update bmm

This commit is contained in:
Meng-Hao 2020-07-12 17:33:41 +08:00
commit a859396aa8
1 changed files with 2 additions and 5 deletions

View File

@ -33,18 +33,15 @@ def bmm(a, b):
'''
'''
if jt.flags.use_cuda:
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
assert len(a.shape) >= 2 and len(b.shape) >= 2
assert a.shape[-1] == b.shape[-2]
if jt.flags.use_cuda:
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
shape = list(a.shape) + [b.shape[-1]]
a = a.broadcast(shape, [len(shape)-1])
b = b.broadcast(shape, [len(shape)-3])
return (a*b).sum(len(shape)-2)
def matmul(a, b):
assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-2]