mirror of https://github.com/Jittor/Jittor
cutt update to v1.1
This commit is contained in:
parent
36a3c24a46
commit
d20916f972
|
@ -94,7 +94,7 @@ void CuttTransposeOp::run() {
|
|||
cuttExecute(iter->second, xp, yp);
|
||||
} else {
|
||||
cuttHandle plan;
|
||||
cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0);
|
||||
CHECK(0==cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0));
|
||||
cutt_plan_cache[jk.to_string()] = plan;
|
||||
cuttExecute(plan, xp, yp);
|
||||
}
|
||||
|
|
|
@ -196,13 +196,12 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
|
||||
def install_cutt(root_folder):
|
||||
# Modified from: https://github.com/ap-hynninen/cutt
|
||||
url = "https://github.com/Jittor/cutt/archive/master.zip"
|
||||
url = "https://codeload.github.com/Jittor/cutt/zip/master"
|
||||
url = "https://codeload.github.com/Jittor/cutt/zip/v1.1"
|
||||
|
||||
filename = "cutt-master.zip"
|
||||
filename = "cutt-1.1.zip"
|
||||
fullname = os.path.join(root_folder, filename)
|
||||
dirname = os.path.join(root_folder, filename.replace(".zip",""))
|
||||
true_md5 = "af5bc35eea1832a42c0e0011659b7209"
|
||||
true_md5 = "7bb71cf7c49dbe57772539bf043778f7"
|
||||
|
||||
if os.path.exists(fullname):
|
||||
md5 = run_cmd('md5sum '+fullname).split()[0]
|
||||
|
@ -250,7 +249,7 @@ def setup_cutt():
|
|||
|
||||
make_cache_dir(cutt_path)
|
||||
install_cutt(cutt_path)
|
||||
cutt_home = os.path.join(cutt_path, "cutt-master")
|
||||
cutt_home = os.path.join(cutt_path, "cutt-1.1")
|
||||
cutt_include_path = os.path.join(cutt_home, "src")
|
||||
cutt_lib_path = os.path.join(cutt_home, "lib")
|
||||
|
||||
|
|
|
@ -67,5 +67,12 @@ class TestTransposeOp(unittest.TestCase):
|
|||
assert a.permute().shape == [4,3,2]
|
||||
assert a.permute(0,2,1).shape == [2,4,3]
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cutt(self):
|
||||
a = jt.rand((10,2)) > 0.5
|
||||
b = a.transpose()
|
||||
assert (a.data.transpose() == b.data).all()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue