mirror of https://github.com/Jittor/Jittor
polish make_grid interface
This commit is contained in:
parent
4883d75e1d
commit
84967c21c4
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.23'
|
||||
__version__ = '1.2.2.24'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -363,6 +363,10 @@ def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=Fals
|
|||
assert isinstance(range, tuple) or range is None
|
||||
assert scale_each == False
|
||||
if isinstance(x, list): x = jt.stack(x)
|
||||
assert isinstance(x, jt.Var)
|
||||
if x.ndim < 4: return x
|
||||
if x.ndim == 4 and x.shape[0] <= 1: return x
|
||||
nrow = min(nrow, x.shape[0])
|
||||
if normalize:
|
||||
if range is None: x = (x - x.min()) / (x.max() - x.min())
|
||||
else: x = (x - range[0]) / (range[1] - range[0])
|
||||
|
|
|
@ -88,6 +88,19 @@ class TestPad(unittest.TestCase):
|
|||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)))
|
||||
print('pass make_grid test ...')
|
||||
|
||||
def test_make_grid2(self):
|
||||
def check(shape):
|
||||
arr = np.random.randn(*shape)
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr)), jt.make_grid(jt.array(arr)))
|
||||
check((3,100,200))
|
||||
check((1,100,200))
|
||||
check((100,200))
|
||||
check((1,3,100,200))
|
||||
check((4,3,100,200))
|
||||
check((10,3,100,200))
|
||||
|
||||
|
||||
|
||||
def test_save_image(self):
|
||||
arr = jt.array(np.random.randn(16,3,10,10))
|
||||
jt.save_image(arr, "/tmp/a.jpg")
|
||||
|
|
Loading…
Reference in New Issue