mirror of https://github.com/Jittor/Jittor
polish interpolate
This commit is contained in:
parent
eae0357224
commit
6bc17cb99c
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.21'
|
||||
__version__ = '1.3.1.22'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -1653,7 +1653,7 @@ upsample = resize
|
|||
|
||||
def interpolate(X, size=None, scale_factor=None, mode='bilinear', align_corners=False, tf_mode=False):
|
||||
if scale_factor is not None:
|
||||
size = [X.shape[-2] * scale_factor, X.shape[-1] * scale_factor]
|
||||
size = [int(X.shape[-2] * scale_factor), int(X.shape[-1] * scale_factor)]
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
if scale_factor is not None and scale_factor > 1:
|
||||
|
|
|
@ -130,6 +130,12 @@ class TestResizeAndCrop(unittest.TestCase):
|
|||
arr = np.random.randn(1,1,2,2)
|
||||
check_equal(arr, jnn.Resize((4,4)), tnn.Upsample(scale_factor=2))
|
||||
# check_equal(arr, jnn.Upsample(scale_factor=0.5), tnn.Upsample(scale_factor=0.5))
|
||||
|
||||
def test_interpolate(self):
|
||||
a = jt.rand(1,3,64,64)
|
||||
b = jt.nn.interpolate(a, scale_factor=0.5)
|
||||
b.sync()
|
||||
assert b.shape == (1,3,32,32)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue