mirror of https://github.com/Jittor/Jittor
update groupnorm
This commit is contained in:
parent
cad1845e68
commit
1e32486f59
|
@ -311,13 +311,11 @@ class InstanceNorm2d(Module):
|
|||
return norm_x * w + b
|
||||
|
||||
class GroupNorm(Module):
|
||||
def __init__(self, num_groups, num_channels, eps=1e-05, affine=None, is_train=True, sync=True):
|
||||
def __init__(self, num_groups, num_channels, eps=1e-05, affine=None, is_train=True):
|
||||
assert affine == None
|
||||
self.num_groups = num_groups
|
||||
self.num_channels = num_channels
|
||||
self.eps = eps
|
||||
self.is_train = is_train
|
||||
self.sync = sync
|
||||
self.weight = init.constant((num_channels,), "float32", 1.0)
|
||||
self.bias = init.constant((num_channels,), "float32", 0.0)
|
||||
|
||||
|
@ -330,10 +328,6 @@ class GroupNorm(Module):
|
|||
])
|
||||
xmean = jt.mean(x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
|
||||
x2mean = jt.mean(x_*x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
|
||||
if self.sync and jt.in_mpi:
|
||||
xmean = xmean.mpi_all_reduce("mean")
|
||||
x2mean = x2mean.mpi_all_reduce("mean")
|
||||
|
||||
xvar = jt.maximum(x2mean-xmean*xmean, 0)
|
||||
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
|
||||
w = self.weight.broadcast(x, [0,2,3])
|
||||
|
|
Loading…
Reference in New Issue