fix cudnn 8 cuda 11 link problem

This commit is contained in:
Dun Liang 2021-01-19 15:21:49 +08:00
parent b9f6f048cc
commit 2476a471d4
2 changed files with 17 additions and 1 deletions

View File

@ -6,7 +6,7 @@
# ***************************************************************
import os, sys, shutil
from .compiler import *
from jittor_utils import run_cmd, get_version
from jittor_utils import run_cmd, get_version, get_int_version
from jittor.utils.misc import download_url_to_local
def search_file(dirs, name):
@ -160,6 +160,16 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
extra_lib_path = os.path.abspath(os.path.join(cuda_lib, "..", "targets/x86_64-linux/lib"))
cuda_include_name = search_file([cuda_include, extra_include_path, "/usr/include"], lib_name+".h")
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so")
if lib_name == "cudnn":
# cudnn cannot found libcudnn_cnn_train.so.8, we manual link for it.
nvcc_version = get_int_version(nvcc_path)
if nvcc_version >= (11,0,0):
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
for l in libs:
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], l)
ctypes.CDLL(ex_cudnn_path, dlopen_flags)
# dynamic link cuda library
ctypes.CDLL(culib_path, dlopen_flags)
link_flags = f"-l{lib_name} -L'{cuda_lib}'"

View File

@ -241,6 +241,12 @@ def get_version(output):
version = "("+v[-1]+")"
return version
def get_int_version(output):
ver = get_version(output)
ver = ver[1:-1].split('.')
ver = tuple(( int(v) for v in ver ))
return ver
def find_exe(name, check_version=True, silent=False):
output = run_cmd(f'which {name}', err_msg=f'{name} not found')
if check_version: