mirror of https://github.com/Jittor/Jittor
support cuda win
This commit is contained in:
parent
d85af13024
commit
123e915bb3
|
@ -11,8 +11,17 @@ from jittor_utils import run_cmd, get_version, get_int_version
|
||||||
from jittor_utils.misc import download_url_to_local
|
from jittor_utils.misc import download_url_to_local
|
||||||
|
|
||||||
def search_file(dirs, name, prefer_version=()):
|
def search_file(dirs, name, prefer_version=()):
|
||||||
|
if os.name == 'nt':
|
||||||
|
if name.startswith("lib"):
|
||||||
|
name = name[3:].replace(".so", "64*.dll")
|
||||||
for d in dirs:
|
for d in dirs:
|
||||||
fname = os.path.join(d, name)
|
fname = os.path.join(d, name)
|
||||||
|
if os.name == 'nt':
|
||||||
|
lname = os.path.join(d, name)
|
||||||
|
names = glob.glob(lname)
|
||||||
|
if len(names):
|
||||||
|
return names[0]
|
||||||
|
continue
|
||||||
prefer_version = tuple( str(p) for p in prefer_version )
|
prefer_version = tuple( str(p) for p in prefer_version )
|
||||||
for i in range(len(prefer_version),-1,-1):
|
for i in range(len(prefer_version),-1,-1):
|
||||||
vname = ".".join((fname,)+prefer_version[:i])
|
vname = ".".join((fname,)+prefer_version[:i])
|
||||||
|
@ -122,8 +131,7 @@ def setup_mkl():
|
||||||
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
|
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
|
||||||
mkl_bin_path = os.path.join(mkl_home, 'bin')
|
mkl_bin_path = os.path.join(mkl_home, 'bin')
|
||||||
os.add_dll_directory(mkl_bin_path)
|
os.add_dll_directory(mkl_bin_path)
|
||||||
mkl_lib = os.path.join(mkl_lib_path, "dnnl.lib")
|
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -ldnnl "
|
||||||
extra_flags = f" -I\"{mkl_include_path}\" \"{mkl_lib}\" "
|
|
||||||
assert os.path.isdir(mkl_include_path)
|
assert os.path.isdir(mkl_include_path)
|
||||||
assert os.path.isdir(mkl_lib_path)
|
assert os.path.isdir(mkl_lib_path)
|
||||||
assert os.path.isfile(mkl_lib_name)
|
assert os.path.isfile(mkl_lib_name)
|
||||||
|
@ -156,17 +164,17 @@ def install_cub(root_folder):
|
||||||
fullname = os.path.join(root_folder, filename)
|
fullname = os.path.join(root_folder, filename)
|
||||||
dirname = os.path.join(root_folder, filename.replace(".tgz",""))
|
dirname = os.path.join(root_folder, filename.replace(".tgz",""))
|
||||||
|
|
||||||
if not os.path.isfile(os.path.join(dirname, "examples", "test")):
|
if not os.path.isfile(os.path.join(dirname, "examples", "device/example_device_radix_sort.cu")):
|
||||||
LOG.i("Downloading cub...")
|
LOG.i("Downloading cub...")
|
||||||
download_url_to_local(url, filename, root_folder, md5)
|
download_url_to_local(url, filename, root_folder, md5)
|
||||||
import tarfile
|
import tarfile
|
||||||
|
|
||||||
with tarfile.open(fullname, "r") as tar:
|
with tarfile.open(fullname, "r") as tar:
|
||||||
tar.extractall(root_folder)
|
tar.extractall(root_folder)
|
||||||
assert 0 == os.system(f"cd {dirname}/examples && "
|
# assert 0 == os.system(f"cd {dirname}/examples && "
|
||||||
f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
|
# f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
|
||||||
if core.get_device_count():
|
# if core.get_device_count():
|
||||||
assert 0 == os.system(f"cd {dirname}/examples && ./test")
|
# assert 0 == os.system(f"cd {dirname}/examples && ./test")
|
||||||
return dirname
|
return dirname
|
||||||
|
|
||||||
def setup_cub():
|
def setup_cub():
|
||||||
|
@ -191,8 +199,9 @@ def setup_cuda_extern():
|
||||||
cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
|
cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
|
||||||
cuda_extern_files = [os.path.join(cuda_extern_src, name)
|
cuda_extern_files = [os.path.join(cuda_extern_src, name)
|
||||||
for name in os.listdir(cuda_extern_src)]
|
for name in os.listdir(cuda_extern_src)]
|
||||||
so_name = os.path.join(cache_path_cuda, "cuda_extern.so")
|
so_name = os.path.join(cache_path_cuda, "cuda_extern"+so)
|
||||||
compile(cc_path, cc_flags+f" -I'{cuda_include}' ", cuda_extern_files, so_name)
|
compile(cc_path, cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
|
||||||
|
link_cuda_extern = f" -L\"{cache_path_cuda}\" -lcuda_extern "
|
||||||
ctypes.CDLL(so_name, dlopen_flags)
|
ctypes.CDLL(so_name, dlopen_flags)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -205,7 +214,7 @@ def setup_cuda_extern():
|
||||||
libs = ["cublas", "cudnn", "curand"]
|
libs = ["cublas", "cudnn", "curand"]
|
||||||
for lib_name in libs:
|
for lib_name in libs:
|
||||||
try:
|
try:
|
||||||
setup_cuda_lib(lib_name)
|
setup_cuda_lib(lib_name, extra_flags=link_cuda_extern)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
line = traceback.format_exc()
|
line = traceback.format_exc()
|
||||||
|
@ -244,12 +253,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
||||||
prefer_version = ()
|
prefer_version = ()
|
||||||
if nvcc_version[0] == 11:
|
if nvcc_version[0] == 11:
|
||||||
prefer_version = ("8",)
|
prefer_version = ("8",)
|
||||||
culib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
|
culib_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
|
||||||
|
|
||||||
if lib_name == "cublas" and nvcc_version[0] >= 10:
|
if lib_name == "cublas" and nvcc_version[0] >= 10:
|
||||||
# manual link libcublasLt.so
|
# manual link libcublasLt.so
|
||||||
try:
|
try:
|
||||||
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
|
cublas_lt_lib_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
|
||||||
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
|
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
|
||||||
except:
|
except:
|
||||||
# some aarch64 os, such as uos with FT2000 cpu,
|
# some aarch64 os, such as uos with FT2000 cpu,
|
||||||
|
@ -263,12 +272,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
||||||
if nvcc_version >= (11,0,0):
|
if nvcc_version >= (11,0,0):
|
||||||
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
|
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
|
||||||
for l in libs:
|
for l in libs:
|
||||||
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], l, prefer_version)
|
ex_cudnn_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], l, prefer_version)
|
||||||
ctypes.CDLL(ex_cudnn_path, dlopen_flags)
|
ctypes.CDLL(ex_cudnn_path, dlopen_flags)
|
||||||
|
|
||||||
# dynamic link cuda library
|
# dynamic link cuda library
|
||||||
ctypes.CDLL(culib_path, dlopen_flags)
|
ctypes.CDLL(culib_path, dlopen_flags)
|
||||||
link_flags = f"-l{lib_name} -L'{cuda_lib}'"
|
link_flags = f"-l{lib_name} -L\"{cuda_lib}\""
|
||||||
|
|
||||||
# find all source files
|
# find all source files
|
||||||
culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name)
|
culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name)
|
||||||
|
@ -281,7 +290,7 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
||||||
|
|
||||||
# compile and get operators
|
# compile and get operators
|
||||||
culib = compile_custom_ops(culib_src_files, return_module=True,
|
culib = compile_custom_ops(culib_src_files, return_module=True,
|
||||||
extra_flags=f" -I'{jt_cuda_include}' -I'{jt_culib_include}' {link_flags} {extra_flags} ")
|
extra_flags=f" -I\"{jt_cuda_include}\" -I\"{jt_culib_include}\" {link_flags} {extra_flags} ")
|
||||||
culib_ops = culib.ops
|
culib_ops = culib.ops
|
||||||
globals()[lib_name+"_ops"] = culib_ops
|
globals()[lib_name+"_ops"] = culib_ops
|
||||||
globals()[lib_name] = culib
|
globals()[lib_name] = culib
|
||||||
|
@ -289,19 +298,20 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
||||||
|
|
||||||
def install_cutt(root_folder):
|
def install_cutt(root_folder):
|
||||||
# Modified from: https://github.com/ap-hynninen/cutt
|
# Modified from: https://github.com/ap-hynninen/cutt
|
||||||
url = "https://codeload.github.com/Jittor/cutt/zip/v1.1"
|
url = "https://codeload.github.com/Jittor/cutt/zip/v1.2"
|
||||||
|
|
||||||
filename = "cutt-1.1.zip"
|
filename = "cutt-1.2.zip"
|
||||||
fullname = os.path.join(root_folder, filename)
|
fullname = os.path.join(root_folder, filename)
|
||||||
dirname = os.path.join(root_folder, filename.replace(".zip",""))
|
dirname = os.path.join(root_folder, filename.replace(".zip",""))
|
||||||
true_md5 = "7bb71cf7c49dbe57772539bf043778f7"
|
true_md5 = "14d0fd1132c8cd657dc3cf29ce4db931"
|
||||||
|
|
||||||
if os.path.exists(fullname):
|
if os.path.exists(fullname):
|
||||||
md5 = run_cmd('md5sum '+fullname).split()[0]
|
from jittor_utils.misc import calculate_md5
|
||||||
|
md5 = calculate_md5(fullname)
|
||||||
if md5 != true_md5:
|
if md5 != true_md5:
|
||||||
os.remove(fullname)
|
os.remove(fullname)
|
||||||
shutil.rmtree(dirname)
|
shutil.rmtree(dirname)
|
||||||
if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")):
|
if not os.path.isfile(os.path.join(dirname, "lib/libcutt"+so)):
|
||||||
LOG.i("Downloading cutt...")
|
LOG.i("Downloading cutt...")
|
||||||
download_url_to_local(url, filename, root_folder, true_md5)
|
download_url_to_local(url, filename, root_folder, true_md5)
|
||||||
|
|
||||||
|
@ -320,7 +330,17 @@ def install_cutt(root_folder):
|
||||||
if len(flags.cuda_archs):
|
if len(flags.cuda_archs):
|
||||||
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
||||||
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||||
run_cmd(f"make NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' nvcc_path='{nvcc_path}'", cwd=dirname)
|
cutt_include = f" -I\"{dirname}/include\" -I\"{dirname}/src\" "
|
||||||
|
files = glob.glob(dirname+"/src/*.c*", recursive=True)
|
||||||
|
files2 = []
|
||||||
|
for f in files:
|
||||||
|
if f.endswith("cutt_bench.cpp") or \
|
||||||
|
f.endswith("cutt_test.cpp"):
|
||||||
|
continue
|
||||||
|
files2.append(f)
|
||||||
|
cutt_flags = cc_flags+opt_flags+cutt_include
|
||||||
|
os.makedirs(dirname+"/lib", exist_ok=True)
|
||||||
|
compile(cc_path, cutt_flags, files2, dirname+"/lib/libcutt"+so, cuda_flags=arch_flag)
|
||||||
return dirname
|
return dirname
|
||||||
|
|
||||||
def setup_cutt():
|
def setup_cutt():
|
||||||
|
@ -342,11 +362,11 @@ def setup_cutt():
|
||||||
|
|
||||||
make_cache_dir(cutt_path)
|
make_cache_dir(cutt_path)
|
||||||
install_cutt(cutt_path)
|
install_cutt(cutt_path)
|
||||||
cutt_home = os.path.join(cutt_path, "cutt-1.1")
|
cutt_home = os.path.join(cutt_path, "cutt-1.2")
|
||||||
cutt_include_path = os.path.join(cutt_home, "src")
|
cutt_include_path = os.path.join(cutt_home, "src")
|
||||||
cutt_lib_path = os.path.join(cutt_home, "lib")
|
cutt_lib_path = os.path.join(cutt_home, "lib")
|
||||||
|
|
||||||
cutt_lib_name = os.path.join(cutt_lib_path, "libcutt.so")
|
cutt_lib_name = os.path.join(cutt_lib_path, "libcutt"+so)
|
||||||
assert os.path.isdir(cutt_include_path)
|
assert os.path.isdir(cutt_include_path)
|
||||||
assert os.path.isdir(cutt_lib_path)
|
assert os.path.isdir(cutt_lib_path)
|
||||||
assert os.path.isfile(cutt_lib_name), cutt_lib_name
|
assert os.path.isfile(cutt_lib_name), cutt_lib_name
|
||||||
|
@ -354,12 +374,14 @@ def setup_cutt():
|
||||||
LOG.v(f"cutt_lib_path: {cutt_lib_path}")
|
LOG.v(f"cutt_lib_path: {cutt_lib_path}")
|
||||||
LOG.v(f"cutt_lib_name: {cutt_lib_name}")
|
LOG.v(f"cutt_lib_name: {cutt_lib_name}")
|
||||||
# We do not link manually, link in custom ops
|
# We do not link manually, link in custom ops
|
||||||
|
if os.name == "nt":
|
||||||
|
os.add_dll_directory(cutt_lib_path)
|
||||||
ctypes.CDLL(cutt_lib_name, dlopen_flags)
|
ctypes.CDLL(cutt_lib_name, dlopen_flags)
|
||||||
|
|
||||||
cutt_op_dir = os.path.join(jittor_path, "extern", "cuda", "cutt", "ops")
|
cutt_op_dir = os.path.join(jittor_path, "extern", "cuda", "cutt", "ops")
|
||||||
cutt_op_files = [os.path.join(cutt_op_dir, name) for name in os.listdir(cutt_op_dir)]
|
cutt_op_files = [os.path.join(cutt_op_dir, name) for name in os.listdir(cutt_op_dir)]
|
||||||
cutt_ops = compile_custom_ops(cutt_op_files,
|
cutt_ops = compile_custom_ops(cutt_op_files,
|
||||||
extra_flags=f" -I'{cutt_include_path}'")
|
extra_flags=f" -I\"{cutt_include_path}\" -L\"{cutt_lib_path}\" -llibcutt ")
|
||||||
LOG.vv("Get cutt_ops: "+str(dir(cutt_ops)))
|
LOG.vv("Get cutt_ops: "+str(dir(cutt_ops)))
|
||||||
|
|
||||||
|
|
||||||
|
@ -442,7 +464,7 @@ def setup_nccl():
|
||||||
nccl_src_files.append(os.path.join(r, fname))
|
nccl_src_files.append(os.path.join(r, fname))
|
||||||
|
|
||||||
nccl_ops = compile_custom_ops(nccl_src_files,
|
nccl_ops = compile_custom_ops(nccl_src_files,
|
||||||
extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ")
|
extra_flags=f" -I\"{nccl_include_path}\" {mpi_compile_flags} ")
|
||||||
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
||||||
|
|
||||||
def manual_link(flags):
|
def manual_link(flags):
|
||||||
|
@ -498,7 +520,7 @@ def setup_mpi():
|
||||||
mpi_src_files.append(os.path.join(r, fname))
|
mpi_src_files.append(os.path.join(r, fname))
|
||||||
|
|
||||||
# mpi compile flags add for nccl
|
# mpi compile flags add for nccl
|
||||||
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
|
mpi_compile_flags += f" -I\"{os.path.join(mpi_src_dir, 'inc')}\" "
|
||||||
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
|
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
|
||||||
|
|
||||||
mpi_version = get_version(mpicc_path)
|
mpi_version = get_version(mpicc_path)
|
||||||
|
|
|
@ -34,25 +34,62 @@ def make_cache_dir(cache_path):
|
||||||
LOG.i(f"Create cache dir: {cache_path}")
|
LOG.i(f"Create cache dir: {cache_path}")
|
||||||
os.mkdir(cache_path)
|
os.mkdir(cache_path)
|
||||||
|
|
||||||
|
def shsplit(s):
    """Split a flag string on single spaces, keeping quoted spans together.

    The string is cut at every ' ' character, but pieces that fall inside
    an open quote (either ``"`` or ``'``) are glued back onto the token that
    opened the quote. Quote characters are kept in the returned tokens, and
    consecutive spaces yield empty-string tokens, so ``" ".join(result)``
    reproduces ``s`` exactly. Backslash escapes are not interpreted.

    Args:
        s: command-line / flag string to split.

    Returns:
        List of tokens (str); ``['']`` for an empty input.
    """
    tokens = []
    # Quote character currently open, or None when outside any quote.
    open_quote = None
    for piece in s.split(' '):
        if open_quote is not None:
            # Still inside a quoted span: the space belonged to the token.
            tokens[-1] += " " + piece
        else:
            tokens.append(piece)
        # Track the quote state per quote kind, so that an apostrophe inside
        # a double-quoted path (or vice versa) does not toggle the state.
        for ch in piece:
            if open_quote is None:
                if ch == '"' or ch == '\'':
                    open_quote = ch
            elif ch == open_quote:
                open_quote = None
    return tokens
|
||||||
|
|
||||||
|
|
||||||
def remove_flags(flags, rm_flags):
|
def remove_flags(flags, rm_flags):
|
||||||
flags = flags.split(" ")
|
flags = shsplit(flags)
|
||||||
output = []
|
output = []
|
||||||
for s in flags:
|
for s in flags:
|
||||||
|
ss = s.replace("\"", "")
|
||||||
for rm in rm_flags:
|
for rm in rm_flags:
|
||||||
if s.startswith(rm):
|
if ss.startswith(rm) or ss.endswith(rm):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
output.append(s)
|
output.append(s)
|
||||||
return " ".join(output)
|
return " ".join(output)
|
||||||
|
|
||||||
def compile(compiler, flags, inputs, output, combind_build=False):
|
def moveback_flags(flags, rm_flags):
    """Move the flags matching ``rm_flags`` to the end of the flag string.

    ``flags`` is tokenized with ``shsplit``. A token matches when, with its
    double quotes stripped, it starts or ends with any entry of ``rm_flags``.
    Matching tokens keep their relative order but are placed after all
    non-matching tokens; the result is re-joined with single spaces.
    """
    kept = []
    deferred = []
    for token in shsplit(flags):
        # Quotes are ignored only for matching; the emitted token is intact.
        bare = token.replace("\"", "")
        is_match = any(
            bare.startswith(rm) or bare.endswith(rm) for rm in rm_flags
        )
        if is_match:
            deferred.append(token)
        else:
            kept.append(token)
    return " ".join(kept + deferred)
|
||||||
|
|
||||||
|
def map_flags(flags, func):
    """Apply ``func`` to each ``shsplit`` token of ``flags``.

    The transformed tokens are joined back together with single spaces.
    """
    return " ".join([func(token) for token in shsplit(flags)])
|
||||||
|
|
||||||
|
def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags=""):
|
||||||
def do_compile(cmd):
|
def do_compile(cmd):
|
||||||
if jit_utils.cc:
|
if jit_utils.cc:
|
||||||
return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path)
|
return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path)
|
||||||
else:
|
else:
|
||||||
run_cmd(cmd)
|
run_cmd(cmd)
|
||||||
return True
|
return True
|
||||||
link = link_flags
|
|
||||||
base_output = os.path.basename(output).split('.')[0]
|
base_output = os.path.basename(output).split('.')[0]
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
# windows do not combind build, need gen def
|
# windows do not combind build, need gen def
|
||||||
|
@ -64,18 +101,9 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
||||||
# initialize order in windows seems reversed
|
# initialize order in windows seems reversed
|
||||||
inputs = list(inputs[::-1])
|
inputs = list(inputs[::-1])
|
||||||
link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" '
|
link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" '
|
||||||
if base_output == "jit_utils_core":
|
|
||||||
pass
|
|
||||||
elif base_output == "jittor_core":
|
|
||||||
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
|
|
||||||
else:
|
|
||||||
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
|
|
||||||
inputs.append(os.path.join(cache_path, f"jittor_core{lib_suffix}"))
|
|
||||||
|
|
||||||
# if output is core, add core_link_flags
|
if not os.path.isabs(output):
|
||||||
if output.startswith("jittor_core"):
|
output = os.path.join(cache_path, output)
|
||||||
link = link + core_link_flags
|
|
||||||
output = os.path.join(cache_path, output)
|
|
||||||
# don't recompile object file in inputs
|
# don't recompile object file in inputs
|
||||||
obj_files = []
|
obj_files = []
|
||||||
ex_obj_files = []
|
ex_obj_files = []
|
||||||
|
@ -94,30 +122,33 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
||||||
return do_compile(fix_cl_flags(cmd))
|
return do_compile(fix_cl_flags(cmd))
|
||||||
# split compile object file and link
|
# split compile object file and link
|
||||||
# remove -l -L flags when compile object files
|
# remove -l -L flags when compile object files
|
||||||
oflags = remove_flags(flags, ['-l', '-L', '-Wl,'])
|
oflags = remove_flags(flags, ['-l', '-L', '-Wl,', '.lib', '-shared'])
|
||||||
cmds = []
|
cmds = []
|
||||||
for input, obj_file in zip(inputs, obj_files):
|
for input, obj_file in zip(inputs, obj_files):
|
||||||
cc = compiler
|
cc = compiler
|
||||||
nflags = oflags
|
nflags = oflags
|
||||||
|
cmd = f"{input} {nflags} {lto_flags} -c -o {obj_file}"
|
||||||
if input.endswith(".cu"):
|
if input.endswith(".cu"):
|
||||||
if has_cuda:
|
if has_cuda:
|
||||||
nflags = convert_nvcc_flags(oflags)
|
cmd = f"\"{nvcc_path}\" {cuda_flags} {cmd}"
|
||||||
cc = nvcc_path
|
cmd = convert_nvcc_flags(fix_cl_flags(cmd))
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
cmd = f"\"{cc}\" {input} {nflags} {lto_flags} -c -o {obj_file}"
|
else:
|
||||||
|
cmd = f"\"{cc}\" {cmd}"
|
||||||
|
cmd = fix_cl_flags(cmd)
|
||||||
if "nan_checker" in input:
|
if "nan_checker" in input:
|
||||||
# nan checker needs to disable fast_math
|
# nan checker needs to disable fast_math
|
||||||
cmd = cmd.replace("--use_fast_math", "")
|
cmd = cmd.replace("--use_fast_math", "")
|
||||||
cmd = cmd.replace("-Ofast", "-O2")
|
cmd = cmd.replace("-Ofast", "-O2")
|
||||||
cmds.append(fix_cl_flags(cmd))
|
cmds.append(cmd)
|
||||||
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
|
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
|
||||||
obj_files += ex_obj_files
|
obj_files += ex_obj_files
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
dumpdef_path = os.path.join(jittor_path, "utils", "dumpdef.py")
|
dumpdef_path = os.path.join(jittor_path, "utils", "dumpdef.py")
|
||||||
cmd = f"\"{sys.executable}\" \"{dumpdef_path}\" {' '.join(obj_files)} -Fo: \"{output}.def\""
|
cmd = f"\"{sys.executable}\" \"{dumpdef_path}\" {' '.join(obj_files)} -Fo: \"{output}.def\""
|
||||||
do_compile(fix_cl_flags(cmd))
|
do_compile(fix_cl_flags(cmd))
|
||||||
cmd = f"\"{compiler}\" {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
|
cmd = f"\"{compiler}\" {' '.join(obj_files)} -o {output} {flags} {lto_flags}"
|
||||||
return do_compile(fix_cl_flags(cmd))
|
return do_compile(fix_cl_flags(cmd))
|
||||||
|
|
||||||
def gen_jit_tests():
|
def gen_jit_tests():
|
||||||
|
@ -673,11 +704,15 @@ def compile_custom_ops(
|
||||||
|
|
||||||
op_extra_flags = includes + extra_flags
|
op_extra_flags = includes + extra_flags
|
||||||
|
|
||||||
|
lib_path = os.path.join(cache_path, "custom_ops")
|
||||||
|
make_cache_dir(lib_path)
|
||||||
|
gen_src_fname = os.path.join(lib_path, gen_name+".cc")
|
||||||
|
gen_head_fname = os.path.join(lib_path, gen_name+".h")
|
||||||
|
gen_lib = os.path.join(lib_path, gen_name+extension_suffix)
|
||||||
|
libname = gen_name + lib_suffix
|
||||||
|
op_extra_flags += f" -L\"{lib_path}\" -l\"{libname}\" "
|
||||||
|
|
||||||
gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
|
gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
|
||||||
make_cache_dir(os.path.join(cache_path, "custom_ops"))
|
|
||||||
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
|
|
||||||
gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
|
|
||||||
gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
|
|
||||||
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
|
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
|
||||||
# gen src initialize first
|
# gen src initialize first
|
||||||
builds.insert(0, gen_src_fname)
|
builds.insert(0, gen_src_fname)
|
||||||
|
@ -794,8 +829,9 @@ def compile_extern():
|
||||||
def check_cuda():
|
def check_cuda():
|
||||||
if not nvcc_path:
|
if not nvcc_path:
|
||||||
return
|
return
|
||||||
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home
|
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home, cuda_bin
|
||||||
cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path))
|
cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path))
|
||||||
|
cuda_bin = cuda_dir
|
||||||
cuda_home = os.path.abspath(os.path.join(cuda_dir, ".."))
|
cuda_home = os.path.abspath(os.path.join(cuda_dir, ".."))
|
||||||
# try default nvidia-cuda-toolkit in Ubuntu 20.04
|
# try default nvidia-cuda-toolkit in Ubuntu 20.04
|
||||||
# assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}"
|
# assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}"
|
||||||
|
@ -805,10 +841,25 @@ def check_cuda():
|
||||||
# this nvcc is installed by package manager
|
# this nvcc is installed by package manager
|
||||||
cuda_lib = "/usr/lib/x86_64-linux-gnu"
|
cuda_lib = "/usr/lib/x86_64-linux-gnu"
|
||||||
cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc")
|
cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc")
|
||||||
cc_flags += f" -DHAS_CUDA -I'{cuda_include}' -I'{cuda_include2}' "
|
cc_flags += f" -DHAS_CUDA -I\"{cuda_include}\" -I\"{cuda_include2}\" "
|
||||||
core_link_flags += f" -lcudart -L'{cuda_lib}' "
|
if os.name == 'nt':
|
||||||
# ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags)
|
cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib", "x64"))
|
||||||
ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags)
|
# cc_flags += f" \"{cuda_lib}\\cudart.lib\" "
|
||||||
|
cuda_lib_path = glob.glob(cuda_bin+"/cudart64*")[0]
|
||||||
|
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
|
||||||
|
os.add_dll_directory(cuda_dir)
|
||||||
|
# dll = ctypes.CDLL("cudart64_110", dlopen_flags)
|
||||||
|
dll = ctypes.CDLL(cuda_lib_path, dlopen_flags)
|
||||||
|
cuda_driver = ctypes.CDLL(r"nvcuda", dlopen_flags)
|
||||||
|
driver_version = ctypes.c_int()
|
||||||
|
r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version))
|
||||||
|
print("version:", driver_version, r)
|
||||||
|
ret = dll.cudaDeviceSynchronize()
|
||||||
|
assert ret == 0
|
||||||
|
else:
|
||||||
|
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
|
||||||
|
# ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags)
|
||||||
|
ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags)
|
||||||
has_cuda = 1
|
has_cuda = 1
|
||||||
|
|
||||||
def check_cache_compile():
|
def check_cache_compile():
|
||||||
|
@ -950,89 +1001,90 @@ if platform.system() == 'Darwin' and platform.machine() == 'arm64':
|
||||||
if "cc_flags" in os.environ:
|
if "cc_flags" in os.environ:
|
||||||
cc_flags += os.environ["cc_flags"] + ' '
|
cc_flags += os.environ["cc_flags"] + ' '
|
||||||
|
|
||||||
link_flags = " -lstdc++ -ldl -shared "
|
cc_flags += " -lstdc++ -ldl -shared "
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
# TODO: if not using apple clang, there is no need to add -lomp
|
# TODO: if not using apple clang, there is no need to add -lomp
|
||||||
link_flags += "-undefined dynamic_lookup -lomp "
|
cc_flags += "-undefined dynamic_lookup -lomp "
|
||||||
if platform.machine() == "arm64":
|
if platform.machine() == "arm64":
|
||||||
link_flags += " -L/opt/homebrew/lib "
|
cc_flags += " -L/opt/homebrew/lib "
|
||||||
|
|
||||||
core_link_flags = ""
|
|
||||||
opt_flags = ""
|
opt_flags = ""
|
||||||
|
|
||||||
py_include = jit_utils.get_py3_include_path()
|
py_include = jit_utils.get_py3_include_path()
|
||||||
LOG.i(f"py_include: {py_include}")
|
LOG.i(f"py_include: {py_include}")
|
||||||
extension_suffix = jit_utils.get_py3_extension_suffix()
|
extension_suffix = jit_utils.get_py3_extension_suffix()
|
||||||
lib_suffix = extension_suffix.replace(".pyd", ".lib")
|
lib_suffix = extension_suffix.rsplit(".", 1)[0]
|
||||||
LOG.i(f"extension_suffix: {extension_suffix}")
|
LOG.i(f"extension_suffix: {extension_suffix}")
|
||||||
|
so = ".so" if os.name != 'nt' else ".dll"
|
||||||
|
|
||||||
|
|
||||||
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags
|
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
# TODO: if not using apple clang, cannot add -Xpreprocessor
|
# TODO: if not using apple clang, cannot add -Xpreprocessor
|
||||||
kernel_opt_flags = kernel_opt_flags + " -Xpreprocessor -fopenmp "
|
kernel_opt_flags += " -Xpreprocessor -fopenmp "
|
||||||
elif cc_type != 'cl':
|
elif cc_type != 'cl':
|
||||||
kernel_opt_flags = kernel_opt_flags + " -fopenmp "
|
kernel_opt_flags += " -fopenmp "
|
||||||
fix_cl_flags = lambda x:x
|
fix_cl_flags = lambda x:x
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
if cc_type == 'g++':
|
if cc_type == 'g++':
|
||||||
link_flags = link_flags.replace('-ldl', '')
|
pass
|
||||||
py3_link_path = '-L"' + os.path.join(
|
|
||||||
os.path.dirname(sys.executable),
|
|
||||||
"libs"
|
|
||||||
) + f'" -lpython3{sys.version_info.minor} '
|
|
||||||
core_link_flags = py3_link_path
|
|
||||||
link_flags += core_link_flags
|
|
||||||
# link_flags += " -Wl,--unresolved-symbols=ignore-all "
|
|
||||||
# cc_flags += " -Xlinker --allow-shlib-undefined "
|
|
||||||
cc_flags = cc_flags.replace('-std=c++14', '-std=c++17')
|
|
||||||
link_flags += " -fopenmp "
|
|
||||||
kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} "
|
|
||||||
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
|
|
||||||
elif cc_type == 'cl':
|
elif cc_type == 'cl':
|
||||||
py3_link_path = os.path.join(
|
py3_link_path = os.path.join(
|
||||||
os.path.dirname(sys.executable),
|
os.path.dirname(sys.executable),
|
||||||
"libs",
|
"libs",
|
||||||
f'python3{sys.version_info.minor}.lib'
|
|
||||||
)
|
)
|
||||||
# core_link_flags = py3_link_path
|
cc_flags = cc_flags.replace("-std=c++14", "-std=c++17")
|
||||||
link_flags += core_link_flags
|
cc_flags = cc_flags.replace("-lstdc++", "")
|
||||||
# link_flags += " -Wl,--unresolved-symbols=ignore-all "
|
cc_flags = cc_flags.replace("-ldl", "")
|
||||||
# cc_flags += " -Xlinker --allow-shlib-undefined "
|
cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} "
|
||||||
kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} "
|
cc_flags += " -EHsc "
|
||||||
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
|
|
||||||
# cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
|
|
||||||
cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
|
|
||||||
# cc_flags += py3_link_path + " "
|
|
||||||
import jittor_utils
|
import jittor_utils
|
||||||
if jittor_utils.msvc_path:
|
if jittor_utils.msvc_path:
|
||||||
mp = jittor_utils.msvc_path
|
mp = jittor_utils.msvc_path
|
||||||
cc_flags += f' -nologo -I"{mp}\\cl_x64\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX '
|
cc_flags += f' -nologo -I"{mp}\\VC\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX '
|
||||||
win_link_flags = f' -link -LIBPATH:"{mp}\\cl_x64\\lib" -LIBPATH:"{mp}\\win10_kits\\lib\\um\\x64" -LIBPATH:"{mp}\\win10_kits\\lib\\ucrt\\x64" '
|
cc_flags += f' -L"{mp}\\VC\\lib" -L"{mp}\\win10_kits\\lib\\um\\x64" -L"{mp}\\win10_kits\\lib\\ucrt\\x64" '
|
||||||
link_flags = ' -LD '
|
|
||||||
kernel_opt_flags += win_link_flags# + " -EXPORT:\"?jit_run@FusedOp@jittor@@QEAAXXZ\""
|
|
||||||
def fix_cl_flags(cmd):
|
def fix_cl_flags(cmd):
|
||||||
cmd = cmd.replace(".o ", ".obj ")
|
cmd = cmd.replace(".o ", ".obj ")
|
||||||
cmd = cmd.replace(".o\" ", ".obj\" ")
|
cmd = cmd.replace(".o\" ", ".obj\" ")
|
||||||
if cmd.endswith(".o"): cmd += "bj"
|
if cmd.endswith(".o"): cmd += "bj"
|
||||||
from shlex import split
|
if " -o " in cmd:
|
||||||
if " -LD " in cmd:
|
if " -shared " in cmd:
|
||||||
cmd = cmd.replace(" -o ", " -Fe: ")
|
cmd = cmd.replace(" -o ", " -Fe: ")
|
||||||
output = split(cmd.split("-Fe:")[1].strip(), posix=False)[0]
|
output = shsplit(cmd.split("-Fe:")[1].strip())[0]
|
||||||
base_output = os.path.basename(output).split('.')[0]
|
base_output = os.path.basename(output).split('.')[0]
|
||||||
cmd += win_link_flags
|
cmd += f" -DEF:\"{output}.def\" -IGNORE:4102 -IGNORE:4197 -IGNORE:4217 "
|
||||||
cmd += f" -DEF:\"{output}.def\" -IGNORE:4102 -IGNORE:4197 -IGNORE:4217 {py3_link_path}"
|
|
||||||
if base_output == "jit_utils_core":
|
|
||||||
pass
|
|
||||||
elif base_output == "jittor_core":
|
|
||||||
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix}")
|
|
||||||
else:
|
|
||||||
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix} ")
|
|
||||||
cmd += " " + os.path.join(cache_path, f"jittor_core{lib_suffix} ")
|
|
||||||
|
|
||||||
elif " -c -o " in cmd:
|
elif " -c -o " in cmd:
|
||||||
cmd = cmd.replace(" -c -o ", " -c -Fo: ")
|
cmd = cmd.replace(" -c -o ", " -c -Fo: ")
|
||||||
|
flags = shsplit(cmd)
|
||||||
|
output = []
|
||||||
|
output2 = []
|
||||||
|
for f in flags:
|
||||||
|
if f.startswith("-link"):
|
||||||
|
pass
|
||||||
|
elif f.startswith("-l"):
|
||||||
|
output2.append(f[2:]+".lib")
|
||||||
|
elif f.startswith("-LIB"):
|
||||||
|
output2.append(f)
|
||||||
|
elif f.startswith("-LD"):
|
||||||
|
output.append(f)
|
||||||
|
elif f.startswith("-L"):
|
||||||
|
output2.append("-LIBPATH:"+f[2:])
|
||||||
|
elif ".lib" in f:
|
||||||
|
output2.append(f)
|
||||||
|
elif f.startswith("-DEF:"):
|
||||||
|
output2.append(f)
|
||||||
|
elif f.startswith("-W") or f.startswith("-f"):
|
||||||
|
pass
|
||||||
|
elif f.startswith("-std="):
|
||||||
|
output.append(f.replace("=", ":"))
|
||||||
|
else:
|
||||||
|
output.append(f)
|
||||||
|
cmd = " ".join(output)
|
||||||
|
if len(output2):
|
||||||
|
cmd += " -link " + " ".join(output2)
|
||||||
cmd = cmd.replace("-include", "-FI")
|
cmd = cmd.replace("-include", "-FI")
|
||||||
|
cmd = cmd.replace("-shared", "-LD")
|
||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
if ' -O' not in cc_flags:
|
if ' -O' not in cc_flags:
|
||||||
|
@ -1055,7 +1107,7 @@ ck_path = os.path.join(cache_path, "checkpoints")
|
||||||
make_cache_dir(ck_path)
|
make_cache_dir(ck_path)
|
||||||
|
|
||||||
# build cache_compile
|
# build cache_compile
|
||||||
cc_flags += f" -I{os.path.join(jittor_path, 'src')} "
|
cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" "
|
||||||
cc_flags += py_include
|
cc_flags += py_include
|
||||||
check_cache_compile()
|
check_cache_compile()
|
||||||
LOG.v(f"Get cache_compile: {jit_utils.cc}")
|
LOG.v(f"Get cache_compile: {jit_utils.cc}")
|
||||||
|
@ -1065,9 +1117,28 @@ has_cuda = 0
|
||||||
check_cuda()
|
check_cuda()
|
||||||
nvcc_flags = os.environ.get("nvcc_flags", "")
|
nvcc_flags = os.environ.get("nvcc_flags", "")
|
||||||
if has_cuda:
|
if has_cuda:
|
||||||
nvcc_flags += cc_flags + link_flags
|
nvcc_flags += cc_flags
|
||||||
def convert_nvcc_flags(nvcc_flags):
|
def convert_nvcc_flags(nvcc_flags):
|
||||||
# nvcc don't support -Wall option
|
# nvcc don't support -Wall option
|
||||||
|
if os.name == 'nt':
|
||||||
|
nvcc_flags = nvcc_flags.replace("-fp:", "-Xcompiler -fp:")
|
||||||
|
nvcc_flags = nvcc_flags.replace("-EHsc", "-Xcompiler -EHsc")
|
||||||
|
nvcc_flags = nvcc_flags.replace("-nologo", "")
|
||||||
|
nvcc_flags = nvcc_flags.replace("-std:", "-std=")
|
||||||
|
nvcc_flags = nvcc_flags.replace("-Fo:", "-o")
|
||||||
|
nvcc_flags = nvcc_flags.replace("-LD", "-shared")
|
||||||
|
nvcc_flags = nvcc_flags.replace("-LIBPATH:", "-L")
|
||||||
|
nvcc_flags = nvcc_flags.replace("-link", "")
|
||||||
|
def func(x):
    """Turn a flag naming a ``.lib`` file into ``-L"<dir>" -l<name>`` form.

    Flags that do not contain ``.lib`` are returned untouched. Otherwise
    double quotes are removed and the path is split into directory and file
    name; when the file name ends with ``.lib`` the pair of search-path and
    library flags is returned, else the unquoted text is returned as is.
    """
    if ".lib" not in x:
        return x
    unquoted = x.replace("\"", "")
    directory, filename = os.path.split(unquoted)
    if filename.endswith(".lib"):
        # drop the 4-character ".lib" suffix to get the bare library name
        return f"-L\"{directory}\" -l{filename[:-4]}"
    # ".lib" appeared somewhere else in the flag (not as the file suffix)
    return unquoted
|
||||||
|
nvcc_flags = map_flags(nvcc_flags, func)
|
||||||
|
nvcc_flags = nvcc_flags.replace("-std=c++17", "-std=c++14 -Xcompiler -std:c++14")
|
||||||
nvcc_flags = nvcc_flags.replace("-Wall", "")
|
nvcc_flags = nvcc_flags.replace("-Wall", "")
|
||||||
nvcc_flags = nvcc_flags.replace("-Wno-unknown-pragmas", "")
|
nvcc_flags = nvcc_flags.replace("-Wno-unknown-pragmas", "")
|
||||||
nvcc_flags = nvcc_flags.replace("-fopenmp", "")
|
nvcc_flags = nvcc_flags.replace("-fopenmp", "")
|
||||||
|
@ -1075,10 +1146,10 @@ if has_cuda:
|
||||||
nvcc_flags = nvcc_flags.replace("-Werror", "")
|
nvcc_flags = nvcc_flags.replace("-Werror", "")
|
||||||
nvcc_flags = nvcc_flags.replace("-fPIC", "-Xcompiler -fPIC")
|
nvcc_flags = nvcc_flags.replace("-fPIC", "-Xcompiler -fPIC")
|
||||||
nvcc_flags = nvcc_flags.replace("-fdiagnostics", "-Xcompiler -fdiagnostics")
|
nvcc_flags = nvcc_flags.replace("-fdiagnostics", "-Xcompiler -fdiagnostics")
|
||||||
nvcc_flags += f" -x cu --cudart=shared -ccbin='{cc_path}' --use_fast_math "
|
nvcc_flags += f" -x cu --cudart=shared -ccbin=\"{cc_path}\" --use_fast_math "
|
||||||
# nvcc warning is noise
|
# nvcc warning is noise
|
||||||
nvcc_flags += " -w "
|
nvcc_flags += " -w "
|
||||||
nvcc_flags += f" -I'{os.path.join(jittor_path, 'extern/cuda/inc')}' "
|
nvcc_flags += f" -I\"{os.path.join(jittor_path, 'extern/cuda/inc')}\" "
|
||||||
if os.environ.get("cuda_debug", "0") == "1":
|
if os.environ.get("cuda_debug", "0") == "1":
|
||||||
nvcc_flags += " -G "
|
nvcc_flags += " -G "
|
||||||
return nvcc_flags
|
return nvcc_flags
|
||||||
|
@ -1092,7 +1163,7 @@ jit_src = gen_jit_op_maker(op_headers)
|
||||||
LOG.vvvv(jit_src)
|
LOG.vvvv(jit_src)
|
||||||
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
|
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
|
||||||
f.write(jit_src)
|
f.write(jit_src)
|
||||||
cc_flags += f' -I{cache_path} '
|
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" '
|
||||||
# gen pyjt
|
# gen pyjt
|
||||||
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
|
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
|
||||||
|
|
||||||
|
@ -1164,16 +1235,19 @@ if use_data_gz:
|
||||||
f.write(data.decode("utf8"))
|
f.write(data.decode("utf8"))
|
||||||
dflags = (cc_flags+opt_flags)\
|
dflags = (cc_flags+opt_flags)\
|
||||||
.replace("-Wall", "") \
|
.replace("-Wall", "") \
|
||||||
.replace("-Werror", "")
|
.replace("-Werror", "") \
|
||||||
|
.replace("-shared", "")
|
||||||
vdp = os.path.join(jittor_path, "src", "utils", "vdp")
|
vdp = os.path.join(jittor_path, "src", "utils", "vdp")
|
||||||
run_cmd(fix_cl_flags(f"{cc_path} {dflags} -include {vdp} {data_s_path} -c -o {data_o_path}"))
|
run_cmd(fix_cl_flags(f"{cc_path} {dflags} -include {vdp} {data_s_path} -c -o {data_o_path}"))
|
||||||
os.remove(data_s_path)
|
# os.remove(data_s_path)
|
||||||
with open(data_gz_md5_path, 'w') as f:
|
with open(data_gz_md5_path, 'w') as f:
|
||||||
f.write(md5)
|
f.write(md5)
|
||||||
files.append(data_o_path)
|
files.append(data_o_path)
|
||||||
files = [f for f in files if "__data__" not in f]
|
files = [f for f in files if "__data__" not in f]
|
||||||
|
|
||||||
|
cc_flags += f" -l\"jit_utils_core{lib_suffix}\" "
|
||||||
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)
|
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)
|
||||||
|
cc_flags += f" -l\"jittor_core{lib_suffix}\" "
|
||||||
|
|
||||||
# TODO: move to compile_extern.py
|
# TODO: move to compile_extern.py
|
||||||
# compile_extern()
|
# compile_extern()
|
||||||
|
@ -1182,6 +1256,8 @@ with jit_utils.import_scope(import_flags):
|
||||||
import jittor_core as core
|
import jittor_core as core
|
||||||
|
|
||||||
flags = core.flags()
|
flags = core.flags()
|
||||||
|
nvcc_flags = convert_nvcc_flags(cc_flags)
|
||||||
|
|
||||||
if has_cuda:
|
if has_cuda:
|
||||||
if len(flags.cuda_archs):
|
if len(flags.cuda_archs):
|
||||||
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
|
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
|
||||||
|
@ -1189,7 +1265,7 @@ if has_cuda:
|
||||||
|
|
||||||
flags.cc_path = cc_path
|
flags.cc_path = cc_path
|
||||||
flags.cc_type = cc_type
|
flags.cc_type = cc_type
|
||||||
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
|
flags.cc_flags = cc_flags + kernel_opt_flags
|
||||||
flags.nvcc_path = nvcc_path
|
flags.nvcc_path = nvcc_path
|
||||||
flags.nvcc_flags = nvcc_flags
|
flags.nvcc_flags = nvcc_flags
|
||||||
flags.python_path = python_path
|
flags.python_path = python_path
|
||||||
|
|
|
@ -190,7 +190,7 @@ void CudnnConv3dBackwardWOp::jit_run() {
|
||||||
};
|
};
|
||||||
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
||||||
int perf_count;
|
int perf_count;
|
||||||
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos];
|
STACK_ALLOC(cudnnConvolutionBwdFilterAlgoPerf_t,perf_results,num_algos);
|
||||||
cudnnConvolutionBwdFilterAlgo_t algo;
|
cudnnConvolutionBwdFilterAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
|
|
@ -181,7 +181,7 @@ void CudnnConv3dBackwardXOp::jit_run() {
|
||||||
};
|
};
|
||||||
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
||||||
int perf_count;
|
int perf_count;
|
||||||
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
|
STACK_ALLOC(cudnnConvolutionBwdDataAlgoPerf_t,perf_results,num_algos);
|
||||||
cudnnConvolutionBwdDataAlgo_t algo;
|
cudnnConvolutionBwdDataAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
|
|
@ -184,7 +184,7 @@ void CudnnConv3dOp::jit_run() {
|
||||||
};
|
};
|
||||||
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||||
int perf_count;
|
int perf_count;
|
||||||
cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos];
|
STACK_ALLOC(cudnnConvolutionFwdAlgoPerf_t,perf_results,num_algos);
|
||||||
cudnnConvolutionFwdAlgo_t algo;
|
cudnnConvolutionFwdAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
|
|
@ -180,7 +180,7 @@ void CudnnConvBackwardWOp::jit_run() {
|
||||||
};
|
};
|
||||||
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
||||||
int perf_count;
|
int perf_count;
|
||||||
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos];
|
STACK_ALLOC(cudnnConvolutionBwdFilterAlgoPerf_t,perf_results,num_algos);
|
||||||
cudnnConvolutionBwdFilterAlgo_t algo;
|
cudnnConvolutionBwdFilterAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
|
|
@ -181,7 +181,7 @@ void CudnnConvBackwardXOp::jit_run() {
|
||||||
};
|
};
|
||||||
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
||||||
int perf_count;
|
int perf_count;
|
||||||
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
|
STACK_ALLOC(cudnnConvolutionBwdDataAlgoPerf_t,perf_results,num_algos);
|
||||||
cudnnConvolutionBwdDataAlgo_t algo;
|
cudnnConvolutionBwdDataAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
|
|
@ -183,7 +183,7 @@ void CudnnConvOp::jit_run() {
|
||||||
};
|
};
|
||||||
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||||
int perf_count;
|
int perf_count;
|
||||||
cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos];
|
STACK_ALLOC(cudnnConvolutionFwdAlgoPerf_t,perf_results,num_algos);
|
||||||
cudnnConvolutionFwdAlgo_t algo;
|
cudnnConvolutionFwdAlgo_t algo;
|
||||||
bool benchmark=true;
|
bool benchmark=true;
|
||||||
|
|
||||||
|
|
|
@ -84,9 +84,6 @@ static double second (void)
|
||||||
gettimeofday(&tv, NULL);
|
gettimeofday(&tv, NULL);
|
||||||
return (double)tv.tv_sec + (double)tv.tv_usec / 1000000.0;
|
return (double)tv.tv_sec + (double)tv.tv_usec / 1000000.0;
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
#error unsupported platform
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||||
|
@ -996,3 +993,9 @@ int cudnn_test_entry( int argc, char** argv )
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
int cudnn_test_entry( int argc, char** argv ) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -472,9 +472,9 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
last_is_cuda = is_cuda;
|
last_is_cuda = is_cuda;
|
||||||
_JT_SEH_START2;
|
// _JT_SEH_START2;
|
||||||
op->do_run_after_prepare(jkl);
|
op->do_run_after_prepare(jkl);
|
||||||
_JT_SEH_END2;
|
// _JT_SEH_END2;
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
// migrate to gpu
|
// migrate to gpu
|
||||||
if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) {
|
if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) {
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "utils/cache_compile.h"
|
#include "utils/cache_compile.h"
|
||||||
#include "utils/flags.h"
|
#include "utils/flags.h"
|
||||||
#include "fused_op.h"
|
#include "fused_op.h"
|
||||||
|
#include "utils/str_utils.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
|
@ -32,6 +33,71 @@ DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
|
||||||
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
|
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
|
||||||
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
|
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
// Split a command line into arguments while keeping quoted arguments whole.
// The string is first cut with split(s, " "); a piece is then appended to the
// previous argument (re-inserting the space) for as long as the number of
// quote characters (" or ') collected for that argument is odd.
vector<string> shsplit(const string& s) {
    const auto pieces = split(s, " ");
    vector<string> args;
    // quote characters seen so far in the argument currently being built
    int quotes = 0;
    for (const auto& piece : pieces) {
        int piece_quotes = 0;
        for (char ch : piece) {
            if (ch == '"' || ch == '\'')
                piece_quotes++;
        }
        const bool quote_open = (quotes % 2) != 0;
        if (!quote_open) {
            // previous argument is balanced: start a new one
            quotes = piece_quotes;
            args.push_back(piece);
            continue;
        }
        // still inside a quoted argument: glue this piece back on
        quotes += piece_quotes;
        args.back() += " ";
        args.back() += piece;
    }
    return args;
}
|
||||||
|
|
||||||
|
// Translate a gcc-style flag string into the form this file passes to cl.
// Arguments are obtained with shsplit() and sorted into two groups: those
// emitted before "-link" and those emitted after it. The result is
// "<first group> -link <second group>", every argument followed by one space;
// "-link " is always emitted, even when the second group is empty.
// The order of the prefix tests matters ("-link" before "-l", "-LIB" and
// "-LD" before "-L"), so it is kept exactly.
string fix_cl_flags(const string& cmd) {
    string before_link, after_link;
    auto to_compiler = [&](const string& arg) { before_link += arg; before_link += " "; };
    auto to_linker = [&](const string& arg) { after_link += arg; after_link += " "; };

    for (const auto& flag : shsplit(cmd)) {
        // an existing "-link" marker is dropped; it is re-inserted below
        if (startswith(flag, "-link")) continue;
        if (startswith(flag, "-l")) {
            // -lfoo -> foo.lib
            to_linker(flag.substr(2) + ".lib");
            continue;
        }
        if (startswith(flag, "-LIB")) { to_linker(flag); continue; }
        if (startswith(flag, "-LD")) { to_compiler(flag); continue; }
        if (startswith(flag, "-L")) {
            // -Ldir -> -LIBPATH:dir
            to_linker("-LIBPATH:" + flag.substr(2));
            continue;
        }
        if (flag.find(".lib") != string::npos) { to_linker(flag); continue; }
        if (startswith(flag, "-DEF:")) { to_linker(flag); continue; }
        // -W* and -f* flags are discarded
        if (startswith(flag, "-W") || startswith(flag, "-f")) continue;
        if (startswith(flag, "-std=")) {
            // -std=X -> -std:X
            to_compiler("-std:" + flag.substr(5));
            continue;
        }
        if (startswith(flag, "-include")) { to_compiler("-FI"); continue; }
        if (startswith(flag, "-shared")) { to_compiler("-LD"); continue; }
        to_compiler(flag);
    }

    return before_link + "-link " + after_link;
}
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace jit_compiler {
|
namespace jit_compiler {
|
||||||
|
|
||||||
std::mutex dl_open_mutex;
|
std::mutex dl_open_mutex;
|
||||||
|
@ -103,6 +169,7 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
||||||
write(jit_src_path, src);
|
write(jit_src_path, src);
|
||||||
string cmd;
|
string cmd;
|
||||||
|
|
||||||
|
auto symbol_name = get_symbol_name(jit_key);
|
||||||
#ifndef _MSC_VER
|
#ifndef _MSC_VER
|
||||||
if (is_cuda_op) {
|
if (is_cuda_op) {
|
||||||
cmd = "\"" + nvcc_path + "\""
|
cmd = "\"" + nvcc_path + "\""
|
||||||
|
@ -124,21 +191,18 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
||||||
cmd = "\"" + nvcc_path + "\""
|
cmd = "\"" + nvcc_path + "\""
|
||||||
+ " \"" + jit_src_path + "\"" + other_src
|
+ " \"" + jit_src_path + "\"" + other_src
|
||||||
+ nvcc_flags + extra_flags
|
+ nvcc_flags + extra_flags
|
||||||
+ " -o \"" + jit_lib_path + "\"";
|
+ " -o \"" + jit_lib_path + "\""
|
||||||
|
+ " -Xlinker -EXPORT:\""
|
||||||
|
+ symbol_name + "\"";;
|
||||||
} else {
|
} else {
|
||||||
auto symbol_name = get_symbol_name(jit_key);
|
|
||||||
auto pos = cc_flags.find("-link");
|
|
||||||
auto cc_flags1 = cc_flags.substr(0, pos);
|
|
||||||
auto cc_flags2 = cc_flags.substr(pos);
|
|
||||||
cmd = "\"" + cc_path + "\""
|
cmd = "\"" + cc_path + "\""
|
||||||
+ " \"" + jit_src_path + "\"" + other_src
|
+ " \"" + jit_src_path + "\"" + other_src
|
||||||
+ cc_flags1 + extra_flags
|
+ " -Fe: \"" + jit_lib_path + "\" "
|
||||||
+ " -Fe: \"" + jit_lib_path + "\" " + cc_flags2 + " -EXPORT:\""
|
+ fix_cl_flags(cc_flags + extra_flags) + " -EXPORT:\""
|
||||||
+ symbol_name + "\"";
|
+ symbol_name + "\"";
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
cache_compile(cmd, cache_path, jittor_path);
|
cache_compile(cmd, cache_path, jittor_path);
|
||||||
auto symbol_name = get_symbol_name(jit_key);
|
|
||||||
auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);
|
auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);
|
||||||
return jit_entry;
|
return jit_entry;
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,8 @@ namespace jittor {
|
||||||
using std::string_view;
|
using std::string_view;
|
||||||
#elif defined(__GNUC__)
|
#elif defined(__GNUC__)
|
||||||
using std::experimental::string_view;
|
using std::experimental::string_view;
|
||||||
|
#elif __cplusplus < 201400L
|
||||||
|
using string_view = string;
|
||||||
#else
|
#else
|
||||||
using std::string_view;
|
using std::string_view;
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -446,7 +446,7 @@ void GetitemOp::jit_prepare(JK& jk) {
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
if (use_cuda) {
|
if (use_cuda) {
|
||||||
int no = o_shape.size();
|
int no = o_shape.size();
|
||||||
int masks[no];
|
STACK_ALLOC(int, masks, no);
|
||||||
int tdims[6];
|
int tdims[6];
|
||||||
cuda_loop_schedule(o_shape, masks, tdims);
|
cuda_loop_schedule(o_shape, masks, tdims);
|
||||||
for (int i=0; i<no; i++) {
|
for (int i=0; i<no; i++) {
|
||||||
|
|
|
@ -237,7 +237,7 @@ void SetitemOp::jit_prepare(JK& jk) {
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
if (use_cuda) {
|
if (use_cuda) {
|
||||||
int no = o_shape.size();
|
int no = o_shape.size();
|
||||||
int masks[no];
|
STACK_ALLOC(int, masks, no);
|
||||||
int tdims[6];
|
int tdims[6];
|
||||||
cuda_loop_schedule(o_shape, masks, tdims);
|
cuda_loop_schedule(o_shape, masks, tdims);
|
||||||
for (int i=0; i<no; i++) {
|
for (int i=0; i<no; i++) {
|
||||||
|
|
|
@ -254,10 +254,12 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
||||||
continue;
|
continue;
|
||||||
processed.insert(input_names[i]);
|
processed.insert(input_names[i]);
|
||||||
auto src = read_all(input_names[i]);
|
auto src = read_all(input_names[i]);
|
||||||
|
auto back = input_names[i].back();
|
||||||
|
// *.lib
|
||||||
|
if (back == 'b') continue;
|
||||||
ASSERT(src.size()) << "Source read failed:" << input_names[i] << "cmd:" << cmd;
|
ASSERT(src.size()) << "Source read failed:" << input_names[i] << "cmd:" << cmd;
|
||||||
auto hash = S(hash64(src));
|
auto hash = S(hash64(src));
|
||||||
vector<string> new_names;
|
vector<string> new_names;
|
||||||
auto back = input_names[i].back();
|
|
||||||
// *.obj, *.o, *.pyd
|
// *.obj, *.o, *.pyd
|
||||||
if (back != 'j' && back != 'o' && back != 'd')
|
if (back != 'j' && back != 'o' && back != 'd')
|
||||||
process(src, new_names, cmd);
|
process(src, new_names, cmd);
|
||||||
|
|
|
@ -6,6 +6,7 @@ def_path = sys.argv[-1]
|
||||||
|
|
||||||
# print(sys.argv)
|
# print(sys.argv)
|
||||||
dumpbin_path = os.environ.get("dumpbin_path", "dumpbin")
|
dumpbin_path = os.environ.get("dumpbin_path", "dumpbin")
|
||||||
|
export_all = os.environ.get("EXPORT_ALL", "0")=="1"
|
||||||
|
|
||||||
syms = {}
|
syms = {}
|
||||||
|
|
||||||
|
@ -22,6 +23,11 @@ for obj in sys.argv[1:-2]:
|
||||||
if sym.startswith("??$get_from_env"): syms[sym] = 1
|
if sym.startswith("??$get_from_env"): syms[sym] = 1
|
||||||
# if sym.startswith("??"): continue
|
# if sym.startswith("??"): continue
|
||||||
if sym.startswith("my"): syms[sym] = 1
|
if sym.startswith("my"): syms[sym] = 1
|
||||||
|
# for cutt
|
||||||
|
if "custom_cuda" in sym: syms[sym] = 1
|
||||||
|
if "cutt" in sym: syms[sym] = 1
|
||||||
|
if "_cudaGetErrorEnum" in sym: syms[sym] = 1
|
||||||
|
if export_all: syms[sym] = 1
|
||||||
if "jittor" not in sym: continue
|
if "jittor" not in sym: continue
|
||||||
syms[sym] = 1
|
syms[sym] = 1
|
||||||
# print(ret)
|
# print(ret)
|
||||||
|
|
|
@ -219,7 +219,7 @@ if os.name=='nt' and getattr(mp.current_process(), '_inheriting', False):
|
||||||
os.environ["log_silent"] = '1'
|
os.environ["log_silent"] = '1'
|
||||||
|
|
||||||
if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
|
if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
|
||||||
os.environ["use_parallel_op_compiler"] = '1'
|
os.environ["use_parallel_op_compiler"] = '0'
|
||||||
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
||||||
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
|
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
|
||||||
n = len(cmds)
|
n = len(cmds)
|
||||||
|
@ -278,11 +278,12 @@ def find_cache_path():
|
||||||
|
|
||||||
def get_version(output):
|
def get_version(output):
|
||||||
if output.endswith("mpicc"):
|
if output.endswith("mpicc"):
|
||||||
version = run_cmd(output+" --showme:version")
|
version = run_cmd(f"\"{output}\" --showme:version")
|
||||||
elif os.name == 'nt' and output.endswith("cl"):
|
elif os.name == 'nt' and (
|
||||||
|
output.endswith("cl") or output.endswith("cl.exe")):
|
||||||
version = run_cmd(output)
|
version = run_cmd(output)
|
||||||
else:
|
else:
|
||||||
version = run_cmd(output+" --version")
|
version = run_cmd(f"\"{output}\" --version")
|
||||||
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
|
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
|
||||||
if len(v) == 0:
|
if len(v) == 0:
|
||||||
v = re.findall("[0-9]+\\.[0-9]+", version)
|
v = re.findall("[0-9]+\\.[0-9]+", version)
|
||||||
|
@ -427,7 +428,7 @@ msvc_path = ""
|
||||||
if os.name == 'nt' and os.environ.get("cc_path", "")=="":
|
if os.name == 'nt' and os.environ.get("cc_path", "")=="":
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
msvc_path = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
|
msvc_path = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
|
||||||
cc_path = os.path.join(msvc_path, "cl_x64", "bin", "cl")
|
cc_path = os.path.join(msvc_path, "VC", r"_\_\_\_\_\bin", "cl.exe")
|
||||||
check_msvc_install = True
|
check_msvc_install = True
|
||||||
else:
|
else:
|
||||||
cc_path = env_or_find('cc_path', 'g++', silent=True)
|
cc_path = env_or_find('cc_path', 'g++', silent=True)
|
||||||
|
|
|
@ -12,7 +12,21 @@ from jittor_utils import LOG
|
||||||
from jittor_utils.misc import download_url_to_local
|
from jittor_utils.misc import download_url_to_local
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
|
def get_cuda_driver_win():
    """Query the CUDA driver version through the ``nvcuda`` library.

    Loads ``nvcuda`` with ``ctypes`` and calls ``cuDriverGetVersion``.

    Returns:
        ``[v // 1000, v % 1000 // 10, v % 10]`` for the integer ``v`` written
        by the driver, or ``None`` when the library cannot be loaded, the
        call fails, or it returns a non-zero status.
    """
    try:
        import ctypes
        cuda_driver = ctypes.CDLL(r"nvcuda")
        driver_version = ctypes.c_int()
        r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version))
    except Exception:
        # Best-effort probe: a missing library or symbol means "no driver".
        # Unlike a bare ``except``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return None
    if r != 0:
        return None
    v = driver_version.value
    return [v//1000, v%1000//10, v%10]
|
||||||
|
|
||||||
def get_cuda_driver():
|
def get_cuda_driver():
|
||||||
|
if os.name == 'nt':
|
||||||
|
return get_cuda_driver_win()
|
||||||
ret, out = sp.getstatusoutput("nvidia-smi -q -u")
|
ret, out = sp.getstatusoutput("nvidia-smi -q -u")
|
||||||
if ret != 0: return None
|
if ret != 0: return None
|
||||||
try:
|
try:
|
||||||
|
@ -35,21 +49,34 @@ def install_cuda():
|
||||||
if not cuda_driver_version:
|
if not cuda_driver_version:
|
||||||
return None
|
return None
|
||||||
LOG.i("cuda_driver_version: ", cuda_driver_version)
|
LOG.i("cuda_driver_version: ", cuda_driver_version)
|
||||||
|
if "JTCUDA_VERSION" in os.environ:
|
||||||
|
cuda_driver_version = list(map(int,os.enviroment["JTCUDA_VERSION"].split(".")))
|
||||||
|
LOG.i("JTCUDA_VERSION: ", cuda_driver_version)
|
||||||
|
|
||||||
if cuda_driver_version >= [11,2]:
|
if os.name == 'nt':
|
||||||
cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
|
if cuda_driver_version >= [11,4]:
|
||||||
md5 = "b93a1a5d19098e93450ee080509e9836"
|
cuda_tgz = "cuda11.4_cudnn8_win.zip"
|
||||||
elif cuda_driver_version >= [11,]:
|
md5 = "06eed370d0d44bb2cc57809343911187"
|
||||||
cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
|
elif cuda_driver_version >= [10,]:
|
||||||
md5 = "5dbdb43e35b4db8249027997720bf1ca"
|
cuda_tgz = "cuda10.2_cudnn7_win.zip"
|
||||||
elif cuda_driver_version >= [10,2]:
|
md5 = "7dd9963833a91371299a2ba58779dd71"
|
||||||
cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
|
else:
|
||||||
md5 = "40f0563e8eb176f53e55943f6d212ad7"
|
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.2")
|
||||||
elif cuda_driver_version >= [10,]:
|
|
||||||
cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
|
|
||||||
md5 = "f16d3ff63f081031d21faec3ec8b7dac"
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}")
|
if cuda_driver_version >= [11,2]:
|
||||||
|
cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
|
||||||
|
md5 = "b93a1a5d19098e93450ee080509e9836"
|
||||||
|
elif cuda_driver_version >= [11,]:
|
||||||
|
cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
|
||||||
|
md5 = "5dbdb43e35b4db8249027997720bf1ca"
|
||||||
|
elif cuda_driver_version >= [10,2]:
|
||||||
|
cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
|
||||||
|
md5 = "40f0563e8eb176f53e55943f6d212ad7"
|
||||||
|
elif cuda_driver_version >= [10,]:
|
||||||
|
cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
|
||||||
|
md5 = "f16d3ff63f081031d21faec3ec8b7dac"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.0")
|
||||||
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
|
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
|
||||||
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
|
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
|
||||||
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
|
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
|
||||||
|
@ -65,9 +92,14 @@ def install_cuda():
|
||||||
download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5)
|
download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5)
|
||||||
|
|
||||||
|
|
||||||
import tarfile
|
if cuda_tgz.endswith(".zip"):
|
||||||
with tarfile.open(cuda_tgz_path, "r") as tar:
|
import zipfile
|
||||||
tar.extractall(cuda_tgz_path[:-4])
|
zf = zipfile.ZipFile(cuda_tgz_path)
|
||||||
|
zf.extractall(path=cuda_tgz_path[:-4])
|
||||||
|
else:
|
||||||
|
import tarfile
|
||||||
|
with tarfile.open(cuda_tgz_path, "r") as tar:
|
||||||
|
tar.extractall(cuda_tgz_path[:-4])
|
||||||
|
|
||||||
assert os.path.isfile(nvcc_path)
|
assert os.path.isfile(nvcc_path)
|
||||||
return nvcc_path
|
return nvcc_path
|
||||||
|
|
|
@ -8,7 +8,7 @@ def install(path):
|
||||||
LOG.i("Installing MSVC...")
|
LOG.i("Installing MSVC...")
|
||||||
filename = "msvc.zip"
|
filename = "msvc.zip"
|
||||||
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
|
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
|
||||||
md5sum = "0fd71436c034808649b24baf28998ccc"
|
md5sum = "929c4b86ea68160121f73c3edf6b5463"
|
||||||
download_url_to_local(url, filename, path, md5sum)
|
download_url_to_local(url, filename, path, md5sum)
|
||||||
fullname = os.path.join(path, filename)
|
fullname = os.path.join(path, filename)
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
|
@ -21,7 +21,10 @@ def ensure_dir(dir_path):
|
||||||
os.makedirs(dir_path)
|
os.makedirs(dir_path)
|
||||||
|
|
||||||
def _progress():
|
def _progress():
|
||||||
pbar = tqdm(total=None)
|
pbar = tqdm(total=None,
|
||||||
|
unit="B",
|
||||||
|
unit_scale=True,
|
||||||
|
unit_divisor=1024)
|
||||||
|
|
||||||
def bar_update(block_num, block_size, total_size):
|
def bar_update(block_num, block_size, total_size):
|
||||||
""" reporthook
|
""" reporthook
|
||||||
|
|
Loading…
Reference in New Issue