mirror of https://github.com/Jittor/Jittor
support cuda win
This commit is contained in:
parent
d85af13024
commit
123e915bb3
|
@ -11,8 +11,17 @@ from jittor_utils import run_cmd, get_version, get_int_version
|
|||
from jittor_utils.misc import download_url_to_local
|
||||
|
||||
def search_file(dirs, name, prefer_version=()):
|
||||
if os.name == 'nt':
|
||||
if name.startswith("lib"):
|
||||
name = name[3:].replace(".so", "64*.dll")
|
||||
for d in dirs:
|
||||
fname = os.path.join(d, name)
|
||||
if os.name == 'nt':
|
||||
lname = os.path.join(d, name)
|
||||
names = glob.glob(lname)
|
||||
if len(names):
|
||||
return names[0]
|
||||
continue
|
||||
prefer_version = tuple( str(p) for p in prefer_version )
|
||||
for i in range(len(prefer_version),-1,-1):
|
||||
vname = ".".join((fname,)+prefer_version[:i])
|
||||
|
@ -122,8 +131,7 @@ def setup_mkl():
|
|||
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
|
||||
mkl_bin_path = os.path.join(mkl_home, 'bin')
|
||||
os.add_dll_directory(mkl_bin_path)
|
||||
mkl_lib = os.path.join(mkl_lib_path, "dnnl.lib")
|
||||
extra_flags = f" -I\"{mkl_include_path}\" \"{mkl_lib}\" "
|
||||
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -ldnnl "
|
||||
assert os.path.isdir(mkl_include_path)
|
||||
assert os.path.isdir(mkl_lib_path)
|
||||
assert os.path.isfile(mkl_lib_name)
|
||||
|
@ -156,17 +164,17 @@ def install_cub(root_folder):
|
|||
fullname = os.path.join(root_folder, filename)
|
||||
dirname = os.path.join(root_folder, filename.replace(".tgz",""))
|
||||
|
||||
if not os.path.isfile(os.path.join(dirname, "examples", "test")):
|
||||
if not os.path.isfile(os.path.join(dirname, "examples", "device/example_device_radix_sort.cu")):
|
||||
LOG.i("Downloading cub...")
|
||||
download_url_to_local(url, filename, root_folder, md5)
|
||||
import tarfile
|
||||
|
||||
with tarfile.open(fullname, "r") as tar:
|
||||
tar.extractall(root_folder)
|
||||
assert 0 == os.system(f"cd {dirname}/examples && "
|
||||
f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
|
||||
if core.get_device_count():
|
||||
assert 0 == os.system(f"cd {dirname}/examples && ./test")
|
||||
# assert 0 == os.system(f"cd {dirname}/examples && "
|
||||
# f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
|
||||
# if core.get_device_count():
|
||||
# assert 0 == os.system(f"cd {dirname}/examples && ./test")
|
||||
return dirname
|
||||
|
||||
def setup_cub():
|
||||
|
@ -191,8 +199,9 @@ def setup_cuda_extern():
|
|||
cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
|
||||
cuda_extern_files = [os.path.join(cuda_extern_src, name)
|
||||
for name in os.listdir(cuda_extern_src)]
|
||||
so_name = os.path.join(cache_path_cuda, "cuda_extern.so")
|
||||
compile(cc_path, cc_flags+f" -I'{cuda_include}' ", cuda_extern_files, so_name)
|
||||
so_name = os.path.join(cache_path_cuda, "cuda_extern"+so)
|
||||
compile(cc_path, cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
|
||||
link_cuda_extern = f" -L\"{cache_path_cuda}\" -lcuda_extern "
|
||||
ctypes.CDLL(so_name, dlopen_flags)
|
||||
|
||||
try:
|
||||
|
@ -205,7 +214,7 @@ def setup_cuda_extern():
|
|||
libs = ["cublas", "cudnn", "curand"]
|
||||
for lib_name in libs:
|
||||
try:
|
||||
setup_cuda_lib(lib_name)
|
||||
setup_cuda_lib(lib_name, extra_flags=link_cuda_extern)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
line = traceback.format_exc()
|
||||
|
@ -244,12 +253,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
prefer_version = ()
|
||||
if nvcc_version[0] == 11:
|
||||
prefer_version = ("8",)
|
||||
culib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
|
||||
culib_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
|
||||
|
||||
if lib_name == "cublas" and nvcc_version[0] >= 10:
|
||||
# manual link libcublasLt.so
|
||||
try:
|
||||
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
|
||||
cublas_lt_lib_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
|
||||
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
|
||||
except:
|
||||
# some aarch64 os, such as uos with FT2000 cpu,
|
||||
|
@ -263,12 +272,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
if nvcc_version >= (11,0,0):
|
||||
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
|
||||
for l in libs:
|
||||
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], l, prefer_version)
|
||||
ex_cudnn_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], l, prefer_version)
|
||||
ctypes.CDLL(ex_cudnn_path, dlopen_flags)
|
||||
|
||||
# dynamic link cuda library
|
||||
ctypes.CDLL(culib_path, dlopen_flags)
|
||||
link_flags = f"-l{lib_name} -L'{cuda_lib}'"
|
||||
link_flags = f"-l{lib_name} -L\"{cuda_lib}\""
|
||||
|
||||
# find all source files
|
||||
culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name)
|
||||
|
@ -281,7 +290,7 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
|
||||
# compile and get operators
|
||||
culib = compile_custom_ops(culib_src_files, return_module=True,
|
||||
extra_flags=f" -I'{jt_cuda_include}' -I'{jt_culib_include}' {link_flags} {extra_flags} ")
|
||||
extra_flags=f" -I\"{jt_cuda_include}\" -I\"{jt_culib_include}\" {link_flags} {extra_flags} ")
|
||||
culib_ops = culib.ops
|
||||
globals()[lib_name+"_ops"] = culib_ops
|
||||
globals()[lib_name] = culib
|
||||
|
@ -289,19 +298,20 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
|
||||
def install_cutt(root_folder):
|
||||
# Modified from: https://github.com/ap-hynninen/cutt
|
||||
url = "https://codeload.github.com/Jittor/cutt/zip/v1.1"
|
||||
url = "https://codeload.github.com/Jittor/cutt/zip/v1.2"
|
||||
|
||||
filename = "cutt-1.1.zip"
|
||||
filename = "cutt-1.2.zip"
|
||||
fullname = os.path.join(root_folder, filename)
|
||||
dirname = os.path.join(root_folder, filename.replace(".zip",""))
|
||||
true_md5 = "7bb71cf7c49dbe57772539bf043778f7"
|
||||
true_md5 = "14d0fd1132c8cd657dc3cf29ce4db931"
|
||||
|
||||
if os.path.exists(fullname):
|
||||
md5 = run_cmd('md5sum '+fullname).split()[0]
|
||||
from jittor_utils.misc import calculate_md5
|
||||
md5 = calculate_md5(fullname)
|
||||
if md5 != true_md5:
|
||||
os.remove(fullname)
|
||||
shutil.rmtree(dirname)
|
||||
if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")):
|
||||
if not os.path.isfile(os.path.join(dirname, "lib/libcutt"+so)):
|
||||
LOG.i("Downloading cutt...")
|
||||
download_url_to_local(url, filename, root_folder, true_md5)
|
||||
|
||||
|
@ -320,7 +330,17 @@ def install_cutt(root_folder):
|
|||
if len(flags.cuda_archs):
|
||||
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
run_cmd(f"make NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' nvcc_path='{nvcc_path}'", cwd=dirname)
|
||||
cutt_include = f" -I\"{dirname}/include\" -I\"{dirname}/src\" "
|
||||
files = glob.glob(dirname+"/src/*.c*", recursive=True)
|
||||
files2 = []
|
||||
for f in files:
|
||||
if f.endswith("cutt_bench.cpp") or \
|
||||
f.endswith("cutt_test.cpp"):
|
||||
continue
|
||||
files2.append(f)
|
||||
cutt_flags = cc_flags+opt_flags+cutt_include
|
||||
os.makedirs(dirname+"/lib", exist_ok=True)
|
||||
compile(cc_path, cutt_flags, files2, dirname+"/lib/libcutt"+so, cuda_flags=arch_flag)
|
||||
return dirname
|
||||
|
||||
def setup_cutt():
|
||||
|
@ -342,11 +362,11 @@ def setup_cutt():
|
|||
|
||||
make_cache_dir(cutt_path)
|
||||
install_cutt(cutt_path)
|
||||
cutt_home = os.path.join(cutt_path, "cutt-1.1")
|
||||
cutt_home = os.path.join(cutt_path, "cutt-1.2")
|
||||
cutt_include_path = os.path.join(cutt_home, "src")
|
||||
cutt_lib_path = os.path.join(cutt_home, "lib")
|
||||
|
||||
cutt_lib_name = os.path.join(cutt_lib_path, "libcutt.so")
|
||||
cutt_lib_name = os.path.join(cutt_lib_path, "libcutt"+so)
|
||||
assert os.path.isdir(cutt_include_path)
|
||||
assert os.path.isdir(cutt_lib_path)
|
||||
assert os.path.isfile(cutt_lib_name), cutt_lib_name
|
||||
|
@ -354,12 +374,14 @@ def setup_cutt():
|
|||
LOG.v(f"cutt_lib_path: {cutt_lib_path}")
|
||||
LOG.v(f"cutt_lib_name: {cutt_lib_name}")
|
||||
# We do not link manualy, link in custom ops
|
||||
if os.name == "nt":
|
||||
os.add_dll_directory(cutt_lib_path)
|
||||
ctypes.CDLL(cutt_lib_name, dlopen_flags)
|
||||
|
||||
cutt_op_dir = os.path.join(jittor_path, "extern", "cuda", "cutt", "ops")
|
||||
cutt_op_files = [os.path.join(cutt_op_dir, name) for name in os.listdir(cutt_op_dir)]
|
||||
cutt_ops = compile_custom_ops(cutt_op_files,
|
||||
extra_flags=f" -I'{cutt_include_path}'")
|
||||
extra_flags=f" -I\"{cutt_include_path}\" -L\"{cutt_lib_path}\" -llibcutt ")
|
||||
LOG.vv("Get cutt_ops: "+str(dir(cutt_ops)))
|
||||
|
||||
|
||||
|
@ -442,7 +464,7 @@ def setup_nccl():
|
|||
nccl_src_files.append(os.path.join(r, fname))
|
||||
|
||||
nccl_ops = compile_custom_ops(nccl_src_files,
|
||||
extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ")
|
||||
extra_flags=f" -I\"{nccl_include_path}\" {mpi_compile_flags} ")
|
||||
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
||||
|
||||
def manual_link(flags):
|
||||
|
@ -498,7 +520,7 @@ def setup_mpi():
|
|||
mpi_src_files.append(os.path.join(r, fname))
|
||||
|
||||
# mpi compile flags add for nccl
|
||||
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
|
||||
mpi_compile_flags += f" -I\"{os.path.join(mpi_src_dir, 'inc')}\" "
|
||||
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
|
||||
|
||||
mpi_version = get_version(mpicc_path)
|
||||
|
|
|
@ -34,25 +34,62 @@ def make_cache_dir(cache_path):
|
|||
LOG.i(f"Create cache dir: {cache_path}")
|
||||
os.mkdir(cache_path)
|
||||
|
||||
def shsplit(s):
|
||||
s1 = s.split(' ')
|
||||
s2 = []
|
||||
count = 0
|
||||
for s in s1:
|
||||
nc = s.count('"') + s.count('\'')
|
||||
if count&1:
|
||||
count += nc
|
||||
s2[-1] += " "
|
||||
s2[-1] += s
|
||||
else:
|
||||
count = nc
|
||||
s2.append(s)
|
||||
return s2
|
||||
|
||||
|
||||
def remove_flags(flags, rm_flags):
|
||||
flags = flags.split(" ")
|
||||
flags = shsplit(flags)
|
||||
output = []
|
||||
for s in flags:
|
||||
ss = s.replace("\"", "")
|
||||
for rm in rm_flags:
|
||||
if s.startswith(rm):
|
||||
if ss.startswith(rm) or ss.endswith(rm):
|
||||
break
|
||||
else:
|
||||
output.append(s)
|
||||
return " ".join(output)
|
||||
|
||||
def compile(compiler, flags, inputs, output, combind_build=False):
|
||||
def moveback_flags(flags, rm_flags):
|
||||
flags = shsplit(flags)
|
||||
output = []
|
||||
output2 = []
|
||||
for s in flags:
|
||||
ss = s.replace("\"", "")
|
||||
for rm in rm_flags:
|
||||
if ss.startswith(rm) or ss.endswith(rm):
|
||||
output2.append(s)
|
||||
break
|
||||
else:
|
||||
output.append(s)
|
||||
return " ".join(output+output2)
|
||||
|
||||
def map_flags(flags, func):
|
||||
flags = shsplit(flags)
|
||||
output = []
|
||||
for s in flags:
|
||||
output.append(func(s))
|
||||
return " ".join(output)
|
||||
|
||||
def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags=""):
|
||||
def do_compile(cmd):
|
||||
if jit_utils.cc:
|
||||
return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path)
|
||||
else:
|
||||
run_cmd(cmd)
|
||||
return True
|
||||
link = link_flags
|
||||
base_output = os.path.basename(output).split('.')[0]
|
||||
if os.name == 'nt':
|
||||
# windows do not combind build, need gen def
|
||||
|
@ -64,18 +101,9 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
|||
# initialize order in windows seems reversed
|
||||
inputs = list(inputs[::-1])
|
||||
link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" '
|
||||
if base_output == "jit_utils_core":
|
||||
pass
|
||||
elif base_output == "jittor_core":
|
||||
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
|
||||
else:
|
||||
inputs.append(os.path.join(cache_path, f"jit_utils_core{lib_suffix}"))
|
||||
inputs.append(os.path.join(cache_path, f"jittor_core{lib_suffix}"))
|
||||
|
||||
# if output is core, add core_link_flags
|
||||
if output.startswith("jittor_core"):
|
||||
link = link + core_link_flags
|
||||
output = os.path.join(cache_path, output)
|
||||
if not os.path.isabs(output):
|
||||
output = os.path.join(cache_path, output)
|
||||
# don't recompile object file in inputs
|
||||
obj_files = []
|
||||
ex_obj_files = []
|
||||
|
@ -94,30 +122,33 @@ def compile(compiler, flags, inputs, output, combind_build=False):
|
|||
return do_compile(fix_cl_flags(cmd))
|
||||
# split compile object file and link
|
||||
# remove -l -L flags when compile object files
|
||||
oflags = remove_flags(flags, ['-l', '-L', '-Wl,'])
|
||||
oflags = remove_flags(flags, ['-l', '-L', '-Wl,', '.lib', '-shared'])
|
||||
cmds = []
|
||||
for input, obj_file in zip(inputs, obj_files):
|
||||
cc = compiler
|
||||
nflags = oflags
|
||||
cmd = f"{input} {nflags} {lto_flags} -c -o {obj_file}"
|
||||
if input.endswith(".cu"):
|
||||
if has_cuda:
|
||||
nflags = convert_nvcc_flags(oflags)
|
||||
cc = nvcc_path
|
||||
cmd = f"\"{nvcc_path}\" {cuda_flags} {cmd}"
|
||||
cmd = convert_nvcc_flags(fix_cl_flags(cmd))
|
||||
else:
|
||||
continue
|
||||
cmd = f"\"{cc}\" {input} {nflags} {lto_flags} -c -o {obj_file}"
|
||||
else:
|
||||
cmd = f"\"{cc}\" {cmd}"
|
||||
cmd = fix_cl_flags(cmd)
|
||||
if "nan_checker" in input:
|
||||
# nan checker needs to disable fast_math
|
||||
cmd = cmd.replace("--use_fast_math", "")
|
||||
cmd = cmd.replace("-Ofast", "-O2")
|
||||
cmds.append(fix_cl_flags(cmd))
|
||||
cmds.append(cmd)
|
||||
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
|
||||
obj_files += ex_obj_files
|
||||
if os.name == 'nt':
|
||||
dumpdef_path = os.path.join(jittor_path, "utils", "dumpdef.py")
|
||||
cmd = f"\"{sys.executable}\" \"{dumpdef_path}\" {' '.join(obj_files)} -Fo: \"{output}.def\""
|
||||
do_compile(fix_cl_flags(cmd))
|
||||
cmd = f"\"{compiler}\" {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
|
||||
cmd = f"\"{compiler}\" {' '.join(obj_files)} -o {output} {flags} {lto_flags}"
|
||||
return do_compile(fix_cl_flags(cmd))
|
||||
|
||||
def gen_jit_tests():
|
||||
|
@ -673,11 +704,15 @@ def compile_custom_ops(
|
|||
|
||||
op_extra_flags = includes + extra_flags
|
||||
|
||||
lib_path = os.path.join(cache_path, "custom_ops")
|
||||
make_cache_dir(lib_path)
|
||||
gen_src_fname = os.path.join(lib_path, gen_name+".cc")
|
||||
gen_head_fname = os.path.join(lib_path, gen_name+".h")
|
||||
gen_lib = os.path.join(lib_path, gen_name+extension_suffix)
|
||||
libname = gen_name + lib_suffix
|
||||
op_extra_flags += f" -L\"{lib_path}\" -l\"{libname}\" "
|
||||
|
||||
gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
|
||||
make_cache_dir(os.path.join(cache_path, "custom_ops"))
|
||||
gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
|
||||
gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
|
||||
gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
|
||||
pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
|
||||
# gen src initialize first
|
||||
builds.insert(0, gen_src_fname)
|
||||
|
@ -794,8 +829,9 @@ def compile_extern():
|
|||
def check_cuda():
|
||||
if not nvcc_path:
|
||||
return
|
||||
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home
|
||||
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home, cuda_bin
|
||||
cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path))
|
||||
cuda_bin = cuda_dir
|
||||
cuda_home = os.path.abspath(os.path.join(cuda_dir, ".."))
|
||||
# try default nvidia-cuda-toolkit in Ubuntu 20.04
|
||||
# assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}"
|
||||
|
@ -805,10 +841,25 @@ def check_cuda():
|
|||
# this nvcc is install by package manager
|
||||
cuda_lib = "/usr/lib/x86_64-linux-gnu"
|
||||
cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc")
|
||||
cc_flags += f" -DHAS_CUDA -I'{cuda_include}' -I'{cuda_include2}' "
|
||||
core_link_flags += f" -lcudart -L'{cuda_lib}' "
|
||||
# ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags)
|
||||
ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags)
|
||||
cc_flags += f" -DHAS_CUDA -I\"{cuda_include}\" -I\"{cuda_include2}\" "
|
||||
if os.name == 'nt':
|
||||
cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib", "x64"))
|
||||
# cc_flags += f" \"{cuda_lib}\\cudart.lib\" "
|
||||
cuda_lib_path = glob.glob(cuda_bin+"/cudart64*")[0]
|
||||
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
|
||||
os.add_dll_directory(cuda_dir)
|
||||
# dll = ctypes.CDLL("cudart64_110", dlopen_flags)
|
||||
dll = ctypes.CDLL(cuda_lib_path, dlopen_flags)
|
||||
cuda_driver = ctypes.CDLL(r"nvcuda", dlopen_flags)
|
||||
driver_version = ctypes.c_int()
|
||||
r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version))
|
||||
print("version:", driver_version, r)
|
||||
ret = dll.cudaDeviceSynchronize()
|
||||
assert ret == 0
|
||||
else:
|
||||
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
|
||||
# ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags)
|
||||
ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags)
|
||||
has_cuda = 1
|
||||
|
||||
def check_cache_compile():
|
||||
|
@ -950,89 +1001,90 @@ if platform.system() == 'Darwin' and platform.machine() == 'arm64':
|
|||
if "cc_flags" in os.environ:
|
||||
cc_flags += os.environ["cc_flags"] + ' '
|
||||
|
||||
link_flags = " -lstdc++ -ldl -shared "
|
||||
cc_flags += " -lstdc++ -ldl -shared "
|
||||
if platform.system() == 'Darwin':
|
||||
# TODO: if not using apple clang, there is no need to add -lomp
|
||||
link_flags += "-undefined dynamic_lookup -lomp "
|
||||
cc_flags += "-undefined dynamic_lookup -lomp "
|
||||
if platform.machine() == "arm64":
|
||||
link_flags += " -L/opt/homebrew/lib "
|
||||
cc_flags += " -L/opt/homebrew/lib "
|
||||
|
||||
core_link_flags = ""
|
||||
opt_flags = ""
|
||||
|
||||
py_include = jit_utils.get_py3_include_path()
|
||||
LOG.i(f"py_include: {py_include}")
|
||||
extension_suffix = jit_utils.get_py3_extension_suffix()
|
||||
lib_suffix = extension_suffix.replace(".pyd", ".lib")
|
||||
lib_suffix = extension_suffix.rsplit(".", 1)[0]
|
||||
LOG.i(f"extension_suffix: {extension_suffix}")
|
||||
so = ".so" if os.name != 'nt' else ".dll"
|
||||
|
||||
|
||||
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags
|
||||
if platform.system() == 'Darwin':
|
||||
# TODO: if not using apple clang, cannot add -Xpreprocessor
|
||||
kernel_opt_flags = kernel_opt_flags + " -Xpreprocessor -fopenmp "
|
||||
kernel_opt_flags += " -Xpreprocessor -fopenmp "
|
||||
elif cc_type != 'cl':
|
||||
kernel_opt_flags = kernel_opt_flags + " -fopenmp "
|
||||
kernel_opt_flags += " -fopenmp "
|
||||
fix_cl_flags = lambda x:x
|
||||
if os.name == 'nt':
|
||||
if cc_type == 'g++':
|
||||
link_flags = link_flags.replace('-ldl', '')
|
||||
py3_link_path = '-L"' + os.path.join(
|
||||
os.path.dirname(sys.executable),
|
||||
"libs"
|
||||
) + f'" -lpython3{sys.version_info.minor} '
|
||||
core_link_flags = py3_link_path
|
||||
link_flags += core_link_flags
|
||||
# link_flags += " -Wl,--unresolved-symbols=ignore-all "
|
||||
# cc_flags += " -Xlinker --allow-shlib-undefined "
|
||||
cc_flags = cc_flags.replace('-std=c++14', '-std=c++17')
|
||||
link_flags += " -fopenmp "
|
||||
kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} "
|
||||
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
|
||||
pass
|
||||
elif cc_type == 'cl':
|
||||
py3_link_path = os.path.join(
|
||||
os.path.dirname(sys.executable),
|
||||
"libs",
|
||||
f'python3{sys.version_info.minor}.lib'
|
||||
)
|
||||
# core_link_flags = py3_link_path
|
||||
link_flags += core_link_flags
|
||||
# link_flags += " -Wl,--unresolved-symbols=ignore-all "
|
||||
# cc_flags += " -Xlinker --allow-shlib-undefined "
|
||||
kernel_opt_flags += f" {cache_path}\\jit_utils_core{lib_suffix} "
|
||||
kernel_opt_flags += f" {cache_path}\\jittor_core{lib_suffix} "
|
||||
# cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
|
||||
cc_flags = " -std:c++17 -O2 -fp:fast -EHsc "
|
||||
# cc_flags += py3_link_path + " "
|
||||
cc_flags = cc_flags.replace("-std=c++14", "-std=c++17")
|
||||
cc_flags = cc_flags.replace("-lstdc++", "")
|
||||
cc_flags = cc_flags.replace("-ldl", "")
|
||||
cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} "
|
||||
cc_flags += " -EHsc "
|
||||
import jittor_utils
|
||||
if jittor_utils.msvc_path:
|
||||
mp = jittor_utils.msvc_path
|
||||
cc_flags += f' -nologo -I"{mp}\\cl_x64\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX '
|
||||
win_link_flags = f' -link -LIBPATH:"{mp}\\cl_x64\\lib" -LIBPATH:"{mp}\\win10_kits\\lib\\um\\x64" -LIBPATH:"{mp}\\win10_kits\\lib\\ucrt\\x64" '
|
||||
link_flags = ' -LD '
|
||||
kernel_opt_flags += win_link_flags# + " -EXPORT:\"?jit_run@FusedOp@jittor@@QEAAXXZ\""
|
||||
cc_flags += f' -nologo -I"{mp}\\VC\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX '
|
||||
cc_flags += f' -L"{mp}\\VC\\lib" -L"{mp}\\win10_kits\\lib\\um\\x64" -L"{mp}\\win10_kits\\lib\\ucrt\\x64" '
|
||||
def fix_cl_flags(cmd):
|
||||
cmd = cmd.replace(".o ", ".obj ")
|
||||
cmd = cmd.replace(".o\" ", ".obj\" ")
|
||||
if cmd.endswith(".o"): cmd += "bj"
|
||||
from shlex import split
|
||||
if " -LD " in cmd:
|
||||
cmd = cmd.replace(" -o ", " -Fe: ")
|
||||
output = split(cmd.split("-Fe:")[1].strip(), posix=False)[0]
|
||||
base_output = os.path.basename(output).split('.')[0]
|
||||
cmd += win_link_flags
|
||||
cmd += f" -DEF:\"{output}.def\" -IGNORE:4102 -IGNORE:4197 -IGNORE:4217 {py3_link_path}"
|
||||
if base_output == "jit_utils_core":
|
||||
pass
|
||||
elif base_output == "jittor_core":
|
||||
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix}")
|
||||
else:
|
||||
cmd += " " + os.path.join(cache_path, f"jit_utils_core{lib_suffix} ")
|
||||
cmd += " " + os.path.join(cache_path, f"jittor_core{lib_suffix} ")
|
||||
if " -o " in cmd:
|
||||
if " -shared " in cmd:
|
||||
cmd = cmd.replace(" -o ", " -Fe: ")
|
||||
output = shsplit(cmd.split("-Fe:")[1].strip())[0]
|
||||
base_output = os.path.basename(output).split('.')[0]
|
||||
cmd += f" -DEF:\"{output}.def\" -IGNORE:4102 -IGNORE:4197 -IGNORE:4217 "
|
||||
|
||||
elif " -c -o " in cmd:
|
||||
cmd = cmd.replace(" -c -o ", " -c -Fo: ")
|
||||
elif " -c -o " in cmd:
|
||||
cmd = cmd.replace(" -c -o ", " -c -Fo: ")
|
||||
flags = shsplit(cmd)
|
||||
output = []
|
||||
output2 = []
|
||||
for f in flags:
|
||||
if f.startswith("-link"):
|
||||
pass
|
||||
elif f.startswith("-l"):
|
||||
output2.append(f[2:]+".lib")
|
||||
elif f.startswith("-LIB"):
|
||||
output2.append(f)
|
||||
elif f.startswith("-LD"):
|
||||
output.append(f)
|
||||
elif f.startswith("-L"):
|
||||
output2.append("-LIBPATH:"+f[2:])
|
||||
elif ".lib" in f:
|
||||
output2.append(f)
|
||||
elif f.startswith("-DEF:"):
|
||||
output2.append(f)
|
||||
elif f.startswith("-W") or f.startswith("-f"):
|
||||
pass
|
||||
elif f.startswith("-std="):
|
||||
output.append(f.replace("=", ":"))
|
||||
else:
|
||||
output.append(f)
|
||||
cmd = " ".join(output)
|
||||
if len(output2):
|
||||
cmd += " -link " + " ".join(output2)
|
||||
cmd = cmd.replace("-include", "-FI")
|
||||
cmd = cmd.replace("-shared", "-LD")
|
||||
return cmd
|
||||
|
||||
if ' -O' not in cc_flags:
|
||||
|
@ -1055,7 +1107,7 @@ ck_path = os.path.join(cache_path, "checkpoints")
|
|||
make_cache_dir(ck_path)
|
||||
|
||||
# build cache_compile
|
||||
cc_flags += f" -I{os.path.join(jittor_path, 'src')} "
|
||||
cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" "
|
||||
cc_flags += py_include
|
||||
check_cache_compile()
|
||||
LOG.v(f"Get cache_compile: {jit_utils.cc}")
|
||||
|
@ -1065,9 +1117,28 @@ has_cuda = 0
|
|||
check_cuda()
|
||||
nvcc_flags = os.environ.get("nvcc_flags", "")
|
||||
if has_cuda:
|
||||
nvcc_flags += cc_flags + link_flags
|
||||
nvcc_flags += cc_flags
|
||||
def convert_nvcc_flags(nvcc_flags):
|
||||
# nvcc don't support -Wall option
|
||||
if os.name == 'nt':
|
||||
nvcc_flags = nvcc_flags.replace("-fp:", "-Xcompiler -fp:")
|
||||
nvcc_flags = nvcc_flags.replace("-EHsc", "-Xcompiler -EHsc")
|
||||
nvcc_flags = nvcc_flags.replace("-nologo", "")
|
||||
nvcc_flags = nvcc_flags.replace("-std:", "-std=")
|
||||
nvcc_flags = nvcc_flags.replace("-Fo:", "-o")
|
||||
nvcc_flags = nvcc_flags.replace("-LD", "-shared")
|
||||
nvcc_flags = nvcc_flags.replace("-LIBPATH:", "-L")
|
||||
nvcc_flags = nvcc_flags.replace("-link", "")
|
||||
def func(x):
|
||||
if ".lib" not in x: return x
|
||||
x = x.replace("\"", "")
|
||||
a = os.path.dirname(x)
|
||||
b = os.path.basename(x)
|
||||
if not b.endswith(".lib"):
|
||||
return x
|
||||
return f"-L\"{a}\" -l{b[:-4]}"
|
||||
nvcc_flags = map_flags(nvcc_flags, func)
|
||||
nvcc_flags = nvcc_flags.replace("-std=c++17", "-std=c++14 -Xcompiler -std:c++14")
|
||||
nvcc_flags = nvcc_flags.replace("-Wall", "")
|
||||
nvcc_flags = nvcc_flags.replace("-Wno-unknown-pragmas", "")
|
||||
nvcc_flags = nvcc_flags.replace("-fopenmp", "")
|
||||
|
@ -1075,10 +1146,10 @@ if has_cuda:
|
|||
nvcc_flags = nvcc_flags.replace("-Werror", "")
|
||||
nvcc_flags = nvcc_flags.replace("-fPIC", "-Xcompiler -fPIC")
|
||||
nvcc_flags = nvcc_flags.replace("-fdiagnostics", "-Xcompiler -fdiagnostics")
|
||||
nvcc_flags += f" -x cu --cudart=shared -ccbin='{cc_path}' --use_fast_math "
|
||||
nvcc_flags += f" -x cu --cudart=shared -ccbin=\"{cc_path}\" --use_fast_math "
|
||||
# nvcc warning is noise
|
||||
nvcc_flags += " -w "
|
||||
nvcc_flags += f" -I'{os.path.join(jittor_path, 'extern/cuda/inc')}' "
|
||||
nvcc_flags += f" -I\"{os.path.join(jittor_path, 'extern/cuda/inc')}\" "
|
||||
if os.environ.get("cuda_debug", "0") == "1":
|
||||
nvcc_flags += " -G "
|
||||
return nvcc_flags
|
||||
|
@ -1092,7 +1163,7 @@ jit_src = gen_jit_op_maker(op_headers)
|
|||
LOG.vvvv(jit_src)
|
||||
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
|
||||
f.write(jit_src)
|
||||
cc_flags += f' -I{cache_path} '
|
||||
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" '
|
||||
# gen pyjt
|
||||
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
|
||||
|
||||
|
@ -1164,16 +1235,19 @@ if use_data_gz:
|
|||
f.write(data.decode("utf8"))
|
||||
dflags = (cc_flags+opt_flags)\
|
||||
.replace("-Wall", "") \
|
||||
.replace("-Werror", "")
|
||||
.replace("-Werror", "") \
|
||||
.replace("-shared", "")
|
||||
vdp = os.path.join(jittor_path, "src", "utils", "vdp")
|
||||
run_cmd(fix_cl_flags(f"{cc_path} {dflags} -include {vdp} {data_s_path} -c -o {data_o_path}"))
|
||||
os.remove(data_s_path)
|
||||
# os.remove(data_s_path)
|
||||
with open(data_gz_md5_path, 'w') as f:
|
||||
f.write(md5)
|
||||
files.append(data_o_path)
|
||||
files = [f for f in files if "__data__" not in f]
|
||||
|
||||
cc_flags += f" -l\"jit_utils_core{lib_suffix}\" "
|
||||
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)
|
||||
cc_flags += f" -l\"jittor_core{lib_suffix}\" "
|
||||
|
||||
# TODO: move to compile_extern.py
|
||||
# compile_extern()
|
||||
|
@ -1182,6 +1256,8 @@ with jit_utils.import_scope(import_flags):
|
|||
import jittor_core as core
|
||||
|
||||
flags = core.flags()
|
||||
nvcc_flags = convert_nvcc_flags(cc_flags)
|
||||
|
||||
if has_cuda:
|
||||
if len(flags.cuda_archs):
|
||||
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
|
@ -1189,7 +1265,7 @@ if has_cuda:
|
|||
|
||||
flags.cc_path = cc_path
|
||||
flags.cc_type = cc_type
|
||||
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
|
||||
flags.cc_flags = cc_flags + kernel_opt_flags
|
||||
flags.nvcc_path = nvcc_path
|
||||
flags.nvcc_flags = nvcc_flags
|
||||
flags.python_path = python_path
|
||||
|
|
|
@ -190,7 +190,7 @@ void CudnnConv3dBackwardWOp::jit_run() {
|
|||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos];
|
||||
STACK_ALLOC(cudnnConvolutionBwdFilterAlgoPerf_t,perf_results,num_algos);
|
||||
cudnnConvolutionBwdFilterAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
|
|
|
@ -181,7 +181,7 @@ void CudnnConv3dBackwardXOp::jit_run() {
|
|||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
|
||||
STACK_ALLOC(cudnnConvolutionBwdDataAlgoPerf_t,perf_results,num_algos);
|
||||
cudnnConvolutionBwdDataAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
|
|
|
@ -184,7 +184,7 @@ void CudnnConv3dOp::jit_run() {
|
|||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos];
|
||||
STACK_ALLOC(cudnnConvolutionFwdAlgoPerf_t,perf_results,num_algos);
|
||||
cudnnConvolutionFwdAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
|
|
|
@ -180,7 +180,7 @@ void CudnnConvBackwardWOp::jit_run() {
|
|||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos];
|
||||
STACK_ALLOC(cudnnConvolutionBwdFilterAlgoPerf_t,perf_results,num_algos);
|
||||
cudnnConvolutionBwdFilterAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
|
|
|
@ -181,7 +181,7 @@ void CudnnConvBackwardXOp::jit_run() {
|
|||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
|
||||
STACK_ALLOC(cudnnConvolutionBwdDataAlgoPerf_t,perf_results,num_algos);
|
||||
cudnnConvolutionBwdDataAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
|
|
|
@ -183,7 +183,7 @@ void CudnnConvOp::jit_run() {
|
|||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos];
|
||||
STACK_ALLOC(cudnnConvolutionFwdAlgoPerf_t,perf_results,num_algos);
|
||||
cudnnConvolutionFwdAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
|
|
|
@ -84,9 +84,6 @@ static double second (void)
|
|||
gettimeofday(&tv, NULL);
|
||||
return (double)tv.tv_sec + (double)tv.tv_usec / 1000000.0;
|
||||
}
|
||||
#else
|
||||
#error unsupported platform
|
||||
#endif
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
|
@ -995,4 +992,10 @@ int cudnn_test_entry( int argc, char** argv )
|
|||
doTest<half1>(algo, dimA, padA, convstrideA, filterdimA, filterFormat, mathType, benchmark);
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
int cudnn_test_entry( int argc, char** argv ) {
|
||||
return 0;
|
||||
}
|
||||
#endif
|
|
@ -472,9 +472,9 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
}
|
||||
#endif
|
||||
last_is_cuda = is_cuda;
|
||||
_JT_SEH_START2;
|
||||
// _JT_SEH_START2;
|
||||
op->do_run_after_prepare(jkl);
|
||||
_JT_SEH_END2;
|
||||
// _JT_SEH_END2;
|
||||
#ifdef HAS_CUDA
|
||||
// migrate to gpu
|
||||
if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) {
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "utils/cache_compile.h"
|
||||
#include "utils/flags.h"
|
||||
#include "fused_op.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -32,6 +33,71 @@ DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
|
|||
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
|
||||
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
|
||||
|
||||
#ifdef _MSC_VER
|
||||
vector<string> shsplit(const string& s) {
|
||||
auto s1 = split(s, " ");
|
||||
vector<string> s2;
|
||||
int count = 0;
|
||||
for (auto& s : s1) {
|
||||
int nc = 0;
|
||||
for (auto& c : s)
|
||||
nc += c=='"' || c=='\'';
|
||||
if (count&1) {
|
||||
count += nc;
|
||||
s2.back() += " ";
|
||||
s2.back() += s;
|
||||
} else {
|
||||
count = nc;
|
||||
s2.push_back(s);
|
||||
}
|
||||
}
|
||||
return s2;
|
||||
}
|
||||
|
||||
string fix_cl_flags(const string& cmd) {
|
||||
auto flags = shsplit(cmd);
|
||||
vector<string> output, output2;
|
||||
|
||||
for (auto& f : flags) {
|
||||
if (startswith(f, "-link"))
|
||||
continue;
|
||||
else if (startswith(f, "-l"))
|
||||
output2.push_back(f.substr(2)+".lib");
|
||||
else if (startswith(f, "-LIB"))
|
||||
output2.push_back(f);
|
||||
else if (startswith(f, "-LD"))
|
||||
output.push_back(f);
|
||||
else if (startswith(f, "-L"))
|
||||
output2.push_back("-LIBPATH:"+f.substr(2));
|
||||
else if (f.find(".lib") != string::npos)
|
||||
output2.push_back(f);
|
||||
else if (startswith(f, "-DEF:"))
|
||||
output2.push_back(f);
|
||||
else if (startswith(f, "-W") || startswith(f,"-f"))
|
||||
continue;
|
||||
else if (startswith(f,"-std="))
|
||||
output.push_back("-std:"+f.substr(5));
|
||||
else if (startswith(f,"-include"))
|
||||
output.push_back("-FI");
|
||||
else if (startswith(f,"-shared"))
|
||||
output.push_back("-LD");
|
||||
else
|
||||
output.push_back(f);
|
||||
}
|
||||
string cmdx = "";
|
||||
for (auto& s : output) {
|
||||
cmdx += s;
|
||||
cmdx += " ";
|
||||
}
|
||||
cmdx += "-link ";
|
||||
for (auto& s : output2) {
|
||||
cmdx += s;
|
||||
cmdx += " ";
|
||||
}
|
||||
return cmdx;
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace jit_compiler {
|
||||
|
||||
std::mutex dl_open_mutex;
|
||||
|
@ -103,6 +169,7 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
|||
write(jit_src_path, src);
|
||||
string cmd;
|
||||
|
||||
auto symbol_name = get_symbol_name(jit_key);
|
||||
#ifndef _MSC_VER
|
||||
if (is_cuda_op) {
|
||||
cmd = "\"" + nvcc_path + "\""
|
||||
|
@ -124,21 +191,18 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
|||
cmd = "\"" + nvcc_path + "\""
|
||||
+ " \"" + jit_src_path + "\"" + other_src
|
||||
+ nvcc_flags + extra_flags
|
||||
+ " -o \"" + jit_lib_path + "\"";
|
||||
+ " -o \"" + jit_lib_path + "\""
|
||||
+ " -Xlinker -EXPORT:\""
|
||||
+ symbol_name + "\"";;
|
||||
} else {
|
||||
auto symbol_name = get_symbol_name(jit_key);
|
||||
auto pos = cc_flags.find("-link");
|
||||
auto cc_flags1 = cc_flags.substr(0, pos);
|
||||
auto cc_flags2 = cc_flags.substr(pos);
|
||||
cmd = "\"" + cc_path + "\""
|
||||
+ " \"" + jit_src_path + "\"" + other_src
|
||||
+ cc_flags1 + extra_flags
|
||||
+ " -Fe: \"" + jit_lib_path + "\" " + cc_flags2 + " -EXPORT:\""
|
||||
+ " -Fe: \"" + jit_lib_path + "\" "
|
||||
+ fix_cl_flags(cc_flags + extra_flags) + " -EXPORT:\""
|
||||
+ symbol_name + "\"";
|
||||
}
|
||||
#endif
|
||||
cache_compile(cmd, cache_path, jittor_path);
|
||||
auto symbol_name = get_symbol_name(jit_key);
|
||||
auto jit_entry = load_jit_lib(jit_lib_path, symbol_name);
|
||||
return jit_entry;
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@ namespace jittor {
|
|||
using std::string_view;
|
||||
#elif defined(__GNUC__)
|
||||
using std::experimental::string_view;
|
||||
#elif __cplusplus < 201400L
|
||||
using string_view = string;
|
||||
#else
|
||||
using std::string_view;
|
||||
#endif
|
||||
|
|
|
@ -446,7 +446,7 @@ void GetitemOp::jit_prepare(JK& jk) {
|
|||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
int no = o_shape.size();
|
||||
int masks[no];
|
||||
STACK_ALLOC(int, masks, no);
|
||||
int tdims[6];
|
||||
cuda_loop_schedule(o_shape, masks, tdims);
|
||||
for (int i=0; i<no; i++) {
|
||||
|
|
|
@ -237,7 +237,7 @@ void SetitemOp::jit_prepare(JK& jk) {
|
|||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
int no = o_shape.size();
|
||||
int masks[no];
|
||||
STACK_ALLOC(int, masks, no);
|
||||
int tdims[6];
|
||||
cuda_loop_schedule(o_shape, masks, tdims);
|
||||
for (int i=0; i<no; i++) {
|
||||
|
|
|
@ -254,10 +254,12 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
|||
continue;
|
||||
processed.insert(input_names[i]);
|
||||
auto src = read_all(input_names[i]);
|
||||
auto back = input_names[i].back();
|
||||
// *.lib
|
||||
if (back == 'b') continue;
|
||||
ASSERT(src.size()) << "Source read failed:" << input_names[i] << "cmd:" << cmd;
|
||||
auto hash = S(hash64(src));
|
||||
vector<string> new_names;
|
||||
auto back = input_names[i].back();
|
||||
// *.obj, *.o, *.pyd
|
||||
if (back != 'j' && back != 'o' && back != 'd')
|
||||
process(src, new_names, cmd);
|
||||
|
|
|
@ -6,6 +6,7 @@ def_path = sys.argv[-1]
|
|||
|
||||
# print(sys.argv)
|
||||
dumpbin_path = os.environ.get("dumpbin_path", "dumpbin")
|
||||
export_all = os.environ.get("EXPORT_ALL", "0")=="1"
|
||||
|
||||
syms = {}
|
||||
|
||||
|
@ -22,6 +23,11 @@ for obj in sys.argv[1:-2]:
|
|||
if sym.startswith("??$get_from_env"): syms[sym] = 1
|
||||
# if sym.startswith("??"): continue
|
||||
if sym.startswith("my"): syms[sym] = 1
|
||||
# for cutt
|
||||
if "custom_cuda" in sym: syms[sym] = 1
|
||||
if "cutt" in sym: syms[sym] = 1
|
||||
if "_cudaGetErrorEnum" in sym: syms[sym] = 1
|
||||
if export_all: syms[sym] = 1
|
||||
if "jittor" not in sym: continue
|
||||
syms[sym] = 1
|
||||
# print(ret)
|
||||
|
|
|
@ -219,7 +219,7 @@ if os.name=='nt' and getattr(mp.current_process(), '_inheriting', False):
|
|||
os.environ["log_silent"] = '1'
|
||||
|
||||
if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
|
||||
os.environ["use_parallel_op_compiler"] = '1'
|
||||
os.environ["use_parallel_op_compiler"] = '0'
|
||||
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
||||
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
|
||||
n = len(cmds)
|
||||
|
@ -278,11 +278,12 @@ def find_cache_path():
|
|||
|
||||
def get_version(output):
|
||||
if output.endswith("mpicc"):
|
||||
version = run_cmd(output+" --showme:version")
|
||||
elif os.name == 'nt' and output.endswith("cl"):
|
||||
version = run_cmd(f"\"{output}\" --showme:version")
|
||||
elif os.name == 'nt' and (
|
||||
output.endswith("cl") or output.endswith("cl.exe")):
|
||||
version = run_cmd(output)
|
||||
else:
|
||||
version = run_cmd(output+" --version")
|
||||
version = run_cmd(f"\"{output}\" --version")
|
||||
v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version)
|
||||
if len(v) == 0:
|
||||
v = re.findall("[0-9]+\\.[0-9]+", version)
|
||||
|
@ -427,7 +428,7 @@ msvc_path = ""
|
|||
if os.name == 'nt' and os.environ.get("cc_path", "")=="":
|
||||
from pathlib import Path
|
||||
msvc_path = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
|
||||
cc_path = os.path.join(msvc_path, "cl_x64", "bin", "cl")
|
||||
cc_path = os.path.join(msvc_path, "VC", r"_\_\_\_\_\bin", "cl.exe")
|
||||
check_msvc_install = True
|
||||
else:
|
||||
cc_path = env_or_find('cc_path', 'g++', silent=True)
|
||||
|
|
|
@ -12,7 +12,21 @@ from jittor_utils import LOG
|
|||
from jittor_utils.misc import download_url_to_local
|
||||
import pathlib
|
||||
|
||||
def get_cuda_driver_win():
|
||||
try:
|
||||
import ctypes
|
||||
cuda_driver = ctypes.CDLL(r"nvcuda")
|
||||
driver_version = ctypes.c_int()
|
||||
r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version))
|
||||
if r != 0: return None
|
||||
v = driver_version.value
|
||||
return [v//1000, v%1000//10, v%10]
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_cuda_driver():
|
||||
if os.name == 'nt':
|
||||
return get_cuda_driver_win()
|
||||
ret, out = sp.getstatusoutput("nvidia-smi -q -u")
|
||||
if ret != 0: return None
|
||||
try:
|
||||
|
@ -35,21 +49,34 @@ def install_cuda():
|
|||
if not cuda_driver_version:
|
||||
return None
|
||||
LOG.i("cuda_driver_version: ", cuda_driver_version)
|
||||
if "JTCUDA_VERSION" in os.environ:
|
||||
cuda_driver_version = list(map(int,os.enviroment["JTCUDA_VERSION"].split(".")))
|
||||
LOG.i("JTCUDA_VERSION: ", cuda_driver_version)
|
||||
|
||||
if cuda_driver_version >= [11,2]:
|
||||
cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
|
||||
md5 = "b93a1a5d19098e93450ee080509e9836"
|
||||
elif cuda_driver_version >= [11,]:
|
||||
cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
|
||||
md5 = "5dbdb43e35b4db8249027997720bf1ca"
|
||||
elif cuda_driver_version >= [10,2]:
|
||||
cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
|
||||
md5 = "40f0563e8eb176f53e55943f6d212ad7"
|
||||
elif cuda_driver_version >= [10,]:
|
||||
cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
|
||||
md5 = "f16d3ff63f081031d21faec3ec8b7dac"
|
||||
if os.name == 'nt':
|
||||
if cuda_driver_version >= [11,4]:
|
||||
cuda_tgz = "cuda11.4_cudnn8_win.zip"
|
||||
md5 = "06eed370d0d44bb2cc57809343911187"
|
||||
elif cuda_driver_version >= [10,]:
|
||||
cuda_tgz = "cuda10.2_cudnn7_win.zip"
|
||||
md5 = "7dd9963833a91371299a2ba58779dd71"
|
||||
else:
|
||||
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.2")
|
||||
else:
|
||||
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}")
|
||||
if cuda_driver_version >= [11,2]:
|
||||
cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
|
||||
md5 = "b93a1a5d19098e93450ee080509e9836"
|
||||
elif cuda_driver_version >= [11,]:
|
||||
cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
|
||||
md5 = "5dbdb43e35b4db8249027997720bf1ca"
|
||||
elif cuda_driver_version >= [10,2]:
|
||||
cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
|
||||
md5 = "40f0563e8eb176f53e55943f6d212ad7"
|
||||
elif cuda_driver_version >= [10,]:
|
||||
cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
|
||||
md5 = "f16d3ff63f081031d21faec3ec8b7dac"
|
||||
else:
|
||||
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.0")
|
||||
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
|
||||
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
|
||||
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
|
||||
|
@ -65,9 +92,14 @@ def install_cuda():
|
|||
download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5)
|
||||
|
||||
|
||||
import tarfile
|
||||
with tarfile.open(cuda_tgz_path, "r") as tar:
|
||||
tar.extractall(cuda_tgz_path[:-4])
|
||||
if cuda_tgz.endswith(".zip"):
|
||||
import zipfile
|
||||
zf = zipfile.ZipFile(cuda_tgz_path)
|
||||
zf.extractall(path=cuda_tgz_path[:-4])
|
||||
else:
|
||||
import tarfile
|
||||
with tarfile.open(cuda_tgz_path, "r") as tar:
|
||||
tar.extractall(cuda_tgz_path[:-4])
|
||||
|
||||
assert os.path.isfile(nvcc_path)
|
||||
return nvcc_path
|
||||
|
|
|
@ -8,7 +8,7 @@ def install(path):
|
|||
LOG.i("Installing MSVC...")
|
||||
filename = "msvc.zip"
|
||||
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
|
||||
md5sum = "0fd71436c034808649b24baf28998ccc"
|
||||
md5sum = "929c4b86ea68160121f73c3edf6b5463"
|
||||
download_url_to_local(url, filename, path, md5sum)
|
||||
fullname = os.path.join(path, filename)
|
||||
import zipfile
|
||||
|
|
|
@ -21,7 +21,10 @@ def ensure_dir(dir_path):
|
|||
os.makedirs(dir_path)
|
||||
|
||||
def _progress():
|
||||
pbar = tqdm(total=None)
|
||||
pbar = tqdm(total=None,
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024)
|
||||
|
||||
def bar_update(block_num, block_size, total_size):
|
||||
""" reporthook
|
||||
|
|
Loading…
Reference in New Issue