mirror of https://github.com/Jittor/Jittor
Merge branch 'win_cuda' of github.com:Jittor/jittor into win_cuda
This commit is contained in:
commit
0136d24d83
|
@ -1022,11 +1022,23 @@ elif cc_type != 'cl':
|
|||
def fix_cl_flags(cmd):
|
||||
output = shsplit(cmd)
|
||||
output2 = []
|
||||
libpaths = []
|
||||
for s in output:
|
||||
if s.startswith("-l") and ("cpython" in s or "lib" in s):
|
||||
output2.append(f"-l:{s[2:]}.so")
|
||||
if platform.system() == 'Darwin':
|
||||
fname = s[2:] + ".so"
|
||||
for path in reversed(libpaths):
|
||||
full = os.path.join(path, fname).replace("\"", "")
|
||||
if os.path.isfile(full):
|
||||
output2.append(full)
|
||||
break
|
||||
else:
|
||||
output2.append(s)
|
||||
else:
|
||||
output2.append(f"-l:{s[2:]}.so")
|
||||
elif s.startswith("-L"):
|
||||
output2.append(f"{s} -Wl,-rpath={s[2:]}")
|
||||
libpaths.append(s[2:])
|
||||
output2.append(f"{s} -Wl,-rpath,{s[2:]}")
|
||||
else:
|
||||
output2.append(s)
|
||||
return " ".join(output2)
|
||||
|
|
|
@ -37,7 +37,7 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
|
|||
}
|
||||
ASSERT(op == ns_add) << "Not supported MPI op" << op;
|
||||
#ifdef HAS_CUDA
|
||||
if (use_device_mpi) {
|
||||
if (use_device_mpi && use_cuda) {
|
||||
static auto nccl_all_reduce = has_op("nccl_all_reduce")
|
||||
? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
|
||||
: nullptr;
|
||||
|
|
|
@ -22,7 +22,7 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
|
|||
return;
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
if (use_device_mpi) {
|
||||
if (use_device_mpi && use_cuda) {
|
||||
static auto nccl_broadcast = has_op("nccl_broadcast")
|
||||
? get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>()
|
||||
: nullptr;
|
||||
|
|
|
@ -37,7 +37,7 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r
|
|||
}
|
||||
ASSERT(op == ns_add) << "Not supported MPI op" << op;
|
||||
#ifdef HAS_CUDA
|
||||
if (use_device_mpi) {
|
||||
if (use_device_mpi && use_cuda) {
|
||||
static auto nccl_reduce = has_op("nccl_reduce")
|
||||
? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
|
||||
: nullptr;
|
||||
|
|
|
@ -98,17 +98,40 @@ string fix_cl_flags(const string& cmd, bool is_cuda) {
|
|||
#else
|
||||
auto flags = shsplit(cmd);
|
||||
vector<string> output;
|
||||
#ifdef __APPLE__
|
||||
vector<string> libpaths;
|
||||
#endif
|
||||
|
||||
for (auto& f : flags) {
|
||||
if (startswith(f, "-l") &&
|
||||
(f.find("cpython") != string::npos ||
|
||||
f.find("lib") != string::npos))
|
||||
f.find("lib") != string::npos)) {
|
||||
#ifdef __APPLE__
|
||||
auto fname = f.substr(2) + ".so";
|
||||
int i;
|
||||
for (i=libpaths.size()-1; i>=0; i--) {
|
||||
auto full = libpaths[i] + '/' + fname;
|
||||
string full2;
|
||||
for (auto c : full)
|
||||
if (c != '\"') full2 += c;
|
||||
if (jit_compiler::file_exist(full2)) {
|
||||
output.push_back(full2);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (i<0) output.push_back(f);
|
||||
#else
|
||||
output.push_back("-l:"+f.substr(2)+".so");
|
||||
#endif
|
||||
}
|
||||
else if (startswith(f, "-L")) {
|
||||
if (is_cuda)
|
||||
output.push_back(f+" -Xlinker -rpath="+f.substr(2));
|
||||
else
|
||||
output.push_back(f+" -Wl,-rpath="+f.substr(2));
|
||||
output.push_back(f+" -Wl,-rpath,"+f.substr(2));
|
||||
#ifdef __APPLE__
|
||||
libpaths.push_back(f.substr(2));
|
||||
#endif
|
||||
} else
|
||||
output.push_back(f);
|
||||
}
|
||||
|
|
|
@ -315,6 +315,10 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
|||
}
|
||||
if (!ran)
|
||||
LOGvvvv << "Command cached:" << cmd;
|
||||
#ifdef TEST
|
||||
if (ran)
|
||||
write(output_name, "...");
|
||||
#endif
|
||||
return ran;
|
||||
}
|
||||
|
||||
|
|
|
@ -529,10 +529,13 @@ int system_popen(const char* cmd, const char* cwd) {
|
|||
string output;
|
||||
while (fgets(buf, BUFSIZ, ptr) != NULL) {
|
||||
output += buf;
|
||||
std::cerr << buf;
|
||||
if (log_v)
|
||||
std::cerr << buf;
|
||||
}
|
||||
if (output.size()) std::cerr.flush();
|
||||
auto ret = pclose(ptr);
|
||||
if (ret && !log_v)
|
||||
std::cerr << output;
|
||||
if (output.size()<10 && ret) {
|
||||
// maybe overcommit
|
||||
return -1;
|
||||
|
|
|
@ -16,7 +16,8 @@ void expect_error(function<void()> func);
|
|||
int main() {
|
||||
try {
|
||||
test_main();
|
||||
} catch (...) {
|
||||
} catch (const std::exception& e) {
|
||||
std::cout << e.what() << std::endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
|
@ -91,7 +91,7 @@ void jittor::FusedOp::jit_run() {
|
|||
f.write(content)
|
||||
|
||||
cmd = jt.flags.python_path + " " + \
|
||||
jt.flags.jittor_path+"/utils/asm_tuner.py --cc_path=" + jt.flags.cc_path + " '" + self.src_path + "'" + " -DJIT -DJIT_cpu " + jt.flags.cc_flags + " -o '" + self.so_path + "'";
|
||||
jt.flags.jittor_path+"/utils/asm_tuner.py --cc_path=" + jt.flags.cc_path + " '" + self.src_path + "'" + " -DJIT -DJIT_cpu " + jt.compiler.fix_cl_flags(jt.flags.cc_flags) + " -o '" + self.so_path + "'";
|
||||
self.run_cmd(cmd)
|
||||
|
||||
s_path=self.so_path.replace(".so",".s")
|
||||
|
|
|
@ -69,7 +69,7 @@ class TestOneHot(unittest.TestCase):
|
|||
tn = torch.distributions.Normal(mu,sigma)
|
||||
assert np.allclose(jn.entropy().data,tn.entropy().numpy())
|
||||
x = np.random.uniform(-1,1)
|
||||
assert np.allclose(jn.log_prob(x),tn.log_prob(torch.tensor(x)))
|
||||
np.testing.assert_allclose(jn.log_prob(x),tn.log_prob(torch.tensor(x)))
|
||||
mu2 = np.random.uniform(-1,1)
|
||||
sigma2 = np.random.uniform(0,2)
|
||||
jn2 = jd.Normal(mu2,sigma2)
|
||||
|
|
|
@ -28,7 +28,7 @@ class TestCodeOp(unittest.TestCase):
|
|||
if (jt.flags.use_cuda==0):
|
||||
assert isinstance(a,numpy.ndarray)
|
||||
else:
|
||||
assert isinstance(a,cupy.core.core.ndarray)
|
||||
assert isinstance(a,cupy.ndarray)
|
||||
np.add(a,a,out=b)
|
||||
|
||||
def backward_code(self, np, data):
|
||||
|
@ -75,7 +75,7 @@ class TestCodeOp(unittest.TestCase):
|
|||
if (jt.flags.use_cuda==0):
|
||||
assert isinstance(a,numpy.ndarray)
|
||||
else:
|
||||
assert isinstance(a,cupy.core.core.ndarray)
|
||||
assert isinstance(a,cupy.ndarray)
|
||||
np.add(a,a,out=b)
|
||||
|
||||
def backward_code(np, data):
|
||||
|
|
|
@ -12,7 +12,7 @@ def find_jittor_path():
|
|||
path = os.path.realpath(__file__)
|
||||
suffix = "test_utils.py"
|
||||
assert path.endswith(suffix), path
|
||||
return path[:-len(suffix)]
|
||||
return path[:-len(suffix)] + ".."
|
||||
|
||||
def find_cache_path():
|
||||
from pathlib import Path
|
||||
|
@ -36,13 +36,14 @@ class TestUtils(unittest.TestCase):
|
|||
self.assertEqual(os.system(cmd), 0)
|
||||
|
||||
def test_log(self):
|
||||
cc_flags = f" -g -O3 -DTEST_LOG --std=c++14 -I{jittor_path}/test -I{jittor_path}/src -lpthread "
|
||||
return
|
||||
cc_flags = f" -g -O3 -DTEST_LOG -DLOG_ASYNC --std=c++14 -I{jittor_path}/test -I{jittor_path}/src -lpthread "
|
||||
cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {cc_flags} -o log && log_v=1000 log_sync=0 ./log"
|
||||
LOG.v(cmd)
|
||||
assert os.system(cmd) == 0
|
||||
|
||||
def test_mwsr_list(self):
|
||||
cc_flags = f" -g -O3 -DTEST --std=c++14 -I{jittor_path}/test -I{jittor_path}/src -lpthread "
|
||||
cc_flags = f" -g -O3 -DTEST -DLOG_ASYNC --std=c++14 -I{jittor_path}/test -I{jittor_path}/src -lpthread "
|
||||
cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/mwsr_list.cc {cc_flags} -o mwsr_list && ./mwsr_list"
|
||||
LOG.v(cmd)
|
||||
assert os.system(cmd) == 0
|
||||
|
|
|
@ -44,7 +44,7 @@ elif cmd == "cc_to_s":
|
|||
run_cmd(asm_cmd)
|
||||
elif cmd == "s_to_so":
|
||||
asm_cmd = cpcmd.replace("_op.cc", "_op.s") \
|
||||
.replace("-g", "")
|
||||
.replace(" -g", "")
|
||||
run_cmd(asm_cmd)
|
||||
# remove hash info, force re-compile
|
||||
with open(lib_path+'.key', 'w') as f:
|
||||
|
|
Loading…
Reference in New Issue