add cuda arch recommend

This commit is contained in:
Dun Liang 2021-10-01 20:31:57 +08:00
parent 6300b0908f
commit 7bc620f274
3 changed files with 17 additions and 6 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.0.3'
__version__ = '1.3.0.4'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -1279,15 +1279,26 @@ flags = core.flags()
if has_cuda:
nvcc_flags = convert_nvcc_flags(cc_flags)
nvcc_version = jit_utils.get_int_version(nvcc_path)
max_arch = 1000
if nvcc_version < (11,):
max_arch = 75
elif nvcc_version < (11,1):
max_arch = 80
if len(flags.cuda_archs):
min_arch = 30
archs = []
for arch in flags.cuda_archs:
if arch<50:
LOG.w(f"CUDA arch({arch})<30 is not supported")
if arch<min_arch:
LOG.w(f"CUDA arch({arch})<{min_arch} is not supported")
continue
if arch>max_arch:
LOG.w(f"CUDA arch({arch})>{max_arch} will be backward-compatible")
arch = max_arch
archs.append(arch)
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
flags.cuda_archs = archs
nvcc_flags += f" -arch=compute_{min(archs)} "
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', archs))
flags.cc_path = cc_path
flags.cc_type = cc_type

View File

@ -50,7 +50,7 @@ def install_cuda():
return None
LOG.i("cuda_driver_version: ", cuda_driver_version)
if "JTCUDA_VERSION" in os.environ:
cuda_driver_version = list(map(int,os.enviroment["JTCUDA_VERSION"].split(".")))
cuda_driver_version = list(map(int,os.environ["JTCUDA_VERSION"].split(".")))
LOG.i("JTCUDA_VERSION: ", cuda_driver_version)
if os.name == 'nt':