setitem inplace failure case

This commit is contained in:
Dun Liang 2021-01-07 22:24:23 +08:00
parent 3f858b76e1
commit efcb32b1e1
2 changed files with 39 additions and 7 deletions

View File

@ -114,6 +114,24 @@ class TestSetitem(unittest.TestCase):
assert arr_int32[1,0,0] == 0
arr_int32_res.data[1,1,2] = 0
assert arr_int32[2,1,2] == 0
def test_setitem_inplace_case1(self):
# test type case
a = jt.zeros((3,))
a[1] = 123
assert a.data[1] == 123
def test_setitem_inplace_case2(self):
# test un-continuous first dim
a = jt.zeros((3,))
a[0::2] = jt.ones((2,))
assert a.data[2] == 1
def test_setitem_inplace_case3(self):
# test broadcast
a = jt.zeros((3,))
a[0:] = 1.0
assert a.data[2] == 1
if __name__ == "__main__":
unittest.main()

View File

@ -50,9 +50,11 @@ static void setitem_inplace(SetitemOp* op) {
output->share_with(input);
// return;
// LOGir << "pass setitem optim one";
auto data = op->input(1);
// if setitem requires type conversion, don't inplace
if (data->dtype() != input->dtype())
return;
input_op = input->input();
if (input_op && input_op->inputs().size() == 1) {
@ -72,12 +74,14 @@ static void setitem_inplace(SetitemOp* op) {
return;
auto in_shape = input->shape;
int64 inplace_size = 1;
for (int i = vs.n - 1; i > 0; --i) {
VarSlice s = vs.slices[i];
if (!(s.is_slice())) return;
Slice ss = s.slice;
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
return;
return;
inplace_size *= in_shape[i];
}
VarSlice s = vs.slices[0];
@ -86,12 +90,22 @@ static void setitem_inplace(SetitemOp* op) {
auto size = 0;
if (s.is_int())
size = s.i * input->size / in_shape[0];
else if (s.is_slice())
size = s.slice.start * input->size / in_shape[0];
else if (s.is_slice()) {
Slice ss = s.slice;
// we also need to check the first dim is continuous
if (ss.step != 1)
return;
size = ss.start * input->size / in_shape[0];
inplace_size *= ss.stop - ss.start;
}
if (inplace_size > data->num) {
// if data has been broadcast into input, don't
// inplace data, because their shapes are not match
// This would lead partial setitem
return;
}
add_dependency(data->input(), {input->node()});
data->share_with(input, size);
// LOGir << "pass setitem optim two";
}
struct BBox {