polish slice ellipsis_with_none

This commit is contained in:
Dun Liang 2021-07-23 14:38:45 +08:00
parent 51a28574c8
commit 7052251098
3 changed files with 11 additions and 1 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.77'
__version__ = '1.2.3.78'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -93,6 +93,9 @@ void GetitemOp::infer_slices(
} else
if (s.is_ellipsis()) {
auto remain_slice = vs.n-vid-1;
for (int i=vid+1; i<vs.n; i++)
if (vs.slices[i].is_none())
remain_slice--;
auto remain_idims = nin-i;
auto ellipsis_size = remain_idims - remain_slice;
ASSERT(ellipsis_size>=0) << "NDims not match";

View File

@ -220,6 +220,13 @@ class TestSetitem(unittest.TestCase):
y = x.roll(shifts=(2, 1), dims=(0, 1))
assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all()
def test_ellipsis_with_none(self):
a = jt.arange(2*4*4).reshape(2,4,4)
b = a[...,:,None,:2]
assert b.shape == [2,4,1,2]
np.testing.assert_allclose(b.data, a.data[...,:,None,:2])
if __name__ == "__main__":
unittest.main()