From b7429396057c6fe0976fce57d54f682afb7b7a5e Mon Sep 17 00:00:00 2001 From: zhouwy19 Date: Wed, 28 Oct 2020 14:07:42 +0800 Subject: [PATCH] update sign --- python/jittor/nn.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 72620dc2..94019b1d 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -142,9 +142,10 @@ def get_init_var_rand(shape, dtype): def relu(x): return jt.maximum(x, 0) def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale) def relu6(x): return jt.minimum(jt.maximum(x, 0), 6) -def sign(x): - x = jt.ternary(x>0, jt.ones(x.shape), x) - return jt.ternary(x<0, -1*jt.ones(x.shape), x) +def sign(x): + one = jt.ones(x.shape) + x = jt.ternary(x>0, one, x) + return jt.ternary(x<0, -one, x) def gelu(x): _sqrt2 = 1.4142135623730951 @@ -1173,9 +1174,6 @@ class Sequential(Module): self.append(m) else: self.append(mod) - def __iter__(self): - for v in self.layers.values(): - yield v def __getitem__(self, idx): return self.layers[idx] def __iter__(self):