mirror of https://github.com/Jittor/Jittor
align resize and upsample
This commit is contained in:
parent
a6dab87634
commit
370a3cc8ef
|
@ -499,10 +499,10 @@ class PixelShuffle(Module):
|
|||
def execute(self, x):
|
||||
n,c,h,w = x.shape
|
||||
r = self.upscale_factor
|
||||
assert c%(r**2)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
|
||||
assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle"
|
||||
return x.reindex([n,int(c/r**2),h*r,w*r], [
|
||||
"i0",
|
||||
f"i1*{r**2}+i2%{r}*{r}+i3%{r}",
|
||||
f"i1*{r*r}+i2%{r}*{r}+i3%{r}",
|
||||
f"i2/{r}",
|
||||
f"i3/{r}"
|
||||
])
|
||||
|
@ -519,30 +519,58 @@ class Sigmoid(Module):
|
|||
def execute(self, x) :
|
||||
return x.sigmoid()
|
||||
|
||||
def resize(x, size, mode="nearest"):
|
||||
img = x
|
||||
n,c,h,w = x.shape
|
||||
H,W = size
|
||||
new_size = [n,c,H,W]
|
||||
nid, cid, hid, wid = jt.index(new_size)
|
||||
x = hid * h / H
|
||||
y = wid * w / W
|
||||
class Resize(Module):
|
||||
def __init__(self, size, mode="nearest", align_corners=False):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
def execute(self, x):
|
||||
return resize(x, self.size, self.mode, self.align_corners)
|
||||
|
||||
def _interpolate(img, x, y, ids, mode):
|
||||
if mode=="nearest":
|
||||
return img.reindex([nid, cid, x.floor(), y.floor()])
|
||||
return img.reindex([*ids, x.floor(), y.floor()])
|
||||
if mode=="bilinear":
|
||||
fx, fy = x.floor(), y.floor()
|
||||
cx, cy = fx+1, fy+1
|
||||
dx, dy = x-fx, y-fy
|
||||
a = img.reindex_var([nid, cid, fx, fy])
|
||||
b = img.reindex_var([nid, cid, cx, fy])
|
||||
c = img.reindex_var([nid, cid, fx, cy])
|
||||
d = img.reindex_var([nid, cid, cx, cy])
|
||||
a = img.reindex_var([*ids, fx, fy])
|
||||
b = img.reindex_var([*ids, cx, fy])
|
||||
c = img.reindex_var([*ids, fx, cy])
|
||||
d = img.reindex_var([*ids, cx, cy])
|
||||
dnx, dny = 1-dx, 1-dy
|
||||
ab = dx*b + dnx*a
|
||||
cd = dx*d + dnx*c
|
||||
o = ab*dny + cd*dy
|
||||
return o
|
||||
raise(f"Not support {interpolation}")
|
||||
raise(f"Not support interpolation mode: {mode}")
|
||||
|
||||
def resize(img, size, mode="nearest", align_corners=False):
|
||||
n,c,h,w = img.shape
|
||||
H,W = size
|
||||
nid, cid, hid, wid = jt.index((n,c,H,W))
|
||||
if align_corners:
|
||||
x = hid * ((h-1) / max(1, H-1))
|
||||
y = wid * ((w-1) / max(1, W-1))
|
||||
else:
|
||||
x = hid * (h / H) + (h/H*0.5 - 0.5)
|
||||
if H>h: x = x.clamp(0, h-1)
|
||||
y = wid * (w / W) + (w/W*0.5 - 0.5)
|
||||
if W>w: y = y.clamp(0, w-1)
|
||||
return _interpolate(img, x, y, (nid,cid), mode)
|
||||
|
||||
def upsample(img, size, mode="nearest", align_corners=False):
|
||||
n,c,h,w = img.shape
|
||||
H,W = size
|
||||
nid, cid, hid, wid = jt.index((n,c,H,W))
|
||||
if align_corners:
|
||||
x = hid * ((h-1) / max(1, H-1))
|
||||
y = wid * ((w-1) / max(1, W-1))
|
||||
else:
|
||||
x = hid * (h / H)
|
||||
y = wid * (w / W)
|
||||
return _interpolate(img, x, y, (nid,cid), mode)
|
||||
|
||||
class Upsample(Module):
|
||||
def __init__(self, scale_factor=None, mode='nearest'):
|
||||
|
@ -550,7 +578,11 @@ class Upsample(Module):
|
|||
self.mode = mode
|
||||
|
||||
def execute(self, x):
|
||||
return resize(x, size=(int(x.shape[2]*self.scale_factor[0]), int(x.shape[3]*self.scale_factor[1])), mode=self.mode)
|
||||
return upsample(x,
|
||||
size=(
|
||||
int(x.shape[2]*self.scale_factor[0]),
|
||||
int(x.shape[3]*self.scale_factor[1])),
|
||||
mode=self.mode)
|
||||
|
||||
class Sequential(Module):
|
||||
def __init__(self, *args):
|
||||
|
|
|
@ -99,16 +99,25 @@ class TestResizeAndCrop(unittest.TestCase):
|
|||
test_case(20, [1024, 1024], [1.2, 1.8][mid])
|
||||
test_case(20, [1024, 666], [0.8,1.0][mid])
|
||||
|
||||
def test_resize(self):
|
||||
import torch.nn.functional as F
|
||||
x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32")
|
||||
for r_size in [3,4,5,6]:
|
||||
for align_corners in [True,False]:
|
||||
check_equal(x,
|
||||
jnn.Resize((r_size, r_size), 'bilinear', align_corners),
|
||||
lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners))
|
||||
|
||||
def test_upsample(self):
|
||||
arr = np.random.randn(16,10,224,224)
|
||||
arr = np.random.randn(2,3,224,224)
|
||||
check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2))
|
||||
check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2))
|
||||
|
||||
def test_pixelshuffle(self):
|
||||
arr = np.random.randn(16,16,224,224)
|
||||
arr = np.random.randn(2,4,224,224)
|
||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2))
|
||||
arr = np.random.randn(1,16*16,224,224)
|
||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=16), tnn.PixelShuffle(upscale_factor=16))
|
||||
arr = np.random.randn(1,3*3,224,224)
|
||||
check_equal(arr, jnn.PixelShuffle(upscale_factor=3), tnn.PixelShuffle(upscale_factor=3))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue