diff --git a/python/jittor/test/test_unary_op.py b/python/jittor/test/test_unary_op.py index 4a839396..98f9c2fc 100644 --- a/python/jittor/test/test_unary_op.py +++ b/python/jittor/test/test_unary_op.py @@ -34,11 +34,19 @@ class TestUnaryOp(unittest.TestCase): check("sqrt", a) def test_grad(self): - ops = ["abs", "negative", "log", "exp", "sqrt"] + ops = ["abs", "negative", "log", "exp", "sqrt", + "sin", "arcsin", "sinh", "arcsinh", + "tan", "arctan", "tanh", "arctanh", + "cos", "arccos", "cosh", "arccosh", + ] a = [1.1, 2.2, 3.3, 4.4] for op in ops: if op == "abs": b = np.array(a+[-1,]) + elif op == "arccosh": + b = np.array(a) + elif "sin" in op or "cos" in op or "tan" in op: + b = np.array(a) / 5 else: b = np.array(a) func = lambda x: eval(f"np.{op}(x[0]).sum()") diff --git a/src/ops/unary_op.cc b/src/ops/unary_op.cc index 954bbc83..12294288 100644 --- a/src/ops/unary_op.cc +++ b/src/ops/unary_op.cc @@ -50,16 +50,22 @@ static unordered_set unary_ops = { "floor", "ceil", "sin", + // @pybind(asin, arcsin) "asin", "sinh", + // @pybind(asinh, arcsinh) "asinh", "tan", + // @pybind(atan, arctan) "atan", "tanh", + // @pybind(atanh, arctanh) "atanh", "cos", + // @pybind(acos, arccos) "acos", "cosh", + // @pybind(acosh, arccosh) "acosh", }; @@ -104,6 +110,79 @@ VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) { auto twoy = make_binary(two, y, ns_multiply); return make_binary(dout, twoy, ns_divide); } + // dsin(x) = cos(x) + if (ns == ns_sin) + return make_binary(dout, make_unary(x, ns_cos), ns_multiply); + // dcos(x) = -sin(x) + if (ns == ns_cos) + return make_binary(dout, make_unary(make_unary(x, ns_sin), ns_negative), ns_multiply); + // dtan(x) = 1/cos^2(x) + if (ns == ns_tan) { + auto one = make_number(1, x); + auto cosx = make_unary(x, ns_cos); + auto cos2x = make_binary(cosx, cosx, ns_multiply); + return make_binary(dout, cos2x, ns_divide); + } + // dasin(x) = 1/sqrt(1-x^2) + if (ns == ns_asin) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(one, x2, ns_subtract); + x2 = make_unary(x2, ns_sqrt); + return make_binary(dout, x2, ns_divide); + } + // dacos(x) = -1/sqrt(1-x^2) + if (ns == ns_acos) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(one, x2, ns_subtract); + x2 = make_unary(x2, ns_sqrt); + return make_unary(make_binary(dout, x2, ns_divide), ns_negative); + } + // datan(x) = 1/(x^2+1) + if (ns == ns_atan) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(one, x2, ns_add); + return make_binary(dout, x2, ns_divide); + } + + // dsinh(x) = cosh(x) + if (ns == ns_sinh) + return make_binary(dout, make_unary(x, ns_cosh), ns_multiply); + // dcosh(x) = sinh(x) + if (ns == ns_cosh) + return make_binary(dout, make_unary(x, ns_sinh), ns_multiply); + // dtanh(x) = 1/cosh^2(x) + if (ns == ns_tanh) { + auto cosx = make_unary(x, ns_cosh); + auto cos2x = make_binary(cosx, cosx, ns_multiply); + return make_binary(dout, cos2x, ns_divide); + } + + // dasinh(x) = 1/sqrt(x^2+1) + if (ns == ns_asinh) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(x2, one, ns_add); + x2 = make_unary(x2, ns_sqrt); + return make_binary(dout, x2, ns_divide); + } + // dacosh(x) = 1/sqrt(x^2-1) + if (ns == ns_acosh) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(x2, one, ns_subtract); + x2 = make_unary(x2, ns_sqrt); + return make_binary(dout, x2, ns_divide); + } + // datanh(x) = 1/(1-x^2) + if (ns == ns_atanh) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(one, x2, ns_subtract); + return make_binary(dout, x2, ns_divide); + } return nullptr; }