mirror of https://github.com/Jittor/Jittor
Merge pull request #80 from Jittor/zwy
change ReflectionPad2d & ReplicationPad2d
This commit is contained in:
commit
627c7d4916
|
@ -486,38 +486,10 @@ class ReflectionPad2d(Module):
|
|||
r = self.pl + w - 1
|
||||
t = self.pt
|
||||
b = self.pt + h - 1
|
||||
x_idx = np.zeros((oh,ow))
|
||||
y_idx = np.zeros((oh,ow))
|
||||
for j in range(oh):
|
||||
for i in range(ow):
|
||||
if i >= l and i <= r and j >= t and j <= b:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = j
|
||||
elif i < l and j < t:
|
||||
x_idx[j,i] = 2 * l - i
|
||||
y_idx[j,i] = 2 * t - j
|
||||
elif i < l and j > b:
|
||||
x_idx[j,i] = 2 * l - i
|
||||
y_idx[j,i] = 2 * b - j
|
||||
elif i > r and j < t:
|
||||
x_idx[j,i] = 2 * r - i
|
||||
y_idx[j,i] = 2 * t - j
|
||||
elif i > r and j > b:
|
||||
x_idx[j,i] = 2 * r - i
|
||||
y_idx[j,i] = 2 * b - j
|
||||
elif i < l:
|
||||
x_idx[j,i] = 2 * l - i
|
||||
y_idx[j,i] = j
|
||||
elif i > r:
|
||||
x_idx[j,i] = 2 * r - i
|
||||
y_idx[j,i] = j
|
||||
elif j < t:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = 2 * t - j
|
||||
elif j > b:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = 2 * b - j
|
||||
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
|
||||
return x.reindex([n,c,oh,ow], ["i0","i1",
|
||||
f"i2<{t} ? {t}-i2 : i2 > {b} ? {h-1+b}-i2 : i2-{t}",
|
||||
f"i3<{l} ? {l}-i3 : i3 > {r} ? {w-1+r}-i3 : i3-{l}",
|
||||
])
|
||||
|
||||
class ZeroPad2d(Module):
|
||||
def __init__(self, padding):
|
||||
|
@ -575,38 +547,10 @@ class ReplicationPad2d(Module):
|
|||
r = self.pl + w - 1
|
||||
t = self.pt
|
||||
b = self.pt + h - 1
|
||||
x_idx = np.zeros((oh,ow))
|
||||
y_idx = np.zeros((oh,ow))
|
||||
for j in range(oh):
|
||||
for i in range(ow):
|
||||
if i >= l and i <= r and j >= t and j <= b:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = j
|
||||
elif i < l and j < t:
|
||||
x_idx[j,i] = l
|
||||
y_idx[j,i] = t
|
||||
elif i < l and j > b:
|
||||
x_idx[j,i] = l
|
||||
y_idx[j,i] = b
|
||||
elif i > r and j < t:
|
||||
x_idx[j,i] = r
|
||||
y_idx[j,i] = t
|
||||
elif i > r and j > b:
|
||||
x_idx[j,i] = r
|
||||
y_idx[j,i] = b
|
||||
elif i < l:
|
||||
x_idx[j,i] = l
|
||||
y_idx[j,i] = j
|
||||
elif i > r:
|
||||
x_idx[j,i] = r
|
||||
y_idx[j,i] = j
|
||||
elif j < t:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = t
|
||||
elif j > b:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = b
|
||||
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
|
||||
return x.reindex([n,c,oh,ow], ["i0","i1",
|
||||
f"i2<{t} ? 0 : i2 > {b} ? {h-1} : i2-{t}",
|
||||
f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
|
||||
])
|
||||
|
||||
class PixelShuffle(Module):
|
||||
def __init__(self, upscale_factor):
|
||||
|
|
|
@ -39,6 +39,8 @@ class TestPad(unittest.TestCase):
|
|||
arr = np.random.randn(16,3,224,224)
|
||||
check_equal(arr, jnn.ReplicationPad2d(10), tnn.ReplicationPad2d(10))
|
||||
check_equal(arr, jnn.ReplicationPad2d((1,23,4,5)), tnn.ReplicationPad2d((1,23,4,5)))
|
||||
check_equal(arr, jnn.ReplicationPad2d((1,0,1,5)), tnn.ReplicationPad2d((1,0,1,5)))
|
||||
check_equal(arr, jnn.ReplicationPad2d((100)), tnn.ReplicationPad2d((100)))
|
||||
|
||||
# ***************************************************************
|
||||
# Test ConstantPad2d Layer
|
||||
|
@ -60,6 +62,8 @@ class TestPad(unittest.TestCase):
|
|||
arr = np.random.randn(16,3,224,224)
|
||||
check_equal(arr, jnn.ReflectionPad2d(20), tnn.ReflectionPad2d(20))
|
||||
check_equal(arr, jnn.ReflectionPad2d((2,3,34,1)), tnn.ReflectionPad2d((2,3,34,1)))
|
||||
check_equal(arr, jnn.ReflectionPad2d((10,123,34,1)), tnn.ReflectionPad2d((10,123,34,1)))
|
||||
check_equal(arr, jnn.ReflectionPad2d((100)), tnn.ReflectionPad2d((100)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue