From e25a8ac28dd1f32e0f27877effd7aeb5a45b724a Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Tue, 11 Aug 2020 19:53:48 +0800 Subject: [PATCH] fix numpy code op test --- python/jittor/__init__.py | 2 +- python/jittor/test/test_numpy_code_op.py | 18 +++++++++--------- python/jittor/test/test_parallel_pass.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 2a6c651e..814d63cf 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -7,7 +7,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.1.7.4' +__version__ = '1.1.7.5' from . import lock with lock.lock_scope(): from . import compiler diff --git a/python/jittor/test/test_numpy_code_op.py b/python/jittor/test/test_numpy_code_op.py index bc8eba1e..afa4a391 100644 --- a/python/jittor/test/test_numpy_code_op.py +++ b/python/jittor/test/test_numpy_code_op.py @@ -58,9 +58,9 @@ class TestCodeOp(unittest.TestCase): one=numpy.ones(a.shape) assert numpy.allclose(da.data,one*2.0) - jt.flags.use_cuda = 0 - check() - jt.flags.use_cuda = 1 + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + check() check() def test(self): @@ -92,9 +92,9 @@ class TestCodeOp(unittest.TestCase): one=numpy.ones(a.shape) assert numpy.allclose(da.data,one*2.0) - jt.flags.use_cuda = 0 - check() - jt.flags.use_cuda = 1 + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + check() check() def test_multi_input(self): @@ -139,9 +139,9 @@ class TestCodeOp(unittest.TestCase): assert numpy.allclose(dda.data,one) assert numpy.allclose(ddb.data,mone) - jt.flags.use_cuda = 0 - check() - jt.flags.use_cuda = 1 + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + check() check() @unittest.skipIf(True, "Memory leak testing is not in progress, Skip") diff --git a/python/jittor/test/test_parallel_pass.py b/python/jittor/test/test_parallel_pass.py index 9b334198..cda40ac2 100644 --- a/python/jittor/test/test_parallel_pass.py +++ b/python/jittor/test/test_parallel_pass.py @@ -36,7 +36,7 @@ class TestParallelPass(unittest.TestCase): b = jt.random((n, n)) a.data, b.data with jt.profile_scope(compile_options = { - "compile_shapes":1, "parallel":1 + "compile_shapes":1, "parallel":2, "try_use_32bit_index":use_int32 }, try_use_32bit_index = use_int32) as rep: c = a + b nc = c.data