update sign

This commit is contained in:
zhouwy19 2020-10-28 14:07:42 +08:00
parent 8e93e36f53
commit b742939605
1 changed files with 4 additions and 6 deletions

View File

@ -142,9 +142,10 @@ def get_init_var_rand(shape, dtype):
def relu(x): return jt.maximum(x, 0) def relu(x): return jt.maximum(x, 0)
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale) def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6) def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)
def sign(x): def sign(x):
x = jt.ternary(x>0, jt.ones(x.shape), x) one = jt.ones(x.shape)
return jt.ternary(x<0, -1*jt.ones(x.shape), x) x = jt.ternary(x>0, one, x)
return jt.ternary(x<0, -one, x)
def gelu(x): def gelu(x):
_sqrt2 = 1.4142135623730951 _sqrt2 = 1.4142135623730951
@ -1173,9 +1174,6 @@ class Sequential(Module):
self.append(m) self.append(m)
else: else:
self.append(mod) self.append(mod)
def __iter__(self):
for v in self.layers.values():
yield v
def __getitem__(self, idx): def __getitem__(self, idx):
return self.layers[idx] return self.layers[idx]
def __iter__(self): def __iter__(self):