Merge branch 'master' into macOS

This commit is contained in:
lzhengning 2021-06-04 13:26:29 +08:00
commit 23106e0606
35 changed files with 262 additions and 99 deletions

View File

@ -9,8 +9,8 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.9'
from . import lock
__version__ = '1.2.3.14'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
ori_float = float
@ -59,7 +59,7 @@ def safeunpickle(path):
if path.startswith("https:") or path.startswith("http:"):
base = path.split("/")[-1]
fname = os.path.join(compiler.ck_path, base)
from jittor.utils.misc import download_url_to_local
from jittor_utils.misc import download_url_to_local
download_url_to_local(path, base, compiler.ck_path, None)
path = fname
if path.endswith(".pth"):

View File

@ -8,11 +8,12 @@ import os, sys, shutil
import platform
from .compiler import *
from jittor_utils import run_cmd, get_version, get_int_version
from jittor.utils.misc import download_url_to_local
from jittor_utils.misc import download_url_to_local
def search_file(dirs, name, prefer_version=()):
for d in dirs:
fname = os.path.join(d, name)
prefer_version = tuple( str(p) for p in prefer_version )
for i in range(len(prefer_version),-1,-1):
vname = ".".join((fname,)+prefer_version[:i])
if os.path.isfile(vname):
@ -114,7 +115,7 @@ def install_cub(root_folder):
with tarfile.open(fullname, "r") as tar:
tar.extractall(root_folder)
assert 0 == os.system(f"cd {dirname}/examples && "
f"{nvcc_path} device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
if core.get_device_count():
assert 0 == os.system(f"cd {dirname}/examples && ./test")
return dirname
@ -161,9 +162,11 @@ def setup_cuda_extern():
line = traceback.format_exc()
LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
if lib_name == "cudnn":
LOG.w(f"Develop version of CUDNN not found, "
"please refer to CUDA offical tar file installation: "
"https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar")
LOG.w(f"""Develop version of CUDNN not found,
please refer to CUDA offical tar file installation:
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar
or you can let jittor install cuda and cudnn for you:
>>> python3.{sys.version_info.minor} -m jittor_utils.install_cuda""")
def setup_cuda_lib(lib_name, link=True, extra_flags=""):
globals()[lib_name+"_ops"] = None
@ -187,6 +190,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
prefer_version = ("8",)
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so", prefer_version)
if lib_name == "cublas" and nvcc_version[0] >= 10:
# manual link libcublasLt.so
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"libcublasLt.so", nvcc_version)
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
if lib_name == "cudnn":
# cudnn cannot found libcudnn_cnn_train.so.8, we manual link for it.
if nvcc_version >= (11,0,0):
@ -249,7 +258,7 @@ def install_cutt(root_folder):
if len(flags.cuda_archs):
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
run_cmd(f"make NVCC_GENCODE='{arch_flag}' nvcc_path='{nvcc_path}'", cwd=dirname)
run_cmd(f"make NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' nvcc_path='{nvcc_path}'", cwd=dirname)
return dirname
def setup_cutt():
@ -325,7 +334,7 @@ def install_nccl(root_folder):
if len(flags.cuda_archs):
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag}' ", cwd=dirname)
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' ", cwd=dirname)
return dirname
def setup_nccl():

View File

@ -19,7 +19,8 @@ from ctypes.util import find_library
import jittor_utils as jit_utils
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
from . import pyjt_compiler
from . import lock
from jittor_utils import lock
from jittor_utils import install_cuda
from jittor import __version__
def find_jittor_path():
@ -766,7 +767,7 @@ def compile_extern():
LOG.vv(f"Compile extern llvm passes: {str(files)}")
def check_cuda():
if nvcc_path == "":
if not nvcc_path:
return
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home
cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path))
@ -775,6 +776,9 @@ def check_cuda():
# assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}"
cuda_include = os.path.abspath(os.path.join(cuda_dir, "..", "include"))
cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib64"))
if nvcc_path == "/usr/bin/nvcc":
# this nvcc is install by package manager
cuda_lib = "/usr/lib/x86_64-linux-gnu"
cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc")
cc_flags += f" -DHAS_CUDA -I'{cuda_include}' -I'{cuda_include2}' "
core_link_flags += f" -lcudart -L'{cuda_lib}' "
@ -788,6 +792,7 @@ def check_cache_compile():
"src/utils/log.cc",
"src/utils/tracer.cc",
"src/utils/jit_utils.cc",
"src/utils/str_utils.cc",
]
global jit_utils_core_files
jit_utils_core_files = files
@ -860,21 +865,37 @@ check_debug_flags()
sys.path.append(cache_path)
LOG.i(f"Jittor({__version__}) src: {jittor_path}")
LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}")
LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}{jit_utils.get_version(jit_utils.cc_path)}")
LOG.i(f"cache_path: {cache_path}")
with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core()
python_path = sys.executable
# something python do not return the correct sys executable
# sometime python do not return the correct sys executable
# this will happend when multiple python version installed
ex_python_path = python_path + '.' + str(sys.version_info.minor)
if os.path.isfile(ex_python_path):
python_path = ex_python_path
py3_config_path = jit_utils.py3_config_path
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
# if jtcuda is already installed
nvcc_path = None
if install_cuda.has_installation():
nvcc_path = install_cuda.install_cuda()
if nvcc_path:
nvcc_path = try_find_exe(nvcc_path)
# check system installed cuda
if not nvcc_path:
nvcc_path = env_or_try_find('nvcc_path', 'nvcc') or try_find_exe('/usr/local/cuda/bin/nvcc') or try_find_exe('/usr/bin/nvcc')
# if system has no cuda, install jtcuda
if not nvcc_path:
nvcc_path = install_cuda.install_cuda()
if nvcc_path:
nvcc_path = try_find_exe(nvcc_path)
if not nvcc_path:
nvcc_path = ""
gdb_path = try_find_exe('gdb')
addr2line_path = try_find_exe('addr2line')
has_pybt = check_pybt(gdb_path, python_path)

View File

@ -12,7 +12,7 @@ import gzip
from PIL import Image
# our lib jittor import
from jittor.dataset.dataset import Dataset, dataset_root
from jittor.utils.misc import ensure_dir, download_url_to_local
from jittor_utils.misc import ensure_dir, download_url_to_local
import jittor as jt
import jittor.transform as trans

View File

@ -8,7 +8,7 @@
#include "var.h"
#include "cub_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#ifdef JIT
#include "cub_test.h"

View File

@ -8,7 +8,7 @@
#include "var.h"
#include "cudnn_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
int cudnn_test_entry( int argc, char** argv );

View File

@ -5,7 +5,7 @@
// ***************************************************************
#include "var.h"
#include "cutt_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#ifdef JIT
#include "cutt.h"

View File

@ -8,7 +8,7 @@
// ***************************************************************
#include "var.h"
#include "nccl_all_reduce_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include <nccl.h>
#include <cuda_runtime.h>

View File

@ -8,7 +8,7 @@
// ***************************************************************
#include "var.h"
#include "nccl_broadcast_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include <nccl.h>
#include <cuda_runtime.h>

View File

@ -8,7 +8,7 @@
// ***************************************************************
#include "var.h"
#include "nccl_reduce_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include <nccl.h>
#include <cuda_runtime.h>

View File

@ -5,7 +5,7 @@
// ***************************************************************
#include "var.h"
#include "nccl_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "nccl_warper.h"

View File

@ -10,7 +10,7 @@
#include "var.h"
#include "mpi_all_reduce_op.h"
#include "ops/op_register.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "misc/cuda_flags.h"
namespace jittor {

View File

@ -10,7 +10,7 @@
#include "var.h"
#include "mpi_broadcast_op.h"
#include "ops/op_register.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "misc/cuda_flags.h"
namespace jittor {

View File

@ -10,7 +10,7 @@
#include "var.h"
#include "mpi_reduce_op.h"
#include "ops/op_register.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "misc/cuda_flags.h"
namespace jittor {

View File

@ -7,7 +7,7 @@
#include "var.h"
#include "mpi_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -37,6 +37,7 @@ void cleanup() {
static void init_cuda_devices() {
#ifdef HAS_CUDA
if (cuda_archs.size()) return;
int count=0;
cudaGetDeviceCount(&count);
for (int i=0; i<count; i++) {

View File

@ -7,7 +7,7 @@
#include <sys/mman.h>
#include <sstream>
#include "jit_key.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -12,7 +12,7 @@
#include "jit_compiler.h"
#include "utils/cache_compile.h"
#include "opt/tuner_manager.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "ops/op_register.h"
#include "ops/array_op.h"
#include "lock.h"

View File

@ -5,7 +5,7 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "opt/expr.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {
namespace expr {

View File

@ -28,44 +28,6 @@ std::ostream& operator<<(std::ostream& os, KernelIR& ir) {
return os << ir.to_string();
}
bool startswith(const string& a, const string& b, uint start, bool equal, uint end) {
if (!end) end = a.size();
if (b.size()+start > end) return false;
if (equal && b.size()+start != end) return false;
for (uint i=0; i<b.size(); i++)
if (a[i+start] != b[i]) return false;
return true;
}
bool endswith(const string& a, const string& b) {
if (a.size() < b.size()) return false;
return startswith(a, b, a.size()-b.size());
}
vector<string> split(const string& s, const string& sep, int max_split) {
vector<string> ret;
int pos = -1, pos_next;
while (1) {
pos_next = s.find(sep, pos+1);
if (pos_next == (int)string::npos || (int)ret.size() == max_split-1) {
ret.push_back(s.substr(pos+sep.size()));
return ret;
}
ret.push_back(s.substr(pos+sep.size(), pos_next-pos-sep.size()));
pos = pos_next;
}
ASSERT(max_split==0);
return ret;
}
string strip(const string& s) {
int i=0;
while (i<s.size() && (s[i]==' ' || s[i]=='\t' || s[i]=='\n')) i++;
int j = s.size();
while (j>i && (s[j]==' ' || s[j]=='\t' || s[j]=='\n')) j--;
return s.substr(i,j-i);
}
void KernelIR::del_scope() {
if (father && (type=="define" || type=="func" || type=="macro")) {
father->scope[attrs["lvalue"]].remove(this);

View File

@ -6,7 +6,7 @@
// ***************************************************************
#pragma once
#include "common.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -11,7 +11,7 @@
#include "opt/expr.h"
#include "opt/pass_manager.h"
#include "opt/pass/float_atomic_fix_pass.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -32,7 +32,7 @@
#include "opt/pass/fake_main_pass.h"
#include "opt/pass/check_cache_pass.h"
#include "opt/pass/mark_raw_pass.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -12,7 +12,7 @@
#include "pyjt/py_obj_holder.h"
#include "pyjt/py_converter.h"
#include "pybind/py_var_tracer.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "op.h"
#include "var.h"
#include "fused_op.h"

View File

@ -24,8 +24,8 @@ void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr
tmp_data_t tmp_data;
void numpy_init() {
PyObjHolder np(PyImport_ImportModule("numpy.core.multiarray"));
PyObjHolder api(PyObject_GetAttrString(np.obj, "_ARRAY_API"));
PyObjHolder np(PyImport_ImportModule("numpy.core.multiarray"), "numpy is not installed");
PyObjHolder api(PyObject_GetAttrString(np.obj, "_ARRAY_API"), "numpy _ARRAY_API not found, you may need to reinstall numpy");
PyArray_API = (void **) PyCapsule_GetPointer(api.obj, NULL);
#define fill(name, i) name = (decltype(name))PyArray_API[i]

View File

@ -86,15 +86,19 @@ ArrayOp::ArrayOp(PyObject* obj) {
// use 32-bit by default
if ((auto_convert_64_to_32 || holder.obj)
&& args.dtype.dsize() == 8 && args.ptr) {
auto num = PyArray_Size(arr)/8;
auto size = PyArray_Size(arr);
args.buffer.reset(new char[size]);
auto pre_data = args.ptr;
args.ptr = args.buffer.get();
auto num = size/8;
if (args.dtype.is_int()) {
auto* __restrict__ i64 = (int64*)args.ptr;
auto* __restrict__ i64 = (int64*)pre_data;
auto* __restrict__ i32 = (int32*)args.ptr;
for (int i=0; i<num; i++)
i32[i] = (int32)i64[i];
args.dtype = ns_int32;
} else if (args.dtype.is_float()) {
auto* __restrict__ f64 = (float64*)args.ptr;
auto* __restrict__ f64 = (float64*)pre_data;
auto* __restrict__ f32 = (float32*)args.ptr;
for (int i=0; i<num; i++)
f32[i] = (float32)f64[i];

View File

@ -25,6 +25,17 @@ struct PyObjHolder {
LOGf << "Python error occur";
}
}
inline void assign(PyObject* obj, const char* err_msg) {
if (!obj) {
LOGf << err_msg;
}
this->obj = obj;
}
inline PyObjHolder(PyObject* obj, const char* err_msg) : obj(obj) {
if (!obj) {
LOGf << err_msg;
}
}
inline ~PyObjHolder() {
if (obj) Py_DECREF(obj);
}

View File

@ -13,6 +13,7 @@
#include <unistd.h>
#include "utils/log.h"
#include "utils/mwsr_list.h"
#include "utils/str_utils.h"
namespace jittor {
@ -305,6 +306,26 @@ bool check_vlog(const char* fileline, int verbose) {
return verbose <= log_v;
}
static inline void check_cuda_unsupport_version(const string& output) {
// check error like:
// /usr/include/crt/host_config.h:121:2: error: #error -- unsupported GNU version! gcc versions later than 6 are not supported!
// #error -- unsupported GNU version! gcc versions later than 6 are not supported!
string pat = "crt/host_config.h";
auto id = output.find(pat);
if (id == string::npos) return;
auto end = id + pat.size();
while (id>=0 && !(output[id]==' ' || output[id]=='\t' || output[id]=='\n'))
id--;
id ++;
auto fname = output.substr(id, end-id);
LOGw << R"(
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Dear user, your nvcc and gcc version are not match,
but you can hot fix it by this command:
>>> sudo python3 -c 's=open(")" >> fname >> R"(","r").read().replace("#error", "//#error");open(")" >> fname >> R"(","w").write(s)'
)";
}
int system_popen(const char* cmd) {
char buf[BUFSIZ];
string cmd2;
@ -312,17 +333,20 @@ int system_popen(const char* cmd) {
cmd2 += " 2>&1 ";
FILE *ptr = popen(cmd2.c_str(), "r");
if (!ptr) return -1;
int64 len=0;
string output;
while (fgets(buf, BUFSIZ, ptr) != NULL) {
len += strlen(buf);
output += buf;
std::cerr << buf;
}
if (len) std::cerr.flush();
if (output.size()) std::cerr.flush();
auto ret = pclose(ptr);
if (len<10 && ret) {
if (output.size()<10 && ret) {
// maybe overcommit
return -1;
}
if (ret) {
check_cuda_unsupport_version(output);
}
return ret;
}

View File

@ -125,8 +125,9 @@ struct LogFatalVoidify {
#define LOG_IF(level, cond) _LOG_IF(level, cond, 0)
template<class T> T get_from_env(const char* name,const T& _default) {
auto s = getenv(name);
if (s == NULL) return _default;
auto ss = getenv(name);
if (ss == NULL) return _default;
string s = ss;
std::istringstream is(s);
T env;
if (is >> env) {
@ -135,6 +136,8 @@ template<class T> T get_from_env(const char* name,const T& _default) {
return env;
}
}
if (s.size() && is.eof())
return env;
LOGw << "Load" << name << "from env(" << s << ") failed, use default" << _default;
return _default;
}

View File

@ -0,0 +1,50 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "utils/str_utils.h"
namespace jittor {
bool startswith(const string& a, const string& b, uint start, bool equal, uint end) {
if (!end) end = a.size();
if (b.size()+start > end) return false;
if (equal && b.size()+start != end) return false;
for (uint i=0; i<b.size(); i++)
if (a[i+start] != b[i]) return false;
return true;
}
bool endswith(const string& a, const string& b) {
if (a.size() < b.size()) return false;
return startswith(a, b, a.size()-b.size());
}
vector<string> split(const string& s, const string& sep, int max_split) {
vector<string> ret;
int pos = -1, pos_next;
while (1) {
pos_next = s.find(sep, pos+1);
if (pos_next == (int)string::npos || (int)ret.size() == max_split-1) {
ret.push_back(s.substr(pos+sep.size()));
return ret;
}
ret.push_back(s.substr(pos+sep.size(), pos_next-pos-sep.size()));
pos = pos_next;
}
ASSERT(max_split==0);
return ret;
}
string strip(const string& s) {
int i=0;
while (i<s.size() && (s[i]==' ' || s[i]=='\t' || s[i]=='\n')) i++;
int j = s.size();
while (j>i && (s[j]==' ' || s[j]=='\t' || s[j]=='\n')) j--;
return s.substr(i,j-i);
}
} // jittor

View File

@ -16,6 +16,8 @@ import contextlib
import threading
import time
from ctypes import cdll
import shutil
import urllib.request
class LogWarper:
def __init__(self):
@ -182,7 +184,6 @@ def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
def download(url, filename):
from six.moves import urllib
if os.path.isfile(filename):
if os.path.getsize(filename) > 100:
return
@ -248,7 +249,9 @@ def get_int_version(output):
return ver
def find_exe(name, check_version=True, silent=False):
output = run_cmd(f'which {name}', err_msg=f'{name} not found')
output = shutil.which(name)
if not output:
raise RuntimeError(f"{name} not found")
if check_version:
version = get_version(name)
else:
@ -299,6 +302,6 @@ for py3_config_path in py3_config_paths:
break
else:
raise RuntimeError(f"python3.{sys.version_info.minor}-config "
"not found in {py3_config_paths}, please specify "
"enviroment variable 'python_config_path',"
" or apt install python3.{sys.version_info.minor}-dev")
f"not found in {py3_config_paths}, please specify "
f"enviroment variable 'python_config_path',"
f" or apt install python3.{sys.version_info.minor}-dev")

View File

@ -0,0 +1,78 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os
import sys
import subprocess as sp
import jittor_utils as jit_utils
from jittor_utils import LOG
from jittor_utils.misc import download_url_to_local
import pathlib
def get_cuda_driver():
ret, out = sp.getstatusoutput("nvidia-smi -q -u")
if ret != 0: return None
try:
out = out.lower()
out = out.split('cuda version')[1] \
.split(':')[1] \
.splitlines()[0] \
.strip()
out = [ int(s) for s in out.split('.')]
return out
except:
return None
def has_installation():
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
return os.path.isdir(jtcuda_path)
def install_cuda():
cuda_driver_version = get_cuda_driver()
if not cuda_driver_version:
return None
LOG.i("cuda_driver_version: ", cuda_driver_version)
if cuda_driver_version >= [11,2]:
cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
md5 = "b93a1a5d19098e93450ee080509e9836"
elif cuda_driver_version >= [11,]:
cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
md5 = "5dbdb43e35b4db8249027997720bf1ca"
elif cuda_driver_version >= [10,2]:
cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
md5 = "a78f296746d97e9d76615289c2fe98ac"
elif cuda_driver_version >= [10,]:
cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
md5 = "f16d3ff63f081031d21faec3ec8b7dac"
else:
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}")
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
sys.path.append(nvcc_lib_path)
new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
os.environ["LD_LIBRARY_PATH"] = new_ld_path
if os.path.isfile(nvcc_path):
return nvcc_path
os.makedirs(jtcuda_path, exist_ok=True)
cuda_tgz_path = os.path.join(jtcuda_path, cuda_tgz)
download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5)
import tarfile
with tarfile.open(cuda_tgz_path, "r") as tar:
tar.extractall(cuda_tgz_path[:-4])
assert os.path.isfile(nvcc_path)
return nvcc_path
if __name__ == "__main__":
nvcc_path = install_cuda()
LOG.i("nvcc is installed at ", nvcc_path)

View File

@ -7,12 +7,11 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import os
from six.moves import urllib
import hashlib
import urllib.request
from tqdm import tqdm
from .. import lock
from jittor_utils import lock
def ensure_dir(dir_path):
if not os.path.isdir(dir_path):
@ -41,14 +40,11 @@ def download_url_to_local(url, filename, root_folder, md5):
if check_file_exist(file_path, md5):
return
else:
try:
print('Downloading ' + url + ' to ' + file_path)
urllib.request.urlretrieve(
url, file_path,
reporthook=_progress()
)
except(urllib.error.URLError, IOError) as e:
raise e
print('Downloading ' + url + ' to ' + file_path)
urllib.request.urlretrieve(
url, file_path,
reporthook=_progress()
)
if not check_file_exist(file_path, md5):
raise RuntimeError("File downloads failed.")
@ -72,3 +68,4 @@ def calculate_md5(file_path, chunk_size=1024 * 1024):
def check_md5(file_path, md5, **kwargs):
return md5 == calculate_md5(file_path, **kwargs)