fix: Improve CUDA version detection (#3700)

Signed-off-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Klaim <Klaim@users.noreply.github.com>
This commit is contained in:
Julien Jerphanion 2025-01-06 15:14:06 +01:00 committed by GitHub
parent 9dd3999b8d
commit 4f97739aed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 66 additions and 14 deletions

View File

@ -17,6 +17,7 @@
#include "mamba/core/context.hpp"
#include "mamba/core/output.hpp"
#include "mamba/core/util.hpp"
#include "mamba/core/util_os.hpp"
#include "mamba/core/virtual_packages.hpp"
#include "mamba/util/build.hpp"
@ -65,9 +66,57 @@ namespace mamba
auto override_version = util::get_env("CONDA_OVERRIDE_CUDA");
if (override_version)
{
LOG_DEBUG << "CUDA version set by `CONDA_OVERRIDE_CUDA`: "
<< override_version.value();
return override_version.value();
}
std::string cuda_version;
std::string cuda_version_file = "/usr/local/cuda/version.json";
if (fs::exists(cuda_version_file))
{
LOG_DEBUG << "CUDA version file found: " << cuda_version_file;
std::ifstream f = open_ifstream(cuda_version_file);
nlohmann::json j;
f >> j;
if (auto it_cuda = j.find("cuda"); it_cuda != j.end())
{
auto cuda_val = *it_cuda;
if (auto it_cuda_version = cuda_val.find("version");
it_cuda_version != cuda_val.end())
{
cuda_version = it_cuda_version->get<std::string>();
LOG_DEBUG << "CUDA version found: " << cuda_version;
// Extract the major, minor and patch version numbers from the version string
// and return only major.minor to match the cuda package version returned
// by `nvidia-smi --query -u -x`
std::regex re("([0-9]+)\\.([0-9]+)\\.([0-9]+)");
std::smatch m;
if (std::regex_search(cuda_version, m, re) && m.size() >= 3)
{
std::ssub_match major = m[1];
std::ssub_match minor = m[2];
cuda_version = major.str() + "." + minor.str();
LOG_DEBUG << "CUDA version returned: " << cuda_version;
return cuda_version;
}
}
LOG_WARNING << "Could not extract CUDA version from: " << cuda_version;
}
else
{
LOG_WARNING << "CUDA version not found in the JSON file (`.cuda.version` is missing)";
}
}
else
{
LOG_DEBUG << "CUDA version file not found: " << cuda_version_file;
}
LOG_DEBUG << "Trying to find CUDA version by running `nvidia-smi --query -u -x`";
std::string out, err;
std::vector<std::string> args = { "nvidia-smi", "--query", "-u", "-x" };
auto [status, ec] = reproc::run(
@ -125,26 +174,29 @@ namespace mamba
}
}
if (out.empty())
if (!out.empty())
{
LOG_DEBUG << "Could not find CUDA version by calling 'nvidia-smi' (skipped)\n";
return "";
}
std::regex re("<cuda_version>(.*)<\\/cuda_version>");
std::smatch m;
std::regex re("<cuda_version>(.*)<\\/cuda_version>");
std::smatch m;
if (std::regex_search(out, m, re))
{
if (m.size() == 2)
if (std::regex_search(out, m, re))
{
std::ssub_match cuda_version = m[1];
LOG_DEBUG << "CUDA driver version found: " << cuda_version;
return cuda_version.str();
if (m.size() == 2)
{
std::ssub_match cuda_version_match = m[1];
LOG_DEBUG << "CUDA driver version found: " << cuda_version_match;
return cuda_version_match.str();
}
}
}
LOG_DEBUG << "CUDA not found";
LOG_WARNING << "Could not find CUDA version by, in this order:\n";
LOG_WARNING << " - inspecting the `CONDA_OVERRIDE_CUDA` environment variable\n";
LOG_WARNING << " - parsing the : " << cuda_version_file << "\n";
LOG_WARNING << " - parsing the output of `nvidia-smi --query -u -x`\n";
LOG_WARNING << "\n";
LOG_WARNING << "We recommend setting the `CONDA_OVERRIDE_CUDA` environment variable\n";
LOG_WARNING << "to the desired CUDA version.";
return "";
}