From 987d42cc886cc4097f48ef62bb687c2d26bfe7a6 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Sat, 11 Apr 2020 17:51:01 +0800 Subject: [PATCH] fix lock test --- python/jittor/compiler.py | 1 + python/jittor/test/__main__.py | 32 +++++++++++++++++--------- python/jittor/test/test_arg_pool_op.py | 1 + python/jittor/test/test_clone.py | 1 + python/jittor/test/test_cuda.py | 18 +++++++-------- python/jittor/test/test_cutt.py | 3 +-- python/jittor/test/test_lock.py | 13 ++++------- python/jittor/test/test_misc_issue.py | 3 +++ 8 files changed, 42 insertions(+), 30 deletions(-) diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index 6216c6da..c0a0bbcd 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -519,6 +519,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""): """ return jit_src +@lock.lock_scope() def compile_custom_op(header, source, op_name, warp=True): """Compile a single custom op header: code of op header, not path diff --git a/python/jittor/test/__main__.py b/python/jittor/test/__main__.py index ce74a6df..c6374c62 100644 --- a/python/jittor/test/__main__.py +++ b/python/jittor/test/__main__.py @@ -5,23 +5,33 @@ # *************************************************************** if __name__ == "__main__": - import unittest + import unittest, os suffix = "__main__.py" assert __file__.endswith(suffix) test_dir = __file__[:-len(suffix)] - import os + + skip_l = int(os.environ.get("test_skip_l", "0")) + skip_r = int(os.environ.get("test_skip_r", "1000000")) + test_only = None + if "test_only" in os.environ: + test_only = set(os.environ.get("test_only").split(",")) + test_files = os.listdir(test_dir) - for test_file in test_files: + test_files = sorted(test_files) + suite = unittest.TestSuite() + + for _, test_file in enumerate(test_files): if not test_file.startswith("test_"): continue + if _ < skip_l or _ > skip_r: + continue test_name = test_file.split(".")[0] - exec(f"from . import {test_name}") - test_mod = globals()[test_name] - print(test_name) - for i in dir(test_mod): - obj = getattr(test_mod, i) - if isinstance(obj, type) and issubclass(obj, unittest.TestCase): - globals()[test_name+"_"+i] = obj + if test_only and test_name not in test_only: + continue - unittest.main() + print("Add Test", _, test_name) + suite.addTest(unittest.defaultTestLoader.loadTestsFromName( + "jittor.test."+test_name)) + + unittest.TextTestRunner(verbosity=3).run(suite) \ No newline at end of file diff --git a/python/jittor/test/test_arg_pool_op.py b/python/jittor/test/test_arg_pool_op.py index 6fa758de..6481d30e 100644 --- a/python/jittor/test/test_arg_pool_op.py +++ b/python/jittor/test/test_arg_pool_op.py @@ -18,6 +18,7 @@ import pickle as pk skip_this_test = False try: + jt.dirty_fix_pytorch_runtime_error() import torch from torch.nn import MaxPool2d, Sequential except: diff --git a/python/jittor/test/test_clone.py b/python/jittor/test/test_clone.py index 95a08cc6..b63b9468 100644 --- a/python/jittor/test/test_clone.py +++ b/python/jittor/test/test_clone.py @@ -11,6 +11,7 @@ import numpy as np class TestClone(unittest.TestCase): def test(self): + jt.clean() b = a = jt.array(1) for i in range(10): b = b.clone() diff --git a/python/jittor/test/test_cuda.py b/python/jittor/test/test_cuda.py index 0a0f9738..0778f878 100644 --- a/python/jittor/test/test_cuda.py +++ b/python/jittor/test/test_cuda.py @@ -18,11 +18,12 @@ def test_cuda(use_cuda=1): @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") class TestCuda(unittest.TestCase): + @jt.flag_scope(use_cuda=1) def test_cuda_flags(self): - with jt.var_scope(use_cuda=1): - a = jt.random((10, 10)) - a.sync() + a = jt.random((10, 10)) + a.sync() + @jt.flag_scope(use_cuda=2) def test_no_cuda_op(self): no_cuda_op = jt.compile_custom_op(""" struct NoCudaOp : Op { @@ -49,10 +50,10 @@ class TestCuda(unittest.TestCase): """, "no_cuda") # force use cuda - with jt.var_scope(use_cuda=2): - a = no_cuda_op([3,4,5], 'float') - expect_error(lambda: a()) + a = no_cuda_op([3,4,5], 'float') + expect_error(lambda: a()) + @jt.flag_scope(use_cuda=1) def test_cuda_custom_op(self): my_op = jt.compile_custom_op(""" struct MyCudaOp : Op { @@ -94,9 +95,8 @@ class TestCuda(unittest.TestCase): #endif // JIT """, "my_cuda") - with jt.var_scope(use_cuda=1): - a = my_op([3,4,5], 'float') - na = a.data + a = my_op([3,4,5], 'float') + na = a.data assert a.shape == [3,4,5] and a.dtype == 'float' assert (-na.flatten() == range(3*4*5)).all(), na diff --git a/python/jittor/test/test_cutt.py b/python/jittor/test/test_cutt.py index d5b9b714..282cc07c 100644 --- a/python/jittor/test/test_cutt.py +++ b/python/jittor/test/test_cutt.py @@ -22,7 +22,6 @@ class TestCutt(unittest.TestCase): @jt.flag_scope(use_cuda=1) def test(self): t = cutt_ops.cutt_test("213") - jt.sync_all(True) - print(t.data) + assert t.data == 123 if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_lock.py b/python/jittor/test/test_lock.py index 65c93131..7cff705e 100644 --- a/python/jittor/test/test_lock.py +++ b/python/jittor/test/test_lock.py @@ -11,19 +11,16 @@ import os, sys import jittor as jt from pathlib import Path -@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found") class TestLock(unittest.TestCase): def test(self): - mpi = jt.compile_extern.mpi - mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun") if os.environ.get('lock_full_test', '0') == '1': cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock") - cmd = f"rm -rf {cache_path} && cache_name=lock {mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example" + assert os.system(f"rm -rf {cache_path}") == 0 + cmd = f"cache_name=lock {sys.executable} -m jittor.test.test_example" else: - cache_path = os.path.join(str(Path.home()), ".cache", "jittor") - cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example" - print("run cmd", cmd) - assert os.system(cmd) == 0 + cmd = f"{sys.executable} -m jittor.test.test_example" + print("run cmd twice", cmd) + assert os.system(f"{cmd} & {cmd} & wait %1 && wait %2") == 0 if __name__ == "__main__": diff --git a/python/jittor/test/test_misc_issue.py b/python/jittor/test/test_misc_issue.py index bdbd6643..f13630ff 100644 --- a/python/jittor/test/test_misc_issue.py +++ b/python/jittor/test/test_misc_issue.py @@ -11,6 +11,7 @@ import numpy as np class TestMiscIssue(unittest.TestCase): def test_issue4(self): try: + jt.dirty_fix_pytorch_runtime_error() import torch except: return @@ -42,6 +43,7 @@ b.sync() def test_mkl_conflict1(self): try: + jt.dirty_fix_pytorch_runtime_error() import torch except: return @@ -67,6 +69,7 @@ m(torch.rand(*nchw)) def test_mkl_conflict2(self): try: + jt.dirty_fix_pytorch_runtime_error() import torch except: return