delete batchnorm1d

zwy 2020-04-29 23:05:35 +08:00
parent 856db28b00
commit dd58aace62
2 changed files with 14 additions and 51 deletions


@@ -285,40 +285,6 @@ class BatchNorm(Module):
         b = self.bias.broadcast(x, [0,2,3])
         return norm_x * w + b
-class BatchNorm1d(Module):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
-        assert affine == None
-        self.sync = sync
-        self.num_features = num_features
-        self.is_train = is_train
-        self.eps = eps
-        self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
-        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
-        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
-    def execute(self, x):
-        if self.is_train:
-            xmean = jt.mean(x, dims=[0], keepdims=1)
-            x2mean = jt.mean(x*x, dims=[0], keepdims=1)
-            if self.sync and jt.mpi:
-                xmean = xmean.mpi_all_reduce("mean")
-                x2mean = x2mean.mpi_all_reduce("mean")
-            xvar = x2mean-xmean*xmean
-            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0])-self.running_var)*self.momentum
-        else:
-            running_mean = self.running_mean.broadcast(x, [0])
-            running_var = self.running_var.broadcast(x, [0])
-            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
-        w = self.weight.broadcast(x, [0])
-        b = self.bias.broadcast(x, [0])
-        return norm_x * w + b
 Relu = jt.make_module(relu)
 ReLU = Relu
 Leaky_relu = jt.make_module(leaky_relu, 2)
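
For context, the deleted BatchNorm1d normalized an (N, num_features) input over the batch dimension: in training mode it used per-batch statistics (optionally MPI-averaged across workers when sync is set) and updated the running statistics in place; in eval mode it used the stored running statistics instead. A minimal NumPy sketch of the training-mode computation, purely illustrative and not part of this commit:

import numpy as np

# Illustrative sketch (not part of this commit) of what the removed BatchNorm1d
# computed in training mode for an input x of shape (N, num_features).
def batchnorm1d_train(x, weight, bias, running_mean, running_var,
                      eps=1e-5, momentum=0.1):
    xmean = x.mean(axis=0)                      # per-feature batch mean
    xvar = (x * x).mean(axis=0) - xmean ** 2    # biased variance, E[x^2] - E[x]^2
    norm_x = (x - xmean) / np.sqrt(xvar + eps)  # normalize over the batch dimension
    # running statistics are updated in place with an exponential moving average
    running_mean += (xmean - running_mean) * momentum
    running_var += (xvar - running_var) * momentum
    return norm_x * weight + bias

In eval mode the stored running_mean and running_var simply replace xmean and xvar in the normalization step, as in the else branch of the deleted execute.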


@@ -23,11 +23,23 @@ except:
     tnn = None
     skip_this_test = True
-def check_equal(arr, j_layer, p_layer, threshold=1e-5):
+def check_equal(arr, j_layer, p_layer, is_train=True, threshold=1e-5):
     jittor_arr = jt.array(arr)
     pytorch_arr = torch.Tensor(arr)
+    if is_train:
+        assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
+        assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
+    else:
+        assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
+        assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
     jittor_result = j_layer(jittor_arr)
     pytorch_result = p_layer(pytorch_arr)
+    if is_train:
+        assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
+        assert np.allclose(p_layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
+    else:
+        assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold)
+        assert np.allclose(p_layer.layer.running_var.detach().numpy(), j_layer.running_var.numpy(), threshold)
     assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold)

 @unittest.skipIf(skip_this_test, "No Torch found")
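
The new is_train flag mirrors where the PyTorch running statistics live in these tests: training-mode checks pass the torch layer directly, so the statistics are read from p_layer, while eval-mode checks wrap the layer in a small Module (see the test below), so they are read from p_layer.layer. A hedged calling sketch for the training-mode case, assuming the check_equal above is in scope; the shape and constructors are illustrative only:

import numpy as np
from jittor import nn as jnn
from torch import nn as tnn

# Fresh layers: the running stats start out equal, are updated identically by
# one forward pass, and the layer outputs are compared at the end of check_equal.
arr = np.random.randn(16, 10, 24, 24)
check_equal(arr, jnn.BatchNorm(10), tnn.BatchNorm2d(10))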
@@ -47,23 +59,8 @@ class TestBatchNorm(unittest.TestCase):
                 return self.layer(x)
         model = Model()
         model.eval()
-        check_equal(arr, jnn.BatchNorm(10, is_train=False), model)
+        check_equal(arr, jnn.BatchNorm(10, is_train=False), model, False)
-        # ***************************************************************
-        # Test BatchNorm1d Layer
-        # ***************************************************************
-        arr = np.random.randn(16,1000)
-        check_equal(arr, jnn.BatchNorm1d(1000), tnn.BatchNorm1d(1000), 1e-3)
-        class Model(tnn.Module):
-            def __init__(self):
-                super(Model, self).__init__()
-                self.layer = tnn.BatchNorm1d(1000)
-            def forward(self, x):
-                return self.layer(x)
-        model = Model()
-        model.eval()
-        check_equal(arr, jnn.BatchNorm1d(1000, is_train=False), model, 1e-3)

 if __name__ == "__main__":
     unittest.main()