mirror of https://github.com/Jittor/Jittor
fix slice bug
This commit is contained in:
parent
acfceca90f
commit
4fdb7b339b
|
@ -38,6 +38,11 @@ class TestNanoVector(unittest.TestCase):
|
||||||
a += [3,4]
|
a += [3,4]
|
||||||
assert a == [1,2,3,4], a
|
assert a == [1,2,3,4], a
|
||||||
|
|
||||||
|
def test_slice_bug(self):
|
||||||
|
a = jt.NanoVector([2,3,4,5])
|
||||||
|
assert a[:] == [2,3,4,5]
|
||||||
|
assert a[1:] == [3,4,5]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
|
@ -21,7 +21,7 @@ static inline int lzcnt(int64 v) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Slice {
|
struct Slice {
|
||||||
int start, stop, step;
|
int64 start, stop, step, mask;
|
||||||
};
|
};
|
||||||
|
|
||||||
// @pyjt(NanoVector)
|
// @pyjt(NanoVector)
|
||||||
|
@ -108,9 +108,10 @@ struct NanoVector {
|
||||||
|
|
||||||
// @pyjt(__map_getitem__)
|
// @pyjt(__map_getitem__)
|
||||||
inline NanoVector slice(Slice slice) {
|
inline NanoVector slice(Slice slice) {
|
||||||
|
if (slice.mask&2) slice.stop = size();
|
||||||
if (slice.start<0) slice.start += size();
|
if (slice.start<0) slice.start += size();
|
||||||
if (slice.stop<0) slice.stop += size();
|
if (slice.stop<0) slice.stop += size();
|
||||||
ASSERT(slice.start>=0 && slice.stop>=0 && slice.start<size() && slice.stop<size())
|
ASSERT(slice.start>=0 && slice.stop>=0 && slice.start<size() && slice.stop<=size())
|
||||||
<< "slice overflow:" << slice.start << slice.stop << slice.step;
|
<< "slice overflow:" << slice.start << slice.stop << slice.step;
|
||||||
NanoVector v;
|
NanoVector v;
|
||||||
if (slice.step>0) {
|
if (slice.step>0) {
|
||||||
|
|
|
@ -123,8 +123,13 @@ DEF_IS(Slice, bool) is_type(PyObject* obj) {
|
||||||
}
|
}
|
||||||
DEF_IS(Slice, T) from_py_object(PyObject* obj) {
|
DEF_IS(Slice, T) from_py_object(PyObject* obj) {
|
||||||
Py_ssize_t start, stop, step;
|
Py_ssize_t start, stop, step;
|
||||||
|
auto slice = (PySliceObject*)obj;
|
||||||
|
|
||||||
PySlice_Unpack(obj, &start, &stop, &step);
|
PySlice_Unpack(obj, &start, &stop, &step);
|
||||||
return {(int)start, (int)stop, (int)step};
|
return {start, stop, step,
|
||||||
|
(slice->start == Py_None) |
|
||||||
|
((slice->stop == Py_None) << 1) |
|
||||||
|
((slice->step == Py_None) << 2)};
|
||||||
}
|
}
|
||||||
|
|
||||||
#define GET_RAW_PTR(T, obj) ((T*)(((char*)obj) + sizeof(PyObject)))
|
#define GET_RAW_PTR(T, obj) ((T*)(((char*)obj) + sizeof(PyObject)))
|
||||||
|
|
Loading…
Reference in New Issue