add bmm function

This commit is contained in:
gmh 2020-05-07 15:37:47 +08:00
parent 1a0d7e3810
commit d61dcc8f8e
1 changed files with 10 additions and 0 deletions

View File

@ -28,6 +28,16 @@ def matmul_transpose(a, b):
b = b.broadcast(shape)
return (a*b).sum(len(shape)-1)
def bmm(a, b):
assert len(a.shape) >= 2 and len(b.shape) >= 2
assert a.shape[-1] == b.shape[-2]
shape = list(a.shape) + [b.shape[-1]]
a = a.broadcast(shape, [len(shape)-1])
b = b.broadcast(shape, [len(shape)-3])
return (a*b).sum(len(shape)-2)
def matmul(a, b):
assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-2]