mirror of https://github.com/Jittor/Jittor
Merge branch 'master' into macOS
This commit is contained in:
commit
23106e0606
|
@ -9,8 +9,8 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.9'
|
||||
from . import lock
|
||||
__version__ = '1.2.3.14'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
ori_float = float
|
||||
|
@ -59,7 +59,7 @@ def safeunpickle(path):
|
|||
if path.startswith("https:") or path.startswith("http:"):
|
||||
base = path.split("/")[-1]
|
||||
fname = os.path.join(compiler.ck_path, base)
|
||||
from jittor.utils.misc import download_url_to_local
|
||||
from jittor_utils.misc import download_url_to_local
|
||||
download_url_to_local(path, base, compiler.ck_path, None)
|
||||
path = fname
|
||||
if path.endswith(".pth"):
|
||||
|
|
|
@ -8,11 +8,12 @@ import os, sys, shutil
|
|||
import platform
|
||||
from .compiler import *
|
||||
from jittor_utils import run_cmd, get_version, get_int_version
|
||||
from jittor.utils.misc import download_url_to_local
|
||||
from jittor_utils.misc import download_url_to_local
|
||||
|
||||
def search_file(dirs, name, prefer_version=()):
|
||||
for d in dirs:
|
||||
fname = os.path.join(d, name)
|
||||
prefer_version = tuple( str(p) for p in prefer_version )
|
||||
for i in range(len(prefer_version),-1,-1):
|
||||
vname = ".".join((fname,)+prefer_version[:i])
|
||||
if os.path.isfile(vname):
|
||||
|
@ -114,7 +115,7 @@ def install_cub(root_folder):
|
|||
with tarfile.open(fullname, "r") as tar:
|
||||
tar.extractall(root_folder)
|
||||
assert 0 == os.system(f"cd {dirname}/examples && "
|
||||
f"{nvcc_path} device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
|
||||
f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
|
||||
if core.get_device_count():
|
||||
assert 0 == os.system(f"cd {dirname}/examples && ./test")
|
||||
return dirname
|
||||
|
@ -161,9 +162,11 @@ def setup_cuda_extern():
|
|||
line = traceback.format_exc()
|
||||
LOG.w(f"CUDA found but {lib_name} is not loaded:\n{line}")
|
||||
if lib_name == "cudnn":
|
||||
LOG.w(f"Develop version of CUDNN not found, "
|
||||
"please refer to CUDA offical tar file installation: "
|
||||
"https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar")
|
||||
LOG.w(f"""Develop version of CUDNN not found,
|
||||
please refer to CUDA offical tar file installation:
|
||||
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar
|
||||
or you can let jittor install cuda and cudnn for you:
|
||||
>>> python3.{sys.version_info.minor} -m jittor_utils.install_cuda""")
|
||||
|
||||
def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
||||
globals()[lib_name+"_ops"] = None
|
||||
|
@ -187,6 +190,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
prefer_version = ("8",)
|
||||
culib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"lib{lib_name}.so", prefer_version)
|
||||
|
||||
if lib_name == "cublas" and nvcc_version[0] >= 10:
|
||||
# manual link libcublasLt.so
|
||||
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu"], f"libcublasLt.so", nvcc_version)
|
||||
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
|
||||
|
||||
|
||||
if lib_name == "cudnn":
|
||||
# cudnn cannot found libcudnn_cnn_train.so.8, we manual link for it.
|
||||
if nvcc_version >= (11,0,0):
|
||||
|
@ -249,7 +258,7 @@ def install_cutt(root_folder):
|
|||
if len(flags.cuda_archs):
|
||||
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
run_cmd(f"make NVCC_GENCODE='{arch_flag}' nvcc_path='{nvcc_path}'", cwd=dirname)
|
||||
run_cmd(f"make NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' nvcc_path='{nvcc_path}'", cwd=dirname)
|
||||
return dirname
|
||||
|
||||
def setup_cutt():
|
||||
|
@ -325,7 +334,7 @@ def install_nccl(root_folder):
|
|||
if len(flags.cuda_archs):
|
||||
arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
|
||||
arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs))
|
||||
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag}' ", cwd=dirname)
|
||||
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag} --cudart=shared -ccbin=\"{cc_path}\" ' ", cwd=dirname)
|
||||
return dirname
|
||||
|
||||
def setup_nccl():
|
||||
|
|
|
@ -19,7 +19,8 @@ from ctypes.util import find_library
|
|||
import jittor_utils as jit_utils
|
||||
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
|
||||
from . import pyjt_compiler
|
||||
from . import lock
|
||||
from jittor_utils import lock
|
||||
from jittor_utils import install_cuda
|
||||
from jittor import __version__
|
||||
|
||||
def find_jittor_path():
|
||||
|
@ -766,7 +767,7 @@ def compile_extern():
|
|||
LOG.vv(f"Compile extern llvm passes: {str(files)}")
|
||||
|
||||
def check_cuda():
|
||||
if nvcc_path == "":
|
||||
if not nvcc_path:
|
||||
return
|
||||
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home
|
||||
cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path))
|
||||
|
@ -775,6 +776,9 @@ def check_cuda():
|
|||
# assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}"
|
||||
cuda_include = os.path.abspath(os.path.join(cuda_dir, "..", "include"))
|
||||
cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib64"))
|
||||
if nvcc_path == "/usr/bin/nvcc":
|
||||
# this nvcc is install by package manager
|
||||
cuda_lib = "/usr/lib/x86_64-linux-gnu"
|
||||
cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc")
|
||||
cc_flags += f" -DHAS_CUDA -I'{cuda_include}' -I'{cuda_include2}' "
|
||||
core_link_flags += f" -lcudart -L'{cuda_lib}' "
|
||||
|
@ -788,6 +792,7 @@ def check_cache_compile():
|
|||
"src/utils/log.cc",
|
||||
"src/utils/tracer.cc",
|
||||
"src/utils/jit_utils.cc",
|
||||
"src/utils/str_utils.cc",
|
||||
]
|
||||
global jit_utils_core_files
|
||||
jit_utils_core_files = files
|
||||
|
@ -860,21 +865,37 @@ check_debug_flags()
|
|||
|
||||
sys.path.append(cache_path)
|
||||
LOG.i(f"Jittor({__version__}) src: {jittor_path}")
|
||||
LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}")
|
||||
LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}{jit_utils.get_version(jit_utils.cc_path)}")
|
||||
LOG.i(f"cache_path: {cache_path}")
|
||||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
jit_utils.try_import_jit_utils_core()
|
||||
|
||||
python_path = sys.executable
|
||||
# something python do not return the correct sys executable
|
||||
# sometime python do not return the correct sys executable
|
||||
# this will happend when multiple python version installed
|
||||
ex_python_path = python_path + '.' + str(sys.version_info.minor)
|
||||
if os.path.isfile(ex_python_path):
|
||||
python_path = ex_python_path
|
||||
py3_config_path = jit_utils.py3_config_path
|
||||
|
||||
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
|
||||
# if jtcuda is already installed
|
||||
nvcc_path = None
|
||||
if install_cuda.has_installation():
|
||||
nvcc_path = install_cuda.install_cuda()
|
||||
if nvcc_path:
|
||||
nvcc_path = try_find_exe(nvcc_path)
|
||||
# check system installed cuda
|
||||
if not nvcc_path:
|
||||
nvcc_path = env_or_try_find('nvcc_path', 'nvcc') or try_find_exe('/usr/local/cuda/bin/nvcc') or try_find_exe('/usr/bin/nvcc')
|
||||
# if system has no cuda, install jtcuda
|
||||
if not nvcc_path:
|
||||
nvcc_path = install_cuda.install_cuda()
|
||||
if nvcc_path:
|
||||
nvcc_path = try_find_exe(nvcc_path)
|
||||
if not nvcc_path:
|
||||
nvcc_path = ""
|
||||
|
||||
gdb_path = try_find_exe('gdb')
|
||||
addr2line_path = try_find_exe('addr2line')
|
||||
has_pybt = check_pybt(gdb_path, python_path)
|
||||
|
|
|
@ -12,7 +12,7 @@ import gzip
|
|||
from PIL import Image
|
||||
# our lib jittor import
|
||||
from jittor.dataset.dataset import Dataset, dataset_root
|
||||
from jittor.utils.misc import ensure_dir, download_url_to_local
|
||||
from jittor_utils.misc import ensure_dir, download_url_to_local
|
||||
import jittor as jt
|
||||
import jittor.transform as trans
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
#include "var.h"
|
||||
#include "cub_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
#ifdef JIT
|
||||
#include "cub_test.h"
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
#include "var.h"
|
||||
#include "cudnn_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
int cudnn_test_entry( int argc, char** argv );
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cutt_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
#ifdef JIT
|
||||
#include "cutt.h"
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "nccl_all_reduce_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "nccl_broadcast_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "nccl_reduce_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "nccl_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
#include "nccl_warper.h"
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#include "var.h"
|
||||
#include "mpi_all_reduce_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
namespace jittor {
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#include "var.h"
|
||||
#include "mpi_broadcast_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
namespace jittor {
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#include "var.h"
|
||||
#include "mpi_reduce_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
namespace jittor {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
#include "var.h"
|
||||
#include "mpi_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ void cleanup() {
|
|||
|
||||
static void init_cuda_devices() {
|
||||
#ifdef HAS_CUDA
|
||||
if (cuda_archs.size()) return;
|
||||
int count=0;
|
||||
cudaGetDeviceCount(&count);
|
||||
for (int i=0; i<count; i++) {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <sys/mman.h>
|
||||
#include <sstream>
|
||||
#include "jit_key.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include "jit_compiler.h"
|
||||
#include "utils/cache_compile.h"
|
||||
#include "opt/tuner_manager.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "lock.h"
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "opt/expr.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
namespace expr {
|
||||
|
|
|
@ -28,44 +28,6 @@ std::ostream& operator<<(std::ostream& os, KernelIR& ir) {
|
|||
return os << ir.to_string();
|
||||
}
|
||||
|
||||
bool startswith(const string& a, const string& b, uint start, bool equal, uint end) {
|
||||
if (!end) end = a.size();
|
||||
if (b.size()+start > end) return false;
|
||||
if (equal && b.size()+start != end) return false;
|
||||
for (uint i=0; i<b.size(); i++)
|
||||
if (a[i+start] != b[i]) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool endswith(const string& a, const string& b) {
|
||||
if (a.size() < b.size()) return false;
|
||||
return startswith(a, b, a.size()-b.size());
|
||||
}
|
||||
|
||||
vector<string> split(const string& s, const string& sep, int max_split) {
|
||||
vector<string> ret;
|
||||
int pos = -1, pos_next;
|
||||
while (1) {
|
||||
pos_next = s.find(sep, pos+1);
|
||||
if (pos_next == (int)string::npos || (int)ret.size() == max_split-1) {
|
||||
ret.push_back(s.substr(pos+sep.size()));
|
||||
return ret;
|
||||
}
|
||||
ret.push_back(s.substr(pos+sep.size(), pos_next-pos-sep.size()));
|
||||
pos = pos_next;
|
||||
}
|
||||
ASSERT(max_split==0);
|
||||
return ret;
|
||||
}
|
||||
|
||||
string strip(const string& s) {
|
||||
int i=0;
|
||||
while (i<s.size() && (s[i]==' ' || s[i]=='\t' || s[i]=='\n')) i++;
|
||||
int j = s.size();
|
||||
while (j>i && (s[j]==' ' || s[j]=='\t' || s[j]=='\n')) j--;
|
||||
return s.substr(i,j-i);
|
||||
}
|
||||
|
||||
void KernelIR::del_scope() {
|
||||
if (father && (type=="define" || type=="func" || type=="macro")) {
|
||||
father->scope[attrs["lvalue"]].remove(this);
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
#include "opt/expr.h"
|
||||
#include "opt/pass_manager.h"
|
||||
#include "opt/pass/float_atomic_fix_pass.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
#include "opt/pass/fake_main_pass.h"
|
||||
#include "opt/pass/check_cache_pass.h"
|
||||
#include "opt/pass/mark_raw_pass.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include "pyjt/py_obj_holder.h"
|
||||
#include "pyjt/py_converter.h"
|
||||
#include "pybind/py_var_tracer.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "op.h"
|
||||
#include "var.h"
|
||||
#include "fused_op.h"
|
||||
|
|
|
@ -24,8 +24,8 @@ void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr
|
|||
tmp_data_t tmp_data;
|
||||
|
||||
void numpy_init() {
|
||||
PyObjHolder np(PyImport_ImportModule("numpy.core.multiarray"));
|
||||
PyObjHolder api(PyObject_GetAttrString(np.obj, "_ARRAY_API"));
|
||||
PyObjHolder np(PyImport_ImportModule("numpy.core.multiarray"), "numpy is not installed");
|
||||
PyObjHolder api(PyObject_GetAttrString(np.obj, "_ARRAY_API"), "numpy _ARRAY_API not found, you may need to reinstall numpy");
|
||||
PyArray_API = (void **) PyCapsule_GetPointer(api.obj, NULL);
|
||||
|
||||
#define fill(name, i) name = (decltype(name))PyArray_API[i]
|
||||
|
|
|
@ -86,15 +86,19 @@ ArrayOp::ArrayOp(PyObject* obj) {
|
|||
// use 32-bit by default
|
||||
if ((auto_convert_64_to_32 || holder.obj)
|
||||
&& args.dtype.dsize() == 8 && args.ptr) {
|
||||
auto num = PyArray_Size(arr)/8;
|
||||
auto size = PyArray_Size(arr);
|
||||
args.buffer.reset(new char[size]);
|
||||
auto pre_data = args.ptr;
|
||||
args.ptr = args.buffer.get();
|
||||
auto num = size/8;
|
||||
if (args.dtype.is_int()) {
|
||||
auto* __restrict__ i64 = (int64*)args.ptr;
|
||||
auto* __restrict__ i64 = (int64*)pre_data;
|
||||
auto* __restrict__ i32 = (int32*)args.ptr;
|
||||
for (int i=0; i<num; i++)
|
||||
i32[i] = (int32)i64[i];
|
||||
args.dtype = ns_int32;
|
||||
} else if (args.dtype.is_float()) {
|
||||
auto* __restrict__ f64 = (float64*)args.ptr;
|
||||
auto* __restrict__ f64 = (float64*)pre_data;
|
||||
auto* __restrict__ f32 = (float32*)args.ptr;
|
||||
for (int i=0; i<num; i++)
|
||||
f32[i] = (float32)f64[i];
|
||||
|
|
|
@ -25,6 +25,17 @@ struct PyObjHolder {
|
|||
LOGf << "Python error occur";
|
||||
}
|
||||
}
|
||||
inline void assign(PyObject* obj, const char* err_msg) {
|
||||
if (!obj) {
|
||||
LOGf << err_msg;
|
||||
}
|
||||
this->obj = obj;
|
||||
}
|
||||
inline PyObjHolder(PyObject* obj, const char* err_msg) : obj(obj) {
|
||||
if (!obj) {
|
||||
LOGf << err_msg;
|
||||
}
|
||||
}
|
||||
inline ~PyObjHolder() {
|
||||
if (obj) Py_DECREF(obj);
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <unistd.h>
|
||||
#include "utils/log.h"
|
||||
#include "utils/mwsr_list.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -305,6 +306,26 @@ bool check_vlog(const char* fileline, int verbose) {
|
|||
return verbose <= log_v;
|
||||
}
|
||||
|
||||
static inline void check_cuda_unsupport_version(const string& output) {
|
||||
// check error like:
|
||||
// /usr/include/crt/host_config.h:121:2: error: #error -- unsupported GNU version! gcc versions later than 6 are not supported!
|
||||
// #error -- unsupported GNU version! gcc versions later than 6 are not supported!
|
||||
string pat = "crt/host_config.h";
|
||||
auto id = output.find(pat);
|
||||
if (id == string::npos) return;
|
||||
auto end = id + pat.size();
|
||||
while (id>=0 && !(output[id]==' ' || output[id]=='\t' || output[id]=='\n'))
|
||||
id--;
|
||||
id ++;
|
||||
auto fname = output.substr(id, end-id);
|
||||
LOGw << R"(
|
||||
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
Dear user, your nvcc and gcc version are not match,
|
||||
but you can hot fix it by this command:
|
||||
>>> sudo python3 -c 's=open(")" >> fname >> R"(","r").read().replace("#error", "//#error");open(")" >> fname >> R"(","w").write(s)'
|
||||
)";
|
||||
}
|
||||
|
||||
int system_popen(const char* cmd) {
|
||||
char buf[BUFSIZ];
|
||||
string cmd2;
|
||||
|
@ -312,17 +333,20 @@ int system_popen(const char* cmd) {
|
|||
cmd2 += " 2>&1 ";
|
||||
FILE *ptr = popen(cmd2.c_str(), "r");
|
||||
if (!ptr) return -1;
|
||||
int64 len=0;
|
||||
string output;
|
||||
while (fgets(buf, BUFSIZ, ptr) != NULL) {
|
||||
len += strlen(buf);
|
||||
output += buf;
|
||||
std::cerr << buf;
|
||||
}
|
||||
if (len) std::cerr.flush();
|
||||
if (output.size()) std::cerr.flush();
|
||||
auto ret = pclose(ptr);
|
||||
if (len<10 && ret) {
|
||||
if (output.size()<10 && ret) {
|
||||
// maybe overcommit
|
||||
return -1;
|
||||
}
|
||||
if (ret) {
|
||||
check_cuda_unsupport_version(output);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -125,8 +125,9 @@ struct LogFatalVoidify {
|
|||
#define LOG_IF(level, cond) _LOG_IF(level, cond, 0)
|
||||
|
||||
template<class T> T get_from_env(const char* name,const T& _default) {
|
||||
auto s = getenv(name);
|
||||
if (s == NULL) return _default;
|
||||
auto ss = getenv(name);
|
||||
if (ss == NULL) return _default;
|
||||
string s = ss;
|
||||
std::istringstream is(s);
|
||||
T env;
|
||||
if (is >> env) {
|
||||
|
@ -135,6 +136,8 @@ template<class T> T get_from_env(const char* name,const T& _default) {
|
|||
return env;
|
||||
}
|
||||
}
|
||||
if (s.size() && is.eof())
|
||||
return env;
|
||||
LOGw << "Load" << name << "from env(" << s << ") failed, use default" << _default;
|
||||
return _default;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
||||
bool startswith(const string& a, const string& b, uint start, bool equal, uint end) {
|
||||
if (!end) end = a.size();
|
||||
if (b.size()+start > end) return false;
|
||||
if (equal && b.size()+start != end) return false;
|
||||
for (uint i=0; i<b.size(); i++)
|
||||
if (a[i+start] != b[i]) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool endswith(const string& a, const string& b) {
|
||||
if (a.size() < b.size()) return false;
|
||||
return startswith(a, b, a.size()-b.size());
|
||||
}
|
||||
|
||||
vector<string> split(const string& s, const string& sep, int max_split) {
|
||||
vector<string> ret;
|
||||
int pos = -1, pos_next;
|
||||
while (1) {
|
||||
pos_next = s.find(sep, pos+1);
|
||||
if (pos_next == (int)string::npos || (int)ret.size() == max_split-1) {
|
||||
ret.push_back(s.substr(pos+sep.size()));
|
||||
return ret;
|
||||
}
|
||||
ret.push_back(s.substr(pos+sep.size(), pos_next-pos-sep.size()));
|
||||
pos = pos_next;
|
||||
}
|
||||
ASSERT(max_split==0);
|
||||
return ret;
|
||||
}
|
||||
|
||||
string strip(const string& s) {
|
||||
int i=0;
|
||||
while (i<s.size() && (s[i]==' ' || s[i]=='\t' || s[i]=='\n')) i++;
|
||||
int j = s.size();
|
||||
while (j>i && (s[j]==' ' || s[j]=='\t' || s[j]=='\n')) j--;
|
||||
return s.substr(i,j-i);
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -16,6 +16,8 @@ import contextlib
|
|||
import threading
|
||||
import time
|
||||
from ctypes import cdll
|
||||
import shutil
|
||||
import urllib.request
|
||||
|
||||
class LogWarper:
|
||||
def __init__(self):
|
||||
|
@ -182,7 +184,6 @@ def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
|||
|
||||
|
||||
def download(url, filename):
|
||||
from six.moves import urllib
|
||||
if os.path.isfile(filename):
|
||||
if os.path.getsize(filename) > 100:
|
||||
return
|
||||
|
@ -248,7 +249,9 @@ def get_int_version(output):
|
|||
return ver
|
||||
|
||||
def find_exe(name, check_version=True, silent=False):
|
||||
output = run_cmd(f'which {name}', err_msg=f'{name} not found')
|
||||
output = shutil.which(name)
|
||||
if not output:
|
||||
raise RuntimeError(f"{name} not found")
|
||||
if check_version:
|
||||
version = get_version(name)
|
||||
else:
|
||||
|
@ -299,6 +302,6 @@ for py3_config_path in py3_config_paths:
|
|||
break
|
||||
else:
|
||||
raise RuntimeError(f"python3.{sys.version_info.minor}-config "
|
||||
"not found in {py3_config_paths}, please specify "
|
||||
"enviroment variable 'python_config_path',"
|
||||
" or apt install python3.{sys.version_info.minor}-dev")
|
||||
f"not found in {py3_config_paths}, please specify "
|
||||
f"enviroment variable 'python_config_path',"
|
||||
f" or apt install python3.{sys.version_info.minor}-dev")
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import os
|
||||
import sys
|
||||
import subprocess as sp
|
||||
import jittor_utils as jit_utils
|
||||
from jittor_utils import LOG
|
||||
from jittor_utils.misc import download_url_to_local
|
||||
import pathlib
|
||||
|
||||
def get_cuda_driver():
|
||||
ret, out = sp.getstatusoutput("nvidia-smi -q -u")
|
||||
if ret != 0: return None
|
||||
try:
|
||||
out = out.lower()
|
||||
out = out.split('cuda version')[1] \
|
||||
.split(':')[1] \
|
||||
.splitlines()[0] \
|
||||
.strip()
|
||||
out = [ int(s) for s in out.split('.')]
|
||||
return out
|
||||
except:
|
||||
return None
|
||||
|
||||
def has_installation():
|
||||
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
|
||||
return os.path.isdir(jtcuda_path)
|
||||
|
||||
def install_cuda():
|
||||
cuda_driver_version = get_cuda_driver()
|
||||
if not cuda_driver_version:
|
||||
return None
|
||||
LOG.i("cuda_driver_version: ", cuda_driver_version)
|
||||
|
||||
if cuda_driver_version >= [11,2]:
|
||||
cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
|
||||
md5 = "b93a1a5d19098e93450ee080509e9836"
|
||||
elif cuda_driver_version >= [11,]:
|
||||
cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
|
||||
md5 = "5dbdb43e35b4db8249027997720bf1ca"
|
||||
elif cuda_driver_version >= [10,2]:
|
||||
cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
|
||||
md5 = "a78f296746d97e9d76615289c2fe98ac"
|
||||
elif cuda_driver_version >= [10,]:
|
||||
cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
|
||||
md5 = "f16d3ff63f081031d21faec3ec8b7dac"
|
||||
else:
|
||||
raise RuntimeError(f"Unsupport cuda driver version: {cuda_driver_version}")
|
||||
jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
|
||||
nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
|
||||
nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
|
||||
sys.path.append(nvcc_lib_path)
|
||||
new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
|
||||
os.environ["LD_LIBRARY_PATH"] = new_ld_path
|
||||
|
||||
if os.path.isfile(nvcc_path):
|
||||
return nvcc_path
|
||||
|
||||
os.makedirs(jtcuda_path, exist_ok=True)
|
||||
cuda_tgz_path = os.path.join(jtcuda_path, cuda_tgz)
|
||||
download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5)
|
||||
|
||||
|
||||
import tarfile
|
||||
with tarfile.open(cuda_tgz_path, "r") as tar:
|
||||
tar.extractall(cuda_tgz_path[:-4])
|
||||
|
||||
assert os.path.isfile(nvcc_path)
|
||||
return nvcc_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
nvcc_path = install_cuda()
|
||||
LOG.i("nvcc is installed at ", nvcc_path)
|
|
@ -7,12 +7,11 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import os
|
||||
from six.moves import urllib
|
||||
import hashlib
|
||||
import urllib.request
|
||||
from tqdm import tqdm
|
||||
from .. import lock
|
||||
from jittor_utils import lock
|
||||
|
||||
def ensure_dir(dir_path):
|
||||
if not os.path.isdir(dir_path):
|
||||
|
@ -41,14 +40,11 @@ def download_url_to_local(url, filename, root_folder, md5):
|
|||
if check_file_exist(file_path, md5):
|
||||
return
|
||||
else:
|
||||
try:
|
||||
print('Downloading ' + url + ' to ' + file_path)
|
||||
urllib.request.urlretrieve(
|
||||
url, file_path,
|
||||
reporthook=_progress()
|
||||
)
|
||||
except(urllib.error.URLError, IOError) as e:
|
||||
raise e
|
||||
print('Downloading ' + url + ' to ' + file_path)
|
||||
urllib.request.urlretrieve(
|
||||
url, file_path,
|
||||
reporthook=_progress()
|
||||
)
|
||||
if not check_file_exist(file_path, md5):
|
||||
raise RuntimeError("File downloads failed.")
|
||||
|
||||
|
@ -72,3 +68,4 @@ def calculate_md5(file_path, chunk_size=1024 * 1024):
|
|||
|
||||
def check_md5(file_path, md5, **kwargs):
|
||||
return md5 == calculate_md5(file_path, **kwargs)
|
||||
|
Loading…
Reference in New Issue