auto cuda downloader

This commit is contained in:
Dun Liang 2021-05-28 15:46:05 +08:00
parent 3ac160cb9a
commit 07eb4a7a0e
10 changed files with 117 additions and 30 deletions

View File

@ -9,8 +9,8 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.10'
from . import lock
__version__ = '1.2.3.11'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
ori_float = float
@ -59,7 +59,7 @@ def safeunpickle(path):
if path.startswith("https:") or path.startswith("http:"):
base = path.split("/")[-1]
fname = os.path.join(compiler.ck_path, base)
from jittor.utils.misc import download_url_to_local
from jittor_utils.misc import download_url_to_local
download_url_to_local(path, base, compiler.ck_path, None)
path = fname
if path.endswith(".pth"):

View File

@ -7,11 +7,12 @@
import os, sys, shutil
from .compiler import *
from jittor_utils import run_cmd, get_version, get_int_version
from jittor.utils.misc import download_url_to_local
from jittor_utils.misc import download_url_to_local
def search_file(dirs, name, prefer_version=()):
for d in dirs:
fname = os.path.join(d, name)
prefer_version = tuple( str(p) for p in prefer_version )
for i in range(len(prefer_version),-1,-1):
vname = ".".join((fname,)+prefer_version[:i])
if os.path.isfile(vname):
@ -106,7 +107,7 @@ def install_cub(root_folder):
with tarfile.open(fullname, "r") as tar:
tar.extractall(root_folder)
assert 0 == os.system(f"cd {dirname}/examples && "
f"{nvcc_path} device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
if core.get_device_count():
assert 0 == os.system(f"cd {dirname}/examples && ./test")
return dirname
@ -153,9 +154,11 @@ def setup_cuda_extern():
line = traceback.format_exc()
LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
if lib_name == "cudnn":
LOG.w(f"Develop version of CUDNN not found, "
"please refer to CUDA offical tar file installation: "
"https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar")
LOG.w(f"""Develop version of CUDNN not found,
please refer to CUDA offical tar file installation:
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar
or you can let jittor install cuda and cudnn for you:
>>> python3.{sys.version_info.minor} -m jittor_utils.install_cuda""")
def setup_cuda_lib(lib_name, link=True, extra_flags=""):
globals()[lib_name+"_ops"] = None
@ -179,6 +182,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
prefer_version = ("8",)
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so", prefer_version)
if lib_name == "cublas":
# manually link libcublasLt.so
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"libcublasLt.so", nvcc_version)
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
if lib_name == "cudnn":
# cudnn cannot find libcudnn_cnn_train.so.8, so we link it manually.
if nvcc_version >= (11,0,0):
@ -241,7 +250,7 @@ def install_cutt(root_folder):
if len(flags.cuda_archs):
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
run_cmd(f"make NVCC_GENCODE='{arch_flag}' nvcc_path='{nvcc_path}'", cwd=dirname)
run_cmd(f"make NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' nvcc_path='{nvcc_path}'", cwd=dirname)
return dirname
def setup_cutt():
@ -317,7 +326,7 @@ def install_nccl(root_folder):
if len(flags.cuda_archs):
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag}' ", cwd=dirname)
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' ", cwd=dirname)
return dirname
def setup_nccl():

View File

@ -18,7 +18,8 @@ from ctypes.util import find_library
import jittor_utils as jit_utils
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
from . import pyjt_compiler
from . import lock
from jittor_utils import lock
from jittor_utils import install_cuda
from jittor import __version__
def find_jittor_path():
@ -867,6 +868,11 @@ if os.path.isfile(ex_python_path):
python_path = ex_python_path
py3_config_path = jit_utils.py3_config_path
nvcc_path = env_or_try_find('nvcc_path', 'nvcc') or try_find_exe('/usr/local/cuda/bin/nvcc') or try_find_exe('/usr/bin/nvcc')
if not nvcc_path:
cuda_driver = install_cuda.get_cuda_driver()
nvcc_path = install_cuda.install_cuda()
if nvcc_path:
nvcc_path = try_find_exe(nvcc_path)
gdb_path = try_find_exe('gdb')
addr2line_path = try_find_exe('addr2line')
has_pybt = check_pybt(gdb_path, python_path)

View File

@ -12,7 +12,7 @@ import gzip
from PIL import Image
# our lib jittor import
from jittor.dataset.dataset import Dataset, dataset_root
from jittor.utils.misc import ensure_dir, download_url_to_local
from jittor_utils.misc import ensure_dir, download_url_to_local
import jittor as jt
import jittor.transform as trans

View File

@ -24,8 +24,8 @@ void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr
tmp_data_t tmp_data;
void numpy_init() {
PyObjHolder np(PyImport_ImportModule("numpy.core.multiarray"));
PyObjHolder api(PyObject_GetAttrString(np.obj, "_ARRAY_API"));
PyObjHolder np(PyImport_ImportModule("numpy.core.multiarray"), "numpy is not installed");
PyObjHolder api(PyObject_GetAttrString(np.obj, "_ARRAY_API"), "numpy _ARRAY_API not found, you may need to reinstall numpy");
PyArray_API = (void **) PyCapsule_GetPointer(api.obj, NULL);
#define fill(name, i) name = (decltype(name))PyArray_API[i]

View File

@ -25,6 +25,17 @@ struct PyObjHolder {
LOGf << "Python error occur";
}
}
// Store `obj` in this holder. When `obj` is null, `err_msg` is streamed to
// LOGf before the assignment.
// NOTE(review): LOGf is presumably a fatal log that does not return — confirm.
// NOTE(review): the previous `this->obj` is overwritten here without a
// Py_DECREF — confirm callers only assign into an empty holder.
inline void assign(PyObject* obj, const char* err_msg) {
if (!obj) {
LOGf << err_msg;
}
this->obj = obj;
}
// Construct a holder around `obj`, reporting `err_msg` through LOGf when
// `obj` is null, so the caller can supply a specific diagnostic for a
// failed Python C-API call.
// NOTE(review): LOGf is presumably a fatal log that does not return — confirm.
inline PyObjHolder(PyObject* obj, const char* err_msg) : obj(obj) {
if (!obj) {
LOGf << err_msg;
}
}
// Drop the held reference on destruction; a null `obj` is skipped because
// Py_DECREF must not be called on a null pointer.
inline ~PyObjHolder() {
if (obj) Py_DECREF(obj);
}

View File

@ -16,6 +16,8 @@ import contextlib
import threading
import time
from ctypes import cdll
import shutil
import urllib.request
class LogWarper:
def __init__(self):
@ -180,7 +182,6 @@ def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
def download(url, filename):
from six.moves import urllib
if os.path.isfile(filename):
if os.path.getsize(filename) > 100:
return
@ -246,7 +247,9 @@ def get_int_version(output):
return ver
def find_exe(name, check_version=True, silent=False):
output = run_cmd(f'which {name}', err_msg=f'{name} not found')
output = shutil.which(name)
if not output:
raise RuntimeError(f"{name} not found")
if check_version:
version = get_version(name)
else:
@ -297,6 +300,6 @@ for py3_config_path in py3_config_paths:
break
else:
raise RuntimeError(f"python3.{sys.version_info.minor}-config "
"not found in {py3_config_paths}, please specify "
"enviroment variable 'python_config_path',"
" or apt install python3.{sys.version_info.minor}-dev")
f"not found in {py3_config_paths}, please specify "
f"enviroment variable 'python_config_path',"
f" or apt install python3.{sys.version_info.minor}-dev")

View File

@ -0,0 +1,61 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os
import sys
import subprocess as sp
import jittor_utils as jit_utils
from jittor_utils import LOG
from jittor_utils.misc import download_url_to_local
import pathlib
def get_cuda_driver():
    """Return the "CUDA Version" reported by ``nvidia-smi`` as a list of ints.

    Runs ``nvidia-smi -q -u`` and parses the text that follows the first
    (case-insensitive) ``cuda version`` label and its ``:`` separator.

    Returns:
        A list of integers such as ``[11, 2]``, or ``None`` when
        ``nvidia-smi`` exits with a non-zero status or the version field
        is missing or not numeric.
    """
    ret, out = sp.getstatusoutput("nvidia-smi -q -u")
    if ret != 0:
        return None
    try:
        version = out.lower().split('cuda version')[1] \
            .split(':')[1] \
            .splitlines()[0] \
            .strip()
        return [int(part) for part in version.split('.')]
    except (IndexError, ValueError):
        # IndexError: the label, the ':' separator or the value is missing.
        # ValueError: a version component is not an integer.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return None
def install_cuda():
    """Make sure jittor's bundled CUDA toolkit is unpacked and return its nvcc path.

    The toolkit lives under ``~/.cache/jittor/jtcuda/<archive name>``. Its
    ``lib64`` directory is registered in ``sys.path`` and in the
    ``LD_LIBRARY_PATH`` environment variable of this process (once; repeated
    calls do not add duplicates). If ``bin/nvcc`` is not there yet, the
    archive is downloaded (with an md5 check) and extracted.

    Returns:
        The path of the ``nvcc`` executable inside the unpacked toolkit.

    Raises:
        RuntimeError: if ``nvcc`` is still missing after extraction.
        Errors from ``download_url_to_local`` and ``tarfile`` propagate
        unchanged.
    """
    cuda_driver_version = get_cuda_driver()
    LOG.i("cuda_driver_version: ", cuda_driver_version)

    cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
    md5 = "a78f296746d97e9d76615289c2fe98ac"
    jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
    # The archive is unpacked into a directory named after it, minus ".tgz".
    cuda_root = os.path.join(jtcuda_path, cuda_tgz[:-4])
    nvcc_path = os.path.join(cuda_root, "bin", "nvcc")
    nvcc_lib_path = os.path.join(cuda_root, "lib64")

    if nvcc_lib_path not in sys.path:
        sys.path.append(nvcc_lib_path)
    old_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    if nvcc_lib_path not in old_ld_path.split(":"):
        # Never emit a leading ':' — an empty LD_LIBRARY_PATH entry makes the
        # dynamic loader search the current working directory.
        os.environ["LD_LIBRARY_PATH"] = (
            old_ld_path + ":" + nvcc_lib_path if old_ld_path else nvcc_lib_path
        )

    if os.path.isfile(nvcc_path):
        return nvcc_path

    os.makedirs(jtcuda_path, exist_ok=True)
    cuda_tgz_path = os.path.join(jtcuda_path, cuda_tgz)
    download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5)

    import tarfile
    with tarfile.open(cuda_tgz_path, "r") as tar:
        tar.extractall(cuda_root)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(nvcc_path):
        raise RuntimeError(
            f"nvcc not found at {nvcc_path} after extracting {cuda_tgz_path}")
    return nvcc_path
if __name__ == "__main__":
    # Script entry point: make sure the toolkit is installed, then log
    # where its nvcc executable is.
    installed_nvcc = install_cuda()
    LOG.i("nvcc is installed at ", installed_nvcc)

View File

@ -7,12 +7,11 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import os
from six.moves import urllib
import hashlib
import urllib.request
from tqdm import tqdm
from .. import lock
from jittor_utils import lock
def ensure_dir(dir_path):
if not os.path.isdir(dir_path):
@ -41,14 +40,11 @@ def download_url_to_local(url, filename, root_folder, md5):
if check_file_exist(file_path, md5):
return
else:
try:
print('Downloading ' + url + ' to ' + file_path)
urllib.request.urlretrieve(
url, file_path,
reporthook=_progress()
)
except(urllib.error.URLError, IOError) as e:
raise e
print('Downloading ' + url + ' to ' + file_path)
urllib.request.urlretrieve(
url, file_path,
reporthook=_progress()
)
if not check_file_exist(file_path, md5):
raise RuntimeError("File downloads failed.")
@ -72,3 +68,4 @@ def calculate_md5(file_path, chunk_size=1024 * 1024):
def check_md5(file_path, md5, **kwargs):
return md5 == calculate_md5(file_path, **kwargs)