mirror of https://github.com/Jittor/Jittor
cascade_setitem v[a][b][c] = x -> v[a,b,c] = x
This commit is contained in:
parent
499d3ee99c
commit
c7e604af1a
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.43'
|
||||
__version__ = '1.3.6.0'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -209,7 +209,7 @@ def setitem(x, slices, value):
|
|||
else:
|
||||
ss.append(s)
|
||||
slices = tuple(ss)
|
||||
return x.assign(x.setitem(slices, value))
|
||||
return x.check_cascade_setitem(x.setitem(slices, value))
|
||||
|
||||
jt.Var.__getitem__ = jt.Var.slice_var = getitem
|
||||
jt.Var.__setitem__ = setitem
|
||||
|
|
|
@ -12,8 +12,9 @@
|
|||
namespace jittor {
|
||||
|
||||
inline static bool fast_strcmp(const char* a, const char* b) {
|
||||
while (*b && *a == *b) a++, b++;
|
||||
return !*b;
|
||||
return ((const uint64*)a)[0] == ((const uint64*)b)[0];
|
||||
// while (*b && *a == *b) a++, b++;
|
||||
// return !*b;
|
||||
}
|
||||
|
||||
// add dependency b -> a
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
namespace jittor {
|
||||
|
||||
constexpr size_t alignment = 32;
|
||||
struct VarHolder;
|
||||
|
||||
struct Var : Node {
|
||||
NanoVector shape;
|
||||
|
@ -25,6 +26,7 @@ struct Var : Node {
|
|||
Allocator* allocator = nullptr;
|
||||
size_t allocation;
|
||||
int64 size, num;
|
||||
VarHolder* holder = nullptr;
|
||||
inline bool is_float() const { CHECK_EXIST; return ns.is_float(); }
|
||||
inline int dsize() const { CHECK_EXIST; return ns.dsize(); }
|
||||
inline NanoString dtype() const { CHECK_EXIST; return ns; }
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
#include "graph.h"
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "ops/getitem_op.h"
|
||||
#include "ops/setitem_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -45,6 +47,7 @@ void add_hold_vars(VarHolder* self) {
|
|||
|
||||
VarHolder::VarHolder(Var* v) : var(v) {
|
||||
// Var holder has both forward and backward liveness
|
||||
own_holder();
|
||||
var->own_both_liveness();
|
||||
add_hold_vars(this);
|
||||
}
|
||||
|
@ -52,10 +55,12 @@ VarHolder::VarHolder(Var* v) : var(v) {
|
|||
VarHolder::VarHolder(VarPtr&& v) {
|
||||
var = v.ptr;
|
||||
v.ptr = nullptr;
|
||||
own_holder();
|
||||
add_hold_vars(this);
|
||||
}
|
||||
|
||||
VarHolder::VarHolder(VarHolder* v) : var(v->var) {
|
||||
own_holder();
|
||||
iter = v->iter;
|
||||
*iter = this;
|
||||
// free memory without calling deconstructor
|
||||
|
@ -73,6 +78,7 @@ VarHolder::VarHolder(PyObject* obj, NanoString dtype) {
|
|||
vp = make_unary(vp, dtype);
|
||||
var = vp.ptr;
|
||||
vp.ptr = nullptr;
|
||||
own_holder();
|
||||
add_hold_vars(this);
|
||||
}
|
||||
|
||||
|
@ -82,6 +88,7 @@ VarHolder::~VarHolder() {
|
|||
if (iter == sync_ptr)
|
||||
sync_ptr = std::next(sync_ptr);
|
||||
hold_vars.erase(iter);
|
||||
release_holder();
|
||||
var->release_both_liveness();
|
||||
}
|
||||
|
||||
|
@ -105,8 +112,10 @@ void VarHolder::operator=(VarPtr&& v) {
|
|||
v.ptr->flags.set(NodeFlags::_th_require_grad);
|
||||
}
|
||||
assign_var(v.ptr, var);
|
||||
release_holder();
|
||||
var->release_both_liveness();
|
||||
var = v.ptr;
|
||||
own_holder();
|
||||
v.ptr = nullptr;
|
||||
}
|
||||
|
||||
|
@ -137,9 +146,11 @@ VarHolder* VarHolder::assign(VarHolder* v) {
|
|||
v->set_requires_grad(get_requires_grad());
|
||||
}
|
||||
assign_var(v->var, var);
|
||||
release_holder();
|
||||
v->var->own_both_liveness();
|
||||
var->release_both_liveness();
|
||||
var = v->var;
|
||||
own_holder();
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -149,9 +160,11 @@ VarHolder* VarHolder::update(VarHolder* v) {
|
|||
}
|
||||
|
||||
VarHolder* VarHolder::_update(VarHolder* v) {
|
||||
release_holder();
|
||||
v->var->own_both_liveness();
|
||||
var->release_both_liveness();
|
||||
var = v->var;
|
||||
own_holder();
|
||||
var->flags.set(NodeFlags::_out_hint);
|
||||
return this;
|
||||
}
|
||||
|
@ -278,4 +291,46 @@ void migrate_all_to_cpu() {
|
|||
#endif
|
||||
}
|
||||
|
||||
static auto make_setitem = get_op_info("setitem")
|
||||
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
|
||||
|
||||
inline static bool fast_strcmp(const char* a, const char* b) {
|
||||
return ((const uint64*)a)[0] == ((const uint64*)b)[0];
|
||||
}
|
||||
|
||||
VarHolder* VarHolder::check_cascade_setitem(VarHolder* out) {
|
||||
// return this;
|
||||
auto v = var;
|
||||
int n=0;
|
||||
int64 slices[10];
|
||||
while (n<10) {
|
||||
Op* iop = v->input();
|
||||
if (!iop) break;
|
||||
if (!fast_strcmp(iop->name(), "getitem")) break;
|
||||
v = iop->inputs().front();
|
||||
GetitemOp* gop = (GetitemOp*)iop;
|
||||
if (gop->vs.n == 1 && gop->vs.slices[0].is_int()) {
|
||||
slices[n++] = gop->vs.slices[0].i;
|
||||
} else break;
|
||||
if (v->holder) {
|
||||
// found holder var: v
|
||||
// v[a][b][c][d] = y
|
||||
// ^
|
||||
auto* prev_op = (SetitemOp*)out->var->input();
|
||||
VarSlices& old_slices = prev_op->vs;
|
||||
Var* y = prev_op->input(1);
|
||||
VarSlices new_slices(n+old_slices.n);
|
||||
for (int i=n-1; i>=0; i--)
|
||||
new_slices.slices[n-1-i].set_int(slices[i]);
|
||||
for (int i=0; i<old_slices.n; i++)
|
||||
new_slices.slices[n+i] = old_slices.slices[i];
|
||||
// apply new slice
|
||||
// v[a][b][c][d] = y -> v[a,b,c,d] = y
|
||||
(*v->holder) = make_setitem(v, move(new_slices), y, ns_void);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return assign(out);
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -58,6 +58,9 @@ struct VarHolder {
|
|||
// @pyjt(fetch_sync,numpy)
|
||||
ArrayArgs fetch_sync();
|
||||
|
||||
inline void release_holder() {var->holder = nullptr;}
|
||||
inline void own_holder() {var->holder = this;}
|
||||
|
||||
/**
|
||||
* assign the data from another Var.
|
||||
*/
|
||||
|
@ -87,7 +90,11 @@ struct VarHolder {
|
|||
*/
|
||||
// @pyjt(swap)
|
||||
// @attrs(return_self)
|
||||
inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };
|
||||
inline VarHolder* swap(VarHolder* v) {
|
||||
std::swap(var, v->var);
|
||||
own_holder(); v->own_holder();
|
||||
return this;
|
||||
};
|
||||
|
||||
void operator=(VarPtr&& v);
|
||||
|
||||
|
@ -330,6 +337,11 @@ struct VarHolder {
|
|||
return this;
|
||||
}
|
||||
|
||||
/* check a[x][y] = c
|
||||
*/
|
||||
// @pyjt(check_cascade_setitem)
|
||||
// @attrs(return_self)
|
||||
VarHolder* check_cascade_setitem(VarHolder* out);
|
||||
};
|
||||
|
||||
// @pyjt(sync)
|
||||
|
|
|
@ -437,7 +437,11 @@ class TestSetitem(unittest.TestCase):
|
|||
b.sync(True)
|
||||
assert b.item() == 1
|
||||
|
||||
|
||||
def test_cascade_setitem(self):
|
||||
a = jt.zeros(3,3,3,3)
|
||||
a[1][2][0][0] = 1
|
||||
assert a[1,2,0,0] == 1
|
||||
# TODO: convert a[x] = a[x] + b -> a[x] += b
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue