diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 68c75d0b..8c0765e0 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.5.43'
+__version__ = '1.3.6.0'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py
index 4432684c..f8549552 100644
--- a/python/jittor/contrib.py
+++ b/python/jittor/contrib.py
@@ -209,7 +209,7 @@ def setitem(x, slices, value):
             else:
                 ss.append(s)
         slices = tuple(ss)
-    return x.assign(x.setitem(slices, value))
+    return x.check_cascade_setitem(x.setitem(slices, value))
 
 jt.Var.__getitem__ = jt.Var.slice_var = getitem
 jt.Var.__setitem__ = setitem
diff --git a/python/jittor/src/opt/gopt/setitem_gopt.cc b/python/jittor/src/opt/gopt/setitem_gopt.cc
index b6abfb13..788fc4df 100644
--- a/python/jittor/src/opt/gopt/setitem_gopt.cc
+++ b/python/jittor/src/opt/gopt/setitem_gopt.cc
@@ -12,8 +12,15 @@
 namespace jittor {
 
 inline static bool fast_strcmp(const char* a, const char* b) {
-    while (*b && *a == *b) a++, b++;
-    return !*b;
+    // Compare at most the first 8 bytes and stop at the terminator, so a name
+    // shorter than 8 bytes is never read past its end (a single 8-byte load
+    // would over-read it and also assumes alignment). A 7-character name plus
+    // its terminator fills exactly 8 bytes, so for such names the match is exact.
+    for (int i=0; i<8; i++) {
+        if (a[i] != b[i]) return false;
+        if (!a[i]) return true;
+    }
+    return true;
 }
 
 // add dependency b -> a
diff --git a/python/jittor/src/var.h b/python/jittor/src/var.h
index 1b7ca721..3684a7ee 100644
--- a/python/jittor/src/var.h
+++ b/python/jittor/src/var.h
@@ -13,6 +13,7 @@
 namespace jittor {
 
 constexpr size_t alignment = 32;
 
+struct VarHolder; // forward declaration for the Var::holder back-pointer
 struct Var : Node {
     NanoVector shape;
@@ -25,6 +26,7 @@ struct Var : Node {
     Allocator* allocator = nullptr;
     size_t allocation;
     int64 size, num;
+    VarHolder* holder = nullptr; // set by VarHolder::own_holder, cleared by release_holder
     inline bool is_float() const { CHECK_EXIST; return ns.is_float(); }
     inline int dsize() const { CHECK_EXIST; return ns.dsize(); }
     inline NanoString dtype() const { CHECK_EXIST; return ns; }
diff --git a/python/jittor/src/var_holder.cc b/python/jittor/src/var_holder.cc
index 25f090cb..417a5f7c 100644
--- a/python/jittor/src/var_holder.cc
+++ b/python/jittor/src/var_holder.cc
@@ -15,6 +15,8 @@
 #include "graph.h"
 #include "mem/allocator/cuda_dual_allocator.h"
 #include "ops/op_register.h"
+#include "ops/getitem_op.h"
+#include "ops/setitem_op.h"
 
 namespace jittor {
 
@@ -45,6 +47,7 @@ void add_hold_vars(VarHolder* self) {
 
 VarHolder::VarHolder(Var* v) : var(v) {
     // Var holder has both forward and backward liveness
+    own_holder(); // point var->holder back at this holder
     var->own_both_liveness();
     add_hold_vars(this);
 }
@@ -52,10 +55,12 @@ VarHolder::VarHolder(Var* v) : var(v) {
 VarHolder::VarHolder(VarPtr&& v) {
     var = v.ptr;
     v.ptr = nullptr;
+    own_holder(); // point var->holder back at this holder
     add_hold_vars(this);
 }
 
 VarHolder::VarHolder(VarHolder* v) : var(v->var) {
+    own_holder(); // var->holder now refers to this holder instead of v
     iter = v->iter;
     *iter = this;
     // free memory without calling deconstructor
@@ -73,6 +78,7 @@ VarHolder::VarHolder(PyObject* obj, NanoString dtype) {
     vp = make_unary(vp, dtype);
     var = vp.ptr;
     vp.ptr = nullptr;
+    own_holder(); // point var->holder back at this holder
     add_hold_vars(this);
 }
 
@@ -82,6 +88,7 @@ VarHolder::~VarHolder() {
     if (iter == sync_ptr)
         sync_ptr = std::next(sync_ptr);
     hold_vars.erase(iter);
+    release_holder(); // clear var->holder before liveness is dropped
     var->release_both_liveness();
 }
 
@@ -105,8 +112,10 @@ void VarHolder::operator=(VarPtr&& v) {
         v.ptr->flags.set(NodeFlags::_th_require_grad);
     }
     assign_var(v.ptr, var);
+    release_holder(); // detach from the old var
     var->release_both_liveness();
     var = v.ptr;
+    own_holder(); // attach to the new var
     v.ptr = nullptr;
 }
 
@@ -137,9 +146,11 @@ VarHolder* VarHolder::assign(VarHolder* v) {
         v->set_requires_grad(get_requires_grad());
     }
     assign_var(v->var, var);
+    release_holder(); // detach from the old var
     v->var->own_both_liveness();
     var->release_both_liveness();
     var = v->var;
+    own_holder(); // attach to the new var
     return this;
 }
 
@@ -149,9 +160,11 @@ VarHolder* VarHolder::update(VarHolder* v) {
 }
 
 VarHolder* VarHolder::_update(VarHolder* v) {
+    release_holder(); // detach from the old var
     v->var->own_both_liveness();
     var->release_both_liveness();
     var = v->var;
+    own_holder(); // attach to the new var
     var->flags.set(NodeFlags::_out_hint);
     return this;
 }
@@ -278,4 +291,67 @@ void migrate_all_to_cpu() {
 #endif
 }
 
+static auto make_setitem = get_op_info("setitem")
+    .get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
+
+// Bounded name comparison: looks at no more than the first 8 bytes and
+// stops at the terminator, so a short name is never read past its end
+// (a single 8-byte load would over-read it). "getitem" and "setitem" fill
+// exactly 8 bytes with their terminator, so for them this is an exact match.
+inline static bool fast_strcmp(const char* a, const char* b) {
+    for (int i=0; i<8; i++) {
+        if (a[i] != b[i]) return false;
+        if (!a[i]) return true;
+    }
+    return true;
+}
+
+// Fold a cascaded integer-index assignment into one setitem.
+// `this` holds the left-hand getitem result and `out` the setitem built on it
+// (see contrib.py: x.check_cascade_setitem(x.setitem(slices, value))).
+// If this var comes from a chain of single-integer getitem ops whose source
+// var has a registered holder, that holder is reassigned to a single setitem
+// over the combined slices:   v[a][b][c][d] = y  ->  v[a,b,c,d] = y
+// In every case `out` is then assigned to this holder, which is returned.
+VarHolder* VarHolder::check_cascade_setitem(VarHolder* out) {
+    auto v = var;
+    int n=0;
+    // at most 10 chained integer indices are collected
+    int64 slices[10];
+    while (n<10) {
+        Op* iop = v->input();
+        if (!iop) break;
+        if (!fast_strcmp(iop->name(), "getitem")) break;
+        v = iop->inputs().front();
+        GetitemOp* gop = (GetitemOp*)iop;
+        if (gop->vs.n == 1 && gop->vs.slices[0].is_int()) {
+            slices[n++] = gop->vs.slices[0].i;
+        } else break;
+        if (v->holder) {
+            // found holder var: v
+            // v[a][b][c][d] = y
+            // ^
+            Op* sop = out->var->input();
+            // rewrite only when out really comes from a setitem op;
+            // otherwise the cast below would be invalid
+            if (!sop || !fast_strcmp(sop->name(), "setitem")) break;
+            auto* prev_op = (SetitemOp*)sop;
+            VarSlices& old_slices = prev_op->vs;
+            Var* y = prev_op->input(1);
+            VarSlices new_slices(n+old_slices.n);
+            // indices were collected innermost first; store them outermost first
+            for (int i=n-1; i>=0; i--)
+                new_slices.slices[n-1-i].set_int(slices[i]);
+            // keep the slices of the original setitem after the collected ones
+            for (int i=0; i<old_slices.n; i++)
+                new_slices.slices[n+i] = old_slices.slices[i];
+            // apply new slices
+            // v[a][b][c][d] = y -> v[a,b,c,d] = y
+            (*v->holder) = make_setitem(v, move(new_slices), y, ns_void);
+            break;
+        }
+    }
+    return assign(out);
+}
+
 } // jittor
\ No newline at end of file
diff --git a/python/jittor/src/var_holder.h b/python/jittor/src/var_holder.h
index c5bbb3ae..f9511752 100644
--- a/python/jittor/src/var_holder.h
+++ b/python/jittor/src/var_holder.h
@@ -58,6 +58,9 @@ struct VarHolder {
     // @pyjt(fetch_sync,numpy)
     ArrayArgs fetch_sync();
 
+    inline void release_holder() {var->holder = nullptr;} // clear the var's back-pointer
+    inline void own_holder() {var->holder = this;} // make this holder the var's back-pointer
+
     /**
      * assign the data from another Var.
      */
@@ -87,7 +90,11 @@ struct VarHolder {
      */
     // @pyjt(swap)
     // @attrs(return_self)
-    inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };
+    inline VarHolder* swap(VarHolder* v) {
+        std::swap(var, v->var);
+        own_holder(); v->own_holder(); // re-point both swapped vars at their new holders
+        return this;
+    };
 
     void operator=(VarPtr&& v);
 
@@ -330,6 +337,11 @@ struct VarHolder {
         return this;
     }
 
+    /* fuse a cascaded assignment such as a[x][y] = c into a single setitem
+     */
+    // @pyjt(check_cascade_setitem)
+    // @attrs(return_self)
+    VarHolder* check_cascade_setitem(VarHolder* out);
 };
 
 // @pyjt(sync)
diff --git a/python/jittor/test/test_setitem.py b/python/jittor/test/test_setitem.py
index 192e43a0..91c7f3bf 100644
--- a/python/jittor/test/test_setitem.py
+++ b/python/jittor/test/test_setitem.py
@@ -437,7 +437,12 @@ class TestSetitem(unittest.TestCase):
 
         b.sync(True)
         assert b.item() == 1
-
+    def test_cascade_setitem(self):
+        a = jt.zeros(3,3,3,3)
+        a[1][2][0][0] = 1  # chained integer indexing followed by assignment
+        assert a[1,2,0,0].item() == 1
+        assert a.sum().item() == 1  # no other element was written
+        # TODO: convert a[x] = a[x] + b -> a[x] += b
 
 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file