mirror of https://github.com/Jittor/Jittor
update matmul & ConstantPad2d
This commit is contained in:
parent
90a1422b3c
commit
cad1845e68
|
@ -57,13 +57,33 @@ Example::
|
||||||
return (a*b).sum(len(shape)-2)
|
return (a*b).sum(len(shape)-2)
|
||||||
|
|
||||||
def matmul(a, b):
|
def matmul(a, b):
|
||||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
assert len(a.shape) >= 2 and len(b.shape) >= 2
|
||||||
assert a.shape[-1] == b.shape[-2]
|
assert a.shape[-1] == b.shape[-2]
|
||||||
|
|
||||||
shape = list(a.shape) + [b.shape[-1]]
|
len_a = len(a.shape)
|
||||||
a = a.broadcast(shape, [len(shape)-1])
|
len_b = len(b.shape)
|
||||||
b = b.broadcast(shape)
|
len_max = max(len_a, len_b)
|
||||||
return (a*b).sum(len(shape)-2)
|
a_shape = (len_max - len_a) * [1,] + list(a.shape)
|
||||||
|
b_shape = (len_max - len_b) * [1,] + list(b.shape)
|
||||||
|
|
||||||
|
a_rep = []
|
||||||
|
b_rep = []
|
||||||
|
for i in range(len_max-2):
|
||||||
|
if a_shape[i] == 1 or b_shape[i] == 1:
|
||||||
|
a_rep.append(b_shape[i])
|
||||||
|
b_rep.append(a_shape[i])
|
||||||
|
else:
|
||||||
|
if a_shape[i] == b_shape[i]:
|
||||||
|
a_rep.append(1)
|
||||||
|
b_rep.append(1)
|
||||||
|
else:
|
||||||
|
raise(f"{a_shape[i]} and {b_shape[i]} must be same.")
|
||||||
|
a_rep += [1,1,b.shape[-1],]
|
||||||
|
b_rep += [a.shape[-2],1,1,]
|
||||||
|
a = a.unsqueeze(-1).repeat(a_rep)
|
||||||
|
b = b.unsqueeze(-3).repeat(b_rep)
|
||||||
|
|
||||||
|
return (a*b).sum(len(a.shape)-2)
|
||||||
jt.Var.matmul = jt.Var.__matmul__ = matmul
|
jt.Var.matmul = jt.Var.__matmul__ = matmul
|
||||||
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
|
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
|
||||||
|
|
||||||
|
@ -522,8 +542,15 @@ class ConstantPad2d(Module):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def execute(self, x):
|
def execute(self, x):
|
||||||
n,c,h,w = x.shape
|
assert len(x.shape) >= 2
|
||||||
return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"], overflow_value=self.value)
|
shape = x.shape
|
||||||
|
tar_shape = shape[0:-2] + [shape[-2]+self.pt+self.pb,shape[-1]+self.pl+self.pr]
|
||||||
|
tar_dims = []
|
||||||
|
for i in range(len(shape)-2):
|
||||||
|
tar_dims.append(f"i{i}")
|
||||||
|
tar_dims.append(f"i{i+1}-{self.pt}")
|
||||||
|
tar_dims.append(f"i{i+2}-{self.pl}")
|
||||||
|
return x.reindex(tar_shape, tar_dims, overflow_value=self.value)
|
||||||
|
|
||||||
class ReplicationPad2d(Module):
|
class ReplicationPad2d(Module):
|
||||||
def __init__(self, padding):
|
def __init__(self, padding):
|
||||||
|
|
|
@ -59,6 +59,18 @@ class TestCore(unittest.TestCase):
|
||||||
c = np.matmul(a, b)
|
c = np.matmul(a, b)
|
||||||
jtc = jt.matmul(jt.array(a), jt.array(b)).data
|
jtc = jt.matmul(jt.array(a), jt.array(b)).data
|
||||||
assert np.all(jtc == c)
|
assert np.all(jtc == c)
|
||||||
|
|
||||||
|
a = np.random.random((128,3,10,20))
|
||||||
|
b = np.random.random((20,30))
|
||||||
|
c = np.matmul(a, b)
|
||||||
|
jtc = jt.matmul(jt.array(a), jt.array(b)).data
|
||||||
|
assert np.all(jtc == c)
|
||||||
|
|
||||||
|
a = np.random.random((128,3,10,20))
|
||||||
|
b = np.random.random((128,3,20,30))
|
||||||
|
c = np.matmul(a, b)
|
||||||
|
jtc = jt.matmul(jt.array(a), jt.array(b)).data
|
||||||
|
assert np.all(jtc == c)
|
||||||
|
|
||||||
def test_var_holder(self):
|
def test_var_holder(self):
|
||||||
jt.clean()
|
jt.clean()
|
||||||
|
|
|
@ -49,6 +49,10 @@ class TestPad(unittest.TestCase):
|
||||||
check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2))
|
check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2))
|
||||||
check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2))
|
check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2))
|
||||||
|
|
||||||
|
arr = np.random.randn(16,3,224,10,10)
|
||||||
|
check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2))
|
||||||
|
check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2))
|
||||||
|
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
# Test ZeroPad2d Layer
|
# Test ZeroPad2d Layer
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
Loading…
Reference in New Issue