mirror of https://github.com/Jittor/Jittor
fix some acl issue
This commit is contained in:
parent
a57de764f6
commit
524564b763
|
@ -1235,8 +1235,9 @@ from .extern.rocm import rocm_compiler
|
|||
jit_utils.add_backend(rocm_compiler)
|
||||
|
||||
for mod in jit_utils.backends:
|
||||
if mod.check():
|
||||
break
|
||||
mod.check()
|
||||
# if mod.check():
|
||||
# break
|
||||
|
||||
# build core
|
||||
gen_jit_flags()
|
||||
|
|
|
@ -20,8 +20,16 @@ def install():
|
|||
global has_acl, cc_flags
|
||||
acl_compiler_home = os.path.dirname(__file__)
|
||||
cc_files = sorted(glob.glob(acl_compiler_home+"/**/*.cc", recursive=True))
|
||||
cc_flags += f" -DHAS_CUDA -DIS_ACL -I/usr/local/Ascend/latest/x86_64-linux/include/ -I/usr/local/Ascend/latest/x86_64-linux/include/acl -L/usr/local/Ascend/latest/x86_64-linux/lib64 -I/usr/local/Ascend/runtime/include -I/usr/local/Ascend/driver/include -L/usr/local/Ascend/compiler/lib64 -L/usr/local/Ascend/runtime/lib64 -I{acl_compiler_home} -ltikc_runtime -lascendcl "
|
||||
cc_flags += f" -DHAS_CUDA -DIS_ACL \
|
||||
-I/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/ \
|
||||
-L/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/lib64 \
|
||||
-I{acl_compiler_home} -ltikc_runtime -lascendcl "
|
||||
ctypes.CDLL("libascendcl.so", dlopen_flags)
|
||||
'''
|
||||
-I/usr/local/Ascend/driver/include \
|
||||
-L/usr/local/Ascend/compiler/lib64 \
|
||||
-L/usr/local/Ascend/runtime/lib64 \
|
||||
'''
|
||||
jittor_utils.LOG.i("ACL detected")
|
||||
|
||||
mod = jittor_utils.compile_module('''
|
||||
|
|
|
@ -158,7 +158,19 @@ string process_acl(const string& src, const string& name, const map<string,strin
|
|||
}
|
||||
}
|
||||
if (!edit) return src;
|
||||
return join(tokens, "");
|
||||
string new_src = join(tokens, "");
|
||||
if (name == "executor.cc") {
|
||||
new_src = "#include <Python.h>\n#include <pystate.h>\n"+
|
||||
replace(new_src, "op->do_run_after_prepare(jkl);",
|
||||
R"({
|
||||
auto state = _PyThreadState_UncheckedGet();
|
||||
op->do_run_after_prepare(jkl);
|
||||
if (!_PyThreadState_UncheckedGet()) {
|
||||
PyEval_AcquireThread(state);
|
||||
}
|
||||
})");
|
||||
}
|
||||
return new_src;
|
||||
}
|
||||
|
||||
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags) {
|
||||
|
|
|
@ -687,7 +687,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
|
|||
} catch (const std::exception& e) {
|
||||
// log memory info
|
||||
display_memory_info(__FILELINE__, false, true);
|
||||
throw e;
|
||||
throw;
|
||||
}
|
||||
event_queue.flush();
|
||||
}
|
||||
|
|
|
@ -40,6 +40,13 @@ class TestACL(unittest.TestCase):
|
|||
y = jt.float32(x)
|
||||
np.testing.assert_allclose(x, y.numpy())
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_array_cast_half(self):
|
||||
# this test cannot pass because cast error
|
||||
x = np.random.rand(10).astype("float32")
|
||||
y = jt.float16(x)
|
||||
np.testing.assert_allclose(x, y.numpy())
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_rand(self):
|
||||
a = jt.rand(10)
|
||||
|
|
Loading…
Reference in New Issue