mirror of https://github.com/mamba-org/mamba.git
fix: Improve CUDA version detection (#3700)
Signed-off-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Klaim <Klaim@users.noreply.github.com>
This commit is contained in:
parent
9dd3999b8d
commit
4f97739aed
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include "mamba/core/context.hpp"
|
||||
#include "mamba/core/output.hpp"
|
||||
#include "mamba/core/util.hpp"
|
||||
#include "mamba/core/util_os.hpp"
|
||||
#include "mamba/core/virtual_packages.hpp"
|
||||
#include "mamba/util/build.hpp"
|
||||
|
@ -65,9 +66,57 @@ namespace mamba
|
|||
auto override_version = util::get_env("CONDA_OVERRIDE_CUDA");
|
||||
if (override_version)
|
||||
{
|
||||
LOG_DEBUG << "CUDA version set by `CONDA_OVERRIDE_CUDA`: "
|
||||
<< override_version.value();
|
||||
return override_version.value();
|
||||
}
|
||||
|
||||
std::string cuda_version;
|
||||
std::string cuda_version_file = "/usr/local/cuda/version.json";
|
||||
|
||||
if (fs::exists(cuda_version_file))
|
||||
{
|
||||
LOG_DEBUG << "CUDA version file found: " << cuda_version_file;
|
||||
std::ifstream f = open_ifstream(cuda_version_file);
|
||||
nlohmann::json j;
|
||||
f >> j;
|
||||
if (auto it_cuda = j.find("cuda"); it_cuda != j.end())
|
||||
{
|
||||
auto cuda_val = *it_cuda;
|
||||
if (auto it_cuda_version = cuda_val.find("version");
|
||||
it_cuda_version != cuda_val.end())
|
||||
{
|
||||
cuda_version = it_cuda_version->get<std::string>();
|
||||
LOG_DEBUG << "CUDA version found: " << cuda_version;
|
||||
|
||||
// Extract major, minor and patch version number from the version string
|
||||
// and return only major.minor to match the cuda package version returned
|
||||
// by `nvidia-smi --query -u -x`
|
||||
std::regex re("([0-9]+)\\.([0-9]+)\\.([0-9]+)");
|
||||
std::smatch m;
|
||||
if (std::regex_search(cuda_version, m, re) && m.size() >= 3)
|
||||
{
|
||||
std::ssub_match major = m[1];
|
||||
std::ssub_match minor = m[2];
|
||||
cuda_version = major.str() + "." + minor.str();
|
||||
LOG_DEBUG << "CUDA version returned: " << cuda_version;
|
||||
return cuda_version;
|
||||
}
|
||||
}
|
||||
LOG_WARNING << "Could not extract CUDA version from: " << cuda_version;
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG_WARNING << "CUDA version not found in the JSON file (`.cuda.version` is missing)";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG_DEBUG << "CUDA version file not found: " << cuda_version_file;
|
||||
}
|
||||
|
||||
LOG_DEBUG << "Trying to find CUDA version by running `nvidia-smi --query -u -x`";
|
||||
|
||||
std::string out, err;
|
||||
std::vector<std::string> args = { "nvidia-smi", "--query", "-u", "-x" };
|
||||
auto [status, ec] = reproc::run(
|
||||
|
@ -125,26 +174,29 @@ namespace mamba
|
|||
}
|
||||
}
|
||||
|
||||
if (out.empty())
|
||||
if (!out.empty())
|
||||
{
|
||||
LOG_DEBUG << "Could not find CUDA version by calling 'nvidia-smi' (skipped)\n";
|
||||
return "";
|
||||
}
|
||||
std::regex re("<cuda_version>(.*)<\\/cuda_version>");
|
||||
std::smatch m;
|
||||
|
||||
std::regex re("<cuda_version>(.*)<\\/cuda_version>");
|
||||
std::smatch m;
|
||||
|
||||
if (std::regex_search(out, m, re))
|
||||
{
|
||||
if (m.size() == 2)
|
||||
if (std::regex_search(out, m, re))
|
||||
{
|
||||
std::ssub_match cuda_version = m[1];
|
||||
LOG_DEBUG << "CUDA driver version found: " << cuda_version;
|
||||
return cuda_version.str();
|
||||
if (m.size() == 2)
|
||||
{
|
||||
std::ssub_match cuda_version_match = m[1];
|
||||
LOG_DEBUG << "CUDA driver version found: " << cuda_version_match;
|
||||
return cuda_version_match.str();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DEBUG << "CUDA not found";
|
||||
LOG_WARNING << "Could not find CUDA version by, in this order:\n";
|
||||
LOG_WARNING << " - inspecting the `CONDA_OVERRIDE_CUDA` environment variable\n";
|
||||
LOG_WARNING << " - parsing the : " << cuda_version_file << "\n";
|
||||
LOG_WARNING << " - parsing the output of `nvidia-smi --query -u -x`\n";
|
||||
LOG_WARNING << "\n";
|
||||
LOG_WARNING << "We recommend setting the `CONDA_OVERRIDE_CUDA` environment variable\n";
|
||||
LOG_WARNING << "to the desired CUDA version.";
|
||||
return "";
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue