mirror of https://github.com/Jittor/Jittor
fix LayerNorm bugs
This commit is contained in:
parent
f54701a71c
commit
3c7d59ea16
|
@ -420,13 +420,14 @@ class LayerNorm(Module):
|
|||
|
||||
def execute(self, x):
    """Apply layer normalization to `x` over the trailing dims.

    Normalizes across the last `len(self.normalized_shape)` dimensions
    using E[x] and Var[x] computed per remaining (leading) index, then
    applies the learned affine transform (self.weight, self.bias).

    Args:
        x: input Var; its trailing dims are assumed to match
           self.normalized_shape — TODO confirm against callers.

    Returns:
        A Var with the same shape as `x`.
    """
    # Negative axis indices for the trailing normalized dims,
    # e.g. normalized_shape of rank 2 -> dims = [-2, -1].
    dims = [-i for i in range(len(self.normalized_shape), 0, -1)]
    # keepdims=1 keeps the reduced axes as size-1, so the statistics
    # broadcast against x directly (this was the bug fix: without
    # keepdims the explicit .broadcast(x, dims) calls were required).
    xmean = jt.mean(x, dims=dims, keepdims=1)
    x2mean = jt.mean(x*x, dims=dims, keepdims=1)
    # Var = E[x^2] - E[x]^2; clamp at 0 to guard against tiny negative
    # values from floating-point cancellation.
    xvar = (x2mean-xmean*xmean).maximum(0.0)
    # Fold the normalization into a single scale w and shift b:
    # (x - mean) / sqrt(var + eps) * weight + bias == x * w + b.
    w = self.weight / jt.sqrt(xvar+self.eps)
    b = self.bias - xmean * w
    return x * w + b
|
||||
|
||||
|
||||
# LayerNorm handles any trailing-dim shape, so the 1d/2d names are
# plain aliases kept for API compatibility.
LayerNorm1d = LayerNorm
LayerNorm2d = LayerNorm
|
||||
|
||||
|
|
Loading…
Reference in New Issue