cascade_setitem v[a][b][c] = x -> v[a,b,c] = x

This commit is contained in:
Dun Liang 2022-12-05 14:42:15 +08:00
parent 499d3ee99c
commit c7e604af1a
7 changed files with 80 additions and 6 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.3.5.43' __version__ = '1.3.6.0'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -209,7 +209,7 @@ def setitem(x, slices, value):
else: else:
ss.append(s) ss.append(s)
slices = tuple(ss) slices = tuple(ss)
return x.assign(x.setitem(slices, value)) return x.check_cascade_setitem(x.setitem(slices, value))
jt.Var.__getitem__ = jt.Var.slice_var = getitem jt.Var.__getitem__ = jt.Var.slice_var = getitem
jt.Var.__setitem__ = setitem jt.Var.__setitem__ = setitem

View File

@ -12,8 +12,9 @@
namespace jittor { namespace jittor {
inline static bool fast_strcmp(const char* a, const char* b) { inline static bool fast_strcmp(const char* a, const char* b) {
while (*b && *a == *b) a++, b++; return ((const uint64*)a)[0] == ((const uint64*)b)[0];
return !*b; // while (*b && *a == *b) a++, b++;
// return !*b;
} }
// add dependency b -> a // add dependency b -> a

View File

@ -13,6 +13,7 @@
namespace jittor { namespace jittor {
constexpr size_t alignment = 32; constexpr size_t alignment = 32;
struct VarHolder;
struct Var : Node { struct Var : Node {
NanoVector shape; NanoVector shape;
@ -25,6 +26,7 @@ struct Var : Node {
Allocator* allocator = nullptr; Allocator* allocator = nullptr;
size_t allocation; size_t allocation;
int64 size, num; int64 size, num;
VarHolder* holder = nullptr;
inline bool is_float() const { CHECK_EXIST; return ns.is_float(); } inline bool is_float() const { CHECK_EXIST; return ns.is_float(); }
inline int dsize() const { CHECK_EXIST; return ns.dsize(); } inline int dsize() const { CHECK_EXIST; return ns.dsize(); }
inline NanoString dtype() const { CHECK_EXIST; return ns; } inline NanoString dtype() const { CHECK_EXIST; return ns; }

View File

@ -15,6 +15,8 @@
#include "graph.h" #include "graph.h"
#include "mem/allocator/cuda_dual_allocator.h" #include "mem/allocator/cuda_dual_allocator.h"
#include "ops/op_register.h" #include "ops/op_register.h"
#include "ops/getitem_op.h"
#include "ops/setitem_op.h"
namespace jittor { namespace jittor {
@ -45,6 +47,7 @@ void add_hold_vars(VarHolder* self) {
VarHolder::VarHolder(Var* v) : var(v) { VarHolder::VarHolder(Var* v) : var(v) {
// Var holder has both forward and backward liveness // Var holder has both forward and backward liveness
own_holder();
var->own_both_liveness(); var->own_both_liveness();
add_hold_vars(this); add_hold_vars(this);
} }
@ -52,10 +55,12 @@ VarHolder::VarHolder(Var* v) : var(v) {
VarHolder::VarHolder(VarPtr&& v) { VarHolder::VarHolder(VarPtr&& v) {
var = v.ptr; var = v.ptr;
v.ptr = nullptr; v.ptr = nullptr;
own_holder();
add_hold_vars(this); add_hold_vars(this);
} }
VarHolder::VarHolder(VarHolder* v) : var(v->var) { VarHolder::VarHolder(VarHolder* v) : var(v->var) {
own_holder();
iter = v->iter; iter = v->iter;
*iter = this; *iter = this;
// free memory without calling deconstructor // free memory without calling deconstructor
@ -73,6 +78,7 @@ VarHolder::VarHolder(PyObject* obj, NanoString dtype) {
vp = make_unary(vp, dtype); vp = make_unary(vp, dtype);
var = vp.ptr; var = vp.ptr;
vp.ptr = nullptr; vp.ptr = nullptr;
own_holder();
add_hold_vars(this); add_hold_vars(this);
} }
@ -82,6 +88,7 @@ VarHolder::~VarHolder() {
if (iter == sync_ptr) if (iter == sync_ptr)
sync_ptr = std::next(sync_ptr); sync_ptr = std::next(sync_ptr);
hold_vars.erase(iter); hold_vars.erase(iter);
release_holder();
var->release_both_liveness(); var->release_both_liveness();
} }
@ -105,8 +112,10 @@ void VarHolder::operator=(VarPtr&& v) {
v.ptr->flags.set(NodeFlags::_th_require_grad); v.ptr->flags.set(NodeFlags::_th_require_grad);
} }
assign_var(v.ptr, var); assign_var(v.ptr, var);
release_holder();
var->release_both_liveness(); var->release_both_liveness();
var = v.ptr; var = v.ptr;
own_holder();
v.ptr = nullptr; v.ptr = nullptr;
} }
@ -137,9 +146,11 @@ VarHolder* VarHolder::assign(VarHolder* v) {
v->set_requires_grad(get_requires_grad()); v->set_requires_grad(get_requires_grad());
} }
assign_var(v->var, var); assign_var(v->var, var);
release_holder();
v->var->own_both_liveness(); v->var->own_both_liveness();
var->release_both_liveness(); var->release_both_liveness();
var = v->var; var = v->var;
own_holder();
return this; return this;
} }
@ -149,9 +160,11 @@ VarHolder* VarHolder::update(VarHolder* v) {
} }
VarHolder* VarHolder::_update(VarHolder* v) { VarHolder* VarHolder::_update(VarHolder* v) {
release_holder();
v->var->own_both_liveness(); v->var->own_both_liveness();
var->release_both_liveness(); var->release_both_liveness();
var = v->var; var = v->var;
own_holder();
var->flags.set(NodeFlags::_out_hint); var->flags.set(NodeFlags::_out_hint);
return this; return this;
} }
@ -278,4 +291,46 @@ void migrate_all_to_cpu() {
#endif #endif
} }
// Constructor handle for the op registered under the name "setitem",
// callable as VarPtr(Var* x, VarSlices&& slices, Var* y, NanoString op).
// Used below to rebuild a chained `v[a][b] = y` as one `v[a,b] = y`.
static auto make_setitem = get_op_info("setitem")
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
// Compare at most the first 8 bytes of two NUL-terminated strings.
//
// Returns true when `a` and `b` agree on every byte up to and including a
// shared NUL terminator, or on all of their first 8 bytes if neither string
// ends earlier. For a 7-character literal such as "getitem" (7 chars + NUL =
// 8 bytes) this is an exact equality test.
//
// The previous implementation loaded both pointers as a single 64-bit word.
// That read past the end of strings shorter than 7 characters, could perform
// a misaligned load, and aliased `char` storage through an unrelated integer
// type. The bounded byte loop yields the same answer wherever the old code
// was well-defined and never touches memory after a terminator.
inline static bool fast_strcmp(const char* a, const char* b) {
    for (int i = 0; i < 8; i++) {
        if (a[i] != b[i]) return false;
        // Both bytes are equal here; a NUL means both strings ended together.
        if (a[i] == '\0') return true;
    }
    return true;
}
// Handle chained-index assignment such as `v[a][b][c] = y`.
//
// `out` is the result of a setitem applied to this holder's var. The function
// walks backwards from `var` through the ops that produced it; as long as each
// producer is a getitem with exactly one integer slice, the index is recorded
// and the walk continues to that op's first input. If a var on that path has a
// holder attached, the recorded indices and the slices of the setitem that
// produced `out` are merged into one slice list and a single setitem is
// assigned to that holder (`v[a][b][c] = y -> v[a,b,c] = y`).
//
// In every case the function finishes with `assign(out)` and returns its
// result.
VarHolder* VarHolder::check_cascade_setitem(VarHolder* out) {
// return this;
auto v = var;
// Number of getitem levels recorded so far; at most 10 are followed.
int n=0;
// Recorded integer indices, innermost getitem (closest to `var`) first.
int64 slices[10];
while (n<10) {
// Op that produced the current var; stop if there is none.
Op* iop = v->input();
if (!iop) break;
// Only chains made of getitem ops are followed.
if (!fast_strcmp(iop->name(), "getitem")) break;
// Step to the var that was indexed.
v = iop->inputs().front();
GetitemOp* gop = (GetitemOp*)iop;
// Only a single integer index per level is supported; anything else
// (multiple slices, non-integer slice) ends the walk.
if (gop->vs.n == 1 && gop->vs.slices[0].is_int()) {
slices[n++] = gop->vs.slices[0].i;
} else break;
if (v->holder) {
// Found a var on the chain that has a holder: v
// v[a][b][c][d] = y
// ^ this is v
// NOTE(review): the producer of out->var is cast to SetitemOp without
// checking its name or that it is non-null — presumably guaranteed by
// the caller passing the result of setitem; confirm.
auto* prev_op = (SetitemOp*)out->var->input();
VarSlices& old_slices = prev_op->vs;
// Second input of the setitem op is taken as the assigned value y.
Var* y = prev_op->input(1);
VarSlices new_slices(n+old_slices.n);
// Recorded indices were collected innermost-first; write them in
// reverse so the outermost index comes first.
for (int i=n-1; i>=0; i--)
new_slices.slices[n-1-i].set_int(slices[i]);
// Append the slices of the original (innermost) setitem.
for (int i=0; i<old_slices.n; i++)
new_slices.slices[n+i] = old_slices.slices[i];
// apply new slice
// v[a][b][c][d] = y -> v[a,b,c,d] = y
(*v->holder) = make_setitem(v, move(new_slices), y, ns_void);
break;
}
}
// This holder always takes the result of its own setitem.
return assign(out);
}
} // jittor } // jittor

View File

@ -58,6 +58,9 @@ struct VarHolder {
// @pyjt(fetch_sync,numpy) // @pyjt(fetch_sync,numpy)
ArrayArgs fetch_sync(); ArrayArgs fetch_sync();
inline void release_holder() {var->holder = nullptr;}
// Record this holder on the held Var by pointing its `holder` field at us.
inline void own_holder() {
    this->var->holder = this;
}
/** /**
* assign the data from another Var. * assign the data from another Var.
*/ */
@ -87,7 +90,11 @@ struct VarHolder {
*/ */
// @pyjt(swap) // @pyjt(swap)
// @attrs(return_self) // @attrs(return_self)
inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; }; inline VarHolder* swap(VarHolder* v) {
std::swap(var, v->var);
own_holder(); v->own_holder();
return this;
};
void operator=(VarPtr&& v); void operator=(VarPtr&& v);
@ -330,6 +337,11 @@ struct VarHolder {
return this; return this;
} }
/**
 * Support chained-index assignment such as `a[x][y] = c`.
 *
 * `out` is the result of a setitem applied to this holder; it is assigned
 * to this holder. If this holder's var was produced by a chain of
 * single-integer getitem ops on a var that has its own holder, the write
 * is also applied to that holder as one combined setitem
 * (`a[x][y] = c` -> `a[x,y] = c`).
 */
// @pyjt(check_cascade_setitem)
// @attrs(return_self)
VarHolder* check_cascade_setitem(VarHolder* out);
}; };
// @pyjt(sync) // @pyjt(sync)

View File

@ -437,7 +437,11 @@ class TestSetitem(unittest.TestCase):
b.sync(True) b.sync(True)
assert b.item() == 1 assert b.item() == 1
def test_cascade_setitem(self):
    # A write through chained integer indexing must land in the original
    # tensor, readable back with the equivalent multi-index expression.
    tensor = jt.zeros(3,3,3,3)
    tensor[1][2][0][0] = 1
    assert tensor[1,2,0,0] == 1
# TODO: convert a[x] = a[x] + b -> a[x] += b
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()