polish interpolate

This commit is contained in:
Dun Liang 2021-12-06 12:21:02 +08:00
parent eae0357224
commit 6bc17cb99c
3 changed files with 8 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.21'
__version__ = '1.3.1.22'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -1653,7 +1653,7 @@ upsample = resize
def interpolate(X, size=None, scale_factor=None, mode='bilinear', align_corners=False, tf_mode=False):
if scale_factor is not None:
size = [X.shape[-2] * scale_factor, X.shape[-1] * scale_factor]
size = [int(X.shape[-2] * scale_factor), int(X.shape[-1] * scale_factor)]
if isinstance(size, int):
size = (size, size)
if scale_factor is not None and scale_factor > 1:

View File

@ -130,6 +130,12 @@ class TestResizeAndCrop(unittest.TestCase):
arr = np.random.randn(1,1,2,2)
check_equal(arr, jnn.Resize((4,4)), tnn.Upsample(scale_factor=2))
# check_equal(arr, jnn.Upsample(scale_factor=0.5), tnn.Upsample(scale_factor=0.5))
def test_interpolate(self):
a = jt.rand(1,3,64,64)
b = jt.nn.interpolate(a, scale_factor=0.5)
b.sync()
assert b.shape == (1,3,32,32)
if __name__ == "__main__":