From cad1845e6856267ed4eea1215e340a92f8e448e5 Mon Sep 17 00:00:00 2001 From: zhouwy19 Date: Thu, 30 Jul 2020 23:33:31 +0800 Subject: [PATCH] update matmul & ConstantPad2d --- python/jittor/nn.py | 41 +++++++++++++++++++++++++++------ python/jittor/test/test_core.py | 12 ++++++++++ python/jittor/test/test_pad.py | 4 ++++ 3 files changed, 50 insertions(+), 7 deletions(-) diff --git a/python/jittor/nn.py b/python/jittor/nn.py index e7308b48..ab400b3f 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -57,13 +57,33 @@ Example:: return (a*b).sum(len(shape)-2) def matmul(a, b): - assert len(a.shape) >= 2 and len(b.shape) == 2 + assert len(a.shape) >= 2 and len(b.shape) >= 2 assert a.shape[-1] == b.shape[-2] - shape = list(a.shape) + [b.shape[-1]] - a = a.broadcast(shape, [len(shape)-1]) - b = b.broadcast(shape) - return (a*b).sum(len(shape)-2) + len_a = len(a.shape) + len_b = len(b.shape) + len_max = max(len_a, len_b) + a_shape = (len_max - len_a) * [1,] + list(a.shape) + b_shape = (len_max - len_b) * [1,] + list(b.shape) + + a_rep = [] + b_rep = [] + for i in range(len_max-2): + if a_shape[i] == 1 or b_shape[i] == 1: + a_rep.append(b_shape[i]) + b_rep.append(a_shape[i]) + else: + if a_shape[i] == b_shape[i]: + a_rep.append(1) + b_rep.append(1) + else: + raise(f"{a_shape[i]} and {b_shape[i]} must be same.") + a_rep += [1,1,b.shape[-1],] + b_rep += [a.shape[-2],1,1,] + a = a.unsqueeze(-1).repeat(a_rep) + b = b.unsqueeze(-3).repeat(b_rep) + + return (a*b).sum(len(a.shape)-2) jt.Var.matmul = jt.Var.__matmul__ = matmul jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b)) @@ -522,8 +542,15 @@ class ConstantPad2d(Module): self.value = value def execute(self, x): - n,c,h,w = x.shape - return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"], overflow_value=self.value) + assert len(x.shape) >= 2 + shape = x.shape + tar_shape = shape[0:-2] + [shape[-2]+self.pt+self.pb,shape[-1]+self.pl+self.pr] + tar_dims = [] + for i in range(len(shape)-2): + tar_dims.append(f"i{i}") + tar_dims.append(f"i{i+1}-{self.pt}") + tar_dims.append(f"i{i+2}-{self.pl}") + return x.reindex(tar_shape, tar_dims, overflow_value=self.value) class ReplicationPad2d(Module): def __init__(self, padding): diff --git a/python/jittor/test/test_core.py b/python/jittor/test/test_core.py index 867b9f6e..f53165b4 100644 --- a/python/jittor/test/test_core.py +++ b/python/jittor/test/test_core.py @@ -59,6 +59,18 @@ class TestCore(unittest.TestCase): c = np.matmul(a, b) jtc = jt.matmul(jt.array(a), jt.array(b)).data assert np.all(jtc == c) + + a = np.random.random((128,3,10,20)) + b = np.random.random((20,30)) + c = np.matmul(a, b) + jtc = jt.matmul(jt.array(a), jt.array(b)).data + assert np.all(jtc == c) + + a = np.random.random((128,3,10,20)) + b = np.random.random((128,3,20,30)) + c = np.matmul(a, b) + jtc = jt.matmul(jt.array(a), jt.array(b)).data + assert np.all(jtc == c) def test_var_holder(self): jt.clean() diff --git a/python/jittor/test/test_pad.py b/python/jittor/test/test_pad.py index a776b4b5..942313a0 100644 --- a/python/jittor/test/test_pad.py +++ b/python/jittor/test/test_pad.py @@ -49,6 +49,10 @@ class TestPad(unittest.TestCase): check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2)) check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2)) + arr = np.random.randn(16,3,224,10,10) + check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2)) + check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2)) + # *************************************************************** # Test ZeroPad2d Layer # ***************************************************************