update groupnorm

This commit is contained in:
zhouwy19 2020-07-31 08:48:08 +08:00
parent cad1845e68
commit 1e32486f59
1 changed files with 1 additions and 7 deletions

View File

@ -311,13 +311,11 @@ class InstanceNorm2d(Module):
return norm_x * w + b
class GroupNorm(Module):
def __init__(self, num_groups, num_channels, eps=1e-05, affine=None, is_train=True, sync=True):
def __init__(self, num_groups, num_channels, eps=1e-05, affine=None, is_train=True):
assert affine == None
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.is_train = is_train
self.sync = sync
self.weight = init.constant((num_channels,), "float32", 1.0)
self.bias = init.constant((num_channels,), "float32", 0.0)
@ -330,10 +328,6 @@ class GroupNorm(Module):
])
xmean = jt.mean(x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
x2mean = jt.mean(x_*x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
if self.sync and jt.in_mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
xvar = jt.maximum(x2mean-xmean*xmean, 0)
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
w = self.weight.broadcast(x, [0,2,3])