This commit is contained in:
zwy 2020-05-15 15:15:37 +08:00
parent 96ad1db2d8
commit d79bd8e09c
1 changed files with 2 additions and 1 deletions

View File

@ -88,6 +88,8 @@ class TestBatchNorm(unittest.TestCase):
# Test BatchNorm1d Layer
# ***************************************************************
arr = np.random.randn(16,10)
check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=True), tnn.BatchNorm1d(10), 1e-3)
class Model(tnn.Module):
def __init__(self):
super(Model, self).__init__()
@ -97,7 +99,6 @@ class TestBatchNorm(unittest.TestCase):
model = Model()
model.eval()
check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=False), model, False)
check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=True), tnn.BatchNorm1d(10))
if __name__ == "__main__":
unittest.main()