From b14a60ad744e82c8db2a85e40725e1bed9e4314b Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Sat, 5 Jun 2021 23:05:01 +0800 Subject: [PATCH] add roll --- python/jittor/__init__.py | 2 +- python/jittor/misc.py | 41 +++++++++++++++++++++++++++++- python/jittor/test/test_setitem.py | 9 +++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 44c41389..5ef63a55 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.3.19' +__version__ = '1.2.3.20' from jittor_utils import lock with lock.lock_scope(): ori_int = int diff --git a/python/jittor/misc.py b/python/jittor/misc.py index ec06260f..73f71255 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -1197,7 +1197,7 @@ def gather(x, dim, index): Parameters:: - * input (jt.Var) – the source array + * x (jt.Var) – the source array * dim (int) – the axis along which to index * index (jt.Var) – the indices of elements to gather @@ -1216,3 +1216,42 @@ Example:: return x.getitem(tuple(indexes)) jt.Var.gather = gather + +def roll(x, shifts, dims=None): + '''Roll the tensor along the given dimension(s). + +Parameters:: + + * x (jt.Var) – the source array + * shifts (int or tuple) – shift offset of dims + * dims (int or tuple) – shift dims + +Examples:: + + x = jt.array([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + y = x.roll(1, 0) + assert (y.numpy() == [[7,8],[1,2],[3,4],[5,6]]).all() + y = x.roll(-1, 0) + assert (y.numpy() == [[3,4],[5,6],[7,8],[1,2]]).all() + y = x.roll(shifts=(2, 1), dims=(0, 1)) + assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all() + + ''' + if isinstance(shifts, int): + shifts = (shifts,) + if dims is None: + dims = tuple(range(len(shifts))) + elif isinstance(dims, int): + dims = (dims,) + assert len(dims) == len(shifts) + ids = [ f'i{i}' for i in range(x.ndim) ] + for i in range(len(dims)): + shift = shifts[i] + d = dims[i] + size = x.shape[d] + shift = shift % size + if shift<0: shift += size + ids[d] = f'(i{d}<{shift}?i{d}+{size-shift}:(i{d}-{shift}))' + return x.reindex(x.shape, ids) + +jt.Var.roll = roll diff --git a/python/jittor/test/test_setitem.py b/python/jittor/test/test_setitem.py index 2e86266b..1f3468f7 100644 --- a/python/jittor/test/test_setitem.py +++ b/python/jittor/test/test_setitem.py @@ -201,6 +201,15 @@ class TestSetitem(unittest.TestCase): a = jt.array([1,2]) assert a[None,:,None,None,...,None].shape == (1,2,1,1,1) + def test_roll(self): + x = jt.array([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + y = x.roll(1, 0) + assert (y.numpy() == [[7,8],[1,2],[3,4],[5,6]]).all(), y + y = x.roll(-1, 0) + assert (y.numpy() == [[3,4],[5,6],[7,8],[1,2]]).all() + y = x.roll(shifts=(2, 1), dims=(0, 1)) + assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all() + if __name__ == "__main__": unittest.main() \ No newline at end of file