fix issue #401, improve atan and atan2

This commit is contained in:
Dun Liang 2022-10-05 17:21:51 +08:00
parent 9a5e7ea6f5
commit a57de764f6
4 changed files with 31 additions and 8 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.5.17'
__version__ = '1.3.5.18'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -762,11 +762,14 @@ jt.Var.deg2rad = deg2rad
def arctan2(y,x):
angle = jt.zeros(x.shape,dtype=x.dtype)
x = (x!=0.0).ternary(x, x+1e-30)
x = (x!=0.0).ternary(x, 1e-30)
angle = (y/x).arctan()
mask = y<0 | ((y==0) & (x<0))
mask = (x<0)&(y<0)
angle = angle - mask*np.pi
mask = (x<0)&(y>=0)
angle = angle + mask*np.pi
return angle
atan2 = arctan2
def nonzero(x):

View File

@ -95,6 +95,21 @@ static unordered_set<string> float_ops = {
"sqrt",
"mean",
"divide",
"sin",
"asin",
"sinh",
"asinh",
"tan",
"atan",
"tanh",
"atanh",
"cos",
"acos",
"cosh",
"acosh",
"sigmoid",
"erf",
"erfinv"
};
static unordered_set<string> int_ops = {
"round_int",

View File

@ -256,14 +256,19 @@ class TestOther(unittest.TestCase):
assert (x[3]['b'] == np.array([1,2,3])).all()
def test_arctan2(self):
a = jt.arctan2(jt.array([1,1.0,0]), jt.array([1,0.0,-1]))
np.testing.assert_allclose(a.data, [0.7853982,1.5707964,3.1415927])
y = jt.random((100,))
x = jt.random((100,))
x = jt.float32([1,1,-1,-1, 1,-1,0,0,0])
y = jt.float32([-1,1,-1,1, 0,0,1,-1,0])
z = jt.arctan2(y, x)
z2 = np.arctan2(y.data, x.data)
np.testing.assert_allclose(z.data, z2, atol=1e-6)
y = jt.random((100,)) * 2 - 1
x = jt.random((100,)) * 2 - 1
z = jt.arctan2(y, x)
z2 = np.arctan2(y.data, x.data)
np.testing.assert_allclose(z.data, z2, atol=1e-6)
np.testing.assert_allclose(jt.array([1]).arctan().item(), 0.7853982)
def test_softmax_precision(self):
# jt.flags.use_cuda = 1