diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 72620dc2..94019b1d 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -142,9 +142,10 @@ def get_init_var_rand(shape, dtype): def relu(x): return jt.maximum(x, 0) def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale) def relu6(x): return jt.minimum(jt.maximum(x, 0), 6) -def sign(x): - x = jt.ternary(x>0, jt.ones(x.shape), x) - return jt.ternary(x<0, -1*jt.ones(x.shape), x) +def sign(x): + one = jt.ones(x.shape) + x = jt.ternary(x>0, one, x) + return jt.ternary(x<0, -one, x) def gelu(x): _sqrt2 = 1.4142135623730951 @@ -1173,9 +1174,6 @@ class Sequential(Module): self.append(m) else: self.append(mod) - def __iter__(self): - for v in self.layers.values(): - yield v def __getitem__(self, idx): return self.layers[idx] def __iter__(self):