fix float mod

This commit is contained in:
Dun Liang 2020-08-17 14:06:50 +08:00
parent 4b26d65d96
commit 2dc161a00e
3 changed files with 17 additions and 2 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.1.7.7'
__version__ = '1.1.7.8'
from . import lock
with lock.lock_scope():
from . import compiler

View File

@ -126,6 +126,20 @@ class TestBinaryOp(unittest.TestCase):
for jd, nd in zip(jgrads, grads):
assert (np.abs(jd.data-nd)<1e-4).all(), f"\n{jd.data}\n{nd}"
def test_mod_float(self):
a = jt.random((10,))
b = jt.random((10,))
c = a % b
assert np.allclose(c.data, a.data % b.data)
a = jt.random((10,), 'float64')
b = jt.random((10,), 'float64')
c = a % b
assert np.allclose(c.data, a.data % b.data)
a = jt.random((10,)) * 1000
b = (jt.random((10,)) * 10).int() + 1
c = a % b
assert np.allclose(c.data, a.data % b.data), (c.data, a.data%b.data)
class TestBinaryOpCuda(TestBinaryOp, test_cuda(2)):
pass

View File

@ -12,17 +12,18 @@ namespace jittor {
#define pow(T,a,b) ::pow(a,b)
#define maximum(T,a,b) ::max(T(a), T(b))
#define minimum(T,a,b) ::min(T(a), T(b))
#define mod(T,a,b) @if(@strcmp(@Tx,float32)==0,::fmodf(T(a),T(b)),@if(@strcmp(@Tx,float64)==0,::fmod(T(a),T(b)),((a)%(b))))
#else // JIT_cpu
#define pow(T,a,b) std::pow(a,b)
#define maximum(T,a,b) std::max(T(a), T(b))
#define minimum(T,a,b) std::min(T(a), T(b))
#define mod(T,a,b) @if(@strcmp(@Tx,float32)==0 || @strcmp(@Tx,float64)==0,std::fmod((T)a,(T)b),((a)%(b)))
#endif
#define add(T,a,b) ((a)+(b))
#define subtract(T,a,b) ((a)-(b))
#define multiply(T,a,b) ((a)*(b))
#define divide(T,a,b) (T((T(a))/(T(b))))
#define floor_divide(T,a,b) (T((T(a))/(T(b))))
#define mod(T,a,b) ((a)%(b))
#define less(T,a,b) ((a)<(b))
#define less_equal(T,a,b) ((a)<=(b))
#define greater(T,a,b) ((a)>(b))