mirror of https://github.com/Jittor/Jittor
add cuda aarch64 support
commit d563826b2b
parent 0748fc1854
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.2.3.82'
+__version__ = '1.2.3.83'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int

@@ -183,13 +183,20 @@ def setup_cuda_extern():
             line = traceback.format_exc()
             LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
             if lib_name == "cudnn":
-                LOG.w(f"""Develop version of CUDNN not found,
+                msg = """Develop version of CUDNN not found,
 please refer to CUDA offical tar file installation:
-https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar
+https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar"""
+                if platform.machine() == "x86_64":
+                    msg += """
 or you can let jittor install cuda and cudnn for you:
->>> python3.{sys.version_info.minor} -m jittor_utils.install_cuda""")
+>>> python3.{sys.version_info.minor} -m jittor_utils.install_cuda
+"""
+                LOG.w(msg)
 
 def setup_cuda_lib(lib_name, link=True, extra_flags=""):
+    arch_key = "x86_64"
+    if platform.machine() != "x86_64":
+        arch_key = "aarch64"
     globals()[lib_name+"_ops"] = None
     globals()[lib_name] = None
     if not has_cuda: return

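The new arch_key block above is what the rest of the patch keys off: instead of hard-coding x86_64, the architecture comes from platform.machine() and is folded into the CUDA toolkit's targets/<arch>-linux layout and the distro's /usr/lib/<arch>-linux-gnu directory. A minimal standalone sketch of that path resolution (the cuda_search_dirs name and the cuda_home default are illustrative, not from the commit):

# Sketch of the per-arch CUDA directory resolution the patch relies on.
# Assumes the stock toolkit layout targets/<arch>-linux/ and a Debian-style
# /usr/lib/<arch>-linux-gnu; "cuda_home" is a hypothetical parameter.
import os
import platform

def cuda_search_dirs(cuda_home="/usr/local/cuda"):
    arch_key = "x86_64"
    if platform.machine() != "x86_64":
        arch_key = "aarch64"
    return [
        os.path.join(cuda_home, "lib64"),
        os.path.abspath(os.path.join(cuda_home, f"targets/{arch_key}-linux/lib")),
        f"/usr/lib/{arch_key}-linux-gnu",
        "/usr/lib",
    ]

if __name__ == "__main__":
    print(cuda_search_dirs())
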
@@ -201,20 +208,20 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
 
     link_flags = ""
     if link:
-        extra_include_path = os.path.abspath(os.path.join(cuda_include, "..", "targets/x86_64-linux/include"))
-        extra_lib_path = os.path.abspath(os.path.join(cuda_lib, "..", "targets/x86_64-linux/lib"))
+        extra_include_path = os.path.abspath(os.path.join(cuda_include, "..", f"targets/{arch_key}-linux/include"))
+        extra_lib_path = os.path.abspath(os.path.join(cuda_lib, "..", f"targets/{arch_key}-linux/lib"))
         cuda_include_name = search_file([cuda_include, extra_include_path, "/usr/include"], lib_name+".h")
         # cuda11 prefer cudnn 8
         nvcc_version = get_int_version(nvcc_path)
         prefer_version = ()
         if nvcc_version[0] == 11:
             prefer_version = ("8",)
-        culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
+        culib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version)
 
         if lib_name == "cublas" and nvcc_version[0] >= 10:
             # manual link libcublasLt.so
             try:
-                cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
+                cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
                 ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
             except:
                 # some aarch64 os, such as uos with FT2000 cpu,

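search_file and prefer_version here are Jittor internals; the stand-in below (a hypothetical, glob-based find_lib) only illustrates the lookup order the hunk above depends on: each candidate directory is tried in turn, and on CUDA 11 a versioned libcudnn.so.8* match is preferred before falling back to any plain libcudnn.so*.

# Hypothetical, simplified stand-in for Jittor's search_file helper:
# walk the candidate directories in order, preferring a versioned match
# (e.g. libcudnn.so.8*) when prefer_version is given.
import glob
import os

def find_lib(dirs, name, prefer_version=()):
    for d in dirs:
        for v in prefer_version:
            hits = sorted(glob.glob(os.path.join(d, f"{name}.{v}*")))
            if hits:
                return hits[0]
        hits = sorted(glob.glob(os.path.join(d, name + "*")))
        if hits:
            return hits[0]
    raise FileNotFoundError(f"{name} not found in {dirs}")

Together with the directory sketch above, find_lib(cuda_search_dirs(), "libcudnn.so", ("8",)) roughly mirrors the culib_path lookup for cuDNN on CUDA 11.
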
@@ -228,7 +235,7 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
             if nvcc_version >= (11,0,0):
                 libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"]
                 for l in libs:
-                    ex_cudnn_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], l, prefer_version)
+                    ex_cudnn_path = search_file([cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], l, prefer_version)
                     ctypes.CDLL(ex_cudnn_path, dlopen_flags)
 
         # dynamic link cuda library

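cuDNN 8 splits its kernels into ops/cnn and infer/train shared objects, so the hunk above dlopens them explicitly before the main library; dlopen_flags in the patch is Jittor's own value. A sketch of the same preload pattern, assuming the common RTLD_NOW | RTLD_GLOBAL combination so the split libraries' symbols are visible to whatever is loaded next:

# Sketch: preload the split cuDNN 8 libraries, then the umbrella library.
# The flag value is an assumption (RTLD_NOW | RTLD_GLOBAL); the patch uses
# Jittor's own dlopen_flags.
import ctypes
import os

DLOPEN_FLAGS = os.RTLD_NOW | os.RTLD_GLOBAL

def preload_cudnn(lib_dir):
    parts = [
        "libcudnn_ops_infer.so",
        "libcudnn_ops_train.so",
        "libcudnn_cnn_infer.so",
        "libcudnn_cnn_train.so",
    ]
    for name in parts:
        ctypes.CDLL(os.path.join(lib_dir, name), DLOPEN_FLAGS)
    # finally load the umbrella library itself
    return ctypes.CDLL(os.path.join(lib_dir, "libcudnn.so"), DLOPEN_FLAGS)
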
@@ -74,5 +74,9 @@ class TestTransposeOp(unittest.TestCase):
         b = a.transpose()
         assert (a.data.transpose() == b.data).all()
 
+        a = jt.zeros((1,1))
+        b = a.transpose((1,0))
+        b.sync()
+
 if __name__ == "__main__":
     unittest.main()

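The added test lines exercise an edge case: transposing a degenerate (1,1) tensor with explicit axes must still compile and run. Roughly the same check as a standalone script (needs a working jittor install):

# Standalone version of the added regression check.
import jittor as jt

a = jt.zeros((1, 1))
b = a.transpose((1, 0))
b.sync()                      # force the transpose op to compile and execute
assert (b.data == 0).all()    # result is still a 1x1 zero array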