mirror of https://github.com/Jittor/Jittor
polish win_cuda on linux
This commit is contained in:
parent
e77f1ea7cb
commit
c1ee6d9ed3
|
@ -126,7 +126,7 @@ def setup_mkl():
|
||||||
mkl_lib_path = os.path.join(mkl_home, "lib")
|
mkl_lib_path = os.path.join(mkl_home, "lib")
|
||||||
|
|
||||||
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
|
mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so")
|
||||||
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn -Wl,-rpath='{mkl_lib_path}' "
|
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn "
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
|
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
|
||||||
mkl_bin_path = os.path.join(mkl_home, 'bin')
|
mkl_bin_path = os.path.join(mkl_home, 'bin')
|
||||||
|
@ -199,9 +199,9 @@ def setup_cuda_extern():
|
||||||
cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
|
cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src")
|
||||||
cuda_extern_files = [os.path.join(cuda_extern_src, name)
|
cuda_extern_files = [os.path.join(cuda_extern_src, name)
|
||||||
for name in os.listdir(cuda_extern_src)]
|
for name in os.listdir(cuda_extern_src)]
|
||||||
so_name = os.path.join(cache_path_cuda, "cuda_extern"+so)
|
so_name = os.path.join(cache_path_cuda, "libcuda_extern"+so)
|
||||||
compile(cc_path, cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
|
compile(cc_path, cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
|
||||||
link_cuda_extern = f" -L\"{cache_path_cuda}\" -lcuda_extern "
|
link_cuda_extern = f" -L\"{cache_path_cuda}\" -llibcuda_extern "
|
||||||
ctypes.CDLL(so_name, dlopen_flags)
|
ctypes.CDLL(so_name, dlopen_flags)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -118,7 +118,7 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="")
|
||||||
inputs = new_inputs
|
inputs = new_inputs
|
||||||
|
|
||||||
if len(inputs) == 1 or combind_build:
|
if len(inputs) == 1 or combind_build:
|
||||||
cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} {link} -o {output}"
|
cmd = f"\"{compiler}\" {' '.join(inputs)} {flags} -o {output}"
|
||||||
return do_compile(fix_cl_flags(cmd))
|
return do_compile(fix_cl_flags(cmd))
|
||||||
# split compile object file and link
|
# split compile object file and link
|
||||||
# remove -l -L flags when compile object files
|
# remove -l -L flags when compile object files
|
||||||
|
@ -1019,7 +1019,18 @@ if platform.system() == 'Darwin':
|
||||||
kernel_opt_flags += " -Xpreprocessor -fopenmp "
|
kernel_opt_flags += " -Xpreprocessor -fopenmp "
|
||||||
elif cc_type != 'cl':
|
elif cc_type != 'cl':
|
||||||
kernel_opt_flags += " -fopenmp "
|
kernel_opt_flags += " -fopenmp "
|
||||||
fix_cl_flags = lambda x:x
|
def fix_cl_flags(cmd):
|
||||||
|
output = shsplit(cmd)
|
||||||
|
output2 = []
|
||||||
|
for s in output:
|
||||||
|
if s.startswith("-l") and ("cpython" in s or "lib" in s):
|
||||||
|
output2.append(f"-l:{s[2:]}.so")
|
||||||
|
elif s.startswith("-L"):
|
||||||
|
output2.append(f"{s} -Wl,-rpath={s[2:]}")
|
||||||
|
else:
|
||||||
|
output2.append(s)
|
||||||
|
return " ".join(output2)
|
||||||
|
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
if cc_type == 'g++':
|
if cc_type == 'g++':
|
||||||
pass
|
pass
|
||||||
|
@ -1251,9 +1262,9 @@ with jit_utils.import_scope(import_flags):
|
||||||
import jittor_core as core
|
import jittor_core as core
|
||||||
|
|
||||||
flags = core.flags()
|
flags = core.flags()
|
||||||
nvcc_flags = convert_nvcc_flags(cc_flags)
|
|
||||||
|
|
||||||
if has_cuda:
|
if has_cuda:
|
||||||
|
nvcc_flags = convert_nvcc_flags(cc_flags)
|
||||||
if len(flags.cuda_archs):
|
if len(flags.cuda_archs):
|
||||||
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
|
nvcc_flags += f" -arch=compute_{min(flags.cuda_archs)} "
|
||||||
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||||
|
|
|
@ -33,7 +33,6 @@ DEFINE_FLAG(string, python_path, "", "Path of python interpreter");
|
||||||
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
|
DEFINE_FLAG(string, cache_path, "", "Cache path of jittor");
|
||||||
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
|
DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not");
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
|
||||||
vector<string> shsplit(const string& s) {
|
vector<string> shsplit(const string& s) {
|
||||||
auto s1 = split(s, " ");
|
auto s1 = split(s, " ");
|
||||||
vector<string> s2;
|
vector<string> s2;
|
||||||
|
@ -54,7 +53,8 @@ vector<string> shsplit(const string& s) {
|
||||||
return s2;
|
return s2;
|
||||||
}
|
}
|
||||||
|
|
||||||
string fix_cl_flags(const string& cmd) {
|
string fix_cl_flags(const string& cmd, bool is_cuda) {
|
||||||
|
#ifdef _MSC_VER
|
||||||
auto flags = shsplit(cmd);
|
auto flags = shsplit(cmd);
|
||||||
vector<string> output, output2;
|
vector<string> output, output2;
|
||||||
|
|
||||||
|
@ -95,8 +95,31 @@ string fix_cl_flags(const string& cmd) {
|
||||||
cmdx += " ";
|
cmdx += " ";
|
||||||
}
|
}
|
||||||
return cmdx;
|
return cmdx;
|
||||||
}
|
#else
|
||||||
|
auto flags = shsplit(cmd);
|
||||||
|
vector<string> output;
|
||||||
|
|
||||||
|
for (auto& f : flags) {
|
||||||
|
if (startswith(f, "-l") &&
|
||||||
|
(f.find("cpython") != string::npos ||
|
||||||
|
f.find("lib") != string::npos))
|
||||||
|
output.push_back("-l:"+f.substr(2)+".so");
|
||||||
|
else if (startswith(f, "-L")) {
|
||||||
|
if (is_cuda)
|
||||||
|
output.push_back(f+" -Xlinker -rpath="+f.substr(2));
|
||||||
|
else
|
||||||
|
output.push_back(f+" -Wl,-rpath="+f.substr(2));
|
||||||
|
} else
|
||||||
|
output.push_back(f);
|
||||||
|
}
|
||||||
|
string cmdx = "";
|
||||||
|
for (auto& s : output) {
|
||||||
|
cmdx += s;
|
||||||
|
cmdx += " ";
|
||||||
|
}
|
||||||
|
return cmdx;
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
namespace jit_compiler {
|
namespace jit_compiler {
|
||||||
|
|
||||||
|
@ -174,12 +197,12 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
||||||
if (is_cuda_op) {
|
if (is_cuda_op) {
|
||||||
cmd = "\"" + nvcc_path + "\""
|
cmd = "\"" + nvcc_path + "\""
|
||||||
+ " \"" + jit_src_path + "\"" + other_src
|
+ " \"" + jit_src_path + "\"" + other_src
|
||||||
+ nvcc_flags + extra_flags
|
+ fix_cl_flags(nvcc_flags + extra_flags, is_cuda_op)
|
||||||
+ " -o \"" + jit_lib_path + "\"";
|
+ " -o \"" + jit_lib_path + "\"";
|
||||||
} else {
|
} else {
|
||||||
cmd = "\"" + cc_path + "\""
|
cmd = "\"" + cc_path + "\""
|
||||||
+ " \"" + jit_src_path + "\"" + other_src
|
+ " \"" + jit_src_path + "\"" + other_src
|
||||||
+ cc_flags + extra_flags
|
+ fix_cl_flags(cc_flags + extra_flags, is_cuda_op)
|
||||||
+ " -o \"" + jit_lib_path + "\"";
|
+ " -o \"" + jit_lib_path + "\"";
|
||||||
#ifdef __linux__
|
#ifdef __linux__
|
||||||
cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
|
cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py "
|
||||||
|
@ -193,12 +216,12 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
||||||
+ nvcc_flags + extra_flags
|
+ nvcc_flags + extra_flags
|
||||||
+ " -o \"" + jit_lib_path + "\""
|
+ " -o \"" + jit_lib_path + "\""
|
||||||
+ " -Xlinker -EXPORT:\""
|
+ " -Xlinker -EXPORT:\""
|
||||||
+ symbol_name + "\"";;
|
+ symbol_name + "\"";
|
||||||
} else {
|
} else {
|
||||||
cmd = "\"" + cc_path + "\""
|
cmd = "\"" + cc_path + "\""
|
||||||
+ " \"" + jit_src_path + "\"" + other_src
|
+ " \"" + jit_src_path + "\"" + other_src
|
||||||
+ " -Fe: \"" + jit_lib_path + "\" "
|
+ " -Fe: \"" + jit_lib_path + "\" "
|
||||||
+ fix_cl_flags(cc_flags + extra_flags) + " -EXPORT:\""
|
+ fix_cl_flags(cc_flags + extra_flags, is_cuda_op) + " -EXPORT:\""
|
||||||
+ symbol_name + "\"";
|
+ symbol_name + "\"";
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -241,7 +241,8 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
||||||
find_names(cmd, input_names, output_name, extra);
|
find_names(cmd, input_names, output_name, extra);
|
||||||
string output_cache_key;
|
string output_cache_key;
|
||||||
bool ran = false;
|
bool ran = false;
|
||||||
output_cache_key = read_all(output_name+".key");
|
if (file_exist(output_name))
|
||||||
|
output_cache_key = read_all(output_name+".key");
|
||||||
string cache_key;
|
string cache_key;
|
||||||
unordered_set<string> processed;
|
unordered_set<string> processed;
|
||||||
auto src_path = join(jittor_path, "src");
|
auto src_path = join(jittor_path, "src");
|
||||||
|
|
|
@ -131,15 +131,28 @@ so_pos=cmd.find("_op.so")
|
||||||
# remove -Xclang ...
|
# remove -Xclang ...
|
||||||
remove_clang_flag = lambda s: re.sub("-Xclang (('[^']*')|([^ ]*))", "", s)
|
remove_clang_flag = lambda s: re.sub("-Xclang (('[^']*')|([^ ]*))", "", s)
|
||||||
|
|
||||||
|
def shsplit(s):
|
||||||
|
s1 = s.split(' ')
|
||||||
|
s2 = []
|
||||||
|
count = 0
|
||||||
|
for s in s1:
|
||||||
|
nc = s.count('"') + s.count('\'')
|
||||||
|
if count&1:
|
||||||
|
count += nc
|
||||||
|
s2[-1] += " "
|
||||||
|
s2[-1] += s
|
||||||
|
else:
|
||||||
|
count = nc
|
||||||
|
s2.append(s)
|
||||||
|
return s2
|
||||||
|
|
||||||
def remove_flags(flags, rm_flags):
|
def remove_flags(flags, rm_flags):
|
||||||
flags = flags.split(" ")
|
flags = shsplit(flags)
|
||||||
output = []
|
output = []
|
||||||
for s in flags:
|
for s in flags:
|
||||||
if s.startswith("-load"):
|
ss = s.replace("\"", "")
|
||||||
output.append(s)
|
|
||||||
continue
|
|
||||||
for rm in rm_flags:
|
for rm in rm_flags:
|
||||||
if s.startswith(rm):
|
if ss.startswith(rm) or ss.endswith(rm):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
output.append(s)
|
output.append(s)
|
||||||
|
@ -161,7 +174,7 @@ else: #cc_to_so
|
||||||
.replace("-ldl", "") \
|
.replace("-ldl", "") \
|
||||||
.replace("-shared", "-S") \
|
.replace("-shared", "-S") \
|
||||||
.replace(" -o ", " -g -o ")
|
.replace(" -o ", " -g -o ")
|
||||||
asm_cmd = remove_flags(asm_cmd, ['-l', '-L', '-Wl,'])
|
asm_cmd = remove_flags(asm_cmd, ['-l', '-L', '-Wl,', '.lib', '-shared'])
|
||||||
run_cmd(asm_cmd)
|
run_cmd(asm_cmd)
|
||||||
|
|
||||||
s_path = cc_path.replace("_op.cc","_op.post.s")
|
s_path = cc_path.replace("_op.cc","_op.post.s")
|
||||||
|
@ -169,5 +182,5 @@ else: #cc_to_so
|
||||||
pass_asm(cc_path,s_path)
|
pass_asm(cc_path,s_path)
|
||||||
|
|
||||||
asm_cmd = cmd.replace("_op.cc", "_op.s") \
|
asm_cmd = cmd.replace("_op.cc", "_op.s") \
|
||||||
.replace("-g", "")
|
.replace(" -g", "")
|
||||||
run_cmd(remove_clang_flag(asm_cmd))
|
run_cmd(remove_clang_flag(asm_cmd))
|
Loading…
Reference in New Issue