add sin grad

This commit is contained in:
Dun Liang 2020-04-19 22:35:05 +08:00
parent 10b8406784
commit 13a5c8cbc8
2 changed files with 88 additions and 1 deletions

View File

@ -34,11 +34,19 @@ class TestUnaryOp(unittest.TestCase):
check("sqrt", a)
def test_grad(self):
ops = ["abs", "negative", "log", "exp", "sqrt"]
ops = ["abs", "negative", "log", "exp", "sqrt",
"sin", "arcsin", "sinh", "arcsinh",
"tan", "arctan", "tanh", "arctanh",
"cos", "arccos", "cosh", "arccosh",
]
a = [1.1, 2.2, 3.3, 4.4]
for op in ops:
if op == "abs":
b = np.array(a+[-1,])
elif op == "arccosh":
b = np.array(a)
elif "sin" in op or "cos" in op or "tan" in op:
b = np.array(a) / 5
else:
b = np.array(a)
func = lambda x: eval(f"np.{op}(x[0]).sum()")

View File

@ -50,16 +50,22 @@ static unordered_set<string> unary_ops = {
"floor",
"ceil",
"sin",
// @pybind(asin, arcsin)
"asin",
"sinh",
// @pybind(asinh, arcsinh)
"asinh",
"tan",
// @pybind(atan, arctan)
"atan",
"tanh",
// @pybind(atanh, arctanh)
"atanh",
"cos",
// @pybind(acos, arccos)
"acos",
"cosh",
// @pybind(acosh, arccosh)
"acosh",
};
@ -104,6 +110,79 @@ VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
auto twoy = make_binary(two, y, ns_multiply);
return make_binary(dout, twoy, ns_divide);
}
// dsin(x) = cos(x)
if (ns == ns_sin)
return make_binary(dout, make_unary(x, ns_cos), ns_multiply);
// dcos(x) = -sin(x)
if (ns == ns_cos)
return make_binary(dout, make_unary(make_unary(x, ns_sin), ns_negative), ns_multiply);
// dtan(x) = 1/cos^2(x)
if (ns == ns_tan) {
auto one = make_number(1, x);
auto cosx = make_unary(x, ns_cos);
auto cos2x = make_binary(cosx, cosx, ns_multiply);
return make_binary(dout, cos2x, ns_divide);
}
// dasin(x) = 1/sqrt(1-x^2)
if (ns == ns_asin) {
auto one = make_number(1, x);
auto x2 = make_binary(x, x, ns_multiply);
x2 = make_binary(one, x2, ns_subtract);
x2 = make_unary(x2, ns_sqrt);
return make_binary(dout, x2, ns_divide);
}
// dacos(x) = -1/sqrt(1-x^2)
if (ns == ns_acos) {
auto one = make_number(1, x);
auto x2 = make_binary(x, x, ns_multiply);
x2 = make_binary(one, x2, ns_subtract);
x2 = make_unary(x2, ns_sqrt);
return make_unary(make_binary(dout, x2, ns_divide), ns_negative);
}
// datan(x) = 1/(x^2+1)
if (ns == ns_atan) {
auto one = make_number(1, x);
auto x2 = make_binary(x, x, ns_multiply);
x2 = make_binary(one, x2, ns_add);
return make_binary(dout, x2, ns_divide);
}
// dsinh(x) = cosh(x)
if (ns == ns_sinh)
return make_binary(dout, make_unary(x, ns_cosh), ns_multiply);
// dcosh(x) = sinh(x)
if (ns == ns_cosh)
return make_binary(dout, make_unary(x, ns_sinh), ns_multiply);
// dtanh(x) = 1/cosh^2(x)
if (ns == ns_tanh) {
auto cosx = make_unary(x, ns_cosh);
auto cos2x = make_binary(cosx, cosx, ns_multiply);
return make_binary(dout, cos2x, ns_divide);
}
// dasinh(x) = 1/sqrt(x^2+1)
if (ns == ns_asinh) {
auto one = make_number(1, x);
auto x2 = make_binary(x, x, ns_multiply);
x2 = make_binary(x2, one, ns_add);
x2 = make_unary(x2, ns_sqrt);
return make_binary(dout, x2, ns_divide);
}
// dacosh(x) = 1/sqrt(x^2-1)
if (ns == ns_acosh) {
auto one = make_number(1, x);
auto x2 = make_binary(x, x, ns_multiply);
x2 = make_binary(x2, one, ns_subtract);
x2 = make_unary(x2, ns_sqrt);
return make_binary(dout, x2, ns_divide);
}
// datanh(x) = 1/(1-x^2)
if (ns == ns_atanh) {
auto one = make_number(1, x);
auto x2 = make_binary(x, x, ns_multiply);
x2 = make_binary(one, x2, ns_subtract);
return make_binary(dout, x2, ns_divide);
}
return nullptr;
}