mirror of https://github.com/Jittor/Jittor

sync batchnorm
commit a0f6275926
parent 6c9194bb89
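In short: this commit folds the separate sync_batch_norm function and SyncBatchNorm module into batch_norm and BatchNorm behind a new sync argument (default True). When sync is off, or the MPI extension is not built, both fall back to ordinary per-process batch statistics.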
@@ -42,31 +42,7 @@ def get_init_var_rand(shape, dtype):
     return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))
 
 @jt.var_scope('batch_norm')
-def batch_norm(x, is_train, eps=1e-5, momentum=0.1):
-    w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
-    b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
-    running_mean = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
-    running_var = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
-
-    w = w.broadcast(x, [0,2,3])
-    b = b.broadcast(x, [0,2,3])
-    if is_train:
-        xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
-        x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-        xvar = x2mean-xmean*xmean
-        norm_x = (x-xmean)/jt.sqrt(xvar+eps)
-
-        running_mean += (xmean.sum([0,2,3])-running_mean)*momentum
-        running_var += (xvar.sum([0,2,3])-running_var)*momentum
-    else:
-        running_mean = running_mean.broadcast(x, [0,2,3])
-        running_var = running_var.broadcast(x, [0,2,3])
-        norm_x = (x-running_mean)/jt.sqrt(running_var+eps)
-
-    return norm_x * w + b
-
-@jt.var_scope('sync_batch_norm')
-def sync_batch_norm(x, is_train, eps=1e-5, momentum=0.1):
+def batch_norm(x, is_train, eps=1e-5, momentum=0.1, sync=True):
     assert not (jt.compile_extern.mpi_ops is None)
     w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
     b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
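For reference, the per-channel arithmetic batch_norm performs is the standard E[x] / E[x^2] formulation. A minimal NumPy sketch of the same formulas (illustrative only, not Jittor API; batch_norm_ref and its argument layout are assumptions made here):

import numpy as np

def batch_norm_ref(x, running_mean, running_var, w, b,
                   is_train=True, eps=1e-5, momentum=0.1):
    # x: (N, C, H, W); per-channel statistics over N, H, W,
    # mirroring jt.mean(x, dims=[0,2,3], keepdims=1) in the diff above
    if is_train:
        xmean = x.mean(axis=(0, 2, 3), keepdims=True)
        x2mean = (x * x).mean(axis=(0, 2, 3), keepdims=True)
        xvar = x2mean - xmean * xmean              # Var[x] = E[x^2] - E[x]^2
        norm_x = (x - xmean) / np.sqrt(xvar + eps)
        # same EMA update as the diff: r += (stat - r) * momentum,
        # i.e. r = (1 - momentum) * r + momentum * stat
        running_mean += (xmean.reshape(-1) - running_mean) * momentum
        running_var += (xvar.reshape(-1) - running_var) * momentum
    else:
        rm = running_mean.reshape(1, -1, 1, 1)
        rv = running_var.reshape(1, -1, 1, 1)
        norm_x = (x - rm) / np.sqrt(rv + eps)
    return norm_x * w.reshape(1, -1, 1, 1) + b.reshape(1, -1, 1, 1)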
@@ -76,10 +52,14 @@ def sync_batch_norm(x, is_train, eps=1e-5, momentum=0.1):
     w = w.broadcast(x, [0,2,3])
     b = b.broadcast(x, [0,2,3])
     if is_train:
-        tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
-        tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
-        xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
-        x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+        if sync and not (jt.compile_extern.mpi_ops is None):
+            tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
+            tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
+            xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
+            x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+        else:
+            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
+            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
         xvar = x2mean-xmean*xmean
         norm_x = (x-xmean)/jt.sqrt(xvar+eps)
 
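The sync path leans on the identity Var[x] = E[x^2] - E[x]^2: all-reducing x and x*x and dividing by world_size gives the global E[x] and E[x^2] (assuming equal batch sizes per rank), from which the global variance follows. A small NumPy check, simulating two workers' shards without MPI (names like shards are made up for illustration):

import numpy as np

rng = np.random.default_rng(0)
shards = [rng.normal(size=(4, 3, 8, 8)) for _ in range(2)]  # one batch per "worker"

# what mpi_all_reduce(x)/world_size computes, element-wise over workers,
# followed by the per-channel jt.mean of the diff:
tmpx = sum(shards) / len(shards)
tmpx2 = sum(s * s for s in shards) / len(shards)
xmean = tmpx.mean(axis=(0, 2, 3))
x2mean = tmpx2.mean(axis=(0, 2, 3))
xvar = x2mean - xmean * xmean

# direct global statistics over the concatenated batch agree
full = np.concatenate(shards, axis=0)
assert np.allclose(xmean, full.mean(axis=(0, 2, 3)))
assert np.allclose(xvar, full.var(axis=(0, 2, 3)))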
@@ -140,7 +120,7 @@ def cross_entropy_loss(output, target, ignore_index=None):
     target = target.reshape((-1, ))
     target = target.broadcast(output, [1])
     target = target.index(1) == target
 
     output = output - output.max([1], keepdims=True)
     loss = output.exp().sum(1).log()
     loss = loss - (output*target).sum(1)
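The surrounding cross_entropy_loss context builds a one-hot mask via the index(1) == target trick, then evaluates loss = logsumexp(output) - output[target], with the row max subtracted first so exp() cannot overflow; subtracting the max shifts both terms equally, leaving the loss unchanged. A NumPy rendering of just these lines (a sketch; the full function's ignore_index handling falls outside this hunk):

import numpy as np

def cross_entropy_ref(output, target):
    # output: (N, K) logits, target: (N,) integer class ids
    onehot = np.arange(output.shape[1])[None, :] == target[:, None]
    output = output - output.max(axis=1, keepdims=True)  # stabilize exp()
    loss = np.log(np.exp(output).sum(axis=1))            # logsumexp per row
    loss = loss - (output * onehot).sum(axis=1)          # minus the target logit
    return loss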
@@ -267,9 +247,10 @@ class Linear(Module):
         return x
 
 class BatchNorm(Module):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
         assert affine == None
+        self.sync = sync
         self.num_features = num_features
         self.is_train = is_train
         self.eps = eps
         self.momentum = momentum
@@ -281,40 +262,14 @@ class BatchNorm(Module):
     def execute(self, x):
         if self.is_train:
-            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-            xvar = x2mean-xmean*xmean
-            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
-        else:
-            running_mean = self.running_mean.broadcast(x, [0,2,3])
-            running_var = self.running_var.broadcast(x, [0,2,3])
-            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
-        w = self.weight.broadcast(x, [0,2,3])
-        b = self.bias.broadcast(x, [0,2,3])
-        return norm_x * w + b
-
-class SyncBatchNorm(Module):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):
-        assert affine == None
-
-        self.num_features = num_features
-        self.is_train = is_train
-        self.eps = eps
-        self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
-        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
-        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
-
-    def execute(self, x):
-        if self.is_train:
-            assert not (jt.compile_extern.mpi_ops is None)
-            tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
-            tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
-            xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+            if self.sync and not (jt.compile_extern.mpi_ops is None):
+                tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
+                tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
+                xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
+                x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+            else:
+                xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
+                x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
+
             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
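After this change a single module covers both cases. A usage sketch, assuming BatchNorm stays exported via jittor.nn as elsewhere in this file (the shapes and launch command are illustrative, not taken from this commit):

import jittor as jt
from jittor import nn

bn = nn.BatchNorm(64, sync=True)         # all-reduces batch stats across MPI ranks
x = jt.random((8, 64, 16, 16))
y = bn(x)

bn_local = nn.BatchNorm(64, sync=False)  # always per-process statistics

Run under MPI, e.g. mpirun -np 4 python train.py, to exercise the synchronized path; without the MPI extension the same module silently falls back to per-process statistics.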