support cuda win

This commit is contained in:
Dun Liang 2021-09-26 15:35:18 +08:00
parent d85af13024
commit 123e915bb3
20 changed files with 371 additions and 160 deletions

View File

@ -11,8 +11,17 @@ from jittor_utils import run_cmd, get_version, get_int_version
from jittor_utils.misc import download_url_to_local from jittor_utils.misc import download_url_to_local
def search_file(dirs, name, prefer_version=()): def search_file(dirs, name, prefer_version=()):
if os.name == 'nt':
if name.startswith("lib"):
name = name[3:].replace(".so", "64*.dll")
for d in dirs: for d in dirs:
fname = os.path.join(d, name) fname = os.path.join(d, name)
if os.name == 'nt':
lname = os.path.join(d, name)
names = glob.glob(lname)
if len(names):
return names[0]
continue
prefer_version = tuple( str(p) for p in prefer_version ) prefer_version = tuple( str(p) for p in prefer_version )
for i in range(len(prefer_version),-1,-1): for i in range(len(prefer_version),-1,-1):
vname = ".".join((fname,)+prefer_version[:i]) vname = ".".join((fname,)+prefer_version[:i])
@ -122,8 +131,7 @@ def setup_mkl():
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll') mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
mkl_bin_path = os.path.join(mkl_home, 'bin') mkl_bin_path = os.path.join(mkl_home, 'bin')
os.add_dll_directory(mkl_bin_path) os.add_dll_directory(mkl_bin_path)
mkl_lib = os.path.join(mkl_lib_path, "dnnl.lib") extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -ldnnl "
extra_flags = f" -I\"{mkl_include_path}\" \"{mkl_lib}\" "
assert os.path.isdir(mkl_include_path) assert os.path.isdir(mkl_include_path)
assert os.path.isdir(mkl_lib_path) assert os.path.isdir(mkl_lib_path)
assert os.path.isfile(mkl_lib_name) assert os.path.isfile(mkl_lib_name)
@ -156,17 +164,17 @@ def install_cub(root_folder):
fullname = os.path.join(root_folder, filename) fullname = os.path.join(root_folder, filename)
dirname = os.path.join(root_folder, filename.replace(".tgz","")) dirname = os.path.join(root_folder, filename.replace(".tgz",""))
if not os.path.isfile(os.path.join(dirname, "examples", "test")): if not os.path.isfile(os.path.join(dirname, "examples", "device/example_device_radix_sort.cu")):
LOG.i("Downloading cub...") LOG.i("Downloading cub...")
download_url_to_local(url, filename, root_folder, md5) download_url_to_local(url, filename, root_folder, md5)
import tarfile import tarfile
with tarfile.open(fullname, "r") as tar: with tarfile.open(fullname, "r") as tar:
tar.extractall(root_folder) tar.extractall(root_folder)
assert 0 == os.system(f"cd {dirname}/examples && " # assert 0 == os.system(f"cd {dirname}/examples && "
f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test") # f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
if core.get_device_count(): # if core.get_device_count():
assert 0 == os.system(f"cd {dirname}/examples && ./test") # assert 0 == os.system(f"cd {dirname}/examples && ./test")
return dirname return dirname
def setup_cub(): def setup_cub():
@ -191,8 +199,9 @@ def setup_cuda_extern():
cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src") cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
cuda_extern_files = [os.path.join(cuda_extern_src, name) cuda_extern_files = [os.path.join(cuda_extern_src, name)
for name in os.listdir(cuda_extern_src)] for name in os.listdir(cuda_extern_src)]
so_name = os.path.join(cache_path_cuda, "cuda_extern.so") so_name = os.path.join(cache_path_cuda, "cuda_extern"+so)
compile(cc_path, cc_flags+f" -I'{cuda_include}' ", cuda_extern_files, so_name) compile(cc_path, cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
link_cuda_extern = f" -L\"{cache_path_cuda}\" -lcuda_extern "
ctypes.CDLL(so_name, dlopen_flags) ctypes.CDLL(so_name, dlopen_flags)
try: try:
@ -205,7 +214,7 @@ def setup_cuda_extern():
libs = ["cublas", "cudnn", "curand"] libs = ["cublas", "cudnn", "curand"]
for lib_name in libs: for lib_name in libs:
try: try:
setup_cuda_lib(lib_name) setup_cuda_lib(lib_name, extra_flags=link_cuda_extern)
except Exception as e: except Exception as e:
import traceback import traceback
line = traceback.format_exc() line = traceback.format_exc()
@ -244,12 +253,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
prefer_version = () prefer_version = ()
if nvcc_version[0] == 11: if nvcc_version[0] == 11:
prefer_version = ("8",) prefer_version = ("8",)
culib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version) culib_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
if lib_name == "cublas" and nvcc_version[0] >= 10: if lib_name == "cublas" and nvcc_version[0] >= 10:
# manual link libcublasLt.so # manual link libcublasLt.so
try: try:
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version) cublas_lt_lib_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags) ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
except: except:
# some aarch64 os, such as uos with FT2000 cpu, # some aarch64 os, such as uos with FT2000 cpu,
@ -263,12 +272,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
if nvcc_version >= (11,0,0): if nvcc_version >= (11,0,0):
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"] libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
for l in libs: for l in libs:
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], l, prefer_version) ex_cudnn_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], l, prefer_version)
ctypes.CDLL(ex_cudnn_path, dlopen_flags) ctypes.CDLL(ex_cudnn_path, dlopen_flags)
# dynamic link cuda library # dynamic link cuda library
ctypes.CDLL(culib_path, dlopen_flags) ctypes.CDLL(culib_path, dlopen_flags)
link_flags = f"-l{lib_name} -L'{cuda_lib}'" link_flags = f"-l{lib_name} -L\"{cuda_lib}\""
# find all source files # find all source files
culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name) culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name)
@ -281,7 +290,7 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
# compile and get operators # compile and get operators
culib = compile_custom_ops(culib_src_files, return_module=True, culib = compile_custom_ops(culib_src_files, return_module=True,
extra_flags=f" -I'{jt_cuda_include}' -I'{jt_culib_include}' {link_flags} {extra_flags} ") extra_flags=f" -I\"{jt_cuda_include}\" -I\"{jt_culib_include}\" {link_flags} {extra_flags} ")
culib_ops = culib.ops culib_ops = culib.ops
globals()[lib_name+"_ops"] = culib_ops globals()[lib_name+"_ops"] = culib_ops
globals()[lib_name] = culib globals()[lib_name] = culib
@ -289,19 +298,20 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
def install_cutt(root_folder): def install_cutt(root_folder):
# Modified from: https://github.com/ap-hynninen/cutt # Modified from: https://github.com/ap-hynninen/cutt
url = "https://codeload.github.com/Jittor/cutt/zip/v1.1" url = "https://codeload.github.com/Jittor/cutt/zip/v1.2"
filename = "cutt-1.1.zip" filename = "cutt-1.2.zip"
fullname = os.path.join(root_folder, filename) fullname = os.path.join(root_folder, filename)
dirname = os.path.join(root_folder, filename.replace(".zip","")) dirname = os.path.join(root_folder, filename.replace(".zip",""))
true_md5 = "7bb71cf7c49dbe57772539bf043778f7" true_md5 = "14d0fd1132c8cd657dc3cf29ce4db931"
if os.path.exists(fullname): if os.path.exists(fullname):
md5 = run_cmd('md5sum '+fullname).split()[0] from jittor_utils.misc import calculate_md5
md5 = calculate_md5(fullname)
if md5 != true_md5: if md5 != true_md5:
os.remove(fullname) os.remove(fullname)
shutil.rmtree(dirname) shutil.rmtree(dirname)
if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")): if not os.path.isfile(os.path.join(dirname, "lib/libcutt"+so)):
LOG.i("Downloading cutt...") LOG.i("Downloading cutt...")
download_url_to_local(url, filename, root_folder, true_md5) download_url_to_local(url, filename, root_folder, true_md5)
@ -320,7 +330,17 @@ def install_cutt(root_folder):
if len(flags.cuda_archs): if len(flags.cuda_archs):
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} " arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs)) arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
run_cmd(f"make NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' nvcc_path='{nvcc_path}'", cwd=dirname) cutt_include = f" -I\"{dirname}/include\" -I\"{dirname}/src\" "
files = glob.glob(dirname+"/src/*.c*", recursive=True)
files2 = []
for f in files:
if f.endswith("cutt_bench.cpp") or \
f.endswith("cutt_test.cpp"):
continue
files2.append(f)
cutt_flags = cc_flags+opt_flags+cutt_include
os.makedirs(dirname+"/lib", exist_ok=True)
compile(cc_path, cutt_flags, files2, dirname+"/lib/libcutt"+so, cuda_flags=arch_flag)
return dirname return dirname
def setup_cutt(): def setup_cutt():
@ -342,11 +362,11 @@ def setup_cutt():
make_cache_dir(cutt_path) make_cache_dir(cutt_path)
install_cutt(cutt_path) install_cutt(cutt_path)
cutt_home = os.path.join(cutt_path, "cutt-1.1") cutt_home = os.path.join(cutt_path, "cutt-1.2")
cutt_include_path = os.path.join(cutt_home, "src") cutt_include_path = os.path.join(cutt_home, "src")
cutt_lib_path = os.path.join(cutt_home, "lib") cutt_lib_path = os.path.join(cutt_home, "lib")
cutt_lib_name = os.path.join(cutt_lib_path, "libcutt.so") cutt_lib_name = os.path.join(cutt_lib_path, "libcutt"+so)
assert os.path.isdir(cutt_include_path) assert os.path.isdir(cutt_include_path)
assert os.path.isdir(cutt_lib_path) assert os.path.isdir(cutt_lib_path)
assert os.path.isfile(cutt_lib_name), cutt_lib_name assert os.path.isfile(cutt_lib_name), cutt_lib_name
@ -354,12 +374,14 @@ def setup_cutt():
LOG.v(f"cutt_lib_path: {cutt_lib_path}") LOG.v(f"cutt_lib_path: {cutt_lib_path}")
LOG.v(f"cutt_lib_name: {cutt_lib_name}") LOG.v(f"cutt_lib_name: {cutt_lib_name}")
# We do not link manualy, link in custom ops # We do not link manualy, link in custom ops
if os.name == "nt":
os.add_dll_directory(cutt_lib_path)
ctypes.CDLL(cutt_lib_name, dlopen_flags) ctypes.CDLL(cutt_lib_name, dlopen_flags)
cutt_op_dir = os.path.join(jittor_path, "extern", "cuda", "cutt", "ops") cutt_op_dir = os.path.join(jittor_path, "extern", "cuda", "cutt", "ops")
cutt_op_files = [os.path.join(cutt_op_dir, name) for name in os.listdir(cutt_op_dir)] cutt_op_files = [os.path.join(cutt_op_dir, name) for name in os.listdir(cutt_op_dir)]
cutt_ops = compile_custom_ops(cutt_op_files, cutt_ops = compile_custom_ops(cutt_op_files,
extra_flags=f" -I'{cutt_include_path}'") extra_flags=f" -I\"{cutt_include_path}\" -L\"{cutt_lib_path}\" -llibcutt ")
LOG.vv("Get cutt_ops: "+str(dir(cutt_ops))) LOG.vv("Get cutt_ops: "+str(dir(cutt_ops)))
@ -442,7 +464,7 @@ def setup_nccl():
nccl_src_files.append(os.path.join(r, fname)) nccl_src_files.append(os.path.join(r, fname))
nccl_ops = compile_custom_ops(nccl_src_files, nccl_ops = compile_custom_ops(nccl_src_files,
extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ") extra_flags=f" -I\"{nccl_include_path}\" {mpi_compile_flags} ")
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops))) LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
def manual_link(flags): def manual_link(flags):
@ -498,7 +520,7 @@ def setup_mpi():
mpi_src_files.append(os.path.join(r, fname)) mpi_src_files.append(os.path.join(r, fname))
# mpi compile flags add for nccl # mpi compile flags add for nccl
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' " mpi_compile_flags += f" -I\"{os.path.join(mpi_src_dir, 'inc')}\" "
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "") mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
mpi_version = get_version(mpicc_path) mpi_version = get_version(mpicc_path)

View File

@ -34,25 +34,62 @@ def make_cache_dir(cache_path):
LOG.i(f"Create cache dir: {cache_path}") LOG.i(f"Create cache dir: {cache_path}")
os.mkdir(cache_path) os.mkdir(cache_path)
def shsplit(s):
s1 = s.split(' ')
s2 = []
count = 0
for s in s1:
nc = s.count('"') + s.count('\'')
if count&1:
count += nc
s2[-1] += " "
s2[-1] += s
else:
count = nc
s2.append(s)
return s2
def remove_flags(flags, rm_flags): def remove_flags(flags, rm_flags):
flags = flags.split(" ") flags = shsplit(flags)
output = [] output = []
for s in flags: for s in flags:
ss = s.replace("\"", "")
for rm in rm_flags: for rm in rm_flags:
if s.startswith(rm): if ss.startswith(rm) or ss.endswith(rm):
break break
else: else:
output.append(s) output.append(s)
return " ".join(output) return " ".join(output)
def compile(compiler, flags, inputs, output, combind_build=False): def moveback_flags(flags, rm_flags):
flags = shsplit(flags)
output = []
output2 = []
for s in flags:
ss = s.replace("\"", "")
for rm in rm_flags:
if ss.startswith(rm) or ss.endswith(rm):
output2.append(s)
break
else:
output.append(s)
return " ".join(output+output2)
def map_flags(flags, func):
flags = shsplit(flags)
output = []
for s in flags:
output.append(func(s))
return " ".join(output)
def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags=""):
def do_compile(cmd): def do_compile(cmd):
if jit_utils.cc: if jit_utils.cc:
return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path) return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path)
else: else:
run_cmd(cmd) run_cmd(cmd)
return True return True
link = link_flags
base_output = os.path.basename(output).split('.')[0] base_output = os.path.basename(output).split('.')[0]
if os.name == 'nt': if os.name == 'nt':
# windows do not combind build, need gen def # windows do not combind build, need gen def
@ -64,18 +101,9 @@ def compile(compiler, flags, inputs, output, combind_build=False):
# initialize order in windows seems reversed # initialize order in windows seems reversed
inputs = list(inputs[::-1]) inputs = list(inputs[::-1])
link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" ' link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" '
if base_output == "jit_utils_core":
pass
elif base_output == "jittor_core":
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
else:
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
inputs.append(os.path.join(cache_path, f"jittor_core{lib_suffix}"))
# if output is core, add core_link_flags if not os.path.isabs(output):
if output.startswith("jittor_core"): output = os.path.join(cache_path, output)
link = link + core_link_flags
output = os.path.join(cache_path, output)
# don't recompile object file in inputs # don't recompile object file in inputs
obj_files = [] obj_files = []
ex_obj_files = [] ex_obj_files = []
@ -94,30 +122,33 @@ def compile(compiler, flags, inputs, output, combind_build=False):
return do_compile(fix_cl_flags(cmd)) return do_compile(fix_cl_flags(cmd))
# split compile object file and link # split compile object file and link
# remove -l -L flags when compile object files # remove -l -L flags when compile object files
oflags = remove_flags(flags, ['-l', '-L', '-Wl,']) oflags = remove_flags(flags, ['-l', '-L', '-Wl,', '.lib', '-shared'])
cmds = [] cmds = []
for input, obj_file in zip(inputs, obj_files): for input, obj_file in zip(inputs, obj_files):
cc = compiler cc = compiler
nflags = oflags nflags = oflags
cmd = f"{input} {nflags} {lto_flags} -c -o {obj_file}"
if input.endswith(".cu"): if input.endswith(".cu"):
if has_cuda: if has_cuda:
nflags = convert_nvcc_flags(oflags) cmd = f"\"{nvcc_path}\" {cuda_flags} {cmd}"
cc = nvcc_path cmd = convert_nvcc_flags(fix_cl_flags(cmd))
else: else:
continue continue
cmd = f"\"{cc}\" {input} {nflags} {lto_flags} -c -o {obj_file}" else:
cmd = f"\"{cc}\" {cmd}"
cmd = fix_cl_flags(cmd)
if "nan_checker" in input: if "nan_checker" in input:
# nan checker needs to disable fast_math # nan checker needs to disable fast_math
cmd = cmd.replace("--use_fast_math", "") cmd = cmd.replace("--use_fast_math", "")
cmd = cmd.replace("-Ofast", "-O2") cmd = cmd.replace("-Ofast", "-O2")
cmds.append(fix_cl_flags(cmd)) cmds.append(cmd)
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output) jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
obj_files += ex_obj_files obj_files += ex_obj_files
if os.name == 'nt': if os.name == 'nt':
dumpdef_path = os.path.join(jittor_path, "utils", "dumpdef.py") dumpdef_path = os.path.join(jittor_path, "utils", "dumpdef.py")
cmd = f"\"{sys.executable}\" \"{dumpdef_path}\" {' '.join(obj_files)} -Fo: \"{output}.def\"" cmd = f"\"{sys.executable}\" \"{dumpdef_path}\" {' '.join(obj_files)} -Fo: \"{output}.def\""
do_compile(fix_cl_flags(cmd)) do_compile(fix_cl_flags(cmd))
cmd = f"\"{compiler}\" {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}" cmd = f"\"{compiler}\" {' '.join(obj_files)} -o {output} {flags} {lto_flags}"
return do_compile(fix_cl_flags(cmd)) return do_compile(fix_cl_flags(cmd))
def gen_jit_tests(): def gen_jit_tests():
@ -673,11 +704,15 @@ def compile_custom_ops(
op_extra_flags = includes + extra_flags op_extra_flags = includes + extra_flags
lib_path = os.path.join(cache_path, "custom_ops")
make_cache_dir(lib_path)
gen_src_fname = os.path.join(lib_path, gen_name+".cc")
gen_head_fname = os.path.join(lib_path, gen_name+".h")
gen_lib = os.path.join(lib_path, gen_name+extension_suffix)
libname = gen_name + lib_suffix
op_extra_flags += f" -L\"{lib_path}\" -l\"{libname}\" "
gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags) gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
make_cache_dir(os.path.join(cache_path, "custom_ops"))
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src) pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
# gen src initialize first # gen src initialize first
builds.insert(0, gen_src_fname) builds.insert(0, gen_src_fname)
@ -794,8 +829,9 @@ def compile_extern():
def check_cuda(): def check_cuda():
if not nvcc_path: if not nvcc_path:
return return
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home, cuda_bin
cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path)) cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path))
cuda_bin = cuda_dir
cuda_home = os.path.abspath(os.path.join(cuda_dir, "..")) cuda_home = os.path.abspath(os.path.join(cuda_dir, ".."))
# try default nvidia-cuda-toolkit in Ubuntu 20.04 # try default nvidia-cuda-toolkit in Ubuntu 20.04
# assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}" # assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}"
@ -805,10 +841,25 @@ def check_cuda():
# this nvcc is install by package manager # this nvcc is install by package manager
cuda_lib = "/usr/lib/x86_64-linux-gnu" cuda_lib = "/usr/lib/x86_64-linux-gnu"
cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc") cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc")
cc_flags += f" -DHAS_CUDA -I'{cuda_include}' -I'{cuda_include2}' " cc_flags += f" -DHAS_CUDA -I\"{cuda_include}\" -I\"{cuda_include2}\" "
core_link_flags += f" -lcudart -L'{cuda_lib}' " if os.name == 'nt':
# ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags) cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib", "x64"))
ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags) # cc_flags += f" \"{cuda_lib}\\cudart.lib\" "
cuda_lib_path = glob.glob(cuda_bin+"/cudart64*")[0]
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
os.add_dll_directory(cuda_dir)
# dll = ctypes.CDLL("cudart64_110", dlopen_flags)
dll = ctypes.CDLL(cuda_lib_path, dlopen_flags)
cuda_driver = ctypes.CDLL(r"nvcuda", dlopen_flags)
driver_version = ctypes.c_int()
r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version))
print("version:", driver_version, r)
ret = dll.cudaDeviceSynchronize()
assert ret == 0
else:
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
# ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags)
ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags)
has_cuda = 1 has_cuda = 1
def check_cache_compile(): def check_cache_compile():
@ -950,89 +1001,90 @@ if platform.system() == 'Darwin' and platform.machine() == 'arm64':
if "cc_flags" in os.environ: if "cc_flags" in os.environ:
cc_flags += os.environ["cc_flags"] + ' ' cc_flags += os.environ["cc_flags"] + ' '
link_flags = " -lstdc++ -ldl -shared " cc_flags += " -lstdc++ -ldl -shared "
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
# TODO: if not using apple clang, there is no need to add -lomp # TODO: if not using apple clang, there is no need to add -lomp
link_flags += "-undefined dynamic_lookup -lomp " cc_flags += "-undefined dynamic_lookup -lomp "
if platform.machine() == "arm64": if platform.machine() == "arm64":
link_flags += " -L/opt/homebrew/lib " cc_flags += " -L/opt/homebrew/lib "
core_link_flags = ""
opt_flags = "" opt_flags = ""
py_include = jit_utils.get_py3_include_path() py_include = jit_utils.get_py3_include_path()
LOG.i(f"py_include: {py_include}") LOG.i(f"py_include: {py_include}")
extension_suffix = jit_utils.get_py3_extension_suffix() extension_suffix = jit_utils.get_py3_extension_suffix()
lib_suffix = extension_suffix.replace(".pyd", ".lib") lib_suffix = extension_suffix.rsplit(".", 1)[0]
LOG.i(f"extension_suffix: {extension_suffix}") LOG.i(f"extension_suffix: {extension_suffix}")
so = ".so" if os.name != 'nt' else ".dll"
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
# TODO: if not using apple clang, cannot add -Xpreprocessor # TODO: if not using apple clang, cannot add -Xpreprocessor
kernel_opt_flags = kernel_opt_flags + " -Xpreprocessor -fopenmp " kernel_opt_flags += " -Xpreprocessor -fopenmp "
elif cc_type != 'cl': elif cc_type != 'cl':
kernel_opt_flags = kernel_opt_flags + " -fopenmp " kernel_opt_flags += " -fopenmp "
fix_cl_flags = lambda x:x fix_cl_flags = lambda x:x
if os.name == 'nt': if os.name == 'nt':
if cc_type == 'g++': if cc_type == 'g++':
link_flags = link_flags.replace('-ldl', '') pass
py3_link_path = '-L"' + os.path.join(
os.path.dirname(sys.executable),
"libs"
) + f'" -lpython3{sys.version_info.minor} '
core_link_flags = py3_link_path
link_flags += core_link_flags
# link_flags += " -Wl,--unresolved-symbols=ignore-all "
# cc_flags += " -Xlinker --allow-shlib-undefined "
cc_flags = cc_flags.replace('-std=c++14', '-std=c++17')
link_flags += " -fopenmp "
kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} "
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
elif cc_type == 'cl': elif cc_type == 'cl':
py3_link_path = os.path.join( py3_link_path = os.path.join(
os.path.dirname(sys.executable), os.path.dirname(sys.executable),
"libs", "libs",
f'python3{sys.version_info.minor}.lib'
) )
# core_link_flags = py3_link_path cc_flags = cc_flags.replace("-std=c++14", "-std=c++17")
link_flags += core_link_flags cc_flags = cc_flags.replace("-lstdc++", "")
# link_flags += " -Wl,--unresolved-symbols=ignore-all " cc_flags = cc_flags.replace("-ldl", "")
# cc_flags += " -Xlinker --allow-shlib-undefined " cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} "
kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} " cc_flags += " -EHsc "
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
# cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
# cc_flags += py3_link_path + " "
import jittor_utils import jittor_utils
if jittor_utils.msvc_path: if jittor_utils.msvc_path:
mp = jittor_utils.msvc_path mp = jittor_utils.msvc_path
cc_flags += f' -nologo -I"{mp}\\cl_x64\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX ' cc_flags += f' -nologo -I"{mp}\\VC\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX '
win_link_flags = f' -link -LIBPATH:"{mp}\\cl_x64\\lib" -LIBPATH:"{mp}\\win10_kits\\lib\\um\\x64" -LIBPATH:"{mp}\\win10_kits\\lib\\ucrt\\x64" ' cc_flags += f' -L"{mp}\\VC\\lib" -L"{mp}\\win10_kits\\lib\\um\\x64" -L"{mp}\\win10_kits\\lib\\ucrt\\x64" '
link_flags = ' -LD '
kernel_opt_flags += win_link_flags# + " -EXPORT:\"?jit_run@FusedOp@jittor@@QEAAXXZ\""
def fix_cl_flags(cmd): def fix_cl_flags(cmd):
cmd = cmd.replace(".o ", ".obj ") cmd = cmd.replace(".o ", ".obj ")
cmd = cmd.replace(".o\" ", ".obj\" ") cmd = cmd.replace(".o\" ", ".obj\" ")
if cmd.endswith(".o"): cmd += "bj" if cmd.endswith(".o"): cmd += "bj"
from shlex import split if " -o " in cmd:
if " -LD " in cmd: if " -shared " in cmd:
cmd = cmd.replace(" -o ", " -Fe: ") cmd = cmd.replace(" -o ", " -Fe: ")
output = split(cmd.split("-Fe:")[1].strip(), posix=False)[0] output = shsplit(cmd.split("-Fe:")[1].strip())[0]
base_output = os.path.basename(output).split('.')[0] base_output = os.path.basename(output).split('.')[0]
cmd += win_link_flags cmd += f" -DEF:\"{output}.def\" -IGNORE:4102 -IGNORE:4197 -IGNORE:4217 "
cmd += f" -DEF:\"{output}.def\" -IGNORE:4102 -IGNORE:4197 -IGNORE:4217 {py3_link_path}"
if base_output == "jit_utils_core":
pass
elif base_output == "jittor_core":
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix}")
else:
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix} ")
cmd += " " + os.path.join(cache_path, f"jittor_core{lib_suffix} ")
elif " -c -o " in cmd: elif " -c -o " in cmd:
cmd = cmd.replace(" -c -o ", " -c -Fo: ") cmd = cmd.replace(" -c -o ", " -c -Fo: ")
flags = shsplit(cmd)
output = []
output2 = []
for f in flags:
if f.startswith("-link"):
pass
elif f.startswith("-l"):
output2.append(f[2:]+".lib")
elif f.startswith("-LIB"):
output2.append(f)
elif f.startswith("-LD"):
output.append(f)
elif f.startswith("-L"):
output2.append("-LIBPATH:"+f[2:])
elif ".lib" in f:
output2.append(f)
elif f.startswith("-DEF:"):
output2.append(f)
elif f.startswith("-W") or f.startswith("-f"):
pass
elif f.startswith("-std="):
output.append(f.replace("=", ":"))
else:
output.append(f)
cmd = " ".join(output)
if len(output2):
cmd += " -link " + " ".join(output2)
cmd = cmd.replace("-include", "-FI") cmd = cmd.replace("-include", "-FI")
cmd = cmd.replace("-shared", "-LD")
return cmd return cmd
if ' -O' not in cc_flags: if ' -O' not in cc_flags:
@ -1055,7 +1107,7 @@ ck_path = os.path.join(cache_path, "checkpoints")
make_cache_dir(ck_path) make_cache_dir(ck_path)
# build cache_compile # build cache_compile
cc_flags += f" -I{os.path.join(jittor_path, 'src')} " cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" "
cc_flags += py_include cc_flags += py_include
check_cache_compile() check_cache_compile()
LOG.v(f"Get cache_compile: {jit_utils.cc}") LOG.v(f"Get cache_compile: {jit_utils.cc}")
@ -1065,9 +1117,28 @@ has_cuda = 0
check_cuda() check_cuda()
nvcc_flags = os.environ.get("nvcc_flags", "") nvcc_flags = os.environ.get("nvcc_flags", "")
if has_cuda: if has_cuda:
nvcc_flags += cc_flags + link_flags nvcc_flags += cc_flags
def convert_nvcc_flags(nvcc_flags): def convert_nvcc_flags(nvcc_flags):
# nvcc don't support -Wall option # nvcc don't support -Wall option
if os.name == 'nt':
nvcc_flags = nvcc_flags.replace("-fp:", "-Xcompiler -fp:")
nvcc_flags = nvcc_flags.replace("-EHsc", "-Xcompiler -EHsc")
nvcc_flags = nvcc_flags.replace("-nologo", "")
nvcc_flags = nvcc_flags.replace("-std:", "-std=")
nvcc_flags = nvcc_flags.replace("-Fo:", "-o")
nvcc_flags = nvcc_flags.replace("-LD", "-shared")
nvcc_flags = nvcc_flags.replace("-LIBPATH:", "-L")
nvcc_flags = nvcc_flags.replace("-link", "")
def func(x):
if ".lib" not in x: return x
x = x.replace("\"", "")
a = os.path.dirname(x)
b = os.path.basename(x)
if not b.endswith(".lib"):
return x
return f"-L\"{a}\" -l{b[:-4]}"
nvcc_flags = map_flags(nvcc_flags, func)
nvcc_flags = nvcc_flags.replace("-std=c++17", "-std=c++14 -Xcompiler -std:c++14")
nvcc_flags = nvcc_flags.replace("-Wall", "") nvcc_flags = nvcc_flags.replace("-Wall", "")
nvcc_flags = nvcc_flags.replace("-Wno-unknown-pragmas", "") nvcc_flags = nvcc_flags.replace("-Wno-unknown-pragmas", "")
nvcc_flags = nvcc_flags.replace("-fopenmp", "") nvcc_flags = nvcc_flags.replace("-fopenmp", "")
@ -1075,10 +1146,10 @@ if has_cuda:
nvcc_flags = nvcc_flags.replace("-Werror", "") nvcc_flags = nvcc_flags.replace("-Werror", "")
nvcc_flags = nvcc_flags.replace("-fPIC", "-Xcompiler -fPIC") nvcc_flags = nvcc_flags.replace("-fPIC", "-Xcompiler -fPIC")
nvcc_flags = nvcc_flags.replace("-fdiagnostics", "-Xcompiler -fdiagnostics") nvcc_flags = nvcc_flags.replace("-fdiagnostics", "-Xcompiler -fdiagnostics")
nvcc_flags += f" -x cu --cudart=shared -ccbin='{cc_path}' --use_fast_math " nvcc_flags += f" -x cu --cudart=shared -ccbin=\"{cc_path}\" --use_fast_math "
# nvcc warning is noise # nvcc warning is noise
nvcc_flags += " -w " nvcc_flags += " -w "
nvcc_flags += f" -I'{os.path.join(jittor_path, 'extern/cuda/inc')}' " nvcc_flags += f" -I\"{os.path.join(jittor_path, 'extern/cuda/inc')}\" "
if os.environ.get("cuda_debug", "0") == "1": if os.environ.get("cuda_debug", "0") == "1":
nvcc_flags += " -G " nvcc_flags += " -G "
return nvcc_flags return nvcc_flags
@ -1092,7 +1163,7 @@ jit_src = gen_jit_op_maker(op_headers)
LOG.vvvv(jit_src) LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f: with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
f.write(jit_src) f.write(jit_src)
cc_flags += f' -I{cache_path} ' cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" '
# gen pyjt # gen pyjt
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path) pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
@ -1164,16 +1235,19 @@ if use_data_gz:
f.write(data.decode("utf8")) f.write(data.decode("utf8"))
dflags = (cc_flags+opt_flags)\ dflags = (cc_flags+opt_flags)\
.replace("-Wall", "") \ .replace("-Wall", "") \
.replace("-Werror", "") .replace("-Werror", "") \
.replace("-shared", "")
vdp = os.path.join(jittor_path, "src", "utils", "vdp") vdp = os.path.join(jittor_path, "src", "utils", "vdp")
run_cmd(fix_cl_flags(f"{cc_path} {dflags} -include {vdp} {data_s_path} -c -o {data_o_path}")) run_cmd(fix_cl_flags(f"{cc_path} {dflags} -include {vdp} {data_s_path} -c -o {data_o_path}"))
os.remove(data_s_path) # os.remove(data_s_path)
with open(data_gz_md5_path, 'w') as f: with open(data_gz_md5_path, 'w') as f:
f.write(md5) f.write(md5)
files.append(data_o_path) files.append(data_o_path)
files = [f for f in files if "__data__" not in f] files = [f for f in files if "__data__" not in f]
cc_flags += f" -l\"jit_utils_core{lib_suffix}\" "
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix) compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)
cc_flags += f" -l\"jittor_core{lib_suffix}\" "
# TODO: move to compile_extern.py # TODO: move to compile_extern.py
# compile_extern() # compile_extern()
@ -1182,6 +1256,8 @@ with jit_utils.import_scope(import_flags):
import jittor_core as core import jittor_core as core
flags = core.flags() flags = core.flags()
nvcc_flags = convert_nvcc_flags(cc_flags)
if has_cuda: if has_cuda:
if len(flags.cuda_archs): if len(flags.cuda_archs):
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} " nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
@ -1189,7 +1265,7 @@ if has_cuda:
flags.cc_path = cc_path flags.cc_path = cc_path
flags.cc_type = cc_type flags.cc_type = cc_type
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags flags.cc_flags = cc_flags + kernel_opt_flags
flags.nvcc_path = nvcc_path flags.nvcc_path = nvcc_path
flags.nvcc_flags = nvcc_flags flags.nvcc_flags = nvcc_flags
flags.python_path = python_path flags.python_path = python_path

View File

@ -190,7 +190,7 @@ void CudnnConv3dBackwardWOp::jit_run() {
}; };
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
int perf_count; int perf_count;
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos]; STACK_ALLOC(cudnnConvolutionBwdFilterAlgoPerf_t,perf_results,num_algos);
cudnnConvolutionBwdFilterAlgo_t algo; cudnnConvolutionBwdFilterAlgo_t algo;
bool benchmark=true; bool benchmark=true;

View File

@ -181,7 +181,7 @@ void CudnnConv3dBackwardXOp::jit_run() {
}; };
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
int perf_count; int perf_count;
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos]; STACK_ALLOC(cudnnConvolutionBwdDataAlgoPerf_t,perf_results,num_algos);
cudnnConvolutionBwdDataAlgo_t algo; cudnnConvolutionBwdDataAlgo_t algo;
bool benchmark=true; bool benchmark=true;

View File

@ -184,7 +184,7 @@ void CudnnConv3dOp::jit_run() {
}; };
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
int perf_count; int perf_count;
cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos]; STACK_ALLOC(cudnnConvolutionFwdAlgoPerf_t,perf_results,num_algos);
cudnnConvolutionFwdAlgo_t algo; cudnnConvolutionFwdAlgo_t algo;
bool benchmark=true; bool benchmark=true;

View File

@ -180,7 +180,7 @@ void CudnnConvBackwardWOp::jit_run() {
}; };
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
int perf_count; int perf_count;
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos]; STACK_ALLOC(cudnnConvolutionBwdFilterAlgoPerf_t,perf_results,num_algos);
cudnnConvolutionBwdFilterAlgo_t algo; cudnnConvolutionBwdFilterAlgo_t algo;
bool benchmark=true; bool benchmark=true;

View File

@ -181,7 +181,7 @@ void CudnnConvBackwardXOp::jit_run() {
}; };
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
int perf_count; int perf_count;
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos]; STACK_ALLOC(cudnnConvolutionBwdDataAlgoPerf_t,perf_results,num_algos);
cudnnConvolutionBwdDataAlgo_t algo; cudnnConvolutionBwdDataAlgo_t algo;
bool benchmark=true; bool benchmark=true;

View File

@ -183,7 +183,7 @@ void CudnnConvOp::jit_run() {
}; };
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
int perf_count; int perf_count;
cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos]; STACK_ALLOC(cudnnConvolutionFwdAlgoPerf_t,perf_results,num_algos);
cudnnConvolutionFwdAlgo_t algo; cudnnConvolutionFwdAlgo_t algo;
bool benchmark=true; bool benchmark=true;

View File

@ -84,9 +84,6 @@ static double second (void)
gettimeofday(&tv, NULL); gettimeofday(&tv, NULL);
return (double)tv.tv_sec + (double)tv.tv_usec / 1000000.0; return (double)tv.tv_sec + (double)tv.tv_usec / 1000000.0;
} }
#else
#error unsupported platform
#endif
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType(); template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; } template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
@ -995,4 +992,10 @@ int cudnn_test_entry( int argc, char** argv )
doTest<half1>(algo, dimA, padA, convstrideA, filterdimA, filterFormat, mathType, benchmark); doTest<half1>(algo, dimA, padA, convstrideA, filterdimA, filterFormat, mathType, benchmark);
return 0; return 0;
} }
#else
int cudnn_test_entry( int argc, char** argv ) {
return 0;
}
#endif

View File

@ -472,9 +472,9 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
} }
#endif #endif
last_is_cuda = is_cuda; last_is_cuda = is_cuda;
_JT_SEH_START2; // _JT_SEH_START2;
op->do_run_after_prepare(jkl); op->do_run_after_prepare(jkl);
_JT_SEH_END2; // _JT_SEH_END2;
#ifdef HAS_CUDA #ifdef HAS_CUDA
// migrate to gpu // migrate to gpu
if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) { if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) {

View File

@ -19,6 +19,7 @@
#include "utils/cache_compile.h" #include "utils/cache_compile.h"
#include "utils/flags.h" #include "utils/flags.h"
#include "fused_op.h" #include "fused_op.h"
#include "utils/str_utils.h"
namespace jittor { namespace jittor {
@ -32,6 +33,71 @@ DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor"); DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not"); DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
#ifdef _MSC_VER
vector<string> shsplit(const string& s) {
auto s1 = split(s, " ");
vector<string> s2;
int count = 0;
for (auto& s : s1) {
int nc = 0;
for (auto& c : s)
nc += c=='"' || c=='\'';
if (count&1) {
count += nc;
s2.back() += " ";
s2.back() += s;
} else {
count = nc;
s2.push_back(s);
}
}
return s2;
}
string fix_cl_flags(const string& cmd) {
auto flags = shsplit(cmd);
vector<string> output, output2;
for (auto& f : flags) {
if (startswith(f, "-link"))
continue;
else if (startswith(f, "-l"))
output2.push_back(f.substr(2)+".lib");
else if (startswith(f, "-LIB"))
output2.push_back(f);
else if (startswith(f, "-LD"))
output.push_back(f);
else if (startswith(f, "-L"))
output2.push_back("-LIBPATH:"+f.substr(2));
else if (f.find(".lib") != string::npos)
output2.push_back(f);
else if (startswith(f, "-DEF:"))
output2.push_back(f);
else if (startswith(f, "-W") || startswith(f,"-f"))
continue;
else if (startswith(f,"-std="))
output.push_back("-std:"+f.substr(5));
else if (startswith(f,"-include"))
output.push_back("-FI");
else if (startswith(f,"-shared"))
output.push_back("-LD");
else
output.push_back(f);
}
string cmdx = "";
for (auto& s : output) {
cmdx += s;
cmdx += " ";
}
cmdx += "-link ";
for (auto& s : output2) {
cmdx += s;
cmdx += " ";
}
return cmdx;
}
#endif
namespace jit_compiler { namespace jit_compiler {
std::mutex dl_open_mutex; std::mutex dl_open_mutex;
@ -103,6 +169,7 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
write(jit_src_path, src); write(jit_src_path, src);
string cmd; string cmd;
auto symbol_name = get_symbol_name(jit_key);
#ifndef _MSC_VER #ifndef _MSC_VER
if (is_cuda_op) { if (is_cuda_op) {
cmd = "\"" + nvcc_path + "\"" cmd = "\"" + nvcc_path + "\""
@ -124,21 +191,18 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
cmd = "\"" + nvcc_path + "\"" cmd = "\"" + nvcc_path + "\""
+ " \"" + jit_src_path + "\"" + other_src + " \"" + jit_src_path + "\"" + other_src
+ nvcc_flags + extra_flags + nvcc_flags + extra_flags
+ " -o \"" + jit_lib_path + "\""; + " -o \"" + jit_lib_path + "\""
+ " -Xlinker -EXPORT:\""
+ symbol_name + "\"";;
} else { } else {
auto symbol_name = get_symbol_name(jit_key);
auto pos = cc_flags.find("-link");
auto cc_flags1 = cc_flags.substr(0, pos);
auto cc_flags2 = cc_flags.substr(pos);
cmd = "\"" + cc_path + "\"" cmd = "\"" + cc_path + "\""
+ " \"" + jit_src_path + "\"" + other_src + " \"" + jit_src_path + "\"" + other_src
+ cc_flags1 + extra_flags + " -Fe: \"" + jit_lib_path + "\" "
+ " -Fe: \"" + jit_lib_path + "\" " + cc_flags2 + " -EXPORT:\"" + fix_cl_flags(cc_flags + extra_flags) + " -EXPORT:\""
+ symbol_name + "\""; + symbol_name + "\"";
} }
#endif #endif
cache_compile(cmd, cache_path, jittor_path); cache_compile(cmd, cache_path, jittor_path);
auto symbol_name = get_symbol_name(jit_key);
auto jit_entry = load_jit_lib(jit_lib_path, symbol_name); auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);
return jit_entry; return jit_entry;
} }

View File

@ -20,6 +20,8 @@ namespace jittor {
using std::string_view; using std::string_view;
#elif defined(__GNUC__) #elif defined(__GNUC__)
using std::experimental::string_view; using std::experimental::string_view;
#elif __cplusplus < 201400L
using string_view = string;
#else #else
using std::string_view; using std::string_view;
#endif #endif

View File

@ -446,7 +446,7 @@ void GetitemOp::jit_prepare(JK& jk) {
#ifdef HAS_CUDA #ifdef HAS_CUDA
if (use_cuda) { if (use_cuda) {
int no = o_shape.size(); int no = o_shape.size();
int masks[no]; STACK_ALLOC(int, masks, no);
int tdims[6]; int tdims[6];
cuda_loop_schedule(o_shape, masks, tdims); cuda_loop_schedule(o_shape, masks, tdims);
for (int i=0; i<no; i++) { for (int i=0; i<no; i++) {

View File

@ -237,7 +237,7 @@ void SetitemOp::jit_prepare(JK& jk) {
#ifdef HAS_CUDA #ifdef HAS_CUDA
if (use_cuda) { if (use_cuda) {
int no = o_shape.size(); int no = o_shape.size();
int masks[no]; STACK_ALLOC(int, masks, no);
int tdims[6]; int tdims[6];
cuda_loop_schedule(o_shape, masks, tdims); cuda_loop_schedule(o_shape, masks, tdims);
for (int i=0; i<no; i++) { for (int i=0; i<no; i++) {

View File

@ -254,10 +254,12 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
continue; continue;
processed.insert(input_names[i]); processed.insert(input_names[i]);
auto src = read_all(input_names[i]); auto src = read_all(input_names[i]);
auto back = input_names[i].back();
// *.lib
if (back == 'b') continue;
ASSERT(src.size()) << "Source read failed:" << input_names[i] << "cmd:" << cmd; ASSERT(src.size()) << "Source read failed:" << input_names[i] << "cmd:" << cmd;
auto hash = S(hash64(src)); auto hash = S(hash64(src));
vector<string> new_names; vector<string> new_names;
auto back = input_names[i].back();
// *.obj, *.o, *.pyd // *.obj, *.o, *.pyd
if (back != 'j' && back != 'o' && back != 'd') if (back != 'j' && back != 'o' && back != 'd')
process(src, new_names, cmd); process(src, new_names, cmd);

View File

@ -6,6 +6,7 @@ def_path = sys.argv[-1]
# print(sys.argv) # print(sys.argv)
dumpbin_path = os.environ.get("dumpbin_path", "dumpbin") dumpbin_path = os.environ.get("dumpbin_path", "dumpbin")
export_all = os.environ.get("EXPORT_ALL", "0")=="1"
syms = {} syms = {}
@ -22,6 +23,11 @@ for obj in sys.argv[1:-2]:
if sym.startswith("??$get_from_env"): syms[sym] = 1 if sym.startswith("??$get_from_env"): syms[sym] = 1
# if sym.startswith("??"): continue # if sym.startswith("??"): continue
if sym.startswith("my"): syms[sym] = 1 if sym.startswith("my"): syms[sym] = 1
# for cutt
if "custom_cuda" in sym: syms[sym] = 1
if "cutt" in sym: syms[sym] = 1
if "_cudaGetErrorEnum" in sym: syms[sym] = 1
if export_all: syms[sym] = 1
if "jittor" not in sym: continue if "jittor" not in sym: continue
syms[sym] = 1 syms[sym] = 1
# print(ret) # print(ret)

View File

@ -219,7 +219,7 @@ if os.name=='nt' and getattr(mp.current_process(), '_inheriting', False):
os.environ["log_silent"] = '1' os.environ["log_silent"] = '1'
if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1': if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
os.environ["use_parallel_op_compiler"] = '1' os.environ["use_parallel_op_compiler"] = '0'
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"): def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ] cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
n = len(cmds) n = len(cmds)
@ -278,11 +278,12 @@ def find_cache_path():
def get_version(output): def get_version(output):
if output.endswith("mpicc"): if output.endswith("mpicc"):
version = run_cmd(output+" --showme:version") version = run_cmd(f"\"{output}\" --showme:version")
elif os.name == 'nt' and output.endswith("cl"): elif os.name == 'nt' and (
output.endswith("cl") or output.endswith("cl.exe")):
version = run_cmd(output) version = run_cmd(output)
else: else:
version = run_cmd(output+" --version") version = run_cmd(f"\"{output}\" --version")
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version) v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
if len(v) == 0: if len(v) == 0:
v = re.findall("[0-9]+\\.[0-9]+", version) v = re.findall("[0-9]+\\.[0-9]+", version)
@ -427,7 +428,7 @@ msvc_path = ""
if os.name == 'nt' and os.environ.get("cc_path", "")=="": if os.name == 'nt' and os.environ.get("cc_path", "")=="":
from pathlib import Path from pathlib import Path
msvc_path = os.path.join(str(Path.home()), ".cache", "jittor", "msvc") msvc_path = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
cc_path = os.path.join(msvc_path, "cl_x64", "bin", "cl") cc_path = os.path.join(msvc_path, "VC", r"_\_\_\_\_\bin", "cl.exe")
check_msvc_install = True check_msvc_install = True
else: else:
cc_path = env_or_find('cc_path', 'g++', silent=True) cc_path = env_or_find('cc_path', 'g++', silent=True)

View File

@ -12,7 +12,21 @@ from jittor_utils import LOG
from jittor_utils.misc import download_url_to_local from jittor_utils.misc import download_url_to_local
import pathlib import pathlib
def get_cuda_driver_win():
try:
import ctypes
cuda_driver = ctypes.CDLL(r"nvcuda")
driver_version = ctypes.c_int()
r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version))
if r != 0: return None
v = driver_version.value
return [v//1000, v%1000//10, v%10]
except:
return None
def get_cuda_driver(): def get_cuda_driver():
if os.name == 'nt':
return get_cuda_driver_win()
ret, out = sp.getstatusoutput("nvidia-smi -q -u") ret, out = sp.getstatusoutput("nvidia-smi -q -u")
if ret != 0: return None if ret != 0: return None
try: try:
@ -35,21 +49,34 @@ def install_cuda():
if not cuda_driver_version: if not cuda_driver_version:
return None return None
LOG.i("cuda_driver_version: ", cuda_driver_version) LOG.i("cuda_driver_version: ", cuda_driver_version)
if "JTCUDA_VERSION" in os.environ:
cuda_driver_version = list(map(int,os.enviroment["JTCUDA_VERSION"].split(".")))
LOG.i("JTCUDA_VERSION: ", cuda_driver_version)
if cuda_driver_version >= [11,2]: if os.name == 'nt':
cuda_tgz = "cuda11.2_cudnn8_linux.tgz" if cuda_driver_version >= [11,4]:
md5 = "b93a1a5d19098e93450ee080509e9836" cuda_tgz = "cuda11.4_cudnn8_win.zip"
elif cuda_driver_version >= [11,]: md5 = "06eed370d0d44bb2cc57809343911187"
cuda_tgz = "cuda11.0_cudnn8_linux.tgz" elif cuda_driver_version >= [10,]:
md5 = "5dbdb43e35b4db8249027997720bf1ca" cuda_tgz = "cuda10.2_cudnn7_win.zip"
elif cuda_driver_version >= [10,2]: md5 = "7dd9963833a91371299a2ba58779dd71"
cuda_tgz = "cuda10.2_cudnn7_linux.tgz" else:
md5 = "40f0563e8eb176f53e55943f6d212ad7" raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.2")
elif cuda_driver_version >= [10,]:
cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
md5 = "f16d3ff63f081031d21faec3ec8b7dac"
else: else:
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}") if cuda_driver_version >= [11,2]:
cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
md5 = "b93a1a5d19098e93450ee080509e9836"
elif cuda_driver_version >= [11,]:
cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
md5 = "5dbdb43e35b4db8249027997720bf1ca"
elif cuda_driver_version >= [10,2]:
cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
md5 = "40f0563e8eb176f53e55943f6d212ad7"
elif cuda_driver_version >= [10,]:
cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
md5 = "f16d3ff63f081031d21faec3ec8b7dac"
else:
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.0")
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda") jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc") nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64") nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
@ -65,9 +92,14 @@ def install_cuda():
download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5) download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5)
import tarfile if cuda_tgz.endswith(".zip"):
with tarfile.open(cuda_tgz_path, "r") as tar: import zipfile
tar.extractall(cuda_tgz_path[:-4]) zf = zipfile.ZipFile(cuda_tgz_path)
zf.extractall(path=cuda_tgz_path[:-4])
else:
import tarfile
with tarfile.open(cuda_tgz_path, "r") as tar:
tar.extractall(cuda_tgz_path[:-4])
assert os.path.isfile(nvcc_path) assert os.path.isfile(nvcc_path)
return nvcc_path return nvcc_path

View File

@ -8,7 +8,7 @@ def install(path):
LOG.i("Installing MSVC...") LOG.i("Installing MSVC...")
filename = "msvc.zip" filename = "msvc.zip"
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
md5sum = "0fd71436c034808649b24baf28998ccc" md5sum = "929c4b86ea68160121f73c3edf6b5463"
download_url_to_local(url, filename, path, md5sum) download_url_to_local(url, filename, path, md5sum)
fullname = os.path.join(path, filename) fullname = os.path.join(path, filename)
import zipfile import zipfile

View File

@ -21,7 +21,10 @@ def ensure_dir(dir_path):
os.makedirs(dir_path) os.makedirs(dir_path)
def _progress(): def _progress():
pbar = tqdm(total=None) pbar = tqdm(total=None,
unit="B",
unit_scale=True,
unit_divisor=1024)
def bar_update(block_num, block_size, total_size): def bar_update(block_num, block_size, total_size):
""" reporthook """ reporthook