mirror of https://github.com/Jittor/Jittor
fix multiple different gpu archs compile error
This commit is contained in:
parent
35b22799fd
commit
f3e99b96bc
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.1.7.14'
|
||||
__version__ = '1.1.7.15'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
|
|
@ -268,7 +268,10 @@ def install_nccl(root_folder):
|
|||
tar.extractall(root_folder)
|
||||
|
||||
LOG.i("installing nccl...")
|
||||
arch_flag = f" -arch={','.join(map(lambda x:'sm_'+str(x),flags.cuda_archs))} "
|
||||
arch_flag = ""
|
||||
if len(flags.cuda_archs):
|
||||
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag}' ", cwd=dirname)
|
||||
return dirname
|
||||
|
||||
|
|
|
@ -982,7 +982,9 @@ with jit_utils.import_scope(import_flags):
|
|||
|
||||
flags = core.flags()
|
||||
if has_cuda:
|
||||
nvcc_flags += f" -arch={','.join(map(lambda x:'sm_'+str(x),flags.cuda_archs))} "
|
||||
if len(flags.cuda_archs):
|
||||
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
|
||||
flags.cc_path = cc_path
|
||||
flags.cc_type = cc_type
|
||||
|
|
Loading…
Reference in New Issue