From d61dcc8f8e33b627131d2336acdc30f350f67e0e Mon Sep 17 00:00:00 2001 From: gmh Date: Thu, 7 May 2020 15:37:47 +0800 Subject: [PATCH] add bmm function --- python/jittor/nn.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 238e40fa..e4c33b67 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -28,6 +28,16 @@ def matmul_transpose(a, b): b = b.broadcast(shape) return (a*b).sum(len(shape)-1) + +def bmm(a, b): + assert len(a.shape) >= 2 and len(b.shape) >= 2 + assert a.shape[-1] == b.shape[-2] + + shape = list(a.shape) + [b.shape[-1]] + a = a.broadcast(shape, [len(shape)-1]) + b = b.broadcast(shape, [len(shape)-3]) + return (a*b).sum(len(shape)-2) + def matmul(a, b): assert len(a.shape) >= 2 and len(b.shape) == 2 assert a.shape[-1] == b.shape[-2]