add test pool

This commit is contained in:
zhouwy19 2021-03-24 17:50:34 +08:00
parent 99b7fbeec1
commit c4a9816ec0
1 changed files with 22 additions and 0 deletions

View File

@ -90,6 +90,18 @@ class TestArgPoolOp(unittest.TestCase):
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)
def test_cuda_tuple(self):
jt_model = jt.nn.Sequential(Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1), ceil_mode=True), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool(3, 1, 1))
torch_model = Sequential(MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1), ceil_mode=True), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d(3, 1, 1))
shape = [2, 3, 300, 300]
check(jt_model, torch_model, shape, False)
shape = [2, 3, 157, 300]
check(jt_model, torch_model, shape, False)
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
@unittest.skipIf(True, "TODO: cannot pass this test, fix me")
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)
@ -120,6 +132,16 @@ class TestArgPoolOp(unittest.TestCase):
check(jt_model, torch_model, shape, False)
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
def test_cpu_tuple(self):
jt_model = jt.nn.Sequential(Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1), ceil_mode=True), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool(3, 1, 1))
torch_model = Sequential(MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1), ceil_mode=True), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d(3, 1, 1))
shape = [2, 3, 300, 300]
check(jt_model, torch_model, shape, False)
shape = [2, 3, 157, 300]
check(jt_model, torch_model, shape, False)
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)