mirror of https://github.com/Jittor/Jittor
polish slice ellipsis_with_none
This commit is contained in:
parent
51a28574c8
commit
7052251098
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.77'
|
||||
__version__ = '1.2.3.78'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -93,6 +93,9 @@ void GetitemOp::infer_slices(
|
|||
} else
|
||||
if (s.is_ellipsis()) {
|
||||
auto remain_slice = vs.n-vid-1;
|
||||
for (int i=vid+1; i<vs.n; i++)
|
||||
if (vs.slices[i].is_none())
|
||||
remain_slice--;
|
||||
auto remain_idims = nin-i;
|
||||
auto ellipsis_size = remain_idims - remain_slice;
|
||||
ASSERT(ellipsis_size>=0) << "NDims not match";
|
||||
|
|
|
@ -220,6 +220,13 @@ class TestSetitem(unittest.TestCase):
|
|||
y = x.roll(shifts=(2, 1), dims=(0, 1))
|
||||
assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all()
|
||||
|
||||
def test_ellipsis_with_none(self):
|
||||
a = jt.arange(2*4*4).reshape(2,4,4)
|
||||
b = a[...,:,None,:2]
|
||||
assert b.shape == [2,4,1,2]
|
||||
np.testing.assert_allclose(b.data, a.data[...,:,None,:2])
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue