sync batchnorm

guowei yang 2020-04-19 15:40:27 +08:00
parent 6c9194bb89
commit a0f6275926
1 changed file with 20 additions and 65 deletions


@@ -42,31 +42,7 @@ def get_init_var_rand(shape, dtype):
     return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))
 
-@jt.var_scope('batch_norm')
-def batch_norm(x, is_train, eps=1e-5, momentum=0.1):
-    w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
-    b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
-    running_mean = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
-    running_var = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
-    w = w.broadcast(x, [0,2,3])
-    b = b.broadcast(x, [0,2,3])
-    if is_train:
-        xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
-        x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-        xvar = x2mean-xmean*xmean
-        norm_x = (x-xmean)/jt.sqrt(xvar+eps)
-        running_mean += (xmean.sum([0,2,3])-running_mean)*momentum
-        running_var += (xvar.sum([0,2,3])-running_var)*momentum
-    else:
-        running_mean = running_mean.broadcast(x, [0,2,3])
-        running_var = running_var.broadcast(x, [0,2,3])
-        norm_x = (x-running_mean)/jt.sqrt(running_var+eps)
-    return norm_x * w + b
-
-@jt.var_scope('sync_batch_norm')
-def sync_batch_norm(x, is_train, eps=1e-5, momentum=0.1):
+def batch_norm(x, is_train, eps=1e-5, momentum=0.1, sync=True):
     assert not (jt.compile_extern.mpi_ops is None)
     w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
     b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
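The merged function keeps the same per-channel arithmetic as the deleted batch_norm; only the source of the statistics changes. As a reference for that arithmetic, here is a minimal NumPy sketch of the normalization math (the name batch_norm_ref and the test shapes are illustrative, not part of the patch):

import numpy as np

def batch_norm_ref(x, w, b, eps=1e-5):
    # x: (N, C, H, W); statistics reduce over N, H, W per channel,
    # matching jt.mean(x, dims=[0,2,3], keepdims=1) in the diff.
    xmean = x.mean(axis=(0, 2, 3), keepdims=True)
    x2mean = (x * x).mean(axis=(0, 2, 3), keepdims=True)
    xvar = x2mean - xmean * xmean                 # biased variance, E[x^2] - E[x]^2
    norm_x = (x - xmean) / np.sqrt(xvar + eps)
    return norm_x * w.reshape(1, -1, 1, 1) + b.reshape(1, -1, 1, 1)

x = np.random.randn(8, 3, 4, 4)
y = batch_norm_ref(x, np.ones(3), np.zeros(3))
print(y.mean(axis=(0, 2, 3)), y.std(axis=(0, 2, 3)))   # roughly 0 and 1 per channel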
@@ -76,10 +52,14 @@ def sync_batch_norm(x, is_train, eps=1e-5, momentum=0.1):
     w = w.broadcast(x, [0,2,3])
     b = b.broadcast(x, [0,2,3])
     if is_train:
-        tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
-        tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
-        xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
-        x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+        if sync and not (jt.compile_extern.mpi_ops is None):
+            tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
+            tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
+            xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
+            x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+        else:
+            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
+            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
         xvar = x2mean-xmean*xmean
         norm_x = (x-xmean)/jt.sqrt(xvar+eps)
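The sync branch leans on an identity that is easy to miss: when every rank holds a batch of the same shape, all-reducing x and x*x and dividing by world_size yields tensors whose per-channel means equal the mean and mean-of-squares of the pooled global batch, so xvar = x2mean - xmean*xmean is the true global (biased) variance. A NumPy simulation of two ranks, with the all-reduce replaced by an explicit average (no MPI needed; all names are illustrative):

import numpy as np

ranks = [np.random.randn(4, 3, 2, 2) for _ in range(2)]    # world_size = 2
tmpx  = sum(ranks) / len(ranks)                            # stands in for mpi_all_reduce(x) / world_size
tmpx2 = sum(r * r for r in ranks) / len(ranks)             # stands in for mpi_all_reduce(x*x) / world_size

xmean  = tmpx.mean(axis=(0, 2, 3), keepdims=True)
x2mean = tmpx2.mean(axis=(0, 2, 3), keepdims=True)

# Pooling all ranks' samples into one global batch gives the same statistics.
global_batch = np.concatenate(ranks, axis=0)
assert np.allclose(xmean, global_batch.mean(axis=(0, 2, 3), keepdims=True))
assert np.allclose(x2mean - xmean * xmean,
                   global_batch.var(axis=(0, 2, 3), keepdims=True))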
@@ -140,7 +120,7 @@ def cross_entropy_loss(output, target, ignore_index=None):
     target = target.reshape((-1, ))
     target = target.broadcast(output, [1])
     target = target.index(1) == target
+    output = output - output.max([1], keepdims=True)
     loss = output.exp().sum(1).log()
     loss = loss - (output*target).sum(1)
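The line added here is the standard log-sum-exp stabilization: softmax cross-entropy is unchanged when a constant is subtracted from each row of logits, so subtracting the row maximum keeps exp() from overflowing without altering the loss. A self-contained NumPy sketch of the stabilized loss (a hypothetical helper, not the library code):

import numpy as np

def cross_entropy_ref(output, target):
    # Shift each row by its max: softmax is invariant to the shift,
    # but exp() now never sees a large positive argument.
    shifted = output - output.max(axis=1, keepdims=True)
    lse = np.log(np.exp(shifted).sum(axis=1))
    return (lse - shifted[np.arange(len(target)), target]).mean()

logits = np.array([[1000.0, 0.0], [0.0, 1000.0]])   # would overflow without the shift
print(cross_entropy_ref(logits, np.array([0, 1])))  # ~0.0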
@@ -267,9 +247,10 @@ class Linear(Module):
         return x
 
 class BatchNorm(Module):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
         assert affine == None
+        self.sync = sync
         self.num_features = num_features
         self.is_train = is_train
         self.eps = eps
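With the separate SyncBatchNorm class folded in, synchronization becomes a per-module flag. A hypothetical usage sketch, assuming the class is exposed as jittor.nn.BatchNorm:

from jittor import nn

bn_local  = nn.BatchNorm(64, sync=False)   # always use per-process statistics
bn_synced = nn.BatchNorm(64)               # sync=True by default; falls back to
                                           # local statistics if mpi_ops is None

The guard in execute means a sync=True module still runs in a single-process setup, which is the point of merging the two classes.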
@@ -281,40 +262,14 @@ class BatchNorm(Module):
     def execute(self, x):
         if self.is_train:
-            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-            xvar = x2mean-xmean*xmean
-            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
-        else:
-            running_mean = self.running_mean.broadcast(x, [0,2,3])
-            running_var = self.running_var.broadcast(x, [0,2,3])
-            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
-        w = self.weight.broadcast(x, [0,2,3])
-        b = self.bias.broadcast(x, [0,2,3])
-        return norm_x * w + b
-
-class SyncBatchNorm(Module):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):
-        assert affine == None
-        self.num_features = num_features
-        self.is_train = is_train
-        self.eps = eps
-        self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
-        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
-        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
-    def execute(self, x):
-        if self.is_train:
-            assert not (jt.compile_extern.mpi_ops is None)
-            tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
-            tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
-            xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+            if self.sync and not (jt.compile_extern.mpi_ops is None):
+                tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
+                tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
+                xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
+                x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+            else:
+                xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
+                x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
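One detail worth spelling out about the running-statistics lines removed above: running += (batch - running) * momentum is just the in-place form of the usual exponential moving average running = (1 - momentum) * running + momentum * batch. A quick NumPy check of that algebra:

import numpy as np

momentum = 0.1
running = np.array([0.0, 10.0, -5.0])
batch   = np.array([1.0,  2.0,  3.0])

ema = (1 - momentum) * running + momentum * batch   # textbook EMA form
running += (batch - running) * momentum             # in-place form used in the diff
assert np.allclose(running, ema)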