mirror of https://github.com/Jittor/Jittor
add test pool
This commit is contained in:
parent
99b7fbeec1
commit
c4a9816ec0
|
@ -90,6 +90,18 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda_tuple(self):
|
||||
jt_model = jt.nn.Sequential(Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1), ceil_mode=True), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool(3, 1, 1))
|
||||
torch_model = Sequential(MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1), ceil_mode=True), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d(3, 1, 1))
|
||||
shape = [2, 3, 300, 300]
|
||||
check(jt_model, torch_model, shape, False)
|
||||
shape = [2, 3, 157, 300]
|
||||
check(jt_model, torch_model, shape, False)
|
||||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
@unittest.skipIf(True, "TODO: cannot pass this test, fix me")
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
|
@ -120,6 +132,16 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
check(jt_model, torch_model, shape, False)
|
||||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
def test_cpu_tuple(self):
|
||||
jt_model = jt.nn.Sequential(Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1), ceil_mode=True), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool(3, 1, 1))
|
||||
torch_model = Sequential(MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1), ceil_mode=True), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d(3, 1, 1))
|
||||
shape = [2, 3, 300, 300]
|
||||
check(jt_model, torch_model, shape, False)
|
||||
shape = [2, 3, 157, 300]
|
||||
check(jt_model, torch_model, shape, False)
|
||||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
|
|
Loading…
Reference in New Issue