diff --git a/python/jittor/test/test_setitem.py b/python/jittor/test/test_setitem.py new file mode 100644 index 00000000..60cb2d57 --- /dev/null +++ b/python/jittor/test/test_setitem.py @@ -0,0 +1,49 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Wenyang Zhou <576825820@qq.com>. +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +skip_this_test = False + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestSetitem(unittest.TestCase): + def test_getitem(self): + # test getitem for float32/float64/bool/int8/int32 + arr_float32 = jt.random((4,2,3)) + arr_float32_res = arr_float32[1:3,:,:] + arr_float32_res.data[0,0,0] = 1 + assert arr_float32[1,0,0] == 1 + arr_float32_res.data[1,1,2] = 1 + assert arr_float32[2,1,2] == 1 + arr_float32[1,0,0] = 0 + # getitem and setitem do not conflict + assert arr_float32_res[0,0,0] == 1 + + arr_bool = jt.bool(np.ones((4,2,3))) + arr_bool_res = arr_bool[1:3,:,:] + arr_bool_res.data[0,0,0] = False + assert arr_bool[1,0,0] == False + arr_bool_res.data[0,0,1] = False + assert arr_bool[1,0,1] == False + + arr_float64 = jt.random((4,2,3), dtype='float64') + arr_float64_res = arr_float64[1:3,:,:] + arr_float64_res.data[0,0,0] = 1 + assert arr_float64[1,0,0] == 1 + arr_float64_res.data[1,1,2] = 1 + assert arr_float64[2,1,2] == 1 + + arr_int32 = jt.ones((4,2,3), dtype='int32') + arr_int32_res = arr_int32[1:3,:,:] + arr_int32_res.data[0,0,0] = 0 + assert arr_int32[1,0,0] == 0 + arr_int32_res.data[1,1,2] = 0 + assert arr_int32[2,1,2] == 0 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/src/opt/gopt/setitem_gopt.cc b/src/opt/gopt/setitem_gopt.cc index e1ace09a..fa5ec404 100644 --- a/src/opt/gopt/setitem_gopt.cc +++ b/src/opt/gopt/setitem_gopt.cc @@ -103,6 +103,29 @@ static void setitem_grad_opt(GetitemOp* op) { } +static void getitem_inplace(GetitemOp* op) { + // LOGir << "getitem_inplace"; + + auto in = op->inputs().front(); + auto ou = op->outputs().front(); + + // return if input or output's shape is variable + if (in->num < 0 || ou->num < 0) + return; + + VarSlices vs = op->vs; + auto in_shape = in->shape; + auto ou_shape = ou->shape; + + for (int i = vs.n - 1; i > 0; --i) + if (!(vs.slices[i].slice.step == 1 && in_shape[i] == ou_shape[i])) + return; + + Slice s = vs.slices[0].slice; + ou->share_with(in, (s.stop - s.start) * in->size / in_shape[0] / 2); + return; +} + void SetitemOp::graph_optimize() { // LOGir << "hello graph_optimize"; setitem_inplace(this); @@ -113,6 +136,7 @@ void GetitemOp::graph_optimize() { // LOGir << "hello getitem graph_optimize"; // setitem_grad_opt(this); (void)setitem_grad_opt; + getitem_inplace(this); } } diff --git a/src/var.cc b/src/var.cc index a325652d..32b16d51 100644 --- a/src/var.cc +++ b/src/var.cc @@ -64,7 +64,7 @@ bool Var::alloc(Allocator* allocator) { if (mem_ptr) return true; if (auto* x = (Var*)(this->allocator)) { if (x->allocator->share_with(size, x->allocation)) { - mem_ptr = x->mem_ptr; + mem_ptr = ((char*) x->mem_ptr) + allocation; allocation = x->allocation; this->allocator = x->allocator; return true; diff --git a/src/var.h b/src/var.h index a3e58368..37edcd12 100644 --- a/src/var.h +++ b/src/var.h @@ -42,7 +42,7 @@ struct Var : Node { int64_t numel(); void set_shape(NanoVector shape); bool alloc(Allocator* allocator); - inline void share_with(Var* x) { CHECK_EXIST; allocator = (Allocator*)x; } + inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; } }; struct VarPtr {