mirror of https://github.com/Jittor/Jittor
add roll
This commit is contained in:
parent
a07eb6bc12
commit
b14a60ad74
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.19'
|
||||
__version__ = '1.2.3.20'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -1197,7 +1197,7 @@ def gather(x, dim, index):
|
|||
|
||||
Parameters::
|
||||
|
||||
* input (jt.Var) – the source array
|
||||
* x (jt.Var) – the source array
|
||||
* dim (int) – the axis along which to index
|
||||
* index (jt.Var) – the indices of elements to gather
|
||||
|
||||
|
@ -1216,3 +1216,42 @@ Example::
|
|||
return x.getitem(tuple(indexes))
|
||||
|
||||
jt.Var.gather = gather
|
||||
|
||||
def roll(x, shifts, dims=None):
|
||||
'''Roll the tensor along the given dimension(s).
|
||||
|
||||
Parameters::
|
||||
|
||||
* x (jt.Var) – the source array
|
||||
* shifts (int or tuple) – shift offset of dims
|
||||
* dims (int or tuple) – shift dims
|
||||
|
||||
Examples::
|
||||
|
||||
x = jt.array([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
|
||||
y = x.roll(1, 0)
|
||||
assert (y.numpy() == [[7,8],[1,2],[3,4],[5,6]]).all()
|
||||
y = x.roll(-1, 0)
|
||||
assert (y.numpy() == [[3,4],[5,6],[7,8],[1,2]]).all()
|
||||
y = x.roll(shifts=(2, 1), dims=(0, 1))
|
||||
assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all()
|
||||
|
||||
'''
|
||||
if isinstance(shifts, int):
|
||||
shifts = (shifts,)
|
||||
if dims is None:
|
||||
dims = tuple(range(len(shifts)))
|
||||
elif isinstance(dims, int):
|
||||
dims = (dims,)
|
||||
assert len(dims) == len(shifts)
|
||||
ids = [ f'i{i}' for i in range(x.ndim) ]
|
||||
for i in range(len(dims)):
|
||||
shift = shifts[i]
|
||||
d = dims[i]
|
||||
size = x.shape[d]
|
||||
shift = shift % size
|
||||
if shift<0: shift += size
|
||||
ids[d] = f'(i{d}<{shift}?i{d}+{size-shift}:(i{d}-{shift}))'
|
||||
return x.reindex(x.shape, ids)
|
||||
|
||||
jt.Var.roll = roll
|
||||
|
|
|
@ -201,6 +201,15 @@ class TestSetitem(unittest.TestCase):
|
|||
a = jt.array([1,2])
|
||||
assert a[None,:,None,None,...,None].shape == (1,2,1,1,1)
|
||||
|
||||
def test_roll(self):
|
||||
x = jt.array([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
|
||||
y = x.roll(1, 0)
|
||||
assert (y.numpy() == [[7,8],[1,2],[3,4],[5,6]]).all(), y
|
||||
y = x.roll(-1, 0)
|
||||
assert (y.numpy() == [[3,4],[5,6],[7,8],[1,2]]).all()
|
||||
y = x.roll(shifts=(2, 1), dims=(0, 1))
|
||||
assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue