mirror of https://github.com/Jittor/Jittor
polish files and add cuda arch hint
This commit is contained in:
parent
a78c3b4f12
commit
0fa0584fd3
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.0.1'
|
||||
__version__ = '1.3.0.2'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -1259,7 +1259,7 @@ if use_data_gz:
|
|||
.replace("-shared", "")
|
||||
vdp = os.path.join(jittor_path, "src", "utils", "vdp")
|
||||
run_cmd(fix_cl_flags(f"{cc_path} {dflags} -include {vdp} {data_s_path} -c -o {data_o_path}"))
|
||||
# os.remove(data_s_path)
|
||||
os.remove(data_s_path)
|
||||
with open(data_gz_md5_path, 'w') as f:
|
||||
f.write(md5)
|
||||
files.append(data_o_path)
|
||||
|
@ -1280,6 +1280,12 @@ flags = core.flags()
|
|||
if has_cuda:
|
||||
nvcc_flags = convert_nvcc_flags(cc_flags)
|
||||
if len(flags.cuda_archs):
|
||||
archs = []
|
||||
for arch in flags.cuda_archs:
|
||||
if arch<50:
|
||||
LOG.w(f"CUDA arch({arch})<50 is not supported")
|
||||
continue
|
||||
archs.append(arch)
|
||||
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
|
||||
|
|
Loading…
Reference in New Issue