transpose fit pt

This commit is contained in:
Dun Liang 2021-10-18 16:02:02 +08:00
parent 34a9dab5fb
commit e24dd4f14a
2 changed files with 9 additions and 0 deletions

View File

@ -360,6 +360,11 @@ origin_transpose = transpose
def transpose(x, *dim):
if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)):
dim = dim[0]
elif len(dim) == 2:
axes = list(range(x.ndim))
a, b = dim
axes[a], axes[b] = axes[b], axes[a]
dim = axes
return origin_transpose(x, dim)
transpose.__doc__ = origin_transpose.__doc__
Var.transpose = Var.permute = permute = transpose

View File

@ -68,6 +68,10 @@ class TestTransposeOp(unittest.TestCase):
assert a.permute().shape == [4,3,2]
assert a.permute(0,2,1).shape == [2,4,3]
def test_transpose_3d2i(self):
a = jt.ones([2,3,4])
assert a.transpose(0,1).shape == (3,2,4)
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cutt(self):