mirror of https://github.com/Jittor/Jittor
add sin grad
This commit is contained in:
parent
10b8406784
commit
13a5c8cbc8
|
@ -34,11 +34,19 @@ class TestUnaryOp(unittest.TestCase):
|
|||
check("sqrt", a)
|
||||
|
||||
def test_grad(self):
|
||||
ops = ["abs", "negative", "log", "exp", "sqrt"]
|
||||
ops = ["abs", "negative", "log", "exp", "sqrt",
|
||||
"sin", "arcsin", "sinh", "arcsinh",
|
||||
"tan", "arctan", "tanh", "arctanh",
|
||||
"cos", "arccos", "cosh", "arccosh",
|
||||
]
|
||||
a = [1.1, 2.2, 3.3, 4.4]
|
||||
for op in ops:
|
||||
if op == "abs":
|
||||
b = np.array(a+[-1,])
|
||||
elif op == "arccosh":
|
||||
b = np.array(a)
|
||||
elif "sin" in op or "cos" in op or "tan" in op:
|
||||
b = np.array(a) / 5
|
||||
else:
|
||||
b = np.array(a)
|
||||
func = lambda x: eval(f"np.{op}(x[0]).sum()")
|
||||
|
|
|
@ -50,16 +50,22 @@ static unordered_set<string> unary_ops = {
|
|||
"floor",
|
||||
"ceil",
|
||||
"sin",
|
||||
// @pybind(asin, arcsin)
|
||||
"asin",
|
||||
"sinh",
|
||||
// @pybind(asinh, arcsinh)
|
||||
"asinh",
|
||||
"tan",
|
||||
// @pybind(atan, arctan)
|
||||
"atan",
|
||||
"tanh",
|
||||
// @pybind(atanh, arctanh)
|
||||
"atanh",
|
||||
"cos",
|
||||
// @pybind(acos, arccos)
|
||||
"acos",
|
||||
"cosh",
|
||||
// @pybind(acosh, arccosh)
|
||||
"acosh",
|
||||
};
|
||||
|
||||
|
@ -104,6 +110,79 @@ VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
auto twoy = make_binary(two, y, ns_multiply);
|
||||
return make_binary(dout, twoy, ns_divide);
|
||||
}
|
||||
// dsin(x) = cos(x)
|
||||
if (ns == ns_sin)
|
||||
return make_binary(dout, make_unary(x, ns_cos), ns_multiply);
|
||||
// dcos(x) = -sin(x)
|
||||
if (ns == ns_cos)
|
||||
return make_binary(dout, make_unary(make_unary(x, ns_sin), ns_negative), ns_multiply);
|
||||
// dtan(x) = 1/cos^2(x)
|
||||
if (ns == ns_tan) {
|
||||
auto one = make_number(1, x);
|
||||
auto cosx = make_unary(x, ns_cos);
|
||||
auto cos2x = make_binary(cosx, cosx, ns_multiply);
|
||||
return make_binary(dout, cos2x, ns_divide);
|
||||
}
|
||||
// dasin(x) = 1/sqrt(1-x^2)
|
||||
if (ns == ns_asin) {
|
||||
auto one = make_number(1, x);
|
||||
auto x2 = make_binary(x, x, ns_multiply);
|
||||
x2 = make_binary(one, x2, ns_subtract);
|
||||
x2 = make_unary(x2, ns_sqrt);
|
||||
return make_binary(dout, x2, ns_divide);
|
||||
}
|
||||
// dacos(x) = -1/sqrt(1-x^2)
|
||||
if (ns == ns_acos) {
|
||||
auto one = make_number(1, x);
|
||||
auto x2 = make_binary(x, x, ns_multiply);
|
||||
x2 = make_binary(one, x2, ns_subtract);
|
||||
x2 = make_unary(x2, ns_sqrt);
|
||||
return make_unary(make_binary(dout, x2, ns_divide), ns_negative);
|
||||
}
|
||||
// datan(x) = 1/(x^2+1)
|
||||
if (ns == ns_atan) {
|
||||
auto one = make_number(1, x);
|
||||
auto x2 = make_binary(x, x, ns_multiply);
|
||||
x2 = make_binary(one, x2, ns_add);
|
||||
return make_binary(dout, x2, ns_divide);
|
||||
}
|
||||
|
||||
// dsinh(x) = cosh(x)
|
||||
if (ns == ns_sinh)
|
||||
return make_binary(dout, make_unary(x, ns_cosh), ns_multiply);
|
||||
// dcosh(x) = sinh(x)
|
||||
if (ns == ns_cosh)
|
||||
return make_binary(dout, make_unary(x, ns_sinh), ns_multiply);
|
||||
// dtanh(x) = 1/cosh^2(x)
|
||||
if (ns == ns_tanh) {
|
||||
auto cosx = make_unary(x, ns_cosh);
|
||||
auto cos2x = make_binary(cosx, cosx, ns_multiply);
|
||||
return make_binary(dout, cos2x, ns_divide);
|
||||
}
|
||||
|
||||
// dasinh(x) = 1/sqrt(x^2+1)
|
||||
if (ns == ns_asinh) {
|
||||
auto one = make_number(1, x);
|
||||
auto x2 = make_binary(x, x, ns_multiply);
|
||||
x2 = make_binary(x2, one, ns_add);
|
||||
x2 = make_unary(x2, ns_sqrt);
|
||||
return make_binary(dout, x2, ns_divide);
|
||||
}
|
||||
// dacosh(x) = 1/sqrt(x^2-1)
|
||||
if (ns == ns_acosh) {
|
||||
auto one = make_number(1, x);
|
||||
auto x2 = make_binary(x, x, ns_multiply);
|
||||
x2 = make_binary(x2, one, ns_subtract);
|
||||
x2 = make_unary(x2, ns_sqrt);
|
||||
return make_binary(dout, x2, ns_divide);
|
||||
}
|
||||
// datanh(x) = 1/(1-x^2)
|
||||
if (ns == ns_atanh) {
|
||||
auto one = make_number(1, x);
|
||||
auto x2 = make_binary(x, x, ns_multiply);
|
||||
x2 = make_binary(one, x2, ns_subtract);
|
||||
return make_binary(dout, x2, ns_divide);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue