add documentation and tests.

This commit is contained in:
Exusial 2022-09-18 12:00:07 +08:00 committed by Zheng-Ning Liu
parent 56255578f9
commit b258cf3a84
2 changed files with 37 additions and 0 deletions

View File

@ -1770,6 +1770,9 @@ def resize(img, size, mode="nearest", align_corners=False, tf_mode=False):
x = hid * (h / H)
y = wid * (w / W)
elif mode == "area":
'''
Area interpolation uses AdaptivePool2D to resize origin images.
'''
stride = (h // H, w // W)
assert stride[0] > 0 and stride[1] > 0
x, y = jt.meshgrid(jt.arange(0, H, 1), jt.arange(0, W, 1))

View File

@ -0,0 +1,34 @@
# ***************************************************************
# Copyright (c) 2022 Jittor. All Rights Reserved.
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import nn
import numpy as np
import unittest
try:
import torch
has_torch = True
except:
has_torch = False
@unittest.skipIf(not has_torch, "No pytorch installation found.")
class TestInterpolation(unittest.TestCase):
def test_interpolation_area(self):
img = np.random.uniform(0, 1, (1, 3, 24, 10))
output_shape = (12, 5)
jimg = jt.array(img)
timg = torch.from_numpy(img)
joutput = nn.interpolate(jimg, output_shape, mode="area")
toutput = torch.nn.functional.interpolate(timg, output_shape, mode="area")
np.testing.assert_allclose(joutput.numpy(), toutput.numpy(), rtol=1e-7, atol=1e-7)
if __name__ == "__main__":
unittest.main()