mirror of https://github.com/Jittor/Jittor
transpose fit pt
This commit is contained in:
parent
34a9dab5fb
commit
e24dd4f14a
|
@ -360,6 +360,11 @@ origin_transpose = transpose
|
|||
def transpose(x, *dim):
|
||||
if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)):
|
||||
dim = dim[0]
|
||||
elif len(dim) == 2:
|
||||
axes = list(range(x.ndim))
|
||||
a, b = dim
|
||||
axes[a], axes[b] = axes[b], axes[a]
|
||||
dim = axes
|
||||
return origin_transpose(x, dim)
|
||||
transpose.__doc__ = origin_transpose.__doc__
|
||||
Var.transpose = Var.permute = permute = transpose
|
||||
|
|
|
@ -68,6 +68,10 @@ class TestTransposeOp(unittest.TestCase):
|
|||
assert a.permute().shape == [4,3,2]
|
||||
assert a.permute(0,2,1).shape == [2,4,3]
|
||||
|
||||
def test_transpose_3d2i(self):
|
||||
a = jt.ones([2,3,4])
|
||||
assert a.transpose(0,1).shape == (3,2,4)
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cutt(self):
|
||||
|
|
Loading…
Reference in New Issue