add cuda bmm test

This commit is contained in:
Dun Liang 2020-07-31 21:14:07 +08:00
parent 84d9434ecf
commit f296712c37
2 changed files with 5 additions and 2 deletions

View File

@ -49,8 +49,6 @@ Example::
'''
assert len(a.shape) > 2 and len(b.shape) > 2
return matmul(a, b)
if jt.flags.use_cuda:
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
def matmul(a, b):
''' matrix multiply,

View File

@ -339,6 +339,11 @@ class TestMatmul(unittest.TestCase):
check([10,3,4], [4,5])
check([10,3,4], [10,4,5])
check([8,1,3,4], [10,4,5])
check([5,10,3,4], [5,10,4,5])
@jt.flag_scope(use_cuda=1)
def test_matmul_example2_cuda(self):
self.test_matmul_example2()
if __name__ == "__main__":
unittest.main()