polish extra none slice

This commit is contained in:
Dun Liang 2021-03-08 22:23:06 +08:00
parent fcaf0f9da5
commit a0f2516626
5 changed files with 13 additions and 2 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.39'
__version__ = '1.2.2.40'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -1118,7 +1118,7 @@ Example::
'''
shape = index.shape
if src.shape != shape:
if src.shape != shape and src.numel() != 1:
src = src[tuple( slice(None,s) for s in shape )]
indexes = [ f'i{i}' for i in range(len(shape)) ]
indexes[dim] = index

View File

@ -197,6 +197,9 @@ class TestSetitem(unittest.TestCase):
a[b] = jt.array([-1,-2])
assert (a.data == [-1,2,-2,4]).all()
def test_slice_none(self):
a = jt.array([1,2])
assert a[None,:,None,None,...,None].shape == (1,2,1,1,1)
if __name__ == "__main__":

View File

@ -150,6 +150,13 @@ void GetitemOp::infer_slices(
}
}
}
while (vid < vs.n) {
auto& s = vs.slices[vid++];
if (s.is_none()) {
out_shape.push_back(1);
} else
CHECK(s.is_ellipsis()) << "Too many slices" << vs << "shape:" << in->shape;
}
}
void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims) {

View File

@ -21,6 +21,7 @@ std::ostream& operator<<(std::ostream& os, const VarSlice& s) {
if (s.is_ellipsis()) return os << "...";
if (s.is_slice()) return os << s.slice;
if (s.is_int()) return os << s.i;
if (s.is_str()) return os << (const char*)&s;
return os << "-";
}