update pad2d

This commit is contained in:
zwy 2020-05-08 12:13:33 +08:00
parent f38058bbe7
commit 9b6de1fbd8
2 changed files with 13 additions and 65 deletions

View File

@ -546,38 +546,10 @@ class ReflectionPad2d(Module):
r = self.pl + w - 1
t = self.pt
b = self.pt + h - 1
x_idx = np.zeros((oh,ow))
y_idx = np.zeros((oh,ow))
for j in range(oh):
for i in range(ow):
if i >= l and i <= r and j >= t and j <= b:
x_idx[j,i] = i
y_idx[j,i] = j
elif i < l and j < t:
x_idx[j,i] = 2 * l - i
y_idx[j,i] = 2 * t - j
elif i < l and j > b:
x_idx[j,i] = 2 * l - i
y_idx[j,i] = 2 * b - j
elif i > r and j < t:
x_idx[j,i] = 2 * r - i
y_idx[j,i] = 2 * t - j
elif i > r and j > b:
x_idx[j,i] = 2 * r - i
y_idx[j,i] = 2 * b - j
elif i < l:
x_idx[j,i] = 2 * l - i
y_idx[j,i] = j
elif i > r:
x_idx[j,i] = 2 * r - i
y_idx[j,i] = j
elif j < t:
x_idx[j,i] = i
y_idx[j,i] = 2 * t - j
elif j > b:
x_idx[j,i] = i
y_idx[j,i] = 2 * b - j
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
return x.reindex([n,c,oh,ow], ["i0","i1",
f"i2<{t} ? {t}-i2 : i2 > {b} ? {h-1+b}-i2 : i2-{t}",
f"i3<{l} ? {l}-i3 : i3 > {r} ? {w-1+r}-i3 : i3-{l}",
])
class ZeroPad2d(Module):
def __init__(self, padding):
@ -635,39 +607,11 @@ class ReplicationPad2d(Module):
r = self.pl + w - 1
t = self.pt
b = self.pt + h - 1
x_idx = np.zeros((oh,ow))
y_idx = np.zeros((oh,ow))
for j in range(oh):
for i in range(ow):
if i >= l and i <= r and j >= t and j <= b:
x_idx[j,i] = i
y_idx[j,i] = j
elif i < l and j < t:
x_idx[j,i] = l
y_idx[j,i] = t
elif i < l and j > b:
x_idx[j,i] = l
y_idx[j,i] = b
elif i > r and j < t:
x_idx[j,i] = r
y_idx[j,i] = t
elif i > r and j > b:
x_idx[j,i] = r
y_idx[j,i] = b
elif i < l:
x_idx[j,i] = l
y_idx[j,i] = j
elif i > r:
x_idx[j,i] = r
y_idx[j,i] = j
elif j < t:
x_idx[j,i] = i
y_idx[j,i] = t
elif j > b:
x_idx[j,i] = i
y_idx[j,i] = b
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
return x.reindex([n,c,oh,ow], ["i0","i1",
f"i2<{t} ? 0 : i2 > {b} ? {h-1} : i2-{t}",
f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
])
class BCELoss(Module):
def __init__(self):
pass

View File

@ -39,6 +39,8 @@ class TestPad(unittest.TestCase):
arr = np.random.randn(16,3,224,224)
check_equal(arr, jnn.ReplicationPad2d(10), tnn.ReplicationPad2d(10))
check_equal(arr, jnn.ReplicationPad2d((1,23,4,5)), tnn.ReplicationPad2d((1,23,4,5)))
check_equal(arr, jnn.ReplicationPad2d((1,0,1,5)), tnn.ReplicationPad2d((1,0,1,5)))
check_equal(arr, jnn.ReplicationPad2d((100)), tnn.ReplicationPad2d((100)))
# ***************************************************************
# Test ConstantPad2d Layer
@ -60,6 +62,8 @@ class TestPad(unittest.TestCase):
arr = np.random.randn(16,3,224,224)
check_equal(arr, jnn.ReflectionPad2d(20), tnn.ReflectionPad2d(20))
check_equal(arr, jnn.ReflectionPad2d((2,3,34,1)), tnn.ReflectionPad2d((2,3,34,1)))
check_equal(arr, jnn.ReflectionPad2d((10,123,34,1)), tnn.ReflectionPad2d((10,123,34,1)))
check_equal(arr, jnn.ReflectionPad2d((100)), tnn.ReflectionPad2d((100)))
if __name__ == "__main__":
unittest.main()