mirror of https://github.com/Jittor/Jittor
setitem inplace failure case
This commit is contained in:
parent
3f858b76e1
commit
efcb32b1e1
|
@ -114,6 +114,24 @@ class TestSetitem(unittest.TestCase):
|
|||
assert arr_int32[1,0,0] == 0
|
||||
arr_int32_res.data[1,1,2] = 0
|
||||
assert arr_int32[2,1,2] == 0
|
||||
|
||||
def test_setitem_inplace_case1(self):
|
||||
# test type case
|
||||
a = jt.zeros((3,))
|
||||
a[1] = 123
|
||||
assert a.data[1] == 123
|
||||
|
||||
def test_setitem_inplace_case2(self):
|
||||
# test un-continuous first dim
|
||||
a = jt.zeros((3,))
|
||||
a[0::2] = jt.ones((2,))
|
||||
assert a.data[2] == 1
|
||||
|
||||
def test_setitem_inplace_case3(self):
|
||||
# test broadcast
|
||||
a = jt.zeros((3,))
|
||||
a[0:] = 1.0
|
||||
assert a.data[2] == 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -50,9 +50,11 @@ static void setitem_inplace(SetitemOp* op) {
|
|||
output->share_with(input);
|
||||
// return;
|
||||
|
||||
// LOGir << "pass setitem optim one";
|
||||
|
||||
auto data = op->input(1);
|
||||
// if setitem requires type conversion, don't inplace
|
||||
if (data->dtype() != input->dtype())
|
||||
return;
|
||||
|
||||
input_op = input->input();
|
||||
|
||||
if (input_op && input_op->inputs().size() == 1) {
|
||||
|
@ -72,12 +74,14 @@ static void setitem_inplace(SetitemOp* op) {
|
|||
return;
|
||||
|
||||
auto in_shape = input->shape;
|
||||
int64 inplace_size = 1;
|
||||
for (int i = vs.n - 1; i > 0; --i) {
|
||||
VarSlice s = vs.slices[i];
|
||||
if (!(s.is_slice())) return;
|
||||
Slice ss = s.slice;
|
||||
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
|
||||
return;
|
||||
return;
|
||||
inplace_size *= in_shape[i];
|
||||
}
|
||||
|
||||
VarSlice s = vs.slices[0];
|
||||
|
@ -86,12 +90,22 @@ static void setitem_inplace(SetitemOp* op) {
|
|||
auto size = 0;
|
||||
if (s.is_int())
|
||||
size = s.i * input->size / in_shape[0];
|
||||
else if (s.is_slice())
|
||||
size = s.slice.start * input->size / in_shape[0];
|
||||
|
||||
else if (s.is_slice()) {
|
||||
Slice ss = s.slice;
|
||||
// we also need to check the first dim is continuous
|
||||
if (ss.step != 1)
|
||||
return;
|
||||
size = ss.start * input->size / in_shape[0];
|
||||
inplace_size *= ss.stop - ss.start;
|
||||
}
|
||||
if (inplace_size > data->num) {
|
||||
// if data has been broadcast into input, don't
|
||||
// inplace data, because their shapes are not match
|
||||
// This would lead partial setitem
|
||||
return;
|
||||
}
|
||||
add_dependency(data->input(), {input->node()});
|
||||
data->share_with(input, size);
|
||||
// LOGir << "pass setitem optim two";
|
||||
}
|
||||
|
||||
struct BBox {
|
||||
|
|
Loading…
Reference in New Issue