mirror of https://github.com/Jittor/Jittor
optimize get item inplace
This commit is contained in:
parent
4d989e94b1
commit
55b68cc294
|
@ -0,0 +1,49 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Wenyang Zhou <576825820@qq.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
skip_this_test = False
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestSetitem(unittest.TestCase):
|
||||
def test_getitem(self):
|
||||
# test getitem for float32/float64/bool/int8/int32
|
||||
arr_float32 = jt.random((4,2,3))
|
||||
arr_float32_res = arr_float32[1:3,:,:]
|
||||
arr_float32_res.data[0,0,0] = 1
|
||||
assert arr_float32[1,0,0] == 1
|
||||
arr_float32_res.data[1,1,2] = 1
|
||||
assert arr_float32[2,1,2] == 1
|
||||
arr_float32[1,0,0] = 0
|
||||
# getitem and setitem do not conflict
|
||||
assert arr_float32_res[0,0,0] == 1
|
||||
|
||||
arr_bool = jt.bool(np.ones((4,2,3)))
|
||||
arr_bool_res = arr_bool[1:3,:,:]
|
||||
arr_bool_res.data[0,0,0] = False
|
||||
assert arr_bool[1,0,0] == False
|
||||
arr_bool_res.data[0,0,1] = False
|
||||
assert arr_bool[1,0,1] == False
|
||||
|
||||
arr_float64 = jt.random((4,2,3), dtype='float64')
|
||||
arr_float64_res = arr_float64[1:3,:,:]
|
||||
arr_float64_res.data[0,0,0] = 1
|
||||
assert arr_float64[1,0,0] == 1
|
||||
arr_float64_res.data[1,1,2] = 1
|
||||
assert arr_float64[2,1,2] == 1
|
||||
|
||||
arr_int32 = jt.ones((4,2,3), dtype='int32')
|
||||
arr_int32_res = arr_int32[1:3,:,:]
|
||||
arr_int32_res.data[0,0,0] = 0
|
||||
assert arr_int32[1,0,0] == 0
|
||||
arr_int32_res.data[1,1,2] = 0
|
||||
assert arr_int32[2,1,2] == 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -103,6 +103,29 @@ static void setitem_grad_opt(GetitemOp* op) {
|
|||
|
||||
}
|
||||
|
||||
static void getitem_inplace(GetitemOp* op) {
|
||||
// LOGir << "getitem_inplace";
|
||||
|
||||
auto in = op->inputs().front();
|
||||
auto ou = op->outputs().front();
|
||||
|
||||
// return if input or output's shape is variable
|
||||
if (in->num < 0 || ou->num < 0)
|
||||
return;
|
||||
|
||||
VarSlices vs = op->vs;
|
||||
auto in_shape = in->shape;
|
||||
auto ou_shape = ou->shape;
|
||||
|
||||
for (int i = vs.n - 1; i > 0; --i)
|
||||
if (!(vs.slices[i].slice.step == 1 && in_shape[i] == ou_shape[i]))
|
||||
return;
|
||||
|
||||
Slice s = vs.slices[0].slice;
|
||||
ou->share_with(in, (s.stop - s.start) * in->size / in_shape[0] / 2);
|
||||
return;
|
||||
}
|
||||
|
||||
void SetitemOp::graph_optimize() {
|
||||
// LOGir << "hello graph_optimize";
|
||||
setitem_inplace(this);
|
||||
|
@ -113,6 +136,7 @@ void GetitemOp::graph_optimize() {
|
|||
// LOGir << "hello getitem graph_optimize";
|
||||
// setitem_grad_opt(this);
|
||||
(void)setitem_grad_opt;
|
||||
getitem_inplace(this);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ bool Var::alloc(Allocator* allocator) {
|
|||
if (mem_ptr) return true;
|
||||
if (auto* x = (Var*)(this->allocator)) {
|
||||
if (x->allocator->share_with(size, x->allocation)) {
|
||||
mem_ptr = x->mem_ptr;
|
||||
mem_ptr = ((char*) x->mem_ptr) + allocation;
|
||||
allocation = x->allocation;
|
||||
this->allocator = x->allocator;
|
||||
return true;
|
||||
|
|
|
@ -42,7 +42,7 @@ struct Var : Node {
|
|||
int64_t numel();
|
||||
void set_shape(NanoVector shape);
|
||||
bool alloc(Allocator* allocator);
|
||||
inline void share_with(Var* x) { CHECK_EXIST; allocator = (Allocator*)x; }
|
||||
inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; }
|
||||
};
|
||||
|
||||
struct VarPtr {
|
||||
|
|
Loading…
Reference in New Issue