update test_knn

This commit is contained in:
zhouwy19 2022-05-03 15:08:44 +08:00
parent 228dbd9583
commit 5dcc1392e4
1 changed files with 7 additions and 6 deletions

View File

@ -44,12 +44,13 @@ class TestKnnOp(unittest.TestCase):
a2 *= -1
np.testing.assert_allclose(a1.data, a2.data, atol=1e-4)
jt.flags.use_cuda = 1
jt_a = jt.randn(32,512,3)
a1, b1 = jt.misc.knn(jt_a, jt_a, 16)
a2, b2 = knn(jt_a.transpose(0,2,1), 16)
a2 *= -1
np.testing.assert_allclose(a1.data, a2.data, atol=1e-4)
if jt.has_cuda:
with jt.flag_scope(use_cuda=1):
jt_a = jt.randn(32,512,3)
a1, b1 = jt.misc.knn(jt_a, jt_a, 16)
a2, b2 = knn(jt_a.transpose(0,2,1), 16)
a2 *= -1
np.testing.assert_allclose(a1.data, a2.data, atol=1e-4)
if __name__ == "__main__":
unittest.main()