make_grid add range support

This commit is contained in:
zhouwy19 2020-12-15 11:27:44 +08:00
parent 398746044a
commit 73eb05b36e
2 changed files with 5 additions and 2 deletions

View File

@ -338,10 +338,12 @@ def unbind(x, dim=0):
jt.Var.unbind = unbind
def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
assert range == None
assert isinstance(range, tuple) or range is None
assert scale_each == False
if isinstance(x, list): x = jt.stack(x)
if normalize: x = (x - x.min()) / (x.max() - x.min())
if normalize:
if range is None: x = (x - x.min()) / (x.max() - x.min())
else: x = (x - range[0]) / (range[1] - range[0])
b,c,h,w = x.shape
ncol = math.ceil(b / nrow)
return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],

View File

@ -84,6 +84,7 @@ class TestPad(unittest.TestCase):
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4), jt.make_grid(jt.array(arr), nrow=3, padding=4))
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, padding=4, pad_value=-1))
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1))
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)))
print('pass make_grid test ...')
def test_save_image(self):