add save_image and Function apply

This commit is contained in:
Dun Liang 2020-12-14 17:07:17 +08:00
parent 90321aa65d
commit 5d321c13cf
4 changed files with 42 additions and 1 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.4'
__version__ = '1.2.2.5'
from . import lock
with lock.lock_scope():
ori_int = int
@ -831,6 +831,11 @@ can also be None)::
def dfs(self, parents, k, callback, callback_leave=None):
pass
@classmethod
def apply(cls, *args, **kw):
func = cls()
return func(*args, **kw)
def make_module(func, exec_n_args=1):
class MakeModule(Module):

View File

@ -345,6 +345,25 @@ def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=Fals
[f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)
def save_image(
x,
filepath,
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
range = None,
scale_each = False,
pad_value = 0,
format = None
):
from PIL import Image
grid = make_grid(x, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
ndarr = (grid*255+0.5).clamp(0, 255).permute(1, 2, 0).uint8().numpy()
im = Image.fromarray(ndarr)
im.save(filepath, format=format)
def _ntuple(n):
def parse(x):

View File

@ -26,6 +26,19 @@ class TestFunction(unittest.TestCase):
da = jt.grad(b, a)
assert da.data == -1
def test_apply(self):
class MyFunc(Function):
def execute(self, x):
return x+1
def grad(self, grad):
return grad-2
a = jt.ones(1)
func = MyFunc.apply
b = func(a)
da = jt.grad(b, a)
assert da.data == -1
def test2(self):
class MyFunc(Function):
def execute(self, x):

View File

@ -85,6 +85,10 @@ class TestPad(unittest.TestCase):
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1))
print('pass make_grid test ...')
def test_save_image(self):
arr = jt.array(np.random.randn(16,3,10,10))
jt.save_image(arr, "/tmp/a.jpg")
def test_unbind(self):
arr = np.random.randn(2,3,4)
for dim in range(len(arr.shape)):