add cos/sin/tan[ah] and fix tanh nan #61

This commit is contained in:
Dun Liang 2020-04-16 15:53:05 +08:00
parent 68608dd74b
commit 5b7e057f6c
6 changed files with 80 additions and 4 deletions

View File

@ -367,7 +367,7 @@ class Tanh(Module):
def __init__(self):
super().__init__()
def execute(self, x) :
return ((jt.exp (x) - jt.exp(-x)) / (jt.exp(x) + jt.exp (-x)))
return x.tanh()
class Sigmoid(Module):
def __init__(self):

View File

@ -136,6 +136,11 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
da = jt.grad(a**2, a)
assert np.isnan(da.data).sum()==0, da.data
def test_tanh_nan(self):
m=jt.nn.Tanh()
a = m(jt.array([1000]))
assert np.isnan(a.data).sum()==0, a
if __name__ == "__main__":
unittest.main()

View File

@ -61,6 +61,18 @@ static unordered_set<string> unary_ops = {
"floor",
"ceil",
"cast",
"sin",
"asin",
"sinh",
"asinh",
"tan",
"atan",
"tanh",
"atanh",
"cos",
"acos",
"cosh",
"acosh",
};
static unordered_set<string> unary_float_ops = {

View File

@ -63,6 +63,19 @@ namespace jittor {
m(floor) \
m(ceil) \
m(cast) \
\
m(sin) \
m(asin) \
m(sinh) \
m(asinh) \
m(tan) \
m(atan) \
m(tanh) \
m(atanh) \
m(cos) \
m(acos) \
m(cosh) \
m(acosh) \
struct NanoString;
#define DECLEAR_NS(T) extern NanoString ns_##T;

View File

@ -49,6 +49,18 @@ static unordered_set<string> unary_ops = {
"round",
"floor",
"ceil",
"sin",
"asin",
"sinh",
"asinh",
"tan",
"atan",
"tanh",
"atanh",
"cos",
"acos",
"cosh",
"acosh",
};
UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {

View File

@ -11,13 +11,30 @@ namespace jittor {
#define bitwise_not(T,x) (~(x))
#define negative(T,x) (-(x))
#ifdef JIT_cuda
// TODO: add float64 version
#define abs(T,x) ::abs(x)
#define log(T,x) ::log((T)(x))
#define exp(T,x) ::exp((T)(x))
#define sqrt(T,x) ::sqrt((T)(x))
#define log(T,x) ::logf((T)(x))
#define exp(T,x) ::expf((T)(x))
#define sqrt(T,x) ::sqrtf((T)(x))
#define round(T,x) ((T) ::roundf((x)))
#define floor(T,x) ((T) ::floorf((x)))
#define ceil(T,x) ((T) ::ceilf((x)))
#define sin(T,x) ((T) ::sinf((x)))
#define asin(T,x) ((T) ::asinf((x)))
#define sinh(T,x) ((T) ::sinhf((x)))
#define asinh(T,x) ((T) ::asinhf((x)))
#define cos(T,x) ((T) ::cosf((x)))
#define acos(T,x) ((T) ::acosf((x)))
#define cosh(T,x) ((T) ::coshf((x)))
#define acosh(T,x) ((T) ::acoshf((x)))
#define tan(T,x) ((T) ::tanf((x)))
#define atan(T,x) ((T) ::atanf((x)))
#define tanh(T,x) ((T) ::tanhf((x)))
#define atanh(T,x) ((T) ::atanhf((x)))
#else
#define abs(T,x) std::abs(x)
#define log(T,x) std::log((T)(x))
@ -26,7 +43,24 @@ namespace jittor {
#define round(T,x) ((T)std::round((x)))
#define floor(T,x) ((T)std::floor((x)))
#define ceil(T,x) ((T)std::ceil((x)))
#define sin(T,x) ((T) std::sin((x)))
#define asin(T,x) ((T) std::asin((x)))
#define sinh(T,x) ((T) std::sinh((x)))
#define asinh(T,x) ((T) std::asinh((x)))
#define cos(T,x) ((T) std::cos((x)))
#define acos(T,x) ((T) std::acos((x)))
#define cosh(T,x) ((T) std::cosh((x)))
#define acosh(T,x) ((T) std::acosh((x)))
#define tan(T,x) ((T) std::tan((x)))
#define atan(T,x) ((T) std::atan((x)))
#define tanh(T,x) ((T) std::tanh((x)))
#define atanh(T,x) ((T) std::atanh((x)))
#endif
#define cast(T,x) ((T)(x))
} // jittor