mirror of https://github.com/Jittor/Jittor
fix issue #401, improve atan and atan2
This commit is contained in:
parent
9a5e7ea6f5
commit
a57de764f6
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.17'
|
||||
__version__ = '1.3.5.18'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -762,11 +762,14 @@ jt.Var.deg2rad = deg2rad
|
|||
|
||||
def arctan2(y,x):
|
||||
angle = jt.zeros(x.shape,dtype=x.dtype)
|
||||
x = (x!=0.0).ternary(x, x+1e-30)
|
||||
x = (x!=0.0).ternary(x, 1e-30)
|
||||
angle = (y/x).arctan()
|
||||
mask = y<0 | ((y==0) & (x<0))
|
||||
mask = (x<0)&(y<0)
|
||||
angle = angle - mask*np.pi
|
||||
mask = (x<0)&(y>=0)
|
||||
angle = angle + mask*np.pi
|
||||
return angle
|
||||
atan2 = arctan2
|
||||
|
||||
|
||||
def nonzero(x):
|
||||
|
|
|
@ -95,6 +95,21 @@ static unordered_set<string> float_ops = {
|
|||
"sqrt",
|
||||
"mean",
|
||||
"divide",
|
||||
"sin",
|
||||
"asin",
|
||||
"sinh",
|
||||
"asinh",
|
||||
"tan",
|
||||
"atan",
|
||||
"tanh",
|
||||
"atanh",
|
||||
"cos",
|
||||
"acos",
|
||||
"cosh",
|
||||
"acosh",
|
||||
"sigmoid",
|
||||
"erf",
|
||||
"erfinv"
|
||||
};
|
||||
static unordered_set<string> int_ops = {
|
||||
"round_int",
|
||||
|
|
|
@ -256,14 +256,19 @@ class TestOther(unittest.TestCase):
|
|||
assert (x[3]['b'] == np.array([1,2,3])).all()
|
||||
|
||||
def test_arctan2(self):
|
||||
a = jt.arctan2(jt.array([1,1.0,0]), jt.array([1,0.0,-1]))
|
||||
np.testing.assert_allclose(a.data, [0.7853982,1.5707964,3.1415927])
|
||||
|
||||
y = jt.random((100,))
|
||||
x = jt.random((100,))
|
||||
x = jt.float32([1,1,-1,-1, 1,-1,0,0,0])
|
||||
y = jt.float32([-1,1,-1,1, 0,0,1,-1,0])
|
||||
z = jt.arctan2(y, x)
|
||||
z2 = np.arctan2(y.data, x.data)
|
||||
np.testing.assert_allclose(z.data, z2, atol=1e-6)
|
||||
|
||||
y = jt.random((100,)) * 2 - 1
|
||||
x = jt.random((100,)) * 2 - 1
|
||||
z = jt.arctan2(y, x)
|
||||
z2 = np.arctan2(y.data, x.data)
|
||||
np.testing.assert_allclose(z.data, z2, atol=1e-6)
|
||||
|
||||
np.testing.assert_allclose(jt.array([1]).arctan().item(), 0.7853982)
|
||||
|
||||
def test_softmax_precision(self):
|
||||
# jt.flags.use_cuda = 1
|
||||
|
|
Loading…
Reference in New Issue