mirror of https://github.com/Jittor/Jittor
commit
225a8f4944
|
@ -29,6 +29,16 @@ def matmul_transpose(a, b):
|
|||
b = b.broadcast(shape)
|
||||
return (a*b).sum(len(shape)-1)
|
||||
|
||||
|
||||
def bmm(a, b):
|
||||
assert len(a.shape) >= 2 and len(b.shape) >= 2
|
||||
assert a.shape[-1] == b.shape[-2]
|
||||
|
||||
shape = list(a.shape) + [b.shape[-1]]
|
||||
a = a.broadcast(shape, [len(shape)-1])
|
||||
b = b.broadcast(shape, [len(shape)-3])
|
||||
return (a*b).sum(len(shape)-2)
|
||||
|
||||
def matmul(a, b):
|
||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
||||
assert a.shape[-1] == b.shape[-2]
|
||||
|
|
Loading…
Reference in New Issue