mirror of https://github.com/Jittor/Jittor
update test_knn
This commit is contained in:
parent
228dbd9583
commit
5dcc1392e4
|
@ -44,12 +44,13 @@ class TestKnnOp(unittest.TestCase):
|
|||
a2 *= -1
|
||||
np.testing.assert_allclose(a1.data, a2.data, atol=1e-4)
|
||||
|
||||
jt.flags.use_cuda = 1
|
||||
jt_a = jt.randn(32,512,3)
|
||||
a1, b1 = jt.misc.knn(jt_a, jt_a, 16)
|
||||
a2, b2 = knn(jt_a.transpose(0,2,1), 16)
|
||||
a2 *= -1
|
||||
np.testing.assert_allclose(a1.data, a2.data, atol=1e-4)
|
||||
if jt.has_cuda:
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
jt_a = jt.randn(32,512,3)
|
||||
a1, b1 = jt.misc.knn(jt_a, jt_a, 16)
|
||||
a2, b2 = knn(jt_a.transpose(0,2,1), 16)
|
||||
a2 *= -1
|
||||
np.testing.assert_allclose(a1.data, a2.data, atol=1e-4)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue