polish cuda install

This commit is contained in:
Dun Liang 2021-10-17 15:56:44 +08:00
parent 4a47390964
commit b077d6c185
2 changed files with 9 additions and 5 deletions

View File

@ -979,6 +979,7 @@ if nvcc_path:
# gen cuda key for cache_path
cu = "cu"
v = jit_utils.get_version(nvcc_path)[1:-1]
nvcc_version = list(map(int,v.split('.')))
cu += v
try:
r, s = sp.getstatusoutput(f"{sys.executable} -m jittor_utils.query_cuda_cc")
@ -1184,7 +1185,10 @@ if has_cuda:
return x
return f"-L\"{a}\" -l{b[:-4]}"
nvcc_flags = map_flags(nvcc_flags, func)
nvcc_flags = nvcc_flags.replace("-std=c++17", "-std=c++14 -Xcompiler -std:c++14")
if nvcc_version >= [11,4]:
nvcc_flags = nvcc_flags.replace("-std=c++17", "-std=c++14 -Xcompiler -std:c++14")
else:
nvcc_flags = nvcc_flags.replace("-std=c++17", "")
nvcc_flags = nvcc_flags.replace("-Wall", "")
nvcc_flags = nvcc_flags.replace("-Wno-unknown-pragmas", "")
nvcc_flags = nvcc_flags.replace("-fopenmp", "")
@ -1305,11 +1309,11 @@ flags = core.flags()
if has_cuda:
nvcc_flags = convert_nvcc_flags(cc_flags)
nvcc_version = jit_utils.get_int_version(nvcc_path)
nvcc_version = list(jit_utils.get_int_version(nvcc_path))
max_arch = 1000
if nvcc_version < (11,):
if nvcc_version < [11,]:
max_arch = 75
elif nvcc_version < (11,1):
elif nvcc_version < [11,1]:
max_arch = 80
if len(flags.cuda_archs):
min_arch = 30

View File

@ -60,7 +60,7 @@ def install_cuda():
elif cuda_driver_version >= [11,2]:
cuda_tgz = "cuda11.2_cudnn8_win.zip"
md5 = "b5543822c21bc460c1a414af47754556"
elif cuda_driver_version >= [11,0]:
elif cuda_driver_version >= [11,]:
cuda_tgz = "cuda11.0_cudnn8_win.zip"
md5 = "7a248df76ee5e79623236b0560f8d1fd"
elif cuda_driver_version >= [10,]: