cutt update to v1.1

This commit is contained in:
Dun Liang 2021-03-18 13:46:56 +08:00
parent 36a3c24a46
commit d20916f972
3 changed files with 12 additions and 6 deletions

View File

@ -94,7 +94,7 @@ void CuttTransposeOp::run() {
cuttExecute(iter->second, xp, yp);
} else {
cuttHandle plan;
cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0);
CHECK(0==cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0));
cutt_plan_cache[jk.to_string()] = plan;
cuttExecute(plan, xp, yp);
}

View File

@ -196,13 +196,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
def install_cutt(root_folder):
# Modified from: https://github.com/ap-hynninen/cutt
url = "https://github.com/Jittor/cutt/archive/master.zip"
url = "https://codeload.github.com/Jittor/cutt/zip/master"
url = "https://codeload.github.com/Jittor/cutt/zip/v1.1"
filename = "cutt-master.zip"
filename = "cutt-1.1.zip"
fullname = os.path.join(root_folder, filename)
dirname = os.path.join(root_folder, filename.replace(".zip",""))
true_md5 = "af5bc35eea1832a42c0e0011659b7209"
true_md5 = "7bb71cf7c49dbe57772539bf043778f7"
if os.path.exists(fullname):
md5 = run_cmd('md5sum '+fullname).split()[0]
@ -250,7 +249,7 @@ def setup_cutt():
make_cache_dir(cutt_path)
install_cutt(cutt_path)
cutt_home = os.path.join(cutt_path, "cutt-master")
cutt_home = os.path.join(cutt_path, "cutt-1.1")
cutt_include_path = os.path.join(cutt_home, "src")
cutt_lib_path = os.path.join(cutt_home, "lib")

View File

@ -67,5 +67,12 @@ class TestTransposeOp(unittest.TestCase):
assert a.permute().shape == [4,3,2]
assert a.permute(0,2,1).shape == [2,4,3]
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cutt(self):
a = jt.rand((10,2)) > 0.5
b = a.transpose()
assert (a.data.transpose() == b.data).all()
if __name__ == "__main__":
unittest.main()