mirror of https://github.com/Jittor/Jittor
make_grid add range support
This commit is contained in:
parent
398746044a
commit
73eb05b36e
|
@ -338,10 +338,12 @@ def unbind(x, dim=0):
|
|||
jt.Var.unbind = unbind
|
||||
|
||||
def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
|
||||
assert range == None
|
||||
assert isinstance(range, tuple) or range is None
|
||||
assert scale_each == False
|
||||
if isinstance(x, list): x = jt.stack(x)
|
||||
if normalize: x = (x - x.min()) / (x.max() - x.min())
|
||||
if normalize:
|
||||
if range is None: x = (x - x.min()) / (x.max() - x.min())
|
||||
else: x = (x - range[0]) / (range[1] - range[0])
|
||||
b,c,h,w = x.shape
|
||||
ncol = math.ceil(b / nrow)
|
||||
return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],
|
||||
|
|
|
@ -84,6 +84,7 @@ class TestPad(unittest.TestCase):
|
|||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4), jt.make_grid(jt.array(arr), nrow=3, padding=4))
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, padding=4, pad_value=-1))
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1))
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)))
|
||||
print('pass make_grid test ...')
|
||||
|
||||
def test_save_image(self):
|
||||
|
|
Loading…
Reference in New Issue