add /usr/lib as cuda lib search path

This commit is contained in:
Dun Liang 2021-06-05 10:59:06 +08:00
parent 2c2f5b156d
commit fd7d68e6aa
2 changed files with 4 additions and 4 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.17'
__version__ = '1.2.3.18'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -180,11 +180,11 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
prefer_version = ()
if nvcc_version[0] == 11:
prefer_version = ("8",)
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so", prefer_version)
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
if lib_name == "cublas" and nvcc_version[0] >= 10:
# manual link libcublasLt.so
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"libcublasLt.so", nvcc_version)
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
@ -193,7 +193,7 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
if nvcc_version >= (11,0,0):
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
for l in libs:
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], l, prefer_version)
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], l, prefer_version)
ctypes.CDLL(ex_cudnn_path, dlopen_flags)
# dynamic link cuda library