optimize get item inplace

This commit is contained in:
zhouwy19 2020-11-23 15:11:09 +08:00
parent 4d989e94b1
commit 55b68cc294
4 changed files with 75 additions and 2 deletions

View File

@ -0,0 +1,49 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
skip_this_test = False
@unittest.skipIf(skip_this_test, "No Torch found")
class TestSetitem(unittest.TestCase):
def test_getitem(self):
# test getitem for float32/float64/bool/int8/int32
arr_float32 = jt.random((4,2,3))
arr_float32_res = arr_float32[1:3,:,:]
arr_float32_res.data[0,0,0] = 1
assert arr_float32[1,0,0] == 1
arr_float32_res.data[1,1,2] = 1
assert arr_float32[2,1,2] == 1
arr_float32[1,0,0] = 0
# getitem and setitem do not conflict
assert arr_float32_res[0,0,0] == 1
arr_bool = jt.bool(np.ones((4,2,3)))
arr_bool_res = arr_bool[1:3,:,:]
arr_bool_res.data[0,0,0] = False
assert arr_bool[1,0,0] == False
arr_bool_res.data[0,0,1] = False
assert arr_bool[1,0,1] == False
arr_float64 = jt.random((4,2,3), dtype='float64')
arr_float64_res = arr_float64[1:3,:,:]
arr_float64_res.data[0,0,0] = 1
assert arr_float64[1,0,0] == 1
arr_float64_res.data[1,1,2] = 1
assert arr_float64[2,1,2] == 1
arr_int32 = jt.ones((4,2,3), dtype='int32')
arr_int32_res = arr_int32[1:3,:,:]
arr_int32_res.data[0,0,0] = 0
assert arr_int32[1,0,0] == 0
arr_int32_res.data[1,1,2] = 0
assert arr_int32[2,1,2] == 0
if __name__ == "__main__":
unittest.main()

View File

@ -103,6 +103,29 @@ static void setitem_grad_opt(GetitemOp* op) {
} }
static void getitem_inplace(GetitemOp* op) {
// LOGir << "getitem_inplace";
auto in = op->inputs().front();
auto ou = op->outputs().front();
// return if input or output's shape is variable
if (in->num < 0 || ou->num < 0)
return;
VarSlices vs = op->vs;
auto in_shape = in->shape;
auto ou_shape = ou->shape;
for (int i = vs.n - 1; i > 0; --i)
if (!(vs.slices[i].slice.step == 1 && in_shape[i] == ou_shape[i]))
return;
Slice s = vs.slices[0].slice;
ou->share_with(in, (s.stop - s.start) * in->size / in_shape[0] / 2);
return;
}
void SetitemOp::graph_optimize() { void SetitemOp::graph_optimize() {
// LOGir << "hello graph_optimize"; // LOGir << "hello graph_optimize";
setitem_inplace(this); setitem_inplace(this);
@ -113,6 +136,7 @@ void GetitemOp::graph_optimize() {
// LOGir << "hello getitem graph_optimize"; // LOGir << "hello getitem graph_optimize";
// setitem_grad_opt(this); // setitem_grad_opt(this);
(void)setitem_grad_opt; (void)setitem_grad_opt;
getitem_inplace(this);
} }
} }

View File

@ -64,7 +64,7 @@ bool Var::alloc(Allocator* allocator) {
if (mem_ptr) return true; if (mem_ptr) return true;
if (auto* x = (Var*)(this->allocator)) { if (auto* x = (Var*)(this->allocator)) {
if (x->allocator->share_with(size, x->allocation)) { if (x->allocator->share_with(size, x->allocation)) {
mem_ptr = x->mem_ptr; mem_ptr = ((char*) x->mem_ptr) + allocation;
allocation = x->allocation; allocation = x->allocation;
this->allocator = x->allocator; this->allocator = x->allocator;
return true; return true;

View File

@ -42,7 +42,7 @@ struct Var : Node {
int64_t numel(); int64_t numel();
void set_shape(NanoVector shape); void set_shape(NanoVector shape);
bool alloc(Allocator* allocator); bool alloc(Allocator* allocator);
inline void share_with(Var* x) { CHECK_EXIST; allocator = (Allocator*)x; } inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; }
}; };
struct VarPtr { struct VarPtr {