polish gopt setitem concat

This commit is contained in:
Dun Liang 2022-04-02 23:02:58 +08:00
parent da0dc2cfba
commit 18795bd02f
7 changed files with 110 additions and 68 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.63'
__version__ = '1.3.2.0'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -465,7 +465,8 @@ void GetitemOp::jit_run() {
auto in = inputs().front();
auto out = outputs().front();
if (out->num == 0) return;
if (in->allocator == out->allocator &&
if (ns.get(GetitemOp::_inplace) &&
in->allocator == out->allocator &&
in->allocation == out->allocation)
return;

View File

@ -12,6 +12,7 @@
namespace jittor {
struct GetitemOp : Op {
static constexpr jittor::NanoString::Flags _inplace = (jittor::NanoString::Flags)0;
VarSlices vs;
// map i to related var slice
NanoVector i_to_vs;

View File

@ -315,7 +315,7 @@ void SetitemOp::jit_run() {
checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDeviceToDevice, 0));
#endif
if (flags.get((NodeFlags::Flags(SetitemOp::_data_inplaced))) &&
if (ns.get(GetitemOp::_inplace) &&
// array op may move the data allocation, double check
// affects test_contrib.py
in->allocator == data->allocator &&

View File

@ -11,7 +11,6 @@
namespace jittor {
struct SetitemOp : Op {
static constexpr int _data_inplaced = NodeFlags::_has_vary_input + 1;
VarSlices vs;
// map i to related var slice
NanoVector i_to_vs;

View File

@ -107,69 +107,9 @@ static void setitem_inplace(SetitemOp* op) {
}
add_dependency(data->input(), {input->node()});
data->share_with(input, size);
op->flags.set((NodeFlags::Flags(SetitemOp::_data_inplaced)));
}
// Small record with a count and a raw int pointer; judging by the names it is
// presumably meant to hold per-dimension min/max bounds (TODO confirm).
// load_var_slice is a stub: its body is empty, so calling it has no effect.
struct BBox {
    // Defaults to 0.
    int n{0};
    // Defaults to null; nothing in this struct allocates or frees it.
    int* minmax{nullptr};

    // Placeholder with an empty body; `vs` is currently unused.
    void load_var_slice(const VarSlice& vs) {
    }
};
// Graph-optimization helper for a getitem op whose first input comes from an
// op named "setitem".
//
// Steps, as written below:
//   1. Return early unless `op` carries NodeFlags::_has_gopt and the op behind
//      its first input is named "setitem".
//   2. Collect the chain of setitem ops: walk backwards through first inputs,
//      then forwards through consumers of each first output.
//   3. For every setitem in the chain, set _has_gopt to 0 on each consumer of
//      its first output that is named "getitem".
//   4. Set GetitemOp::_inplace in op->ns.
//
// NOTE(review): fast_strcmp is presumably true when the names match -- confirm.
// NOTE(review): this text was taken from a rendered diff; removed and added
// lines may be interleaved (the trailing commented-out LOGir lines mention
// names that are not declared in this function) -- confirm against the file.
static void setitem_grad_opt(GetitemOp* op) {
// Only ops flagged with _has_gopt are processed.
if (!op->flags.get(NodeFlags::_has_gopt))
return;
// First input of the getitem and the op attached to it via input()
// (presumably its producer; may be null).
auto get_in = op->inputs().front();
auto get_in_op = get_in->input();
if (!get_in_op)
return;
auto name = get_in_op->name();
if (!fast_strcmp(name, "setitem"))
return;
// find setitem op chain
auto first_set = (SetitemOp*)get_in_op;
vector<SetitemOp*> chain;
// Walk backwards: follow the first input while the op behind it is also a
// setitem. The current op is pushed before stepping, so the chain is
// collected starting from get_in_op.
while (1) {
auto next = first_set->inputs().front()->input();
if (!next) break;
if (!fast_strcmp(next->name(), "setitem"))
break;
chain.push_back(first_set);
first_set = (SetitemOp*)next;
}
// Push the last setitem reached by the backward walk.
chain.push_back(first_set);
// Reverse the chain in place so that the op reached last by the backward
// walk comes first.
for (int i=0; i<chain.size()/2; i++)
std::swap(chain[i], chain[chain.size()-1-i]);
// Walk forwards from get_in_op: take the first consumer of the current
// first output that is named "setitem", append it, and repeat until none
// is found.
auto last_set = (SetitemOp*)get_in_op;
while (1) {
SetitemOp* next = nullptr;
auto out_var = last_set->outputs().front();
for (auto* out : out_var->outputs()) {
if (fast_strcmp(out->name(), "setitem")) {
next = (SetitemOp*)out;
break;
}
}
if (!next) break;
last_set = next;
chain.push_back(next);
}
// LOGir << "find setitem chain" << chain.size() << chain;
// Set _has_gopt to 0 on every "getitem" consumer of each chain member's
// first output.
for (auto* sop : chain) {
// LOGig << sop << sop->vs;
auto out_var = sop->outputs().front();
for (auto* out : out_var->outputs()) {
if (fast_strcmp(out->name(), "getitem")) {
out->flags.set(NodeFlags::_has_gopt, 0);
}
}
}
// Mark this getitem with the _inplace flag in its ns field.
op->ns.set(GetitemOp::_inplace);
// LOGir << input->shape << input->dtype() << data->shape << data->dtype() << vs << data->input();
// LOGir << output;
}
static void getitem_inplace(GetitemOp* op) {
@ -207,7 +147,9 @@ static void getitem_inplace(GetitemOp* op) {
if (s.slice.step != 1) return;
}
ou->share_with(in, size);
op->ns.set(GetitemOp::_inplace);
// LOGir << "pass getitem_inplace";
// LOGir << "inplace getitem" << vs << in->shape << ou->shape;
}
void SetitemOp::graph_optimize() {
@ -220,7 +162,6 @@ void GetitemOp::graph_optimize() {
// This optimization is still WIP
// LOGir << "hello getitem graph_optimize";
// setitem_grad_opt(this);
(void*)setitem_grad_opt;
// (void)getitem_inplace;
getitem_inplace(this);
(void*)getitem_inplace;

View File

@ -242,6 +242,106 @@ class TestSetitem(unittest.TestCase):
c = b[::-1]
d = c.sum()
jt.grad(d, [a])
# Nested concat: x = concat(b, c) is itself an input of d = concat(a, x).
# The result goes through a chain of eight copy() calls and has x.sum()*0.0
# added, so x is used a second time without changing the expected values.
# The output must equal np.concatenate of a, b and c.
def test_concat2(self):
a = jt.rand(10)
b = jt.rand(11)
c = jt.rand(12)
# Built inside a helper so only the returned value stays referenced here.
def cc():
x = jt.concat([b.copy(), c.copy()])
d = jt.concat([a.copy(), x])
return d.copy().copy().copy().copy().copy().copy()\
.copy().copy() + x.sum()*0.0
d = cc()
np.testing.assert_allclose(d.data,
np.concatenate([a.data,b.data,c.data]))
# Single-input outer concat: x = concat(b, c) is the only input of
# d = concat([x]). As in test_concat2, the result goes through eight copy()
# calls and has x.sum()*0.0 added, which uses x a second time.
# The output must equal np.concatenate of b and c.
def test_concat3(self):
# a = jt.rand(10)
b = jt.rand(11)
c = jt.rand(12)
# Built inside a helper so only the returned value stays referenced here.
def cc():
x = jt.concat([b.copy(), c.copy()])
d = jt.concat([x])
return d.copy().copy().copy().copy().copy().copy()\
.copy().copy() + x.sum()*0.0
d = cc()
np.testing.assert_allclose(d.data,
np.concatenate([b.data,c.data]))
# Minimal variant of test_concat3: concat([x]) of a concat result is returned
# directly, with no copy() chain and no extra use of x.
# The output must equal np.concatenate of b and c.
def test_concat4(self):
# a = jt.rand(10)
b = jt.rand(11)
c = jt.rand(12)
# Built inside a helper so only the returned value stays referenced here.
def cc():
x = jt.concat([b.copy(), c.copy()])
d = jt.concat([x])
return d
d = cc()
np.testing.assert_allclose(d.data,
np.concatenate([b.data,c.data]))
# Randomized comparison: check() builds a pool of vars from a random sequence
# of copy / clone / index / slice / concat operations and returns the
# concatenation of the pool as a numpy array. For each seed 0..999 the same
# sequence is run twice -- once with default flags and once inside
# jt.flag_scope(gopt_disable=1) -- and the two results must match.
# NOTE(review): relies on jt.set_global_seed also seeding `random` and
# np.random so both runs draw the same sequence -- confirm.
def test_concat_random(self):
def check():
# n1: number of steps; n2: pool size above which one entry is dropped;
# n3: new arrays get a random length in [n3, 2*n3].
n1, n2, n3 = 1000, 20, 10
# n1, n2, n3 = 2, 2, 1
import random
data = []
for i in range(n1):
# Keep the pool bounded by deleting a random entry.
if len(data) > n2:
del data[random.randint(0,len(data)-1)]
# Operation selector: 0 copy, 1 clone, 2 permutation index,
# 3 slice [:100], 4 new array, 5-9 concat.
x1 = random.randint(0,9)
# print(i, x1)
# Seed the pool when it is empty so every branch below has a source.
if len(data) == 0:
# a = jt.random((random.randint(10,20),))
a = jt.array(np.random.rand(random.randint(n3,n3*2)))
data.append(a)
if x1 == 0:
a = data[random.randint(0,len(data)-1)]
a = a.copy()
data.append(a)
elif x1 == 1:
a = data[random.randint(0,len(data)-1)]
a = a.clone()
data.append(a)
elif x1 == 2:
# Index a random entry with a random permutation of its indices.
a = data[random.randint(0,len(data)-1)]
b = np.random.permutation(np.arange(a.numel()))
# print("permutation", b)
a = a[b]
data.append(a)
elif x1 == 3:
a = data[random.randint(0,len(data)-1)]
a = a[:100]
# print(a.shape)
data.append(a)
elif x1 == 4:
# a = jt.random((random.randint(10,20),))
a = jt.array(np.random.rand(random.randint(n3,n3*2)))
data.append(a)
else:
# data is never empty here: the empty case appended an entry above.
if not len(data): continue
# Concatenate 1 to 3 randomly chosen entries (repeats allowed).
n = random.randint(1,3)
a = [ data[random.randint(0,len(data)-1)] for i in range(n) ]
a = jt.concat(a)
# Cap growth: permute, then keep only the first 100 elements.
if a.numel() > 1000:
b = np.random.permutation(np.arange(a.numel()))
a = a[b][:100]
data.append(a)
ret = jt.concat(data).numpy()
# print(data)
return ret
for s in range(1000):
# Same seed before each run so both runs draw the same random sequence.
jt.set_global_seed(s)
data = check()
jt.gc()
jt.set_global_seed(s)
with jt.flag_scope(gopt_disable=1):
data2 = check()
jt.gc()
np.testing.assert_allclose(data, data2)
if __name__ == "__main__":