mirror of https://github.com/Jittor/Jittor
add cuda arch recommend
This commit is contained in:
parent
6300b0908f
commit
7bc620f274
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.0.3'
|
||||
__version__ = '1.3.0.4'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -1279,15 +1279,26 @@ flags = core.flags()
|
|||
|
||||
if has_cuda:
|
||||
nvcc_flags = convert_nvcc_flags(cc_flags)
|
||||
nvcc_version = jit_utils.get_int_version(nvcc_path)
|
||||
max_arch = 1000
|
||||
if nvcc_version < (11,):
|
||||
max_arch = 75
|
||||
elif nvcc_version < (11,1):
|
||||
max_arch = 80
|
||||
if len(flags.cuda_archs):
|
||||
min_arch = 30
|
||||
archs = []
|
||||
for arch in flags.cuda_archs:
|
||||
if arch<50:
|
||||
LOG.w(f"CUDA arch({arch})<30 is not supported")
|
||||
if arch<min_arch:
|
||||
LOG.w(f"CUDA arch({arch})<{min_arch} is not supported")
|
||||
continue
|
||||
if arch>max_arch:
|
||||
LOG.w(f"CUDA arch({arch})>{max_arch} will be backward-compatible")
|
||||
arch = max_arch
|
||||
archs.append(arch)
|
||||
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
flags.cuda_archs = archs
|
||||
nvcc_flags += f" -arch=compute_{min(archs)} "
|
||||
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', archs))
|
||||
|
||||
flags.cc_path = cc_path
|
||||
flags.cc_type = cc_type
|
||||
|
|
|
@ -50,7 +50,7 @@ def install_cuda():
|
|||
return None
|
||||
LOG.i("cuda_driver_version: ", cuda_driver_version)
|
||||
if "JTCUDA_VERSION" in os.environ:
|
||||
cuda_driver_version = list(map(int,os.enviroment["JTCUDA_VERSION"].split(".")))
|
||||
cuda_driver_version = list(map(int,os.environ["JTCUDA_VERSION"].split(".")))
|
||||
LOG.i("JTCUDA_VERSION: ", cuda_driver_version)
|
||||
|
||||
if os.name == 'nt':
|
||||
|
|
Loading…
Reference in New Issue