From c1ee6d9ed38c43bd64c10f6dcad371523c9a0893 Mon Sep 17 00:00:00 2001
From: Dun Liang
Date: Sun, 26 Sep 2021 19:48:22 +0800
Subject: [PATCH] polish win_cuda on linux

---
Reviewer notes (this section is ignored by `git am`; it is not part of the
commit message):

 * The line structure of this patch was rebuilt from a copy whose newlines
   had been collapsed. Every hunk's line counts were checked against its
   `@@` header and against the diffstat below, but the indentation of
   context lines and the position of blank context lines are a best-effort
   reconstruction. Verify against the tree, or apply with
   `git apply --ignore-whitespace`.
 * `<string>` template arguments were restored on the C++ declarations;
   the author address on the From: line was missing and is left as found.
 * Summary: on non-MSVC builds, fix_cl_flags (Python and C++) now rewrites
   `-l` flags whose text contains "cpython" or "lib" to `-l:<name>.so`,
   and appends an rpath option after every `-L` flag (`-Xlinker -rpath=`
   for CUDA ops in the C++ version, `-Wl,-rpath=` otherwise).
   asm_tuner.py splits flags with a quote-aware shsplit and also drops
   flags that end with a removed suffix (e.g. `.lib`).

 python/jittor/compile_extern.py          |  6 ++--
 python/jittor/compiler.py                | 17 +++++++++--
 python/jittor/src/jit_compiler.cc        | 37 +++++++++++++++++++-----
 python/jittor/src/utils/cache_compile.cc |  3 +-
 python/jittor/utils/asm_tuner.py         | 27 ++++++++++++-----
 5 files changed, 69 insertions(+), 21 deletions(-)

diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py
index e38ab6e6..d89b6fcb 100644
--- a/python/jittor/compile_extern.py
+++ b/python/jittor/compile_extern.py
@@ -126,7 +126,7 @@ def setup_mkl():
         mkl_lib_path = os.path.join(mkl_home, "lib")
 
     mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
-    extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
+    extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn "
     if os.name == 'nt':
         mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
         mkl_bin_path = os.path.join(mkl_home, 'bin')
@@ -199,9 +199,9 @@ def setup_cuda_extern():
     cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
     cuda_extern_files = [os.path.join(cuda_extern_src, name)
         for name in os.listdir(cuda_extern_src)]
-    so_name = os.path.join(cache_path_cuda, "cuda_extern"+so)
+    so_name = os.path.join(cache_path_cuda, "libcuda_extern"+so)
     compile(cc_path, cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
-    link_cuda_extern = f" -L\"{cache_path_cuda}\" -lcuda_extern "
+    link_cuda_extern = f" -L\"{cache_path_cuda}\" -llibcuda_extern "
     ctypes.CDLL(so_name, dlopen_flags)
 
     try:
diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py
index 61a15494..8d3ea461 100644
--- a/python/jittor/compiler.py
+++ b/python/jittor/compiler.py
@@ -118,7 +118,7 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="")
     inputs = new_inputs
 
     if len(inputs) == 1 or combind_build:
-        cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} {link} -o {output}"
+        cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} -o {output}"
         return do_compile(fix_cl_flags(cmd))
     # split compile object file and link
     # remove -l -L flags when compile object files
@@ -1019,7 +1019,18 @@ if platform.system() == 'Darwin':
     kernel_opt_flags += " -Xpreprocessor -fopenmp "
 elif cc_type != 'cl':
     kernel_opt_flags += " -fopenmp "
-fix_cl_flags = lambda x:x
+def fix_cl_flags(cmd):
+    output = shsplit(cmd)
+    output2 = []
+    for s in output:
+        if s.startswith("-l") and ("cpython" in s or "lib" in s):
+            output2.append(f"-l:{s[2:]}.so")
+        elif s.startswith("-L"):
+            output2.append(f"{s} -Wl,-rpath={s[2:]}")
+        else:
+            output2.append(s)
+    return " ".join(output2)
+
 if os.name == 'nt':
     if cc_type == 'g++':
         pass
@@ -1251,9 +1262,9 @@ with jit_utils.import_scope(import_flags):
     import jittor_core as core
 
 flags = core.flags()
-nvcc_flags = convert_nvcc_flags(cc_flags)
 
 if has_cuda:
+    nvcc_flags = convert_nvcc_flags(cc_flags)
     if len(flags.cuda_archs):
         nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
         nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
diff --git a/python/jittor/src/jit_compiler.cc b/python/jittor/src/jit_compiler.cc
index fe117b6f..03e3594f 100644
--- a/python/jittor/src/jit_compiler.cc
+++ b/python/jittor/src/jit_compiler.cc
@@ -33,7 +33,6 @@ DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
 DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
 DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
 
-#ifdef _MSC_VER
 vector<string> shsplit(const string& s) {
     auto s1 = split(s, " ");
     vector<string> s2;
@@ -54,7 +53,8 @@ vector<string> shsplit(const string& s) {
     return s2;
 }
 
-string fix_cl_flags(const string& cmd) {
+string fix_cl_flags(const string& cmd, bool is_cuda) {
+#ifdef _MSC_VER
     auto flags = shsplit(cmd);
     vector<string> output, output2;
 
@@ -95,8 +95,31 @@ string fix_cl_flags(const string& cmd) {
         cmdx += " ";
     }
     return cmdx;
-}
+#else
+    auto flags = shsplit(cmd);
+    vector<string> output;
+
+    for (auto& f : flags) {
+        if (startswith(f, "-l") &&
+            (f.find("cpython") != string::npos ||
+             f.find("lib") != string::npos))
+            output.push_back("-l:"+f.substr(2)+".so");
+        else if (startswith(f, "-L")) {
+            if (is_cuda)
+                output.push_back(f+" -Xlinker -rpath="+f.substr(2));
+            else
+                output.push_back(f+" -Wl,-rpath="+f.substr(2));
+        } else
+            output.push_back(f);
+    }
+    string cmdx = "";
+    for (auto& s : output) {
+        cmdx += s;
+        cmdx += " ";
+    }
+    return cmdx;
 #endif
+}
 
 namespace jit_compiler {
 
@@ -174,12 +197,12 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
     if (is_cuda_op) {
         cmd = "\"" + nvcc_path + "\""
             + " \"" + jit_src_path + "\"" + other_src
-            + nvcc_flags + extra_flags
+            + fix_cl_flags(nvcc_flags + extra_flags, is_cuda_op)
             + " -o \"" + jit_lib_path + "\"";
     } else {
         cmd = "\"" + cc_path + "\""
             + " \"" + jit_src_path + "\"" + other_src
-            + cc_flags + extra_flags
+            + fix_cl_flags(cc_flags + extra_flags, is_cuda_op)
             + " -o \"" + jit_lib_path + "\"";
 #ifdef __linux__
         cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
@@ -193,12 +216,12 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
             + nvcc_flags + extra_flags
             + " -o \"" + jit_lib_path + "\""
             + " -Xlinker -EXPORT:\""
-            + symbol_name + "\"";;
+            + symbol_name + "\"";
     } else {
         cmd = "\"" + cc_path + "\""
             + " \"" + jit_src_path + "\"" + other_src
             + " -Fe: \"" + jit_lib_path + "\" "
-            + fix_cl_flags(cc_flags + extra_flags) + " -EXPORT:\""
+            + fix_cl_flags(cc_flags + extra_flags, is_cuda_op) + " -EXPORT:\""
             + symbol_name + "\"";
     }
 #endif
diff --git a/python/jittor/src/utils/cache_compile.cc b/python/jittor/src/utils/cache_compile.cc
index 5ee73772..12f9fe59 100644
--- a/python/jittor/src/utils/cache_compile.cc
+++ b/python/jittor/src/utils/cache_compile.cc
@@ -241,7 +241,8 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
     find_names(cmd, input_names, output_name, extra);
     string output_cache_key;
     bool ran = false;
-    output_cache_key = read_all(output_name+".key");
+    if (file_exist(output_name))
+        output_cache_key = read_all(output_name+".key");
     string cache_key;
     unordered_set<string> processed;
     auto src_path = join(jittor_path, "src");
diff --git a/python/jittor/utils/asm_tuner.py b/python/jittor/utils/asm_tuner.py
index 230cabc2..b4187d4d 100755
--- a/python/jittor/utils/asm_tuner.py
+++ b/python/jittor/utils/asm_tuner.py
@@ -131,15 +131,28 @@ so_pos=cmd.find("_op.so")
 # remove -Xclang ...
 remove_clang_flag = lambda s: re.sub("-Xclang (('[^']*')|([^ ]*))", "", s)
 
+def shsplit(s):
+    s1 = s.split(' ')
+    s2 = []
+    count = 0
+    for s in s1:
+        nc = s.count('"') + s.count('\'')
+        if count&1:
+            count += nc
+            s2[-1] += " "
+            s2[-1] += s
+        else:
+            count = nc
+            s2.append(s)
+    return s2
+
 def remove_flags(flags, rm_flags):
-    flags = flags.split(" ")
+    flags = shsplit(flags)
     output = []
     for s in flags:
-        if s.startswith("-load"):
-            output.append(s)
-            continue
+        ss = s.replace("\"", "")
         for rm in rm_flags:
-            if s.startswith(rm):
+            if ss.startswith(rm) or ss.endswith(rm):
                 break
         else:
             output.append(s)
@@ -161,7 +174,7 @@ else: #cc_to_so
         .replace("-ldl", "") \
         .replace("-shared", "-S") \
         .replace(" -o ", " -g -o ")
-    asm_cmd = remove_flags(asm_cmd, ['-l', '-L', '-Wl,'])
+    asm_cmd = remove_flags(asm_cmd, ['-l', '-L', '-Wl,', '.lib', '-shared'])
     run_cmd(asm_cmd)
 
     s_path = cc_path.replace("_op.cc","_op.post.s")
@@ -169,5 +182,5 @@ else: #cc_to_so
     pass_asm(cc_path,s_path)
 
     asm_cmd = cmd.replace("_op.cc", "_op.s") \
-        .replace("-g", "")
+        .replace(" -g", "")
     run_cmd(remove_clang_flag(asm_cmd))
\ No newline at end of file