This commit is contained in:
周文洋 2020-11-27 17:42:22 +08:00
parent 55b68cc294
commit 801eb2241f
3 changed files with 134 additions and 14 deletions

View File

@ -12,8 +12,77 @@ skip_this_test = False
@unittest.skipIf(skip_this_test, "No Torch found")
class TestSetitem(unittest.TestCase):
def test_setitem(self):
arr0 = jt.random((4,2,2))
data0 = jt.ones((2,2))
arr0[1] = data0
arr0.sync()
data0.data[0,0] = 0
assert arr0[1,0,0] == 0
arr00 = jt.random((4,2,2))
data00 = jt.ones((2,2))
# share memory will fail if d has an edge to other nodes.
tmp = data00 + 1
arr00[1] = data00
arr00.sync()
data00.data[0,0] = 0
assert arr00[1,0,0] == 0
arr1 = jt.random((4,2,2))
data1 = jt.zeros((2,2))
arr1[3,:,0:2] = data1
arr1.sync()
data1.data[0,0] = 1
assert arr1[3,0,0] == 1
arr21 = jt.ones((2,2))
arr22 = jt.ones((2,2)) * 2
arr2 = jt.contrib.concat([arr21, arr22], dim=0)
arr2.sync()
arr21.data[0,0] = 3
arr22.data[0,0] = 4
assert arr2[0,0] == 3
assert arr2[2,0] == 4
def test_getitem(self):
# test getitem for float32/float64/bool/int8/int32
# test for different slice type
arr0 = jt.random((4,3))
arr0_res = arr0[2,:]
arr0_res.data[1] = 1
assert arr0[2,1] == 1
arr1 = jt.array([1,2,3,4])
arr1_res = arr1[None]
arr1_res.data[0,2] = -1
assert arr1[2] == -1
arr2 = jt.array([1,2,3,4])
arr2_res = arr2[...]
arr2_res.data[2] = -1
assert arr2[2] == -1
arr3 = jt.array([1,2,3,4])
arr3_res = arr3[3]
arr3_res.data[0] = -1
assert arr3[3] == -1
arr4 = jt.random((4,2,3,3))
arr4_res = arr4[...,:,:]
arr4_res.data[0,0,1,1] = 1
assert arr4[0,0,1,1] == 1
arr5 = jt.random((4,2,3,3))
arr5_res = arr5[1:3,:,:,:]
arr5_res.data[1,0,1,1] = 1
assert arr5[2,0,1,1] == 1
arr6 = jt.random((4,2,3,3))
arr6_res = arr6[1]
arr6_res.data[0,1,1] = 1
assert arr6[1,0,1,1] == 1
# test for different data type (float32/float64/bool/int8/int32)
arr_float32 = jt.random((4,2,3))
arr_float32_res = arr_float32[1:3,:,:]
arr_float32_res.data[0,0,0] = 1

View File

@ -312,6 +312,10 @@ void SetitemOp::jit_run() {
checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDefault, 0));
#endif
if (data->allocation == in->allocation &&
data->allocator == in->allocator)
return;
@for(d, 0, ODIM, for (index_t i@d=0; i@d < oshape@d; i@d++)) {
index_t did = 0 @for(d, 0, ODIM, @if((BMASK>>d)&1,+ i@d * dstride@d));
@for(d, 0, IDIM, index_t iid@d =

View File

@ -17,7 +17,7 @@ inline static bool fast_strcmp(const char* a, const char* b) {
}
static void setitem_inplace(SetitemOp* op) {
// LOGir << "setitem_inplace";
// LOGir << "in setitem_inplace";
auto input = op->inputs().front();
if (!(input->outputs().size() == 1 &&
input->forward_liveness<=1 &&
@ -29,8 +29,7 @@ static void setitem_inplace(SetitemOp* op) {
// make sure input op will not use input
auto input_name = input_op->name();
if (!(input_op->type() == OpType::broadcast ||
fast_strcmp(input_name, "array") ||
fast_strcmp(input_name, "empty") ||
input_op->inputs().size() == 0 ||
fast_strcmp(input_name, "setitem") ||
fast_strcmp(input_name, "getitem")))
// TODO: inplace getitem maybe risky, getitem maybe inplace too
@ -38,7 +37,44 @@ static void setitem_inplace(SetitemOp* op) {
}
auto output = op->outputs().front();
output->share_with(input);
// LOGir << "apply setitem_inplace on" << op << "input:" << input << "output:" << output;
// LOGir << "pass setitem optim one";
auto data = op->input(1);
input_op = input->input();
if (input_op && input_op->inputs().size() == 1) {
input_op = input_op->inputs().front()->input();
}
if (input_op && input_op->inputs().size() == 1) {
input_op = input_op->inputs().front()->input();
}
VarSlices vs = op->vs;
if (!(data->is_finished() == 0 && (data->outputs().size() == 1 || (!input_op || input_op->inputs().size() == 0))))
return;
auto in_shape = input->shape;
for (int i = vs.n - 1; i > 0; --i) {
VarSlice s = vs.slices[i];
if (!(s.is_slice())) return;
Slice ss = s.slice;
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
return;
}
VarSlice s = vs.slices[0];
if (s.is_var()) return;
auto size = 0;
if (s.is_int())
size = s.i * input->size / in_shape[0];
else if (s.is_slice())
size = s.slice.start * input->size / in_shape[0];
data->input()->add_inputs(vector<Var*>{input});
data->share_with(input, size);
// LOGir << "pass setitem optim two";
}
struct BBox {
@ -104,7 +140,7 @@ static void setitem_grad_opt(GetitemOp* op) {
}
static void getitem_inplace(GetitemOp* op) {
// LOGir << "getitem_inplace";
// LOGir << "in getitem_inplace";
auto in = op->inputs().front();
auto ou = op->outputs().front();
@ -115,15 +151,25 @@ static void getitem_inplace(GetitemOp* op) {
VarSlices vs = op->vs;
auto in_shape = in->shape;
auto ou_shape = ou->shape;
for (int i = vs.n - 1; i > 0; --i)
if (!(vs.slices[i].slice.step == 1 && in_shape[i] == ou_shape[i]))
return;
Slice s = vs.slices[0].slice;
ou->share_with(in, (s.stop - s.start) * in->size / in_shape[0] / 2);
return;
for (int i = vs.n - 1; i > 0; --i) {
VarSlice s = vs.slices[i];
if (!(s.is_slice())) return;
Slice ss = s.slice;
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
return;
}
VarSlice s = vs.slices[0];
if (s.is_var()) return;
auto size = 0;
if (s.is_int())
size = s.i * in->size / in_shape[0];
else if (s.is_slice())
size = s.slice.start * in->size / in_shape[0];
ou->share_with(in, size);
// LOGir << "pass getitem_inplace";
}
void SetitemOp::graph_optimize() {
@ -136,6 +182,7 @@ void GetitemOp::graph_optimize() {
// LOGir << "hello getitem graph_optimize";
// setitem_grad_opt(this);
(void)setitem_grad_opt;
// (void)getitem_inplace;
getitem_inplace(this);
}