align resize and upsample

This commit is contained in:
Dun Liang 2020-05-28 14:53:38 +08:00
parent a6dab87634
commit 370a3cc8ef
2 changed files with 62 additions and 21 deletions

View File

@ -499,10 +499,10 @@ class PixelShuffle(Module):
def execute(self, x):
n,c,h,w = x.shape
r = self.upscale_factor
assert c%(r**2)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
return x.reindex([n,int(c/r**2),h*r,w*r], [
"i0",
f"i1*{r**2}+i2%{r}*{r}+i3%{r}",
f"i1*{r*r}+i2%{r}*{r}+i3%{r}",
f"i2/{r}",
f"i3/{r}"
])
@ -519,30 +519,58 @@ class Sigmoid(Module):
def execute(self, x) :
return x.sigmoid()
def resize(x, size, mode="nearest"):
img = x
n,c,h,w = x.shape
H,W = size
new_size = [n,c,H,W]
nid, cid, hid, wid = jt.index(new_size)
x = hid * h / H
y = wid * w / W
class Resize(Module):
def __init__(self, size, mode="nearest", align_corners=False):
super().__init__()
self.size = size
self.mode = mode
self.align_corners = align_corners
def execute(self, x):
return resize(x, self.size, self.mode, self.align_corners)
def _interpolate(img, x, y, ids, mode):
if mode=="nearest":
return img.reindex([nid, cid, x.floor(), y.floor()])
return img.reindex([*ids, x.floor(), y.floor()])
if mode=="bilinear":
fx, fy = x.floor(), y.floor()
cx, cy = fx+1, fy+1
dx, dy = x-fx, y-fy
a = img.reindex_var([nid, cid, fx, fy])
b = img.reindex_var([nid, cid, cx, fy])
c = img.reindex_var([nid, cid, fx, cy])
d = img.reindex_var([nid, cid, cx, cy])
a = img.reindex_var([*ids, fx, fy])
b = img.reindex_var([*ids, cx, fy])
c = img.reindex_var([*ids, fx, cy])
d = img.reindex_var([*ids, cx, cy])
dnx, dny = 1-dx, 1-dy
ab = dx*b + dnx*a
cd = dx*d + dnx*c
o = ab*dny + cd*dy
return o
raise(f"Not support {interpolation}")
raise(f"Not support interpolation mode: {mode}")
def resize(img, size, mode="nearest", align_corners=False):
n,c,h,w = img.shape
H,W = size
nid, cid, hid, wid = jt.index((n,c,H,W))
if align_corners:
x = hid * ((h-1) / max(1, H-1))
y = wid * ((w-1) / max(1, W-1))
else:
x = hid * (h / H) + (h/H*0.5 - 0.5)
if H>h: x = x.clamp(0, h-1)
y = wid * (w / W) + (w/W*0.5 - 0.5)
if W>w: y = y.clamp(0, w-1)
return _interpolate(img, x, y, (nid,cid), mode)
def upsample(img, size, mode="nearest", align_corners=False):
n,c,h,w = img.shape
H,W = size
nid, cid, hid, wid = jt.index((n,c,H,W))
if align_corners:
x = hid * ((h-1) / max(1, H-1))
y = wid * ((w-1) / max(1, W-1))
else:
x = hid * (h / H)
y = wid * (w / W)
return _interpolate(img, x, y, (nid,cid), mode)
class Upsample(Module):
def __init__(self, scale_factor=None, mode='nearest'):
@ -550,7 +578,11 @@ class Upsample(Module):
self.mode = mode
def execute(self, x):
return resize(x, size=(int(x.shape[2]*self.scale_factor[0]), int(x.shape[3]*self.scale_factor[1])), mode=self.mode)
return upsample(x,
size=(
int(x.shape[2]*self.scale_factor[0]),
int(x.shape[3]*self.scale_factor[1])),
mode=self.mode)
class Sequential(Module):
def __init__(self, *args):

View File

@ -99,16 +99,25 @@ class TestResizeAndCrop(unittest.TestCase):
test_case(20, [1024, 1024], [1.2, 1.8][mid])
test_case(20, [1024, 666], [0.8,1.0][mid])
def test_resize(self):
import torch.nn.functional as F
x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")
for r_size in [3,4,5,6]:
for align_corners in [True,False]:
check_equal(x,
jnn.Resize((r_size, r_size), 'bilinear', align_corners),
lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners))
def test_upsample(self):
arr = np.random.randn(16,10,224,224)
arr = np.random.randn(2,3,224,224)
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
def test_pixelshuffle(self):
arr = np.random.randn(16,16,224,224)
arr = np.random.randn(2,4,224,224)
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
arr = np.random.randn(1,16*16,224,224)
check_equal(arr, jnn.PixelShuffle(upscale_factor=16), tnn.PixelShuffle(upscale_factor=16))
arr = np.random.randn(1,3*3,224,224)
check_equal(arr, jnn.PixelShuffle(upscale_factor=3), tnn.PixelShuffle(upscale_factor=3))
if __name__ == "__main__":
unittest.main()