flip support multiple dims

This commit is contained in:
Dun Liang 2020-12-14 17:15:40 +08:00
parent 5d321c13cf
commit 8430c48ad5
2 changed files with 11 additions and 7 deletions

View File

@ -209,15 +209,18 @@ def flip(x, dim=0):
>>> x.flip(1)
[[4 3 2 1]]
'''
assert isinstance(dim, int)
if dim<0:
dim+=x.ndim
assert dim>=0 and dim<len(x.shape)
if isinstance(dim, int):
dim = [dim]
for i in range(len(dim)):
if dim[i]<0:
dim[i] += x.ndim
assert dim[i]>=0 and dim[i]<x.ndim
dim = set(dim)
tar_dims = []
for i in range(len(x.shape)):
if i == dim:
tar_dims.append(f"{x.shape[dim]-1}-i{i}")
if i in dim:
tar_dims.append(f"xshape{i}-1-i{i}")
else:
tar_dims.append(f"i{i}")
return x.reindex(x.shape, tar_dims)
@ -359,7 +362,7 @@ def save_image(
from PIL import Image
grid = make_grid(x, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
ndarr = (grid*255+0.5).clamp(0, 255).permute(1, 2, 0).uint8().numpy()
im = Image.fromarray(ndarr)
im.save(filepath, format=format)

View File

@ -54,6 +54,7 @@ class TestPad(unittest.TestCase):
check_equal(torch.Tensor(arr).flip(1), jt.array(arr).flip(1))
check_equal(torch.Tensor(arr).flip(2), jt.array(arr).flip(2))
check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3))
check_equal(torch.Tensor(arr).flip([2,3]), jt.array(arr).flip([2,3]))
print('pass flip test ...')
def test_cross(self):