add nvcc exe suffix

This commit is contained in:
Dun Liang 2021-09-26 15:57:08 +08:00
parent 123e915bb3
commit e77f1ea7cb
2 changed files with 2 additions and 6 deletions

View File

@ -848,12 +848,7 @@ def check_cuda():
cuda_lib_path = glob.glob(cuda_bin+"/cudart64*")[0]
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
os.add_dll_directory(cuda_dir)
# dll = ctypes.CDLL("cudart64_110", dlopen_flags)
dll = ctypes.CDLL(cuda_lib_path, dlopen_flags)
cuda_driver = ctypes.CDLL(r"nvcuda", dlopen_flags)
driver_version = ctypes.c_int()
r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version))
print("version:", driver_version, r)
ret = dll.cudaDeviceSynchronize()
assert ret == 0
else:

View File

@ -79,6 +79,7 @@ def install_cuda():
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.0")
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
if os.name=='nt': nvcc_path += '.exe'
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
sys.path.append(nvcc_lib_path)
new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
@ -101,7 +102,7 @@ def install_cuda():
with tarfile.open(cuda_tgz_path, "r") as tar:
tar.extractall(cuda_tgz_path[:-4])
assert os.path.isfile(nvcc_path)
assert os.path.isfile(nvcc_path), nvcc_path
return nvcc_path