mirror of https://github.com/Jittor/Jittor
add make_grid & unbind
This commit is contained in:
parent
87b1933447
commit
614a17ef8f
|
@ -9,6 +9,7 @@
|
|||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
|
||||
def repeat(x, *shape):
|
||||
|
@ -229,4 +230,35 @@ def normalize(input, p=2, dim=1, eps=1e-12):
|
|||
assert p == 2
|
||||
if p == 2:
|
||||
return input / jt.maximum(input.sqr().sum(dim,True).sqrt(), eps)
|
||||
jt.Var.normalize = normalize
|
||||
jt.Var.normalize = normalize
|
||||
|
||||
def unbind(x, dim=0):
|
||||
r'''
|
||||
Removes a var dimension.
|
||||
|
||||
Returns a tuple of all slices along a given dimension, already without it.
|
||||
|
||||
Args:
|
||||
|
||||
input (var) – the var to unbind
|
||||
|
||||
dim (int) – dimension to remove
|
||||
|
||||
Example:
|
||||
|
||||
jt.random((3,3))
|
||||
|
||||
'''
|
||||
if dim < 0: dim += len(input.shape)
|
||||
return [x[(slice(None),)*dim+(i,)] for i in range(x.shape[dim])]
|
||||
|
||||
def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
|
||||
assert range == None
|
||||
assert scale_each == False
|
||||
if isinstance(x, list): x = jt.stack(x)
|
||||
if normalize: x = (x - x.min()) / (x.max() - x.min())
|
||||
b,c,h,w = x.shape
|
||||
ncol = math.ceil(b / nrow)
|
||||
return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],
|
||||
[f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
|
||||
f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)
|
|
@ -18,9 +18,11 @@ try:
|
|||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
import torch.nn as tnn
|
||||
import torchvision
|
||||
except:
|
||||
torch = None
|
||||
tnn = None
|
||||
torchvision = None
|
||||
skip_this_test = True
|
||||
|
||||
def check_equal(res1, res2, eps=1e-5):
|
||||
|
@ -73,5 +75,24 @@ class TestPad(unittest.TestCase):
|
|||
check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=3), jt.normalize(jt.array(arr), dim=3), 1e-1)
|
||||
print('pass normalize test ...')
|
||||
|
||||
def test_make_grid(self):
|
||||
arr = np.random.randn(16,3,10,10)
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr)), jt.make_grid(jt.array(arr)))
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=2), jt.make_grid(jt.array(arr), nrow=2))
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3), jt.make_grid(jt.array(arr), nrow=3))
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4), jt.make_grid(jt.array(arr), nrow=3, padding=4))
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, padding=4, pad_value=-1))
|
||||
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1))
|
||||
print('pass make_grid test ...')
|
||||
|
||||
def test_unbind(self):
|
||||
arr = np.random.randn(2,3,4)
|
||||
for dim in range(len(arr.shape)):
|
||||
t_res = torch.unbind(torch.Tensor(arr), dim=dim)
|
||||
j_res = jt.unbind(jt.array(arr), dim=dim)
|
||||
for idx in range(len(t_res)):
|
||||
assert np.allclose(t_res[idx].numpy(), j_res[idx].numpy())
|
||||
print('pass unbind test ...')
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -45,7 +45,7 @@ struct CodeOp : Op {
|
|||
Example-1::
|
||||
|
||||
from jittor import Function
|
||||
from jittor import jt
|
||||
import jittor as jt
|
||||
|
||||
class Func(Function):
|
||||
def execute(self, x):
|
||||
|
@ -140,7 +140,7 @@ struct CodeOp : Op {
|
|||
CUDA Example-1::
|
||||
|
||||
#This example shows how to use CUDA in code op.
|
||||
from jittor import jt
|
||||
import jittor as jt
|
||||
from jittor import Function
|
||||
jt.flags.use_cuda = 1
|
||||
|
||||
|
@ -185,7 +185,7 @@ struct CodeOp : Op {
|
|||
CUDA Example-2::
|
||||
|
||||
#This example shows how to use multi dimension data with CUDA.
|
||||
from jittor import jt
|
||||
import jittor as jt
|
||||
from jittor import Function
|
||||
jt.flags.use_cuda = 1
|
||||
|
||||
|
|
Loading…
Reference in New Issue