mirror of https://github.com/Jittor/Jittor
polish mkl and cutt install
This commit is contained in:
parent
7bc620f274
commit
2bf1b89c5e
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.0.4'
|
||||
__version__ = '1.3.0.5'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -75,9 +75,8 @@ def install_mkl(root_folder):
|
|||
# this env is used for execute example/text
|
||||
bin_path = os.path.join(dirname, "bin")
|
||||
sys.path.append(bin_path)
|
||||
os.add_dll_directory(bin_path)
|
||||
os.environ["PATH"] = os.environ.get("PATH", "") + ";" + bin_path
|
||||
cmd = f"cd /d {dirname}/examples && {cc_path} {dirname}/examples/cnn_inference_f32.cpp -I{dirname}/include -Fe: {dirname}/examples/test {fix_cl_flags(cc_flags)} {dirname}/lib/mkldnn.lib"
|
||||
cmd = f"cd /d {dirname}/examples && {cc_path} {dirname}/examples/cnn_inference_f32.cpp -I{dirname}/include -Fe: {dirname}/examples/test.exe {fix_cl_flags(cc_flags).replace('-LD', '')} {dirname}/lib/mkldnn.lib"
|
||||
|
||||
assert 0 == os.system(cmd)
|
||||
assert 0 == os.system(f"{dirname}/examples/test")
|
||||
|
@ -130,8 +129,7 @@ def setup_mkl():
|
|||
if os.name == 'nt':
|
||||
mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll')
|
||||
mkl_bin_path = os.path.join(mkl_home, 'bin')
|
||||
os.add_dll_directory(mkl_bin_path)
|
||||
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -ldnnl "
|
||||
extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -L\"{mkl_bin_path}\" -ldnnl "
|
||||
assert os.path.isdir(mkl_include_path)
|
||||
assert os.path.isdir(mkl_lib_path)
|
||||
assert os.path.isfile(mkl_lib_name)
|
||||
|
@ -374,8 +372,6 @@ def setup_cutt():
|
|||
LOG.v(f"cutt_lib_path: {cutt_lib_path}")
|
||||
LOG.v(f"cutt_lib_name: {cutt_lib_name}")
|
||||
# We do not link manualy, link in custom ops
|
||||
if os.name == "nt":
|
||||
os.add_dll_directory(cutt_lib_path)
|
||||
ctypes.CDLL(cutt_lib_name, dlopen_flags)
|
||||
|
||||
cutt_op_dir = os.path.join(jittor_path, "extern", "cuda", "cutt", "ops")
|
||||
|
|
|
@ -847,7 +847,6 @@ def check_cuda():
|
|||
# cc_flags += f" \"{cuda_lib}\\cudart.lib\" "
|
||||
cuda_lib_path = glob.glob(cuda_bin+"/cudart64*")[0]
|
||||
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
|
||||
os.add_dll_directory(cuda_dir)
|
||||
dll = ctypes.CDLL(cuda_lib_path, dlopen_flags)
|
||||
ret = dll.cudaDeviceSynchronize()
|
||||
assert ret == 0
|
||||
|
@ -1051,6 +1050,7 @@ if os.name == 'nt':
|
|||
os.path.dirname(sys.executable),
|
||||
"libs",
|
||||
)
|
||||
cc_flags = remove_flags(cc_flags, ["-f", "-m"])
|
||||
cc_flags = cc_flags.replace("-std=c++14", "-std=c++17")
|
||||
cc_flags = cc_flags.replace("-lstdc++", "")
|
||||
cc_flags = cc_flags.replace("-ldl", "")
|
||||
|
@ -1061,6 +1061,7 @@ if os.name == 'nt':
|
|||
mp = jittor_utils.msvc_path
|
||||
cc_flags += f' -nologo -I"{mp}\\VC\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX '
|
||||
cc_flags += f' -L"{mp}\\VC\\lib" -L"{mp}\\win10_kits\\lib\\um\\x64" -L"{mp}\\win10_kits\\lib\\ucrt\\x64" '
|
||||
win_libpaths = {}
|
||||
def fix_cl_flags(cmd):
|
||||
cmd = cmd.replace(".o ", ".obj ")
|
||||
cmd = cmd.replace(".o\" ", ".obj\" ")
|
||||
|
@ -1087,6 +1088,10 @@ if os.name == 'nt':
|
|||
elif f.startswith("-LD"):
|
||||
output.append(f)
|
||||
elif f.startswith("-L"):
|
||||
path = f[2:].replace("\"", "")
|
||||
if path not in win_libpaths:
|
||||
win_libpaths[path] = 1
|
||||
os.add_dll_directory(path)
|
||||
output2.append("-LIBPATH:"+f[2:])
|
||||
elif ".lib" in f:
|
||||
output2.append(f)
|
||||
|
@ -1141,7 +1146,7 @@ if has_cuda:
|
|||
# nvcc don't support -Wall option
|
||||
if os.name == 'nt':
|
||||
nvcc_flags = nvcc_flags.replace("-fp:", "-Xcompiler -fp:")
|
||||
nvcc_flags = nvcc_flags.replace("-EHa", "-Xcompiler -EHa")
|
||||
nvcc_flags = nvcc_flags.replace("-EH", "-Xcompiler -EH")
|
||||
nvcc_flags = nvcc_flags.replace("-M", "-Xcompiler -M")
|
||||
nvcc_flags = nvcc_flags.replace("-nologo", "")
|
||||
nvcc_flags = nvcc_flags.replace("-std:", "-std=")
|
||||
|
|
|
@ -376,6 +376,7 @@ def get_py3_include_path():
|
|||
|
||||
if os.name == 'nt':
|
||||
# Windows
|
||||
sys.executable = sys.executable.lower()
|
||||
_py3_include_path = '-I"' + os.path.join(
|
||||
os.path.dirname(sys.executable),
|
||||
"include"
|
||||
|
|
Loading…
Reference in New Issue