polish make_grid interface

This commit is contained in:
Dun Liang 2021-01-31 11:22:22 +08:00
parent 4883d75e1d
commit 84967c21c4
3 changed files with 18 additions and 1 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.23'
__version__ = '1.2.2.24'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -363,6 +363,10 @@ def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=Fals
assert isinstance(range, tuple) or range is None
assert scale_each == False
if isinstance(x, list): x = jt.stack(x)
assert isinstance(x, jt.Var)
if x.ndim < 4: return x
if x.ndim == 4 and x.shape[0] <= 1: return x
nrow = min(nrow, x.shape[0])
if normalize:
if range is None: x = (x - x.min()) / (x.max() - x.min())
else: x = (x - range[0]) / (range[1] - range[0])

View File

@ -88,6 +88,19 @@ class TestPad(unittest.TestCase):
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)))
print('pass make_grid test ...')
def test_make_grid2(self):
def check(shape):
arr = np.random.randn(*shape)
check_equal(torchvision.utils.make_grid(torch.Tensor(arr)), jt.make_grid(jt.array(arr)))
check((3,100,200))
check((1,100,200))
check((100,200))
check((1,3,100,200))
check((4,3,100,200))
check((10,3,100,200))
def test_save_image(self):
arr = jt.array(np.random.randn(16,3,10,10))
jt.save_image(arr, "/tmp/a.jpg")