mirror of https://github.com/Jittor/Jittor
add bmm function
This commit is contained in:
parent
1a0d7e3810
commit
d61dcc8f8e
|
@ -28,6 +28,16 @@ def matmul_transpose(a, b):
|
|||
b = b.broadcast(shape)
|
||||
return (a*b).sum(len(shape)-1)
|
||||
|
||||
|
||||
def bmm(a, b):
|
||||
assert len(a.shape) >= 2 and len(b.shape) >= 2
|
||||
assert a.shape[-1] == b.shape[-2]
|
||||
|
||||
shape = list(a.shape) + [b.shape[-1]]
|
||||
a = a.broadcast(shape, [len(shape)-1])
|
||||
b = b.broadcast(shape, [len(shape)-3])
|
||||
return (a*b).sum(len(shape)-2)
|
||||
|
||||
def matmul(a, b):
|
||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
||||
assert a.shape[-1] == b.shape[-2]
|
||||
|
|
Loading…
Reference in New Issue