diff --git a/python/jittor/test/test_nano_vector.py b/python/jittor/test/test_nano_vector.py index 7697305a..0451dce2 100644 --- a/python/jittor/test/test_nano_vector.py +++ b/python/jittor/test/test_nano_vector.py @@ -38,6 +38,11 @@ class TestNanoVector(unittest.TestCase): a += [3,4] assert a == [1,2,3,4], a + def test_slice_bug(self): + a = jt.NanoVector([2,3,4,5]) + assert a[:] == [2,3,4,5] + assert a[1:] == [3,4,5] + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/src/misc/nano_vector.h b/src/misc/nano_vector.h index 59dfb6d9..32b99cc0 100644 --- a/src/misc/nano_vector.h +++ b/src/misc/nano_vector.h @@ -21,7 +21,7 @@ static inline int lzcnt(int64 v) { } struct Slice { - int start, stop, step; + int64 start, stop, step, mask; }; // @pyjt(NanoVector) @@ -108,9 +108,10 @@ struct NanoVector { // @pyjt(__map_getitem__) inline NanoVector slice(Slice slice) { + if (slice.mask&2) slice.stop = size(); if (slice.start<0) slice.start += size(); if (slice.stop<0) slice.stop += size(); - ASSERT(slice.start>=0 && slice.stop>=0 && slice.start=0 && slice.stop>=0 && slice.start0) { diff --git a/src/pyjt/py_converter.h b/src/pyjt/py_converter.h index dde2897e..1e497359 100644 --- a/src/pyjt/py_converter.h +++ b/src/pyjt/py_converter.h @@ -123,8 +123,13 @@ DEF_IS(Slice, bool) is_type(PyObject* obj) { } DEF_IS(Slice, T) from_py_object(PyObject* obj) { Py_ssize_t start, stop, step; + auto slice = (PySliceObject*)obj; + PySlice_Unpack(obj, &start, &stop, &step); - return {(int)start, (int)stop, (int)step}; + return {start, stop, step, + (slice->start == Py_None) | + ((slice->stop == Py_None) << 1) | + ((slice->step == Py_None) << 2)}; } #define GET_RAW_PTR(T, obj) ((T*)(((char*)obj) + sizeof(PyObject)))