mirror of https://github.com/Jittor/Jittor
update matmul & ConstantPad2d
This commit is contained in:
parent
90a1422b3c
commit
cad1845e68
|
@ -57,13 +57,33 @@ Example::
|
|||
return (a*b).sum(len(shape)-2)
|
||||
|
||||
def matmul(a, b):
|
||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
||||
assert len(a.shape) >= 2 and len(b.shape) >= 2
|
||||
assert a.shape[-1] == b.shape[-2]
|
||||
|
||||
shape = list(a.shape) + [b.shape[-1]]
|
||||
a = a.broadcast(shape, [len(shape)-1])
|
||||
b = b.broadcast(shape)
|
||||
return (a*b).sum(len(shape)-2)
|
||||
len_a = len(a.shape)
|
||||
len_b = len(b.shape)
|
||||
len_max = max(len_a, len_b)
|
||||
a_shape = (len_max - len_a) * [1,] + list(a.shape)
|
||||
b_shape = (len_max - len_b) * [1,] + list(b.shape)
|
||||
|
||||
a_rep = []
|
||||
b_rep = []
|
||||
for i in range(len_max-2):
|
||||
if a_shape[i] == 1 or b_shape[i] == 1:
|
||||
a_rep.append(b_shape[i])
|
||||
b_rep.append(a_shape[i])
|
||||
else:
|
||||
if a_shape[i] == b_shape[i]:
|
||||
a_rep.append(1)
|
||||
b_rep.append(1)
|
||||
else:
|
||||
raise(f"{a_shape[i]} and {b_shape[i]} must be same.")
|
||||
a_rep += [1,1,b.shape[-1],]
|
||||
b_rep += [a.shape[-2],1,1,]
|
||||
a = a.unsqueeze(-1).repeat(a_rep)
|
||||
b = b.unsqueeze(-3).repeat(b_rep)
|
||||
|
||||
return (a*b).sum(len(a.shape)-2)
|
||||
jt.Var.matmul = jt.Var.__matmul__ = matmul
|
||||
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
|
||||
|
||||
|
@ -522,8 +542,15 @@ class ConstantPad2d(Module):
|
|||
self.value = value
|
||||
|
||||
def execute(self, x):
|
||||
n,c,h,w = x.shape
|
||||
return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"], overflow_value=self.value)
|
||||
assert len(x.shape) >= 2
|
||||
shape = x.shape
|
||||
tar_shape = shape[0:-2] + [shape[-2]+self.pt+self.pb,shape[-1]+self.pl+self.pr]
|
||||
tar_dims = []
|
||||
for i in range(len(shape)-2):
|
||||
tar_dims.append(f"i{i}")
|
||||
tar_dims.append(f"i{i+1}-{self.pt}")
|
||||
tar_dims.append(f"i{i+2}-{self.pl}")
|
||||
return x.reindex(tar_shape, tar_dims, overflow_value=self.value)
|
||||
|
||||
class ReplicationPad2d(Module):
|
||||
def __init__(self, padding):
|
||||
|
|
|
@ -60,6 +60,18 @@ class TestCore(unittest.TestCase):
|
|||
jtc = jt.matmul(jt.array(a), jt.array(b)).data
|
||||
assert np.all(jtc == c)
|
||||
|
||||
a = np.random.random((128,3,10,20))
|
||||
b = np.random.random((20,30))
|
||||
c = np.matmul(a, b)
|
||||
jtc = jt.matmul(jt.array(a), jt.array(b)).data
|
||||
assert np.all(jtc == c)
|
||||
|
||||
a = np.random.random((128,3,10,20))
|
||||
b = np.random.random((128,3,20,30))
|
||||
c = np.matmul(a, b)
|
||||
jtc = jt.matmul(jt.array(a), jt.array(b)).data
|
||||
assert np.all(jtc == c)
|
||||
|
||||
def test_var_holder(self):
|
||||
jt.clean()
|
||||
expect_error(lambda: jt.matmul(1,1))
|
||||
|
|
|
@ -49,6 +49,10 @@ class TestPad(unittest.TestCase):
|
|||
check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2))
|
||||
check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2))
|
||||
|
||||
arr = np.random.randn(16,3,224,10,10)
|
||||
check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2))
|
||||
check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2))
|
||||
|
||||
# ***************************************************************
|
||||
# Test ZeroPad2d Layer
|
||||
# ***************************************************************
|
||||
|
|
Loading…
Reference in New Issue