fix some acl issue

This commit is contained in:
Dun Liang 2022-10-07 14:20:18 +08:00
parent a57de764f6
commit 524564b763
5 changed files with 33 additions and 5 deletions

View File

@ -1235,8 +1235,9 @@ from .extern.rocm import rocm_compiler
jit_utils.add_backend(rocm_compiler)
for mod in jit_utils.backends:
if mod.check():
break
mod.check()
# if mod.check():
# break
# build core
gen_jit_flags()

View File

@ -20,8 +20,16 @@ def install():
global has_acl, cc_flags
acl_compiler_home = os.path.dirname(__file__)
cc_files = sorted(glob.glob(acl_compiler_home+"/**/*.cc", recursive=True))
cc_flags += f" -DHAS_CUDA -DIS_ACL -I/usr/local/Ascend/latest/x86_64-linux/include/ -I/usr/local/Ascend/latest/x86_64-linux/include/acl -L/usr/local/Ascend/latest/x86_64-linux/lib64 -I/usr/local/Ascend/runtime/include -I/usr/local/Ascend/driver/include -L/usr/local/Ascend/compiler/lib64 -L/usr/local/Ascend/runtime/lib64 -I{acl_compiler_home} -ltikc_runtime -lascendcl "
cc_flags += f" -DHAS_CUDA -DIS_ACL \
-I/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include/ \
-L/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/lib64 \
-I{acl_compiler_home} -ltikc_runtime -lascendcl "
ctypes.CDLL("libascendcl.so", dlopen_flags)
'''
-I/usr/local/Ascend/driver/include \
-L/usr/local/Ascend/compiler/lib64 \
-L/usr/local/Ascend/runtime/lib64 \
'''
jittor_utils.LOG.i("ACL detected")
mod = jittor_utils.compile_module('''

View File

@ -158,7 +158,19 @@ string process_acl(const string& src, const string& name, const map<string,strin
}
}
if (!edit) return src;
return join(tokens, "");
string new_src = join(tokens, "");
if (name == "executor.cc") {
new_src = "#include <Python.h>\n#include <pystate.h>\n"+
replace(new_src, "op->do_run_after_prepare(jkl);",
R"({
auto state = _PyThreadState_UncheckedGet();
op->do_run_after_prepare(jkl);
if (!_PyThreadState_UncheckedGet()) {
PyEval_AcquireThread(state);
}
})");
}
return new_src;
}
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags) {

View File

@ -687,7 +687,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
} catch (const std::exception& e) {
// log memory info
display_memory_info(__FILELINE__, false, true);
throw e;
throw;
}
event_queue.flush();
}

View File

@ -40,6 +40,13 @@ class TestACL(unittest.TestCase):
y = jt.float32(x)
np.testing.assert_allclose(x, y.numpy())
@jt.flag_scope(use_acl=1)
def test_array_cast_half(self):
# this test cannot pass because cast error
x = np.random.rand(10).astype("float32")
y = jt.float16(x)
np.testing.assert_allclose(x, y.numpy())
@jt.flag_scope(use_acl=1)
def test_rand(self):
a = jt.rand(10)