mirror of https://github.com/Jittor/Jittor
update
This commit is contained in:
parent
96ad1db2d8
commit
d79bd8e09c
|
@ -88,6 +88,8 @@ class TestBatchNorm(unittest.TestCase):
|
|||
# Test BatchNorm1d Layer
|
||||
# ***************************************************************
|
||||
arr = np.random.randn(16,10)
|
||||
check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=True), tnn.BatchNorm1d(10), 1e-3)
|
||||
|
||||
class Model(tnn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
@ -97,7 +99,6 @@ class TestBatchNorm(unittest.TestCase):
|
|||
model = Model()
|
||||
model.eval()
|
||||
check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=False), model, False)
|
||||
check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=True), tnn.BatchNorm1d(10))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue