mirror of https://github.com/Jittor/Jittor
flip support multiple dims
This commit is contained in:
parent
5d321c13cf
commit
8430c48ad5
|
@ -209,15 +209,18 @@ def flip(x, dim=0):
|
|||
>>> x.flip(1)
|
||||
[[4 3 2 1]]
|
||||
'''
|
||||
assert isinstance(dim, int)
|
||||
if dim<0:
|
||||
dim+=x.ndim
|
||||
assert dim>=0 and dim<len(x.shape)
|
||||
if isinstance(dim, int):
|
||||
dim = [dim]
|
||||
for i in range(len(dim)):
|
||||
if dim[i]<0:
|
||||
dim[i] += x.ndim
|
||||
assert dim[i]>=0 and dim[i]<x.ndim
|
||||
dim = set(dim)
|
||||
|
||||
tar_dims = []
|
||||
for i in range(len(x.shape)):
|
||||
if i == dim:
|
||||
tar_dims.append(f"{x.shape[dim]-1}-i{i}")
|
||||
if i in dim:
|
||||
tar_dims.append(f"xshape{i}-1-i{i}")
|
||||
else:
|
||||
tar_dims.append(f"i{i}")
|
||||
return x.reindex(x.shape, tar_dims)
|
||||
|
@ -359,7 +362,7 @@ def save_image(
|
|||
from PIL import Image
|
||||
grid = make_grid(x, nrow=nrow, padding=padding, pad_value=pad_value,
|
||||
normalize=normalize, range=range, scale_each=scale_each)
|
||||
|
||||
|
||||
ndarr = (grid*255+0.5).clamp(0, 255).permute(1, 2, 0).uint8().numpy()
|
||||
im = Image.fromarray(ndarr)
|
||||
im.save(filepath, format=format)
|
||||
|
|
|
@ -54,6 +54,7 @@ class TestPad(unittest.TestCase):
|
|||
check_equal(torch.Tensor(arr).flip(1), jt.array(arr).flip(1))
|
||||
check_equal(torch.Tensor(arr).flip(2), jt.array(arr).flip(2))
|
||||
check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3))
|
||||
check_equal(torch.Tensor(arr).flip([2,3]), jt.array(arr).flip([2,3]))
|
||||
print('pass flip test ...')
|
||||
|
||||
def test_cross(self):
|
||||
|
|
Loading…
Reference in New Issue