mirror of https://github.com/Jittor/Jittor
auto cuda downloader
This commit is contained in:
parent
3ac160cb9a
commit
07eb4a7a0e
|
@ -9,8 +9,8 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.2.3.10'
|
__version__ = '1.2.3.11'
|
||||||
from . import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
ori_float = float
|
ori_float = float
|
||||||
|
@ -59,7 +59,7 @@ def safeunpickle(path):
|
||||||
if path.startswith("https:") or path.startswith("http:"):
|
if path.startswith("https:") or path.startswith("http:"):
|
||||||
base = path.split("/")[-1]
|
base = path.split("/")[-1]
|
||||||
fname = os.path.join(compiler.ck_path, base)
|
fname = os.path.join(compiler.ck_path, base)
|
||||||
from jittor.utils.misc import download_url_to_local
|
from jittor_utils.misc import download_url_to_local
|
||||||
download_url_to_local(path, base, compiler.ck_path, None)
|
download_url_to_local(path, base, compiler.ck_path, None)
|
||||||
path = fname
|
path = fname
|
||||||
if path.endswith(".pth"):
|
if path.endswith(".pth"):
|
||||||
|
|
|
@ -7,11 +7,12 @@
|
||||||
import os, sys, shutil
|
import os, sys, shutil
|
||||||
from .compiler import *
|
from .compiler import *
|
||||||
from jittor_utils import run_cmd, get_version, get_int_version
|
from jittor_utils import run_cmd, get_version, get_int_version
|
||||||
from jittor.utils.misc import download_url_to_local
|
from jittor_utils.misc import download_url_to_local
|
||||||
|
|
||||||
def search_file(dirs, name, prefer_version=()):
|
def search_file(dirs, name, prefer_version=()):
|
||||||
for d in dirs:
|
for d in dirs:
|
||||||
fname = os.path.join(d, name)
|
fname = os.path.join(d, name)
|
||||||
|
prefer_version = tuple( str(p) for p in prefer_version )
|
||||||
for i in range(len(prefer_version),-1,-1):
|
for i in range(len(prefer_version),-1,-1):
|
||||||
vname = ".".join((fname,)+prefer_version[:i])
|
vname = ".".join((fname,)+prefer_version[:i])
|
||||||
if os.path.isfile(vname):
|
if os.path.isfile(vname):
|
||||||
|
@ -106,7 +107,7 @@ def install_cub(root_folder):
|
||||||
with tarfile.open(fullname, "r") as tar:
|
with tarfile.open(fullname, "r") as tar:
|
||||||
tar.extractall(root_folder)
|
tar.extractall(root_folder)
|
||||||
assert 0 == os.system(f"cd {dirname}/examples && "
|
assert 0 == os.system(f"cd {dirname}/examples && "
|
||||||
f"{nvcc_path} device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
|
f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
|
||||||
if core.get_device_count():
|
if core.get_device_count():
|
||||||
assert 0 == os.system(f"cd {dirname}/examples && ./test")
|
assert 0 == os.system(f"cd {dirname}/examples && ./test")
|
||||||
return dirname
|
return dirname
|
||||||
|
@ -153,9 +154,11 @@ def setup_cuda_extern():
|
||||||
line = traceback.format_exc()
|
line = traceback.format_exc()
|
||||||
LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
|
LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
|
||||||
if lib_name == "cudnn":
|
if lib_name == "cudnn":
|
||||||
LOG.w(f"Develop version of CUDNN not found, "
|
LOG.w(f"""Develop version of CUDNN not found,
|
||||||
"please refer to CUDA offical tar file installation: "
|
please refer to CUDA offical tar file installation:
|
||||||
"https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar")
|
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar
|
||||||
|
or you can let jittor install cuda and cudnn for you:
|
||||||
|
>>> python3.{sys.version_info.minor} -m jittor_utils.install_cuda""")
|
||||||
|
|
||||||
def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
||||||
globals()[lib_name+"_ops"] = None
|
globals()[lib_name+"_ops"] = None
|
||||||
|
@ -179,6 +182,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
||||||
prefer_version = ("8",)
|
prefer_version = ("8",)
|
||||||
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so", prefer_version)
|
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so", prefer_version)
|
||||||
|
|
||||||
|
if lib_name == "cublas":
|
||||||
|
# manual link libcublasLt.so
|
||||||
|
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"libcublasLt.so", nvcc_version)
|
||||||
|
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
|
||||||
|
|
||||||
|
|
||||||
if lib_name == "cudnn":
|
if lib_name == "cudnn":
|
||||||
# cudnn cannot found libcudnn_cnn_train.so.8, we manual link for it.
|
# cudnn cannot found libcudnn_cnn_train.so.8, we manual link for it.
|
||||||
if nvcc_version >= (11,0,0):
|
if nvcc_version >= (11,0,0):
|
||||||
|
@ -241,7 +250,7 @@ def install_cutt(root_folder):
|
||||||
if len(flags.cuda_archs):
|
if len(flags.cuda_archs):
|
||||||
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
||||||
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||||
run_cmd(f"make NVCC_GENCODE='{arch_flag}' nvcc_path='{nvcc_path}'", cwd=dirname)
|
run_cmd(f"make NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' nvcc_path='{nvcc_path}'", cwd=dirname)
|
||||||
return dirname
|
return dirname
|
||||||
|
|
||||||
def setup_cutt():
|
def setup_cutt():
|
||||||
|
@ -317,7 +326,7 @@ def install_nccl(root_folder):
|
||||||
if len(flags.cuda_archs):
|
if len(flags.cuda_archs):
|
||||||
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
||||||
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||||
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag}' ", cwd=dirname)
|
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' ", cwd=dirname)
|
||||||
return dirname
|
return dirname
|
||||||
|
|
||||||
def setup_nccl():
|
def setup_nccl():
|
||||||
|
|
|
@ -18,7 +18,8 @@ from ctypes.util import find_library
|
||||||
import jittor_utils as jit_utils
|
import jittor_utils as jit_utils
|
||||||
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
|
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
|
||||||
from . import pyjt_compiler
|
from . import pyjt_compiler
|
||||||
from . import lock
|
from jittor_utils import lock
|
||||||
|
from jittor_utils import install_cuda
|
||||||
from jittor import __version__
|
from jittor import __version__
|
||||||
|
|
||||||
def find_jittor_path():
|
def find_jittor_path():
|
||||||
|
@ -867,6 +868,11 @@ if os.path.isfile(ex_python_path):
|
||||||
python_path = ex_python_path
|
python_path = ex_python_path
|
||||||
py3_config_path = jit_utils.py3_config_path
|
py3_config_path = jit_utils.py3_config_path
|
||||||
nvcc_path = env_or_try_find('nvcc_path', 'nvcc') or try_find_exe('/usr/local/cuda/bin/nvcc') or try_find_exe('/usr/bin/nvcc')
|
nvcc_path = env_or_try_find('nvcc_path', 'nvcc') or try_find_exe('/usr/local/cuda/bin/nvcc') or try_find_exe('/usr/bin/nvcc')
|
||||||
|
if not nvcc_path:
|
||||||
|
cuda_driver = install_cuda.get_cuda_driver()
|
||||||
|
nvcc_path = install_cuda.install_cuda()
|
||||||
|
if nvcc_path:
|
||||||
|
nvcc_path = try_find_exe(nvcc_path)
|
||||||
gdb_path = try_find_exe('gdb')
|
gdb_path = try_find_exe('gdb')
|
||||||
addr2line_path = try_find_exe('addr2line')
|
addr2line_path = try_find_exe('addr2line')
|
||||||
has_pybt = check_pybt(gdb_path, python_path)
|
has_pybt = check_pybt(gdb_path, python_path)
|
||||||
|
|
|
@ -12,7 +12,7 @@ import gzip
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
# our lib jittor import
|
# our lib jittor import
|
||||||
from jittor.dataset.dataset import Dataset, dataset_root
|
from jittor.dataset.dataset import Dataset, dataset_root
|
||||||
from jittor.utils.misc import ensure_dir, download_url_to_local
|
from jittor_utils.misc import ensure_dir, download_url_to_local
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
import jittor.transform as trans
|
import jittor.transform as trans
|
||||||
|
|
||||||
|
|
|
@ -24,8 +24,8 @@ void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr
|
||||||
tmp_data_t tmp_data;
|
tmp_data_t tmp_data;
|
||||||
|
|
||||||
void numpy_init() {
|
void numpy_init() {
|
||||||
PyObjHolder np(PyImport_ImportModule("numpy.core.multiarray"));
|
PyObjHolder np(PyImport_ImportModule("numpy.core.multiarray"), "numpy is not installed");
|
||||||
PyObjHolder api(PyObject_GetAttrString(np.obj, "_ARRAY_API"));
|
PyObjHolder api(PyObject_GetAttrString(np.obj, "_ARRAY_API"), "numpy _ARRAY_API not found, you may need to reinstall numpy");
|
||||||
PyArray_API = (void **) PyCapsule_GetPointer(api.obj, NULL);
|
PyArray_API = (void **) PyCapsule_GetPointer(api.obj, NULL);
|
||||||
|
|
||||||
#define fill(name, i) name = (decltype(name))PyArray_API[i]
|
#define fill(name, i) name = (decltype(name))PyArray_API[i]
|
||||||
|
|
|
@ -25,6 +25,17 @@ struct PyObjHolder {
|
||||||
LOGf << "Python error occur";
|
LOGf << "Python error occur";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
inline void assign(PyObject* obj, const char* err_msg) {
|
||||||
|
if (!obj) {
|
||||||
|
LOGf << err_msg;
|
||||||
|
}
|
||||||
|
this->obj = obj;
|
||||||
|
}
|
||||||
|
inline PyObjHolder(PyObject* obj, const char* err_msg) : obj(obj) {
|
||||||
|
if (!obj) {
|
||||||
|
LOGf << err_msg;
|
||||||
|
}
|
||||||
|
}
|
||||||
inline ~PyObjHolder() {
|
inline ~PyObjHolder() {
|
||||||
if (obj) Py_DECREF(obj);
|
if (obj) Py_DECREF(obj);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,8 @@ import contextlib
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from ctypes import cdll
|
from ctypes import cdll
|
||||||
|
import shutil
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
class LogWarper:
|
class LogWarper:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -180,7 +182,6 @@ def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
||||||
|
|
||||||
|
|
||||||
def download(url, filename):
|
def download(url, filename):
|
||||||
from six.moves import urllib
|
|
||||||
if os.path.isfile(filename):
|
if os.path.isfile(filename):
|
||||||
if os.path.getsize(filename) > 100:
|
if os.path.getsize(filename) > 100:
|
||||||
return
|
return
|
||||||
|
@ -246,7 +247,9 @@ def get_int_version(output):
|
||||||
return ver
|
return ver
|
||||||
|
|
||||||
def find_exe(name, check_version=True, silent=False):
|
def find_exe(name, check_version=True, silent=False):
|
||||||
output = run_cmd(f'which {name}', err_msg=f'{name} not found')
|
output = shutil.which(name)
|
||||||
|
if not output:
|
||||||
|
raise RuntimeError(f"{name} not found")
|
||||||
if check_version:
|
if check_version:
|
||||||
version = get_version(name)
|
version = get_version(name)
|
||||||
else:
|
else:
|
||||||
|
@ -297,6 +300,6 @@ for py3_config_path in py3_config_paths:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"python3.{sys.version_info.minor}-config "
|
raise RuntimeError(f"python3.{sys.version_info.minor}-config "
|
||||||
"not found in {py3_config_paths}, please specify "
|
f"not found in {py3_config_paths}, please specify "
|
||||||
"enviroment variable 'python_config_path',"
|
f"enviroment variable 'python_config_path',"
|
||||||
" or apt install python3.{sys.version_info.minor}-dev")
|
f" or apt install python3.{sys.version_info.minor}-dev")
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
# ***************************************************************
|
||||||
|
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||||
|
# This file is subject to the terms and conditions defined in
|
||||||
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
# ***************************************************************
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import subprocess as sp
|
||||||
|
import jittor_utils as jit_utils
|
||||||
|
from jittor_utils import LOG
|
||||||
|
from jittor_utils.misc import download_url_to_local
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
def get_cuda_driver():
|
||||||
|
ret, out = sp.getstatusoutput("nvidia-smi -q -u")
|
||||||
|
if ret != 0: return None
|
||||||
|
try:
|
||||||
|
out = out.lower()
|
||||||
|
out = out.split('cuda version')[1] \
|
||||||
|
.split(':')[1] \
|
||||||
|
.splitlines()[0] \
|
||||||
|
.strip()
|
||||||
|
out = [ int(s) for s in out.split('.')]
|
||||||
|
return out
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def install_cuda():
|
||||||
|
cuda_driver_version = get_cuda_driver()
|
||||||
|
LOG.i("cuda_driver_version: ", cuda_driver_version)
|
||||||
|
|
||||||
|
cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
|
||||||
|
md5 = "a78f296746d97e9d76615289c2fe98ac"
|
||||||
|
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
|
||||||
|
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
|
||||||
|
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
|
||||||
|
sys.path.append(nvcc_lib_path)
|
||||||
|
new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
|
||||||
|
os.environ["LD_LIBRARY_PATH"] = new_ld_path
|
||||||
|
|
||||||
|
if os.path.isfile(nvcc_path):
|
||||||
|
return nvcc_path
|
||||||
|
|
||||||
|
os.makedirs(jtcuda_path, exist_ok=True)
|
||||||
|
cuda_tgz_path = os.path.join(jtcuda_path, cuda_tgz)
|
||||||
|
download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5)
|
||||||
|
|
||||||
|
|
||||||
|
import tarfile
|
||||||
|
with tarfile.open(cuda_tgz_path, "r") as tar:
|
||||||
|
tar.extractall(cuda_tgz_path[:-4])
|
||||||
|
|
||||||
|
assert os.path.isfile(nvcc_path)
|
||||||
|
return nvcc_path
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
nvcc_path = install_cuda()
|
||||||
|
LOG.i("nvcc is installed at ", nvcc_path)
|
|
@ -7,12 +7,11 @@
|
||||||
# This file is subject to the terms and conditions defined in
|
# This file is subject to the terms and conditions defined in
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
import jittor as jt
|
|
||||||
import os
|
import os
|
||||||
from six.moves import urllib
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import urllib.request
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .. import lock
|
from jittor_utils import lock
|
||||||
|
|
||||||
def ensure_dir(dir_path):
|
def ensure_dir(dir_path):
|
||||||
if not os.path.isdir(dir_path):
|
if not os.path.isdir(dir_path):
|
||||||
|
@ -41,14 +40,11 @@ def download_url_to_local(url, filename, root_folder, md5):
|
||||||
if check_file_exist(file_path, md5):
|
if check_file_exist(file_path, md5):
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
try:
|
print('Downloading ' + url + ' to ' + file_path)
|
||||||
print('Downloading ' + url + ' to ' + file_path)
|
urllib.request.urlretrieve(
|
||||||
urllib.request.urlretrieve(
|
url, file_path,
|
||||||
url, file_path,
|
reporthook=_progress()
|
||||||
reporthook=_progress()
|
)
|
||||||
)
|
|
||||||
except(urllib.error.URLError, IOError) as e:
|
|
||||||
raise e
|
|
||||||
if not check_file_exist(file_path, md5):
|
if not check_file_exist(file_path, md5):
|
||||||
raise RuntimeError("File downloads failed.")
|
raise RuntimeError("File downloads failed.")
|
||||||
|
|
||||||
|
@ -72,3 +68,4 @@ def calculate_md5(file_path, chunk_size=1024 * 1024):
|
||||||
|
|
||||||
def check_md5(file_path, md5, **kwargs):
|
def check_md5(file_path, md5, **kwargs):
|
||||||
return md5 == calculate_md5(file_path, **kwargs)
|
return md5 == calculate_md5(file_path, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue