support package installed cuda

This commit is contained in:
Dun Liang 2021-05-27 19:35:47 +08:00
parent c4a937cd32
commit b721afeac3
24 changed files with 104 additions and 63 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.9'
__version__ = '1.2.3.10'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -769,6 +769,9 @@ def check_cuda():
# assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}"
cuda_include = os.path.abspath(os.path.join(cuda_dir, "..", "include"))
cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib64"))
if nvcc_path == "/usr/bin/nvcc":
# this nvcc is install by package manager
cuda_lib = "/usr/lib/x86_64-linux-gnu"
cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc")
cc_flags += f" -DHAS_CUDA -I'{cuda_include}' -I'{cuda_include2}' "
core_link_flags += f" -lcudart -L'{cuda_lib}' "
@ -850,14 +853,14 @@ check_debug_flags()
sys.path.append(cache_path)
LOG.i(f"Jittor({__version__}) src: {jittor_path}")
LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}")
LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}{jit_utils.get_version(jit_utils.cc_path)}")
LOG.i(f"cache_path: {cache_path}")
with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core()
python_path = sys.executable
# something python do not return the correct sys executable
# sometime python do not return the correct sys executable
# this will happend when multiple python version installed
ex_python_path = python_path + '.' + str(sys.version_info.minor)
if os.path.isfile(ex_python_path):
@ -865,6 +868,8 @@ if os.path.isfile(ex_python_path):
py3_config_path = jit_utils.py3_config_path
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
if not nvcc_path:
nvcc_path = env_or_try_find('nvcc_path', '/usr/bin/nvcc')
gdb_path = try_find_exe('gdb')
addr2line_path = try_find_exe('addr2line')
has_pybt = check_pybt(gdb_path, python_path)

View File

@ -8,7 +8,7 @@
#include "var.h"
#include "cub_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#ifdef JIT
#include "cub_test.h"

View File

@ -8,7 +8,7 @@
#include "var.h"
#include "cudnn_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
int cudnn_test_entry( int argc, char** argv );

View File

@ -5,7 +5,7 @@
// ***************************************************************
#include "var.h"
#include "cutt_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#ifdef JIT
#include "cutt.h"

View File

@ -8,7 +8,7 @@
// ***************************************************************
#include "var.h"
#include "nccl_all_reduce_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include <nccl.h>
#include <cuda_runtime.h>

View File

@ -8,7 +8,7 @@
// ***************************************************************
#include "var.h"
#include "nccl_broadcast_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include <nccl.h>
#include <cuda_runtime.h>

View File

@ -8,7 +8,7 @@
// ***************************************************************
#include "var.h"
#include "nccl_reduce_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include <nccl.h>
#include <cuda_runtime.h>

View File

@ -5,7 +5,7 @@
// ***************************************************************
#include "var.h"
#include "nccl_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "nccl_warper.h"

View File

@ -10,7 +10,7 @@
#include "var.h"
#include "mpi_all_reduce_op.h"
#include "ops/op_register.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "misc/cuda_flags.h"
namespace jittor {

View File

@ -10,7 +10,7 @@
#include "var.h"
#include "mpi_broadcast_op.h"
#include "ops/op_register.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "misc/cuda_flags.h"
namespace jittor {

View File

@ -10,7 +10,7 @@
#include "var.h"
#include "mpi_reduce_op.h"
#include "ops/op_register.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "misc/cuda_flags.h"
namespace jittor {

View File

@ -7,7 +7,7 @@
#include "var.h"
#include "mpi_test_op.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -7,7 +7,7 @@
#include <sys/mman.h>
#include <sstream>
#include "jit_key.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -12,7 +12,7 @@
#include "jit_compiler.h"
#include "utils/cache_compile.h"
#include "opt/tuner_manager.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "ops/op_register.h"
#include "ops/array_op.h"
#include "lock.h"

View File

@ -5,7 +5,7 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "opt/expr.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {
namespace expr {

View File

@ -28,44 +28,6 @@ std::ostream& operator<<(std::ostream& os, KernelIR& ir) {
return os << ir.to_string();
}
bool startswith(const string& a, const string& b, uint start, bool equal, uint end) {
if (!end) end = a.size();
if (b.size()+start > end) return false;
if (equal && b.size()+start != end) return false;
for (uint i=0; i<b.size(); i++)
if (a[i+start] != b[i]) return false;
return true;
}
bool endswith(const string& a, const string& b) {
if (a.size() < b.size()) return false;
return startswith(a, b, a.size()-b.size());
}
vector<string> split(const string& s, const string& sep, int max_split) {
vector<string> ret;
int pos = -1, pos_next;
while (1) {
pos_next = s.find(sep, pos+1);
if (pos_next == (int)string::npos || (int)ret.size() == max_split-1) {
ret.push_back(s.substr(pos+sep.size()));
return ret;
}
ret.push_back(s.substr(pos+sep.size(), pos_next-pos-sep.size()));
pos = pos_next;
}
ASSERT(max_split==0);
return ret;
}
string strip(const string& s) {
int i=0;
while (i<s.size() && (s[i]==' ' || s[i]=='\t' || s[i]=='\n')) i++;
int j = s.size();
while (j>i && (s[j]==' ' || s[j]=='\t' || s[j]=='\n')) j--;
return s.substr(i,j-i);
}
void KernelIR::del_scope() {
if (father && (type=="define" || type=="func" || type=="macro")) {
father->scope[attrs["lvalue"]].remove(this);

View File

@ -6,7 +6,7 @@
// ***************************************************************
#pragma once
#include "common.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -11,7 +11,7 @@
#include "opt/expr.h"
#include "opt/pass_manager.h"
#include "opt/pass/float_atomic_fix_pass.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -32,7 +32,7 @@
#include "opt/pass/fake_main_pass.h"
#include "opt/pass/check_cache_pass.h"
#include "opt/pass/mark_raw_pass.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
namespace jittor {

View File

@ -12,7 +12,7 @@
#include "pyjt/py_obj_holder.h"
#include "pyjt/py_converter.h"
#include "pybind/py_var_tracer.h"
#include "misc/str_utils.h"
#include "utils/str_utils.h"
#include "op.h"
#include "var.h"
#include "fused_op.h"

View File

@ -13,6 +13,7 @@
#include <unistd.h>
#include "utils/log.h"
#include "utils/mwsr_list.h"
#include "utils/str_utils.h"
namespace jittor {
@ -301,6 +302,26 @@ bool check_vlog(const char* fileline, int verbose) {
return verbose <= log_v;
}
static inline void check_cuda_unsupport_version(const string& output) {
// check error like:
// /usr/include/crt/host_config.h:121:2: error: #error -- unsupported GNU version! gcc versions later than 6 are not supported!
// #error -- unsupported GNU version! gcc versions later than 6 are not supported!
string pat = "crt/host_config.h";
auto id = output.find(pat);
if (id == string::npos) return;
auto end = id + pat.size();
while (id>=0 && !(output[id]==' ' || output[id]=='\t' || output[id]=='\n'))
id--;
id ++;
auto fname = output.substr(id, end-id);
LOGw << R"(
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Dear user, your nvcc and gcc version are not match,
but you can hot fix it by this command:
>>> sudo python3 -c 's=open(")" >> fname >> R"(","r").read().replace("#error", "//#error");open(")" >> fname >> R"(","w").write(s)'
)";
}
int system_popen(const char* cmd) {
char buf[BUFSIZ];
string cmd2;
@ -308,17 +329,20 @@ int system_popen(const char* cmd) {
cmd2 += " 2>&1 ";
FILE *ptr = popen(cmd2.c_str(), "r");
if (!ptr) return -1;
int64 len=0;
string output;
while (fgets(buf, BUFSIZ, ptr) != NULL) {
len += strlen(buf);
output += buf;
std::cerr << buf;
}
if (len) std::cerr.flush();
if (output.size()) std::cerr.flush();
auto ret = pclose(ptr);
if (len<10 && ret) {
if (output.size()<10 && ret) {
// maybe overcommit
return -1;
}
if (ret) {
check_cuda_unsupport_version(output);
}
return ret;
}

View File

@ -0,0 +1,50 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "utils/str_utils.h"
namespace jittor {
bool startswith(const string& a, const string& b, uint start, bool equal, uint end) {
if (!end) end = a.size();
if (b.size()+start > end) return false;
if (equal && b.size()+start != end) return false;
for (uint i=0; i<b.size(); i++)
if (a[i+start] != b[i]) return false;
return true;
}
bool endswith(const string& a, const string& b) {
if (a.size() < b.size()) return false;
return startswith(a, b, a.size()-b.size());
}
vector<string> split(const string& s, const string& sep, int max_split) {
vector<string> ret;
int pos = -1, pos_next;
while (1) {
pos_next = s.find(sep, pos+1);
if (pos_next == (int)string::npos || (int)ret.size() == max_split-1) {
ret.push_back(s.substr(pos+sep.size()));
return ret;
}
ret.push_back(s.substr(pos+sep.size(), pos_next-pos-sep.size()));
pos = pos_next;
}
ASSERT(max_split==0);
return ret;
}
string strip(const string& s) {
int i=0;
while (i<s.size() && (s[i]==' ' || s[i]=='\t' || s[i]=='\n')) i++;
int j = s.size();
while (j>i && (s[j]==' ' || s[j]=='\t' || s[j]=='\n')) j--;
return s.substr(i,j-i);
}
} // jittor