mirror of https://github.com/Jittor/Jittor
add documentation and tests.
This commit is contained in:
parent
56255578f9
commit
b258cf3a84
|
@ -1770,6 +1770,9 @@ def resize(img, size, mode="nearest", align_corners=False, tf_mode=False):
|
|||
x = hid * (h / H)
|
||||
y = wid * (w / W)
|
||||
elif mode == "area":
|
||||
'''
|
||||
Area interpolation uses AdaptivePool2D to resize origin images.
|
||||
'''
|
||||
stride = (h // H, w // W)
|
||||
assert stride[0] > 0 and stride[1] > 0
|
||||
x, y = jt.meshgrid(jt.arange(0, H, 1), jt.arange(0, W, 1))
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2022 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Haoyang Peng <2247838039@qq.com>
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
from jittor import nn
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import torch
|
||||
has_torch = True
|
||||
except:
|
||||
has_torch = False
|
||||
|
||||
@unittest.skipIf(not has_torch, "No pytorch installation found.")
|
||||
class TestInterpolation(unittest.TestCase):
|
||||
def test_interpolation_area(self):
|
||||
img = np.random.uniform(0, 1, (1, 3, 24, 10))
|
||||
output_shape = (12, 5)
|
||||
jimg = jt.array(img)
|
||||
timg = torch.from_numpy(img)
|
||||
joutput = nn.interpolate(jimg, output_shape, mode="area")
|
||||
toutput = torch.nn.functional.interpolate(timg, output_shape, mode="area")
|
||||
np.testing.assert_allclose(joutput.numpy(), toutput.numpy(), rtol=1e-7, atol=1e-7)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue