Merge pull request #79 from Jittor/gmh

add bmm function
This commit is contained in:
MenghaoGuo 2020-06-26 11:15:14 +08:00 committed by GitHub
commit 225a8f4944
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 0 deletions

View File

@ -29,6 +29,16 @@ def matmul_transpose(a, b):
b = b.broadcast(shape) b = b.broadcast(shape)
return (a*b).sum(len(shape)-1) return (a*b).sum(len(shape)-1)
def bmm(a, b):
assert len(a.shape) >= 2 and len(b.shape) >= 2
assert a.shape[-1] == b.shape[-2]
shape = list(a.shape) + [b.shape[-1]]
a = a.broadcast(shape, [len(shape)-1])
b = b.broadcast(shape, [len(shape)-3])
return (a*b).sum(len(shape)-2)
def matmul(a, b): def matmul(a, b):
assert len(a.shape) >= 2 and len(b.shape) == 2 assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-2] assert a.shape[-1] == b.shape[-2]