From 370a3cc8ef6c2ce9e984f86976ac488ef3e54567 Mon Sep 17 00:00:00 2001
From: Dun Liang
Date: Thu, 28 May 2020 14:53:38 +0800
Subject: [PATCH] align resize and upsample

nn.py:
* Add a `Resize` module that forwards `size`, `mode` and `align_corners`
  to the functional `resize`.
* Factor the nearest/bilinear sampling shared by `resize` and `upsample`
  into `_interpolate(img, x, y, ids, mode)`.
* `resize` gains `align_corners` (default False).  With align_corners the
  source coordinate is `dst * (in-1)/max(1, out-1)`; without it the
  coordinate is `dst * (in/out) + (in/out*0.5 - 0.5)`, clamped to
  `[0, in-1]` only when the output is larger than the input.
* Add a functional `upsample` that keeps the plain `dst * (in/out)`
  mapping when align_corners is False; `Upsample.execute` now calls it
  instead of `resize`.
* An unsupported `mode` now raises ValueError.  Raising a bare string is
  itself a TypeError in Python 3, which hid the intended message.
* PixelShuffle: spell the squared factor as `r*r` in the assertion and in
  the generated index expression.

test_resize_and_crop.py:
* Add `test_resize`, comparing bilinear `jnn.Resize` against
  `torch.nn.functional.interpolate` for output sizes 3..6 with
  align_corners both True and False.
* Use smaller input arrays in `test_upsample` and `test_pixelshuffle`.

---
 python/jittor/nn.py                        | 66 ++++++++++++++++------
 python/jittor/test/test_resize_and_crop.py | 17 ++++--
 2 files changed, 62 insertions(+), 21 deletions(-)

diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index 854b4d77..899aff46 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -499,10 +499,10 @@ class PixelShuffle(Module):
     def execute(self, x):
         n,c,h,w = x.shape
         r = self.upscale_factor
-        assert c%(r**2)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
+        assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
         return x.reindex([n,int(c/r**2),h*r,w*r], [
             "i0",
-            f"i1*{r**2}+i2%{r}*{r}+i3%{r}",
+            f"i1*{r*r}+i2%{r}*{r}+i3%{r}",
             f"i2/{r}",
             f"i3/{r}"
         ])
@@ -519,30 +519,58 @@ class Sigmoid(Module):
     def execute(self, x) :
         return x.sigmoid()
 
-def resize(x, size, mode="nearest"):
-    img = x
-    n,c,h,w = x.shape
-    H,W = size
-    new_size = [n,c,H,W]
-    nid, cid, hid, wid = jt.index(new_size)
-    x = hid * h / H
-    y = wid * w / W
+class Resize(Module):
+    def __init__(self, size, mode="nearest", align_corners=False):
+        super().__init__()
+        self.size = size
+        self.mode = mode
+        self.align_corners = align_corners
+    def execute(self, x):
+        return resize(x, self.size, self.mode, self.align_corners)
+
+def _interpolate(img, x, y, ids, mode):
     if mode=="nearest":
-        return img.reindex([nid, cid, x.floor(), y.floor()])
+        return img.reindex([*ids, x.floor(), y.floor()])
     if mode=="bilinear":
         fx, fy = x.floor(), y.floor()
         cx, cy = fx+1, fy+1
         dx, dy = x-fx, y-fy
-        a = img.reindex_var([nid, cid, fx, fy])
-        b = img.reindex_var([nid, cid, cx, fy])
-        c = img.reindex_var([nid, cid, fx, cy])
-        d = img.reindex_var([nid, cid, cx, cy])
+        a = img.reindex_var([*ids, fx, fy])
+        b = img.reindex_var([*ids, cx, fy])
+        c = img.reindex_var([*ids, fx, cy])
+        d = img.reindex_var([*ids, cx, cy])
         dnx, dny = 1-dx, 1-dy
         ab = dx*b + dnx*a
         cd = dx*d + dnx*c
         o = ab*dny + cd*dy
         return o
-    raise(f"Not support {interpolation}")
+    raise ValueError(f"Not support interpolation mode: {mode}")
+
+def resize(img, size, mode="nearest", align_corners=False):
+    n,c,h,w = img.shape
+    H,W = size
+    nid, cid, hid, wid = jt.index((n,c,H,W))
+    if align_corners:
+        x = hid * ((h-1) / max(1, H-1))
+        y = wid * ((w-1) / max(1, W-1))
+    else:
+        x = hid * (h / H) + (h/H*0.5 - 0.5)
+        if H>h: x = x.clamp(0, h-1)
+        y = wid * (w / W) + (w/W*0.5 - 0.5)
+        if W>w: y = y.clamp(0, w-1)
+    return _interpolate(img, x, y, (nid,cid), mode)
+
+def upsample(img, size, mode="nearest", align_corners=False):
+    n,c,h,w = img.shape
+    H,W = size
+    nid, cid, hid, wid = jt.index((n,c,H,W))
+    if align_corners:
+        x = hid * ((h-1) / max(1, H-1))
+        y = wid * ((w-1) / max(1, W-1))
+    else:
+        x = hid * (h / H)
+        y = wid * (w / W)
+    return _interpolate(img, x, y, (nid,cid), mode)
 
 class Upsample(Module):
     def __init__(self, scale_factor=None, mode='nearest'):
@@ -550,7 +578,11 @@ class Upsample(Module):
         self.mode = mode
 
     def execute(self, x):
-        return resize(x, size=(int(x.shape[2]*self.scale_factor[0]), int(x.shape[3]*self.scale_factor[1])), mode=self.mode)
+        return upsample(x,
+            size=(
+                int(x.shape[2]*self.scale_factor[0]),
+                int(x.shape[3]*self.scale_factor[1])),
+            mode=self.mode)
 
 class Sequential(Module):
     def __init__(self, *args):
diff --git a/python/jittor/test/test_resize_and_crop.py b/python/jittor/test/test_resize_and_crop.py
index 5020e936..b6e2c3c6 100644
--- a/python/jittor/test/test_resize_and_crop.py
+++ b/python/jittor/test/test_resize_and_crop.py
@@ -99,16 +99,25 @@ class TestResizeAndCrop(unittest.TestCase):
         test_case(20, [1024, 1024], [1.2, 1.8][mid])
         test_case(20, [1024, 666], [0.8,1.0][mid])
 
+    def test_resize(self):
+        import torch.nn.functional as F
+        x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")
+        for r_size in [3,4,5,6]:
+            for align_corners in [True,False]:
+                check_equal(x,
+                    jnn.Resize((r_size, r_size), 'bilinear', align_corners),
+                    lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners))
+
     def test_upsample(self):
-        arr = np.random.randn(16,10,224,224)
+        arr = np.random.randn(2,3,224,224)
         check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
         check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
 
     def test_pixelshuffle(self):
-        arr = np.random.randn(16,16,224,224)
+        arr = np.random.randn(2,4,224,224)
         check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
-        arr = np.random.randn(1,16*16,224,224)
-        check_equal(arr, jnn.PixelShuffle(upscale_factor=16), tnn.PixelShuffle(upscale_factor=16))
+        arr = np.random.randn(1,3*3,224,224)
+        check_equal(arr, jnn.PixelShuffle(upscale_factor=3), tnn.PixelShuffle(upscale_factor=3))
 
 if __name__ == "__main__":
     unittest.main()