mirror of https://github.com/Jittor/Jittor
commit
225a8f4944
|
@ -29,6 +29,16 @@ def matmul_transpose(a, b):
|
||||||
b = b.broadcast(shape)
|
b = b.broadcast(shape)
|
||||||
return (a*b).sum(len(shape)-1)
|
return (a*b).sum(len(shape)-1)
|
||||||
|
|
||||||
|
|
||||||
|
def bmm(a, b):
|
||||||
|
assert len(a.shape) >= 2 and len(b.shape) >= 2
|
||||||
|
assert a.shape[-1] == b.shape[-2]
|
||||||
|
|
||||||
|
shape = list(a.shape) + [b.shape[-1]]
|
||||||
|
a = a.broadcast(shape, [len(shape)-1])
|
||||||
|
b = b.broadcast(shape, [len(shape)-3])
|
||||||
|
return (a*b).sum(len(shape)-2)
|
||||||
|
|
||||||
def matmul(a, b):
|
def matmul(a, b):
|
||||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
assert len(a.shape) >= 2 and len(b.shape) == 2
|
||||||
assert a.shape[-1] == b.shape[-2]
|
assert a.shape[-1] == b.shape[-2]
|
||||||
|
|
Loading…
Reference in New Issue