mirror of https://github.com/Jittor/Jittor
polish gopt setitem concat
This commit is contained in:
parent
da0dc2cfba
commit
18795bd02f
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.63'
|
||||
__version__ = '1.3.2.0'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -465,7 +465,8 @@ void GetitemOp::jit_run() {
|
|||
auto in = inputs().front();
|
||||
auto out = outputs().front();
|
||||
if (out->num == 0) return;
|
||||
if (in->allocator == out->allocator &&
|
||||
if (ns.get(GetitemOp::_inplace) &&
|
||||
in->allocator == out->allocator &&
|
||||
in->allocation == out->allocation)
|
||||
return;
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
namespace jittor {
|
||||
|
||||
struct GetitemOp : Op {
|
||||
static constexpr jittor::NanoString::Flags _inplace = (jittor::NanoString::Flags)0;
|
||||
VarSlices vs;
|
||||
// map i to related var slice
|
||||
NanoVector i_to_vs;
|
||||
|
|
|
@ -315,7 +315,7 @@ void SetitemOp::jit_run() {
|
|||
checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDeviceToDevice, 0));
|
||||
#endif
|
||||
|
||||
if (flags.get((NodeFlags::Flags(SetitemOp::_data_inplaced))) &&
|
||||
if (ns.get(GetitemOp::_inplace) &&
|
||||
// array op may move the data allocation, double check
|
||||
// affect test_contrib.pu
|
||||
in->allocator == data->allocator &&
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
namespace jittor {
|
||||
|
||||
struct SetitemOp : Op {
|
||||
static constexpr int _data_inplaced = NodeFlags::_has_vary_input + 1;
|
||||
VarSlices vs;
|
||||
// map i to related var slice
|
||||
NanoVector i_to_vs;
|
||||
|
|
|
@ -107,69 +107,9 @@ static void setitem_inplace(SetitemOp* op) {
|
|||
}
|
||||
add_dependency(data->input(), {input->node()});
|
||||
data->share_with(input, size);
|
||||
op->flags.set((NodeFlags::Flags(SetitemOp::_data_inplaced)));
|
||||
}
|
||||
|
||||
struct BBox {
|
||||
int n = 0;
|
||||
int* minmax = nullptr;
|
||||
|
||||
|
||||
|
||||
void load_var_slice(const VarSlice& vs) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
static void setitem_grad_opt(GetitemOp* op) {
|
||||
if (!op->flags.get(NodeFlags::_has_gopt))
|
||||
return;
|
||||
auto get_in = op->inputs().front();
|
||||
auto get_in_op = get_in->input();
|
||||
if (!get_in_op)
|
||||
return;
|
||||
auto name = get_in_op->name();
|
||||
if (!fast_strcmp(name, "setitem"))
|
||||
return;
|
||||
// find setitem op chain
|
||||
auto first_set = (SetitemOp*)get_in_op;
|
||||
vector<SetitemOp*> chain;
|
||||
while (1) {
|
||||
auto next = first_set->inputs().front()->input();
|
||||
if (!next) break;
|
||||
if (!fast_strcmp(next->name(), "setitem"))
|
||||
break;
|
||||
chain.push_back(first_set);
|
||||
first_set = (SetitemOp*)next;
|
||||
}
|
||||
chain.push_back(first_set);
|
||||
for (int i=0; i<chain.size()/2; i++)
|
||||
std::swap(chain[i], chain[chain.size()-1-i]);
|
||||
auto last_set = (SetitemOp*)get_in_op;
|
||||
while (1) {
|
||||
SetitemOp* next = nullptr;
|
||||
auto out_var = last_set->outputs().front();
|
||||
for (auto* out : out_var->outputs()) {
|
||||
if (fast_strcmp(out->name(), "setitem")) {
|
||||
next = (SetitemOp*)out;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!next) break;
|
||||
last_set = next;
|
||||
chain.push_back(next);
|
||||
}
|
||||
// LOGir << "find setitem chain" << chain.size() << chain;
|
||||
for (auto* sop : chain) {
|
||||
// LOGig << sop << sop->vs;
|
||||
auto out_var = sop->outputs().front();
|
||||
for (auto* out : out_var->outputs()) {
|
||||
if (fast_strcmp(out->name(), "getitem")) {
|
||||
out->flags.set(NodeFlags::_has_gopt, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
op->ns.set(GetitemOp::_inplace);
|
||||
// LOGir << input->shape << input->dtype() << data->shape << data->dtype() << vs << data->input();
|
||||
// LOGir << output;
|
||||
}
|
||||
|
||||
static void getitem_inplace(GetitemOp* op) {
|
||||
|
@ -207,7 +147,9 @@ static void getitem_inplace(GetitemOp* op) {
|
|||
if (s.slice.step != 1) return;
|
||||
}
|
||||
ou->share_with(in, size);
|
||||
op->ns.set(GetitemOp::_inplace);
|
||||
// LOGir << "pass getitem_inplace";
|
||||
// LOGir << "inplace getitem" << vs << in->shape << ou->shape;
|
||||
}
|
||||
|
||||
void SetitemOp::graph_optimize() {
|
||||
|
@ -220,7 +162,6 @@ void GetitemOp::graph_optimize() {
|
|||
// This optimize is still WIP
|
||||
// LOGir << "hello getitem graph_optimize";
|
||||
// setitem_grad_opt(this);
|
||||
(void*)setitem_grad_opt;
|
||||
// (void)getitem_inplace;
|
||||
getitem_inplace(this);
|
||||
(void*)getitem_inplace;
|
||||
|
|
|
@ -242,6 +242,106 @@ class TestSetitem(unittest.TestCase):
|
|||
c = b[::-1]
|
||||
d = c.sum()
|
||||
jt.grad(d, [a])
|
||||
|
||||
def test_concat2(self):
|
||||
a = jt.rand(10)
|
||||
b = jt.rand(11)
|
||||
c = jt.rand(12)
|
||||
def cc():
|
||||
x = jt.concat([b.copy(), c.copy()])
|
||||
d = jt.concat([a.copy(), x])
|
||||
return d.copy().copy().copy().copy().copy().copy()\
|
||||
.copy().copy() + x.sum()*0.0
|
||||
d = cc()
|
||||
np.testing.assert_allclose(d.data,
|
||||
np.concatenate([a.data,b.data,c.data]))
|
||||
|
||||
def test_concat3(self):
|
||||
# a = jt.rand(10)
|
||||
b = jt.rand(11)
|
||||
c = jt.rand(12)
|
||||
def cc():
|
||||
x = jt.concat([b.copy(), c.copy()])
|
||||
d = jt.concat([x])
|
||||
return d.copy().copy().copy().copy().copy().copy()\
|
||||
.copy().copy() + x.sum()*0.0
|
||||
d = cc()
|
||||
np.testing.assert_allclose(d.data,
|
||||
np.concatenate([b.data,c.data]))
|
||||
|
||||
|
||||
def test_concat4(self):
|
||||
# a = jt.rand(10)
|
||||
b = jt.rand(11)
|
||||
c = jt.rand(12)
|
||||
def cc():
|
||||
x = jt.concat([b.copy(), c.copy()])
|
||||
d = jt.concat([x])
|
||||
return d
|
||||
d = cc()
|
||||
np.testing.assert_allclose(d.data,
|
||||
np.concatenate([b.data,c.data]))
|
||||
|
||||
def test_concat_random(self):
|
||||
def check():
|
||||
n1, n2, n3 = 1000, 20, 10
|
||||
# n1, n2, n3 = 2, 2, 1
|
||||
import random
|
||||
data = []
|
||||
for i in range(n1):
|
||||
if len(data) > n2:
|
||||
del data[random.randint(0,len(data)-1)]
|
||||
x1 = random.randint(0,9)
|
||||
# print(i, x1)
|
||||
if len(data) == 0:
|
||||
# a = jt.random((random.randint(10,20),))
|
||||
a = jt.array(np.random.rand(random.randint(n3,n3*2)))
|
||||
data.append(a)
|
||||
if x1 == 0:
|
||||
a = data[random.randint(0,len(data)-1)]
|
||||
a = a.copy()
|
||||
data.append(a)
|
||||
elif x1 == 1:
|
||||
a = data[random.randint(0,len(data)-1)]
|
||||
a = a.clone()
|
||||
data.append(a)
|
||||
elif x1 == 2:
|
||||
a = data[random.randint(0,len(data)-1)]
|
||||
b = np.random.permutation(np.arange(a.numel()))
|
||||
# print("permutation", b)
|
||||
a = a[b]
|
||||
data.append(a)
|
||||
elif x1 == 3:
|
||||
a = data[random.randint(0,len(data)-1)]
|
||||
a = a[:100]
|
||||
# print(a.shape)
|
||||
data.append(a)
|
||||
elif x1 == 4:
|
||||
# a = jt.random((random.randint(10,20),))
|
||||
a = jt.array(np.random.rand(random.randint(n3,n3*2)))
|
||||
data.append(a)
|
||||
else:
|
||||
if not len(data): continue
|
||||
n = random.randint(1,3)
|
||||
a = [ data[random.randint(0,len(data)-1)] for i in range(n) ]
|
||||
a = jt.concat(a)
|
||||
if a.numel() > 1000:
|
||||
b = np.random.permutation(np.arange(a.numel()))
|
||||
a = a[b][:100]
|
||||
data.append(a)
|
||||
ret = jt.concat(data).numpy()
|
||||
# print(data)
|
||||
return ret
|
||||
|
||||
for s in range(1000):
|
||||
jt.set_global_seed(s)
|
||||
data = check()
|
||||
jt.gc()
|
||||
jt.set_global_seed(s)
|
||||
with jt.flag_scope(gopt_disable=1):
|
||||
data2 = check()
|
||||
jt.gc()
|
||||
np.testing.assert_allclose(data, data2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue