diff --git a/python/jittor/nn.py b/python/jittor/nn.py index ecaa553f..a0c80fa3 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -29,6 +29,16 @@ def matmul_transpose(a, b): b = b.broadcast(shape) return (a*b).sum(len(shape)-1) + +def bmm(a, b): + assert len(a.shape) >= 2 and len(b.shape) >= 2 + assert a.shape[-1] == b.shape[-2] + + shape = list(a.shape) + [b.shape[-1]] + a = a.broadcast(shape, [len(shape)-1]) + b = b.broadcast(shape, [len(shape)-3]) + return (a*b).sum(len(shape)-2) + def matmul(a, b): assert len(a.shape) >= 2 and len(b.shape) == 2 assert a.shape[-1] == b.shape[-2]