mirror of https://github.com/Jittor/Jittor
align resize and upsample
This commit is contained in:
parent
a6dab87634
commit
370a3cc8ef
|
@ -499,10 +499,10 @@ class PixelShuffle(Module):
|
||||||
def execute(self, x):
|
def execute(self, x):
|
||||||
n,c,h,w = x.shape
|
n,c,h,w = x.shape
|
||||||
r = self.upscale_factor
|
r = self.upscale_factor
|
||||||
assert c%(r**2)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
|
assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
|
||||||
return x.reindex([n,int(c/r**2),h*r,w*r], [
|
return x.reindex([n,int(c/r**2),h*r,w*r], [
|
||||||
"i0",
|
"i0",
|
||||||
f"i1*{r**2}+i2%{r}*{r}+i3%{r}",
|
f"i1*{r*r}+i2%{r}*{r}+i3%{r}",
|
||||||
f"i2/{r}",
|
f"i2/{r}",
|
||||||
f"i3/{r}"
|
f"i3/{r}"
|
||||||
])
|
])
|
||||||
|
@ -519,30 +519,58 @@ class Sigmoid(Module):
|
||||||
def execute(self, x) :
|
def execute(self, x) :
|
||||||
return x.sigmoid()
|
return x.sigmoid()
|
||||||
|
|
||||||
def resize(x, size, mode="nearest"):
|
class Resize(Module):
|
||||||
img = x
|
def __init__(self, size, mode="nearest", align_corners=False):
|
||||||
n,c,h,w = x.shape
|
super().__init__()
|
||||||
H,W = size
|
self.size = size
|
||||||
new_size = [n,c,H,W]
|
self.mode = mode
|
||||||
nid, cid, hid, wid = jt.index(new_size)
|
self.align_corners = align_corners
|
||||||
x = hid * h / H
|
def execute(self, x):
|
||||||
y = wid * w / W
|
return resize(x, self.size, self.mode, self.align_corners)
|
||||||
|
|
||||||
|
def _interpolate(img, x, y, ids, mode):
|
||||||
if mode=="nearest":
|
if mode=="nearest":
|
||||||
return img.reindex([nid, cid, x.floor(), y.floor()])
|
return img.reindex([*ids, x.floor(), y.floor()])
|
||||||
if mode=="bilinear":
|
if mode=="bilinear":
|
||||||
fx, fy = x.floor(), y.floor()
|
fx, fy = x.floor(), y.floor()
|
||||||
cx, cy = fx+1, fy+1
|
cx, cy = fx+1, fy+1
|
||||||
dx, dy = x-fx, y-fy
|
dx, dy = x-fx, y-fy
|
||||||
a = img.reindex_var([nid, cid, fx, fy])
|
a = img.reindex_var([*ids, fx, fy])
|
||||||
b = img.reindex_var([nid, cid, cx, fy])
|
b = img.reindex_var([*ids, cx, fy])
|
||||||
c = img.reindex_var([nid, cid, fx, cy])
|
c = img.reindex_var([*ids, fx, cy])
|
||||||
d = img.reindex_var([nid, cid, cx, cy])
|
d = img.reindex_var([*ids, cx, cy])
|
||||||
dnx, dny = 1-dx, 1-dy
|
dnx, dny = 1-dx, 1-dy
|
||||||
ab = dx*b + dnx*a
|
ab = dx*b + dnx*a
|
||||||
cd = dx*d + dnx*c
|
cd = dx*d + dnx*c
|
||||||
o = ab*dny + cd*dy
|
o = ab*dny + cd*dy
|
||||||
return o
|
return o
|
||||||
raise(f"Not support {interpolation}")
|
raise(f"Not support interpolation mode: {mode}")
|
||||||
|
|
||||||
|
def resize(img, size, mode="nearest", align_corners=False):
|
||||||
|
n,c,h,w = img.shape
|
||||||
|
H,W = size
|
||||||
|
nid, cid, hid, wid = jt.index((n,c,H,W))
|
||||||
|
if align_corners:
|
||||||
|
x = hid * ((h-1) / max(1, H-1))
|
||||||
|
y = wid * ((w-1) / max(1, W-1))
|
||||||
|
else:
|
||||||
|
x = hid * (h / H) + (h/H*0.5 - 0.5)
|
||||||
|
if H>h: x = x.clamp(0, h-1)
|
||||||
|
y = wid * (w / W) + (w/W*0.5 - 0.5)
|
||||||
|
if W>w: y = y.clamp(0, w-1)
|
||||||
|
return _interpolate(img, x, y, (nid,cid), mode)
|
||||||
|
|
||||||
|
def upsample(img, size, mode="nearest", align_corners=False):
|
||||||
|
n,c,h,w = img.shape
|
||||||
|
H,W = size
|
||||||
|
nid, cid, hid, wid = jt.index((n,c,H,W))
|
||||||
|
if align_corners:
|
||||||
|
x = hid * ((h-1) / max(1, H-1))
|
||||||
|
y = wid * ((w-1) / max(1, W-1))
|
||||||
|
else:
|
||||||
|
x = hid * (h / H)
|
||||||
|
y = wid * (w / W)
|
||||||
|
return _interpolate(img, x, y, (nid,cid), mode)
|
||||||
|
|
||||||
class Upsample(Module):
|
class Upsample(Module):
|
||||||
def __init__(self, scale_factor=None, mode='nearest'):
|
def __init__(self, scale_factor=None, mode='nearest'):
|
||||||
|
@ -550,7 +578,11 @@ class Upsample(Module):
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
def execute(self, x):
|
def execute(self, x):
|
||||||
return resize(x, size=(int(x.shape[2]*self.scale_factor[0]), int(x.shape[3]*self.scale_factor[1])), mode=self.mode)
|
return upsample(x,
|
||||||
|
size=(
|
||||||
|
int(x.shape[2]*self.scale_factor[0]),
|
||||||
|
int(x.shape[3]*self.scale_factor[1])),
|
||||||
|
mode=self.mode)
|
||||||
|
|
||||||
class Sequential(Module):
|
class Sequential(Module):
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
|
|
|
@ -99,16 +99,25 @@ class TestResizeAndCrop(unittest.TestCase):
|
||||||
test_case(20, [1024, 1024], [1.2, 1.8][mid])
|
test_case(20, [1024, 1024], [1.2, 1.8][mid])
|
||||||
test_case(20, [1024, 666], [0.8,1.0][mid])
|
test_case(20, [1024, 666], [0.8,1.0][mid])
|
||||||
|
|
||||||
|
def test_resize(self):
|
||||||
|
import torch.nn.functional as F
|
||||||
|
x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")
|
||||||
|
for r_size in [3,4,5,6]:
|
||||||
|
for align_corners in [True,False]:
|
||||||
|
check_equal(x,
|
||||||
|
jnn.Resize((r_size, r_size), 'bilinear', align_corners),
|
||||||
|
lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners))
|
||||||
|
|
||||||
def test_upsample(self):
|
def test_upsample(self):
|
||||||
arr = np.random.randn(16,10,224,224)
|
arr = np.random.randn(2,3,224,224)
|
||||||
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
|
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
|
||||||
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
|
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
|
||||||
|
|
||||||
def test_pixelshuffle(self):
|
def test_pixelshuffle(self):
|
||||||
arr = np.random.randn(16,16,224,224)
|
arr = np.random.randn(2,4,224,224)
|
||||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
|
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
|
||||||
arr = np.random.randn(1,16*16,224,224)
|
arr = np.random.randn(1,3*3,224,224)
|
||||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=16), tnn.PixelShuffle(upscale_factor=16))
|
check_equal(arr, jnn.PixelShuffle(upscale_factor=3), tnn.PixelShuffle(upscale_factor=3))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue