fix sigmoid

This commit is contained in:
Dun Liang 2020-05-07 15:22:50 +08:00
parent fb72d78e14
commit d2b571e2bc
6 changed files with 21 additions and 0 deletions

View File

@ -141,6 +141,11 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
a = m(jt.array([1000]))
assert np.isnan(a.data).sum()==0, a
def test_sigmoid_nan(self):
a = jt.float32([1,-1, -1000.1])
da = jt.grad(a.sigmoid(), a)
assert np.isnan(da.data).sum()==0, da.data
if __name__ == "__main__":
unittest.main()

View File

@ -40,6 +40,7 @@ class TestUnaryOp(unittest.TestCase):
"sin", "arcsin", "sinh", "arcsinh",
"tan", "arctan", "tanh", "arctanh",
"cos", "arccos", "cosh", "arccosh",
"sigmoid",
]
a = [1.1, 2.2, 3.3, 4.4]
for op in ops:
@ -52,6 +53,8 @@ class TestUnaryOp(unittest.TestCase):
else:
b = np.array(a)
func = lambda x: eval(f"np.{op}(x[0]).sum()")
if op == "sigmoid":
func = lambda x: (1/(1+np.exp(-x[0]))).sum()
x, (da,) = ngrad(func, [b], 1e-8)
ja = jt.array(b)
jb = eval(f"jt.{op}(ja)")

View File

@ -73,6 +73,7 @@ static unordered_set<string> unary_ops = {
"acos",
"cosh",
"acosh",
"sigmoid",
};
static unordered_set<string> unary_float_ops = {

View File

@ -76,6 +76,7 @@ namespace jittor {
m(acos) \
m(cosh) \
m(acosh) \
m(sigmoid) \
struct NanoString;
#define DECLEAR_NS(T) extern NanoString ns_##T;

View File

@ -67,6 +67,7 @@ static unordered_set<string> unary_ops = {
"cosh",
// @pybind(acosh, arccosh)
"acosh",
"sigmoid",
};
UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
@ -183,6 +184,12 @@ VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
x2 = make_binary(one, x2, ns_subtract);
return make_binary(dout, x2, ns_divide);
}
// dsigmoid(x) = sigmoid(x) - sigmoid(x)^2
if (ns == ns_sigmoid) {
auto r = make_binary(out, out, ns_multiply);
r = make_binary(out, r, ns_subtract);
return make_binary(dout, r, ns_multiply);
}
return nullptr;
}

View File

@ -35,6 +35,8 @@ namespace jittor {
#define tanh(T,x) ((T) ::tanhf((x)))
#define atanh(T,x) ((T) ::atanhf((x)))
#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf(-(x)))))
#else
#define abs(T,x) std::abs(x)
#define log(T,x) std::log((T)(x))
@ -59,6 +61,8 @@ namespace jittor {
#define tanh(T,x) ((T) std::tanh((x)))
#define atanh(T,x) ((T) std::atanh((x)))
#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(-(x)))))
#endif
#define cast(T,x) ((T)(x))