fix lock test

This commit is contained in:
Dun Liang 2020-04-11 17:51:01 +08:00
parent 9178b3459a
commit 987d42cc88
8 changed files with 42 additions and 30 deletions

View File

@ -519,6 +519,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
""" """
return jit_src return jit_src
@lock.lock_scope()
def compile_custom_op(header, source, op_name, warp=True): def compile_custom_op(header, source, op_name, warp=True):
"""Compile a single custom op """Compile a single custom op
header: code of op header, not path header: code of op header, not path

View File

@ -5,23 +5,33 @@
# *************************************************************** # ***************************************************************
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest, os
suffix = "__main__.py" suffix = "__main__.py"
assert __file__.endswith(suffix) assert __file__.endswith(suffix)
test_dir = __file__[:-len(suffix)] test_dir = __file__[:-len(suffix)]
import os
skip_l = int(os.environ.get("test_skip_l", "0"))
skip_r = int(os.environ.get("test_skip_r", "1000000"))
test_only = None
if "test_only" in os.environ:
test_only = set(os.environ.get("test_only").split(","))
test_files = os.listdir(test_dir) test_files = os.listdir(test_dir)
for test_file in test_files: test_files = sorted(test_files)
suite = unittest.TestSuite()
for _, test_file in enumerate(test_files):
if not test_file.startswith("test_"): if not test_file.startswith("test_"):
continue continue
if _ < skip_l or _ > skip_r:
continue
test_name = test_file.split(".")[0] test_name = test_file.split(".")[0]
exec(f"from . import {test_name}") if test_only and test_name not in test_only:
test_mod = globals()[test_name] continue
print(test_name)
for i in dir(test_mod):
obj = getattr(test_mod, i)
if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
globals()[test_name+"_"+i] = obj
unittest.main() print("Add Test", _, test_name)
suite.addTest(unittest.defaultTestLoader.loadTestsFromName(
"jittor.test."+test_name))
unittest.TextTestRunner(verbosity=3).run(suite)

View File

@ -18,6 +18,7 @@ import pickle as pk
skip_this_test = False skip_this_test = False
try: try:
jt.dirty_fix_pytorch_runtime_error()
import torch import torch
from torch.nn import MaxPool2d, Sequential from torch.nn import MaxPool2d, Sequential
except: except:

View File

@ -11,6 +11,7 @@ import numpy as np
class TestClone(unittest.TestCase): class TestClone(unittest.TestCase):
def test(self): def test(self):
jt.clean()
b = a = jt.array(1) b = a = jt.array(1)
for i in range(10): for i in range(10):
b = b.clone() b = b.clone()

View File

@ -18,11 +18,12 @@ def test_cuda(use_cuda=1):
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
class TestCuda(unittest.TestCase): class TestCuda(unittest.TestCase):
@jt.flag_scope(use_cuda=1)
def test_cuda_flags(self): def test_cuda_flags(self):
with jt.var_scope(use_cuda=1): a = jt.random((10, 10))
a = jt.random((10, 10)) a.sync()
a.sync()
@jt.flag_scope(use_cuda=2)
def test_no_cuda_op(self): def test_no_cuda_op(self):
no_cuda_op = jt.compile_custom_op(""" no_cuda_op = jt.compile_custom_op("""
struct NoCudaOp : Op { struct NoCudaOp : Op {
@ -49,10 +50,10 @@ class TestCuda(unittest.TestCase):
""", """,
"no_cuda") "no_cuda")
# force use cuda # force use cuda
with jt.var_scope(use_cuda=2): a = no_cuda_op([3,4,5], 'float')
a = no_cuda_op([3,4,5], 'float') expect_error(lambda: a())
expect_error(lambda: a())
@jt.flag_scope(use_cuda=1)
def test_cuda_custom_op(self): def test_cuda_custom_op(self):
my_op = jt.compile_custom_op(""" my_op = jt.compile_custom_op("""
struct MyCudaOp : Op { struct MyCudaOp : Op {
@ -94,9 +95,8 @@ class TestCuda(unittest.TestCase):
#endif // JIT #endif // JIT
""", """,
"my_cuda") "my_cuda")
with jt.var_scope(use_cuda=1): a = my_op([3,4,5], 'float')
a = my_op([3,4,5], 'float') na = a.data
na = a.data
assert a.shape == [3,4,5] and a.dtype == 'float' assert a.shape == [3,4,5] and a.dtype == 'float'
assert (-na.flatten() == range(3*4*5)).all(), na assert (-na.flatten() == range(3*4*5)).all(), na

View File

@ -22,7 +22,6 @@ class TestCutt(unittest.TestCase):
@jt.flag_scope(use_cuda=1) @jt.flag_scope(use_cuda=1)
def test(self): def test(self):
t = cutt_ops.cutt_test("213") t = cutt_ops.cutt_test("213")
jt.sync_all(True) assert t.data == 123
print(t.data)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -11,19 +11,16 @@ import os, sys
import jittor as jt import jittor as jt
from pathlib import Path from pathlib import Path
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestLock(unittest.TestCase): class TestLock(unittest.TestCase):
def test(self): def test(self):
mpi = jt.compile_extern.mpi
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
if os.environ.get('lock_full_test', '0') == '1': if os.environ.get('lock_full_test', '0') == '1':
cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock") cache_path = os.path.join(str(Path.home()), ".cache", "jittor", "lock")
cmd = f"rm -rf {cache_path} && cache_name=lock {mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example" assert os.system(f"rm -rf {cache_path}") == 0
cmd = f"cache_name=lock {sys.executable} -m jittor.test.test_example"
else: else:
cache_path = os.path.join(str(Path.home()), ".cache", "jittor") cmd = f"{sys.executable} -m jittor.test.test_example"
cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_example" print("run cmd twice", cmd)
print("run cmd", cmd) assert os.system(f"{cmd} & {cmd} & wait %1 && wait %2") == 0
assert os.system(cmd) == 0
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -11,6 +11,7 @@ import numpy as np
class TestMiscIssue(unittest.TestCase): class TestMiscIssue(unittest.TestCase):
def test_issue4(self): def test_issue4(self):
try: try:
jt.dirty_fix_pytorch_runtime_error()
import torch import torch
except: except:
return return
@ -42,6 +43,7 @@ b.sync()
def test_mkl_conflict1(self): def test_mkl_conflict1(self):
try: try:
jt.dirty_fix_pytorch_runtime_error()
import torch import torch
except: except:
return return
@ -67,6 +69,7 @@ m(torch.rand(*nchw))
def test_mkl_conflict2(self): def test_mkl_conflict2(self):
try: try:
jt.dirty_fix_pytorch_runtime_error()
import torch import torch
except: except:
return return