mirror of https://github.com/Jittor/Jittor
add save_image and Function apply
This commit is contained in:
parent
90321aa65d
commit
5d321c13cf
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.4'
|
||||
__version__ = '1.2.2.5'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -831,6 +831,11 @@ can also be None)::
|
|||
def dfs(self, parents, k, callback, callback_leave=None):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def apply(cls, *args, **kw):
|
||||
func = cls()
|
||||
return func(*args, **kw)
|
||||
|
||||
|
||||
def make_module(func, exec_n_args=1):
|
||||
class MakeModule(Module):
|
||||
|
|
|
@ -345,6 +345,25 @@ def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=Fals
|
|||
[f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
|
||||
f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)
|
||||
|
||||
def save_image(
|
||||
x,
|
||||
filepath,
|
||||
nrow: int = 8,
|
||||
padding: int = 2,
|
||||
normalize: bool = False,
|
||||
range = None,
|
||||
scale_each = False,
|
||||
pad_value = 0,
|
||||
format = None
|
||||
):
|
||||
from PIL import Image
|
||||
grid = make_grid(x, nrow=nrow, padding=padding, pad_value=pad_value,
|
||||
normalize=normalize, range=range, scale_each=scale_each)
|
||||
|
||||
ndarr = (grid*255+0.5).clamp(0, 255).permute(1, 2, 0).uint8().numpy()
|
||||
im = Image.fromarray(ndarr)
|
||||
im.save(filepath, format=format)
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
|
|
|
@ -26,6 +26,19 @@ class TestFunction(unittest.TestCase):
|
|||
da = jt.grad(b, a)
|
||||
assert da.data == -1
|
||||
|
||||
def test_apply(self):
|
||||
class MyFunc(Function):
|
||||
def execute(self, x):
|
||||
return x+1
|
||||
|
||||
def grad(self, grad):
|
||||
return grad-2
|
||||
a = jt.ones(1)
|
||||
func = MyFunc.apply
|
||||
b = func(a)
|
||||
da = jt.grad(b, a)
|
||||
assert da.data == -1
|
||||
|
||||
def test2(self):
|
||||
class MyFunc(Function):
|
||||
def execute(self, x):
|
||||
|
|
|
@ -85,6 +85,10 @@ class TestPad(unittest.TestCase):
|
|||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1))
|
||||
print('pass make_grid test ...')
|
||||
|
||||
def test_save_image(self):
|
||||
arr = jt.array(np.random.randn(16,3,10,10))
|
||||
jt.save_image(arr, "/tmp/a.jpg")
|
||||
|
||||
def test_unbind(self):
|
||||
arr = np.random.randn(2,3,4)
|
||||
for dim in range(len(arr.shape)):
|
||||
|
|
Loading…
Reference in New Issue