mirror of https://github.com/Jittor/Jittor
polish extra none slice
This commit is contained in:
parent
fcaf0f9da5
commit
a0f2516626
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.39'
|
||||
__version__ = '1.2.2.40'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -1118,7 +1118,7 @@ Example::
|
|||
|
||||
'''
|
||||
shape = index.shape
|
||||
if src.shape != shape:
|
||||
if src.shape != shape and src.numel() != 1:
|
||||
src = src[tuple( slice(None,s) for s in shape )]
|
||||
indexes = [ f'i{i}' for i in range(len(shape)) ]
|
||||
indexes[dim] = index
|
||||
|
|
|
@ -197,6 +197,9 @@ class TestSetitem(unittest.TestCase):
|
|||
a[b] = jt.array([-1,-2])
|
||||
assert (a.data == [-1,2,-2,4]).all()
|
||||
|
||||
def test_slice_none(self):
|
||||
a = jt.array([1,2])
|
||||
assert a[None,:,None,None,...,None].shape == (1,2,1,1,1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -150,6 +150,13 @@ void GetitemOp::infer_slices(
|
|||
}
|
||||
}
|
||||
}
|
||||
while (vid < vs.n) {
|
||||
auto& s = vs.slices[vid++];
|
||||
if (s.is_none()) {
|
||||
out_shape.push_back(1);
|
||||
} else
|
||||
CHECK(s.is_ellipsis()) << "Too many slices" << vs << "shape:" << in->shape;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims) {
|
||||
|
|
|
@ -21,6 +21,7 @@ std::ostream& operator<<(std::ostream& os, const VarSlice& s) {
|
|||
if (s.is_ellipsis()) return os << "...";
|
||||
if (s.is_slice()) return os << s.slice;
|
||||
if (s.is_int()) return os << s.i;
|
||||
if (s.is_str()) return os << (const char*)&s;
|
||||
return os << "-";
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue