mirror of https://github.com/Jittor/Jittor
update
This commit is contained in:
parent
55b68cc294
commit
801eb2241f
|
@ -12,8 +12,77 @@ skip_this_test = False
|
|||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestSetitem(unittest.TestCase):
|
||||
def test_setitem(self):
|
||||
arr0 = jt.random((4,2,2))
|
||||
data0 = jt.ones((2,2))
|
||||
arr0[1] = data0
|
||||
arr0.sync()
|
||||
data0.data[0,0] = 0
|
||||
assert arr0[1,0,0] == 0
|
||||
|
||||
arr00 = jt.random((4,2,2))
|
||||
data00 = jt.ones((2,2))
|
||||
# share memory will fail if d has an edge to other nodes.
|
||||
tmp = data00 + 1
|
||||
arr00[1] = data00
|
||||
arr00.sync()
|
||||
data00.data[0,0] = 0
|
||||
assert arr00[1,0,0] == 0
|
||||
|
||||
arr1 = jt.random((4,2,2))
|
||||
data1 = jt.zeros((2,2))
|
||||
arr1[3,:,0:2] = data1
|
||||
arr1.sync()
|
||||
data1.data[0,0] = 1
|
||||
assert arr1[3,0,0] == 1
|
||||
|
||||
arr21 = jt.ones((2,2))
|
||||
arr22 = jt.ones((2,2)) * 2
|
||||
arr2 = jt.contrib.concat([arr21, arr22], dim=0)
|
||||
arr2.sync()
|
||||
arr21.data[0,0] = 3
|
||||
arr22.data[0,0] = 4
|
||||
assert arr2[0,0] == 3
|
||||
assert arr2[2,0] == 4
|
||||
|
||||
def test_getitem(self):
|
||||
# test getitem for float32/float64/bool/int8/int32
|
||||
# test for different slice type
|
||||
arr0 = jt.random((4,3))
|
||||
arr0_res = arr0[2,:]
|
||||
arr0_res.data[1] = 1
|
||||
assert arr0[2,1] == 1
|
||||
|
||||
arr1 = jt.array([1,2,3,4])
|
||||
arr1_res = arr1[None]
|
||||
arr1_res.data[0,2] = -1
|
||||
assert arr1[2] == -1
|
||||
|
||||
arr2 = jt.array([1,2,3,4])
|
||||
arr2_res = arr2[...]
|
||||
arr2_res.data[2] = -1
|
||||
assert arr2[2] == -1
|
||||
|
||||
arr3 = jt.array([1,2,3,4])
|
||||
arr3_res = arr3[3]
|
||||
arr3_res.data[0] = -1
|
||||
assert arr3[3] == -1
|
||||
|
||||
arr4 = jt.random((4,2,3,3))
|
||||
arr4_res = arr4[...,:,:]
|
||||
arr4_res.data[0,0,1,1] = 1
|
||||
assert arr4[0,0,1,1] == 1
|
||||
|
||||
arr5 = jt.random((4,2,3,3))
|
||||
arr5_res = arr5[1:3,:,:,:]
|
||||
arr5_res.data[1,0,1,1] = 1
|
||||
assert arr5[2,0,1,1] == 1
|
||||
|
||||
arr6 = jt.random((4,2,3,3))
|
||||
arr6_res = arr6[1]
|
||||
arr6_res.data[0,1,1] = 1
|
||||
assert arr6[1,0,1,1] == 1
|
||||
|
||||
# test for different data type (float32/float64/bool/int8/int32)
|
||||
arr_float32 = jt.random((4,2,3))
|
||||
arr_float32_res = arr_float32[1:3,:,:]
|
||||
arr_float32_res.data[0,0,0] = 1
|
||||
|
|
|
@ -312,6 +312,10 @@ void SetitemOp::jit_run() {
|
|||
checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDefault, 0));
|
||||
#endif
|
||||
|
||||
if (data->allocation == in->allocation &&
|
||||
data->allocator == in->allocator)
|
||||
return;
|
||||
|
||||
@for(d, 0, ODIM, for (index_t i@d=0; i@d < oshape@d; i@d++)) {
|
||||
index_t did = 0 @for(d, 0, ODIM, @if((BMASK>>d)&1,+ i@d * dstride@d));
|
||||
@for(d, 0, IDIM, index_t iid@d =
|
||||
|
|
|
@ -17,7 +17,7 @@ inline static bool fast_strcmp(const char* a, const char* b) {
|
|||
}
|
||||
|
||||
static void setitem_inplace(SetitemOp* op) {
|
||||
// LOGir << "setitem_inplace";
|
||||
// LOGir << "in setitem_inplace";
|
||||
auto input = op->inputs().front();
|
||||
if (!(input->outputs().size() == 1 &&
|
||||
input->forward_liveness<=1 &&
|
||||
|
@ -29,8 +29,7 @@ static void setitem_inplace(SetitemOp* op) {
|
|||
// make sure input op will not use input
|
||||
auto input_name = input_op->name();
|
||||
if (!(input_op->type() == OpType::broadcast ||
|
||||
fast_strcmp(input_name, "array") ||
|
||||
fast_strcmp(input_name, "empty") ||
|
||||
input_op->inputs().size() == 0 ||
|
||||
fast_strcmp(input_name, "setitem") ||
|
||||
fast_strcmp(input_name, "getitem")))
|
||||
// TODO: inplace getitem maybe risky, getitem maybe inplace too
|
||||
|
@ -38,7 +37,44 @@ static void setitem_inplace(SetitemOp* op) {
|
|||
}
|
||||
auto output = op->outputs().front();
|
||||
output->share_with(input);
|
||||
// LOGir << "apply setitem_inplace on" << op << "input:" << input << "output:" << output;
|
||||
|
||||
// LOGir << "pass setitem optim one";
|
||||
|
||||
auto data = op->input(1);
|
||||
input_op = input->input();
|
||||
|
||||
if (input_op && input_op->inputs().size() == 1) {
|
||||
input_op = input_op->inputs().front()->input();
|
||||
}
|
||||
if (input_op && input_op->inputs().size() == 1) {
|
||||
input_op = input_op->inputs().front()->input();
|
||||
}
|
||||
|
||||
VarSlices vs = op->vs;
|
||||
if (!(data->is_finished() == 0 && (data->outputs().size() == 1 || (!input_op || input_op->inputs().size() == 0))))
|
||||
return;
|
||||
|
||||
auto in_shape = input->shape;
|
||||
for (int i = vs.n - 1; i > 0; --i) {
|
||||
VarSlice s = vs.slices[i];
|
||||
if (!(s.is_slice())) return;
|
||||
Slice ss = s.slice;
|
||||
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
|
||||
return;
|
||||
}
|
||||
|
||||
VarSlice s = vs.slices[0];
|
||||
if (s.is_var()) return;
|
||||
|
||||
auto size = 0;
|
||||
if (s.is_int())
|
||||
size = s.i * input->size / in_shape[0];
|
||||
else if (s.is_slice())
|
||||
size = s.slice.start * input->size / in_shape[0];
|
||||
|
||||
data->input()->add_inputs(vector<Var*>{input});
|
||||
data->share_with(input, size);
|
||||
// LOGir << "pass setitem optim two";
|
||||
}
|
||||
|
||||
struct BBox {
|
||||
|
@ -104,7 +140,7 @@ static void setitem_grad_opt(GetitemOp* op) {
|
|||
}
|
||||
|
||||
static void getitem_inplace(GetitemOp* op) {
|
||||
// LOGir << "getitem_inplace";
|
||||
// LOGir << "in getitem_inplace";
|
||||
|
||||
auto in = op->inputs().front();
|
||||
auto ou = op->outputs().front();
|
||||
|
@ -115,15 +151,25 @@ static void getitem_inplace(GetitemOp* op) {
|
|||
|
||||
VarSlices vs = op->vs;
|
||||
auto in_shape = in->shape;
|
||||
auto ou_shape = ou->shape;
|
||||
|
||||
for (int i = vs.n - 1; i > 0; --i)
|
||||
if (!(vs.slices[i].slice.step == 1 && in_shape[i] == ou_shape[i]))
|
||||
return;
|
||||
|
||||
Slice s = vs.slices[0].slice;
|
||||
ou->share_with(in, (s.stop - s.start) * in->size / in_shape[0] / 2);
|
||||
return;
|
||||
for (int i = vs.n - 1; i > 0; --i) {
|
||||
VarSlice s = vs.slices[i];
|
||||
if (!(s.is_slice())) return;
|
||||
Slice ss = s.slice;
|
||||
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
|
||||
return;
|
||||
}
|
||||
|
||||
VarSlice s = vs.slices[0];
|
||||
if (s.is_var()) return;
|
||||
|
||||
auto size = 0;
|
||||
if (s.is_int())
|
||||
size = s.i * in->size / in_shape[0];
|
||||
else if (s.is_slice())
|
||||
size = s.slice.start * in->size / in_shape[0];
|
||||
ou->share_with(in, size);
|
||||
// LOGir << "pass getitem_inplace";
|
||||
}
|
||||
|
||||
void SetitemOp::graph_optimize() {
|
||||
|
@ -136,6 +182,7 @@ void GetitemOp::graph_optimize() {
|
|||
// LOGir << "hello getitem graph_optimize";
|
||||
// setitem_grad_opt(this);
|
||||
(void)setitem_grad_opt;
|
||||
// (void)getitem_inplace;
|
||||
getitem_inplace(this);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue