mirror of https://github.com/Jittor/Jittor
fix float mod
This commit is contained in:
parent
4b26d65d96
commit
2dc161a00e
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.1.7.7'
|
||||
__version__ = '1.1.7.8'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
|
|
@ -126,6 +126,20 @@ class TestBinaryOp(unittest.TestCase):
|
|||
for jd, nd in zip(jgrads, grads):
|
||||
assert (np.abs(jd.data-nd)<1e-4).all(), f"\n{jd.data}\n{nd}"
|
||||
|
||||
def test_mod_float(self):
|
||||
a = jt.random((10,))
|
||||
b = jt.random((10,))
|
||||
c = a % b
|
||||
assert np.allclose(c.data, a.data % b.data)
|
||||
a = jt.random((10,), 'float64')
|
||||
b = jt.random((10,), 'float64')
|
||||
c = a % b
|
||||
assert np.allclose(c.data, a.data % b.data)
|
||||
a = jt.random((10,)) * 1000
|
||||
b = (jt.random((10,)) * 10).int() + 1
|
||||
c = a % b
|
||||
assert np.allclose(c.data, a.data % b.data), (c.data, a.data%b.data)
|
||||
|
||||
|
||||
class TestBinaryOpCuda(TestBinaryOp, test_cuda(2)):
|
||||
pass
|
||||
|
|
|
@ -12,17 +12,18 @@ namespace jittor {
|
|||
#define pow(T,a,b) ::pow(a,b)
|
||||
#define maximum(T,a,b) ::max(T(a), T(b))
|
||||
#define minimum(T,a,b) ::min(T(a), T(b))
|
||||
#define mod(T,a,b) @if(@strcmp(@Tx,float32)==0,::fmodf(T(a),T(b)),@if(@strcmp(@Tx,float64)==0,::fmod(T(a),T(b)),((a)%(b))))
|
||||
#else // JIT_cpu
|
||||
#define pow(T,a,b) std::pow(a,b)
|
||||
#define maximum(T,a,b) std::max(T(a), T(b))
|
||||
#define minimum(T,a,b) std::min(T(a), T(b))
|
||||
#define mod(T,a,b) @if(@strcmp(@Tx,float32)==0 || @strcmp(@Tx,float64)==0,std::fmod((T)a,(T)b),((a)%(b)))
|
||||
#endif
|
||||
#define add(T,a,b) ((a)+(b))
|
||||
#define subtract(T,a,b) ((a)-(b))
|
||||
#define multiply(T,a,b) ((a)*(b))
|
||||
#define divide(T,a,b) (T((T(a))/(T(b))))
|
||||
#define floor_divide(T,a,b) (T((T(a))/(T(b))))
|
||||
#define mod(T,a,b) ((a)%(b))
|
||||
#define less(T,a,b) ((a)<(b))
|
||||
#define less_equal(T,a,b) ((a)<=(b))
|
||||
#define greater(T,a,b) ((a)>(b))
|
||||
|
|
Loading…
Reference in New Issue