add cuda aarch64 support

This commit is contained in:
Dun Liang 2021-07-28 11:41:13 +08:00
parent 0748fc1854
commit d563826b2b
3 changed files with 20 additions and 9 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.82'
__version__ = '1.2.3.83'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -183,13 +183,20 @@ def setup_cuda_extern():
line = traceback.format_exc()
LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
if lib_name == "cudnn":
LOG.w(f"""Develop version of CUDNN not found,
msg = """Develop version of CUDNN not found,
please refer to CUDA offical tar file installation:
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar"""
if platform.machine() == "x86_64":
msg += """
or you can let jittor install cuda and cudnn for you:
>>> python3.{sys.version_info.minor} -m jittor_utils.install_cuda""")
>>> python3.{sys.version_info.minor} -m jittor_utils.install_cuda
"""
LOG.w(msg)
def setup_cuda_lib(lib_name, link=True, extra_flags=""):
arch_key = "x86_64"
if platform.machine() != "x86_64":
arch_key = "aarch64"
globals()[lib_name+"_ops"] = None
globals()[lib_name] = None
if not has_cuda: return
@ -201,20 +208,20 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
link_flags = ""
if link:
extra_include_path = os.path.abspath(os.path.join(cuda_include, "..", "targets/x86_64-linux/include"))
extra_lib_path = os.path.abspath(os.path.join(cuda_lib, "..", "targets/x86_64-linux/lib"))
extra_include_path = os.path.abspath(os.path.join(cuda_include, "..", f"targets/{arch_key}-linux/include"))
extra_lib_path = os.path.abspath(os.path.join(cuda_lib, "..", f"targets/{arch_key}-linux/lib"))
cuda_include_name = search_file([cuda_include, extra_include_path, "/usr/include"], lib_name+".h")
# cuda11 prefer cudnn 8
nvcc_version = get_int_version(nvcc_path)
prefer_version = ()
if nvcc_version[0] == 11:
prefer_version = ("8",)
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
culib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
if lib_name == "cublas" and nvcc_version[0] >= 10:
# manual link libcublasLt.so
try:
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
except:
# some aarch64 os, such as uos with FT2000 cpu,
@ -228,7 +235,7 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
if nvcc_version >= (11,0,0):
libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
for l in libs:
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], l, prefer_version)
ex_cudnn_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], l, prefer_version)
ctypes.CDLL(ex_cudnn_path, dlopen_flags)
# dynamic link cuda library

View File

@ -74,5 +74,9 @@ class TestTransposeOp(unittest.TestCase):
b = a.transpose()
assert (a.data.transpose() == b.data).all()
a = jt.zeros((1,1))
b = a.transpose((1,0))
b.sync()
if __name__ == "__main__":
unittest.main()