fix slice bug

This commit is contained in:
Dun Liang 2020-04-29 13:36:25 +08:00
parent acfceca90f
commit 4fdb7b339b
3 changed files with 14 additions and 3 deletions

View File

@ -38,6 +38,11 @@ class TestNanoVector(unittest.TestCase):
a += [3,4] a += [3,4]
assert a == [1,2,3,4], a assert a == [1,2,3,4], a
def test_slice_bug(self):
a = jt.NanoVector([2,3,4,5])
assert a[:] == [2,3,4,5]
assert a[1:] == [3,4,5]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -21,7 +21,7 @@ static inline int lzcnt(int64 v) {
} }
struct Slice { struct Slice {
int start, stop, step; int64 start, stop, step, mask;
}; };
// @pyjt(NanoVector) // @pyjt(NanoVector)
@ -108,9 +108,10 @@ struct NanoVector {
// @pyjt(__map_getitem__) // @pyjt(__map_getitem__)
inline NanoVector slice(Slice slice) { inline NanoVector slice(Slice slice) {
if (slice.mask&2) slice.stop = size();
if (slice.start<0) slice.start += size(); if (slice.start<0) slice.start += size();
if (slice.stop<0) slice.stop += size(); if (slice.stop<0) slice.stop += size();
ASSERT(slice.start>=0 && slice.stop>=0 && slice.start<size() && slice.stop<size()) ASSERT(slice.start>=0 && slice.stop>=0 && slice.start<size() && slice.stop<=size())
<< "slice overflow:" << slice.start << slice.stop << slice.step; << "slice overflow:" << slice.start << slice.stop << slice.step;
NanoVector v; NanoVector v;
if (slice.step>0) { if (slice.step>0) {

View File

@ -123,8 +123,13 @@ DEF_IS(Slice, bool) is_type(PyObject* obj) {
} }
DEF_IS(Slice, T) from_py_object(PyObject* obj) { DEF_IS(Slice, T) from_py_object(PyObject* obj) {
Py_ssize_t start, stop, step; Py_ssize_t start, stop, step;
auto slice = (PySliceObject*)obj;
PySlice_Unpack(obj, &start, &stop, &step); PySlice_Unpack(obj, &start, &stop, &step);
return {(int)start, (int)stop, (int)step}; return {start, stop, step,
(slice->start == Py_None) |
((slice->stop == Py_None) << 1) |
((slice->step == Py_None) << 2)};
} }
#define GET_RAW_PTR(T, obj) ((T*)(((char*)obj) + sizeof(PyObject))) #define GET_RAW_PTR(T, obj) ((T*)(((char*)obj) + sizeof(PyObject)))