mirror of https://github.com/Jittor/Jittor
fix sigmoid
This commit is contained in:
parent
fb72d78e14
commit
d2b571e2bc
|
@ -141,6 +141,11 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
|
|||
a = m(jt.array([1000]))
|
||||
assert np.isnan(a.data).sum()==0, a
|
||||
|
||||
def test_sigmoid_nan(self):
|
||||
a = jt.float32([1,-1, -1000.1])
|
||||
da = jt.grad(a.sigmoid(), a)
|
||||
assert np.isnan(da.data).sum()==0, da.data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -40,6 +40,7 @@ class TestUnaryOp(unittest.TestCase):
|
|||
"sin", "arcsin", "sinh", "arcsinh",
|
||||
"tan", "arctan", "tanh", "arctanh",
|
||||
"cos", "arccos", "cosh", "arccosh",
|
||||
"sigmoid",
|
||||
]
|
||||
a = [1.1, 2.2, 3.3, 4.4]
|
||||
for op in ops:
|
||||
|
@ -52,6 +53,8 @@ class TestUnaryOp(unittest.TestCase):
|
|||
else:
|
||||
b = np.array(a)
|
||||
func = lambda x: eval(f"np.{op}(x[0]).sum()")
|
||||
if op == "sigmoid":
|
||||
func = lambda x: (1/(1+np.exp(-x[0]))).sum()
|
||||
x, (da,) = ngrad(func, [b], 1e-8)
|
||||
ja = jt.array(b)
|
||||
jb = eval(f"jt.{op}(ja)")
|
||||
|
|
|
@ -73,6 +73,7 @@ static unordered_set<string> unary_ops = {
|
|||
"acos",
|
||||
"cosh",
|
||||
"acosh",
|
||||
"sigmoid",
|
||||
};
|
||||
|
||||
static unordered_set<string> unary_float_ops = {
|
||||
|
|
|
@ -76,6 +76,7 @@ namespace jittor {
|
|||
m(acos) \
|
||||
m(cosh) \
|
||||
m(acosh) \
|
||||
m(sigmoid) \
|
||||
|
||||
struct NanoString;
|
||||
#define DECLEAR_NS(T) extern NanoString ns_##T;
|
||||
|
|
|
@ -67,6 +67,7 @@ static unordered_set<string> unary_ops = {
|
|||
"cosh",
|
||||
// @pybind(acosh, arccosh)
|
||||
"acosh",
|
||||
"sigmoid",
|
||||
};
|
||||
|
||||
UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
|
||||
|
@ -183,6 +184,12 @@ VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
x2 = make_binary(one, x2, ns_subtract);
|
||||
return make_binary(dout, x2, ns_divide);
|
||||
}
|
||||
// dsigmoid(x) = sigmoid(x) - sigmoid(x)^2
|
||||
if (ns == ns_sigmoid) {
|
||||
auto r = make_binary(out, out, ns_multiply);
|
||||
r = make_binary(out, r, ns_subtract);
|
||||
return make_binary(dout, r, ns_multiply);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -35,6 +35,8 @@ namespace jittor {
|
|||
#define tanh(T,x) ((T) ::tanhf((x)))
|
||||
#define atanh(T,x) ((T) ::atanhf((x)))
|
||||
|
||||
#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf(-(x)))))
|
||||
|
||||
#else
|
||||
#define abs(T,x) std::abs(x)
|
||||
#define log(T,x) std::log((T)(x))
|
||||
|
@ -59,6 +61,8 @@ namespace jittor {
|
|||
#define tanh(T,x) ((T) std::tanh((x)))
|
||||
#define atanh(T,x) ((T) std::atanh((x)))
|
||||
|
||||
#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(-(x)))))
|
||||
|
||||
#endif
|
||||
|
||||
#define cast(T,x) ((T)(x))
|
||||
|
|
Loading…
Reference in New Issue