mirror of https://github.com/Jittor/Jittor
fix test_max_pool2d
This commit is contained in:
parent
49ac9be019
commit
ced499215d
|
@ -173,7 +173,7 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
check(jt_model, torch_model, shape, False)
|
||||
print('finish')
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
def test_max_pool2d(self):
|
||||
from torch.nn.functional import max_pool2d
|
||||
arr = np.random.random((2, 16, 33, 33))
|
||||
jt_model = j_max_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True)
|
||||
|
|
Loading…
Reference in New Issue