fix Layernorm bugs

li-xl 2020-11-27 15:19:31 +08:00
parent f54701a71c
commit 3c7d59ea16
1 changed file with 4 additions and 3 deletions

@@ -420,13 +420,14 @@ class LayerNorm(Module):
     def execute(self, x):
         dims = [-i for i in range(len(self.normalized_shape), 0, -1)]
-        xmean = jt.mean(x, dims=dims)
-        x2mean = jt.mean(x*x, dims=dims)
+        xmean = jt.mean(x, dims=dims, keepdims=1)
+        x2mean = jt.mean(x*x, dims=dims, keepdims=1)
         xvar = (x2mean-xmean*xmean).maximum(0.0)
         w = self.weight / jt.sqrt(xvar+self.eps)
         b = self.bias - xmean * w
-        return x * w.broadcast(x, dims) + b.broadcast(x, dims)
+        return x * w + b
 LayerNorm2d = LayerNorm1d = LayerNorm
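
The change: jt.mean without keepdims drops the reduced axes, so the old code re-expanded w and b with explicit broadcast(x, dims) calls, which this commit identifies as the source of the bug. With keepdims=1 the reduced axes are kept as size 1, and plain elementwise ops broadcast correctly on their own. Below is a minimal NumPy sketch of the fixed arithmetic; the layer_norm function, its signature, and the shapes are illustrative assumptions, not Jittor's actual API.

import numpy as np

# Illustrative sketch of the fixed LayerNorm math (not Jittor's implementation).
def layer_norm(x, weight, bias, normalized_shape, eps=1e-5):
    # Reduce over the trailing axes covered by normalized_shape.
    dims = tuple(range(-len(normalized_shape), 0))
    # keepdims=True keeps each reduced axis as size 1, so the statistics
    # broadcast against x without any explicit broadcast() call.
    xmean = np.mean(x, axis=dims, keepdims=True)
    x2mean = np.mean(x * x, axis=dims, keepdims=True)
    # Var[x] = E[x^2] - E[x]^2, clamped at 0 to absorb rounding error.
    xvar = np.maximum(x2mean - xmean * xmean, 0.0)
    w = weight / np.sqrt(xvar + eps)
    b = bias - xmean * w
    return x * w + b

x = np.random.randn(2, 3, 4).astype(np.float32)
out = layer_norm(x, np.ones(4), np.zeros(4), normalized_shape=(4,))
print(out.mean(axis=-1))  # ~0 per sample
print(out.std(axis=-1))   # ~1 per sample

In this sketch, dropping keepdims=True makes xvar come out as shape (2, 3) instead of (2, 3, 1), and weight / np.sqrt(xvar + eps) fails with a broadcast error, which mirrors why the pre-fix code needed the explicit broadcast calls.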