mirror of https://github.com/Jittor/Jittor
add nvcc exe suffix
This commit is contained in:
parent
123e915bb3
commit
e77f1ea7cb
|
@ -848,12 +848,7 @@ def check_cuda():
|
|||
cuda_lib_path = glob.glob(cuda_bin+"/cudart64*")[0]
|
||||
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
|
||||
os.add_dll_directory(cuda_dir)
|
||||
# dll = ctypes.CDLL("cudart64_110", dlopen_flags)
|
||||
dll = ctypes.CDLL(cuda_lib_path, dlopen_flags)
|
||||
cuda_driver = ctypes.CDLL(r"nvcuda", dlopen_flags)
|
||||
driver_version = ctypes.c_int()
|
||||
r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version))
|
||||
print("version:", driver_version, r)
|
||||
ret = dll.cudaDeviceSynchronize()
|
||||
assert ret == 0
|
||||
else:
|
||||
|
|
|
@ -79,6 +79,7 @@ def install_cuda():
|
|||
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.0")
|
||||
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
|
||||
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
|
||||
if os.name=='nt': nvcc_path += '.exe'
|
||||
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
|
||||
sys.path.append(nvcc_lib_path)
|
||||
new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
|
||||
|
@ -101,7 +102,7 @@ def install_cuda():
|
|||
with tarfile.open(cuda_tgz_path, "r") as tar:
|
||||
tar.extractall(cuda_tgz_path[:-4])
|
||||
|
||||
assert os.path.isfile(nvcc_path)
|
||||
assert os.path.isfile(nvcc_path), nvcc_path
|
||||
return nvcc_path
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue