mirror of https://github.com/Jittor/Jittor
add cuda bmm test
This commit is contained in:
parent
84d9434ecf
commit
f296712c37
|
@ -49,8 +49,6 @@ Example::
|
|||
'''
|
||||
assert len(a.shape) > 2 and len(b.shape) > 2
|
||||
return matmul(a, b)
|
||||
if jt.flags.use_cuda:
|
||||
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
|
||||
|
||||
def matmul(a, b):
|
||||
''' matrix multiply,
|
||||
|
|
|
@ -339,6 +339,11 @@ class TestMatmul(unittest.TestCase):
|
|||
check([10,3,4], [4,5])
|
||||
check([10,3,4], [10,4,5])
|
||||
check([8,1,3,4], [10,4,5])
|
||||
check([5,10,3,4], [5,10,4,5])
|
||||
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_matmul_example2_cuda(self):
|
||||
self.test_matmul_example2()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue