distributed batchnorm

This commit is contained in:
guowei yang 2020-04-10 13:31:40 +08:00
parent 93d7f59985
commit 0fcdf20cfa
1 changed file with 8 additions and 2 deletions

View File

@@ -51,8 +51,14 @@ def batch_norm(x, is_train, eps=1e-5, momentum=0.1):
w = w.broadcast(x, [0,2,3])
b = b.broadcast(x, [0,2,3])
if is_train:
xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
if (jt.compile_extern.mpi_ops is None):
xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
else:
tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+eps)