This commit is contained in:
Dun Liang 2021-06-05 23:05:01 +08:00
parent a07eb6bc12
commit b14a60ad74
3 changed files with 50 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.2.3.19' __version__ = '1.2.3.20'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -1197,7 +1197,7 @@ def gather(x, dim, index):
Parameters:: Parameters::
* input (jt.Var) the source array * x (jt.Var) the source array
* dim (int) the axis along which to index * dim (int) the axis along which to index
* index (jt.Var) the indices of elements to gather * index (jt.Var) the indices of elements to gather
@ -1216,3 +1216,42 @@ Example::
return x.getitem(tuple(indexes)) return x.getitem(tuple(indexes))
jt.Var.gather = gather jt.Var.gather = gather
def roll(x, shifts, dims=None):
'''Roll the tensor along the given dimension(s).
Parameters::
* x (jt.Var) the source array
* shifts (int or tuple) shift offset of dims
* dims (int or tuple) shift dims
Examples::
x = jt.array([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
y = x.roll(1, 0)
assert (y.numpy() == [[7,8],[1,2],[3,4],[5,6]]).all()
y = x.roll(-1, 0)
assert (y.numpy() == [[3,4],[5,6],[7,8],[1,2]]).all()
y = x.roll(shifts=(2, 1), dims=(0, 1))
assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all()
'''
if isinstance(shifts, int):
shifts = (shifts,)
if dims is None:
dims = tuple(range(len(shifts)))
elif isinstance(dims, int):
dims = (dims,)
assert len(dims) == len(shifts)
ids = [ f'i{i}' for i in range(x.ndim) ]
for i in range(len(dims)):
shift = shifts[i]
d = dims[i]
size = x.shape[d]
shift = shift % size
if shift<0: shift += size
ids[d] = f'(i{d}<{shift}?i{d}+{size-shift}:(i{d}-{shift}))'
return x.reindex(x.shape, ids)
jt.Var.roll = roll

View File

@ -201,6 +201,15 @@ class TestSetitem(unittest.TestCase):
a = jt.array([1,2]) a = jt.array([1,2])
assert a[None,:,None,None,...,None].shape == (1,2,1,1,1) assert a[None,:,None,None,...,None].shape == (1,2,1,1,1)
def test_roll(self):
x = jt.array([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
y = x.roll(1, 0)
assert (y.numpy() == [[7,8],[1,2],[3,4],[5,6]]).all(), y
y = x.roll(-1, 0)
assert (y.numpy() == [[3,4],[5,6],[7,8],[1,2]]).all()
y = x.roll(shifts=(2, 1), dims=(0, 1))
assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()