mirror of https://github.com/Jittor/Jittor
fix cudnn 8 cuda 11 link problem
This commit is contained in:
parent
b9f6f048cc
commit
2476a471d4
|
@ -6,7 +6,7 @@
|
|||
# ***************************************************************
|
||||
import os, sys, shutil
|
||||
from .compiler import *
|
||||
from jittor_utils import run_cmd, get_version
|
||||
from jittor_utils import run_cmd, get_version, get_int_version
|
||||
from jittor.utils.misc import download_url_to_local
|
||||
|
||||
def search_file(dirs, name):
|
||||
|
@ -160,6 +160,16 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
extra_lib_path = os.path.abspath(os.path.join(cuda_lib, "..", "targets/x86_64-linux/lib"))
|
||||
cuda_include_name = search_file([cuda_include, extra_include_path, "/usr/include"], lib_name+".h")
|
||||
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so")
|
||||
|
||||
if lib_name == "cudnn":
|
||||
# cudnn cannot found libcudnn_cnn_train.so.8, we manual link for it.
|
||||
nvcc_version = get_int_version(nvcc_path)
|
||||
if nvcc_version >= (11,0,0):
|
||||
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
|
||||
for l in libs:
|
||||
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], l)
|
||||
ctypes.CDLL(ex_cudnn_path, dlopen_flags)
|
||||
|
||||
# dynamic link cuda library
|
||||
ctypes.CDLL(culib_path, dlopen_flags)
|
||||
link_flags = f"-l{lib_name} -L'{cuda_lib}'"
|
||||
|
|
|
@ -241,6 +241,12 @@ def get_version(output):
|
|||
version = "("+v[-1]+")"
|
||||
return version
|
||||
|
||||
def get_int_version(output):
|
||||
ver = get_version(output)
|
||||
ver = ver[1:-1].split('.')
|
||||
ver = tuple(( int(v) for v in ver ))
|
||||
return ver
|
||||
|
||||
def find_exe(name, check_version=True, silent=False):
|
||||
output = run_cmd(f'which {name}', err_msg=f'{name} not found')
|
||||
if check_version:
|
||||
|
|
Loading…
Reference in New Issue