mirror of https://github.com/Jittor/Jittor
cascade_setitem v[a][b][c] = x -> v[a,b,c] = x
This commit is contained in:
parent
499d3ee99c
commit
c7e604af1a
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.3.5.43'
|
__version__ = '1.3.6.0'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -209,7 +209,7 @@ def setitem(x, slices, value):
|
||||||
else:
|
else:
|
||||||
ss.append(s)
|
ss.append(s)
|
||||||
slices = tuple(ss)
|
slices = tuple(ss)
|
||||||
return x.assign(x.setitem(slices, value))
|
return x.check_cascade_setitem(x.setitem(slices, value))
|
||||||
|
|
||||||
jt.Var.__getitem__ = jt.Var.slice_var = getitem
|
jt.Var.__getitem__ = jt.Var.slice_var = getitem
|
||||||
jt.Var.__setitem__ = setitem
|
jt.Var.__setitem__ = setitem
|
||||||
|
|
|
@ -12,8 +12,9 @@
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
inline static bool fast_strcmp(const char* a, const char* b) {
|
inline static bool fast_strcmp(const char* a, const char* b) {
|
||||||
while (*b && *a == *b) a++, b++;
|
return ((const uint64*)a)[0] == ((const uint64*)b)[0];
|
||||||
return !*b;
|
// while (*b && *a == *b) a++, b++;
|
||||||
|
// return !*b;
|
||||||
}
|
}
|
||||||
|
|
||||||
// add dependency b -> a
|
// add dependency b -> a
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
constexpr size_t alignment = 32;
|
constexpr size_t alignment = 32;
|
||||||
|
struct VarHolder;
|
||||||
|
|
||||||
struct Var : Node {
|
struct Var : Node {
|
||||||
NanoVector shape;
|
NanoVector shape;
|
||||||
|
@ -25,6 +26,7 @@ struct Var : Node {
|
||||||
Allocator* allocator = nullptr;
|
Allocator* allocator = nullptr;
|
||||||
size_t allocation;
|
size_t allocation;
|
||||||
int64 size, num;
|
int64 size, num;
|
||||||
|
VarHolder* holder = nullptr;
|
||||||
inline bool is_float() const { CHECK_EXIST; return ns.is_float(); }
|
inline bool is_float() const { CHECK_EXIST; return ns.is_float(); }
|
||||||
inline int dsize() const { CHECK_EXIST; return ns.dsize(); }
|
inline int dsize() const { CHECK_EXIST; return ns.dsize(); }
|
||||||
inline NanoString dtype() const { CHECK_EXIST; return ns; }
|
inline NanoString dtype() const { CHECK_EXIST; return ns; }
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
#include "graph.h"
|
#include "graph.h"
|
||||||
#include "mem/allocator/cuda_dual_allocator.h"
|
#include "mem/allocator/cuda_dual_allocator.h"
|
||||||
#include "ops/op_register.h"
|
#include "ops/op_register.h"
|
||||||
|
#include "ops/getitem_op.h"
|
||||||
|
#include "ops/setitem_op.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
|
@ -45,6 +47,7 @@ void add_hold_vars(VarHolder* self) {
|
||||||
|
|
||||||
VarHolder::VarHolder(Var* v) : var(v) {
|
VarHolder::VarHolder(Var* v) : var(v) {
|
||||||
// Var holder has both forward and backward liveness
|
// Var holder has both forward and backward liveness
|
||||||
|
own_holder();
|
||||||
var->own_both_liveness();
|
var->own_both_liveness();
|
||||||
add_hold_vars(this);
|
add_hold_vars(this);
|
||||||
}
|
}
|
||||||
|
@ -52,10 +55,12 @@ VarHolder::VarHolder(Var* v) : var(v) {
|
||||||
VarHolder::VarHolder(VarPtr&& v) {
|
VarHolder::VarHolder(VarPtr&& v) {
|
||||||
var = v.ptr;
|
var = v.ptr;
|
||||||
v.ptr = nullptr;
|
v.ptr = nullptr;
|
||||||
|
own_holder();
|
||||||
add_hold_vars(this);
|
add_hold_vars(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
VarHolder::VarHolder(VarHolder* v) : var(v->var) {
|
VarHolder::VarHolder(VarHolder* v) : var(v->var) {
|
||||||
|
own_holder();
|
||||||
iter = v->iter;
|
iter = v->iter;
|
||||||
*iter = this;
|
*iter = this;
|
||||||
// free memory without calling deconstructor
|
// free memory without calling deconstructor
|
||||||
|
@ -73,6 +78,7 @@ VarHolder::VarHolder(PyObject* obj, NanoString dtype) {
|
||||||
vp = make_unary(vp, dtype);
|
vp = make_unary(vp, dtype);
|
||||||
var = vp.ptr;
|
var = vp.ptr;
|
||||||
vp.ptr = nullptr;
|
vp.ptr = nullptr;
|
||||||
|
own_holder();
|
||||||
add_hold_vars(this);
|
add_hold_vars(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,6 +88,7 @@ VarHolder::~VarHolder() {
|
||||||
if (iter == sync_ptr)
|
if (iter == sync_ptr)
|
||||||
sync_ptr = std::next(sync_ptr);
|
sync_ptr = std::next(sync_ptr);
|
||||||
hold_vars.erase(iter);
|
hold_vars.erase(iter);
|
||||||
|
release_holder();
|
||||||
var->release_both_liveness();
|
var->release_both_liveness();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,8 +112,10 @@ void VarHolder::operator=(VarPtr&& v) {
|
||||||
v.ptr->flags.set(NodeFlags::_th_require_grad);
|
v.ptr->flags.set(NodeFlags::_th_require_grad);
|
||||||
}
|
}
|
||||||
assign_var(v.ptr, var);
|
assign_var(v.ptr, var);
|
||||||
|
release_holder();
|
||||||
var->release_both_liveness();
|
var->release_both_liveness();
|
||||||
var = v.ptr;
|
var = v.ptr;
|
||||||
|
own_holder();
|
||||||
v.ptr = nullptr;
|
v.ptr = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,9 +146,11 @@ VarHolder* VarHolder::assign(VarHolder* v) {
|
||||||
v->set_requires_grad(get_requires_grad());
|
v->set_requires_grad(get_requires_grad());
|
||||||
}
|
}
|
||||||
assign_var(v->var, var);
|
assign_var(v->var, var);
|
||||||
|
release_holder();
|
||||||
v->var->own_both_liveness();
|
v->var->own_both_liveness();
|
||||||
var->release_both_liveness();
|
var->release_both_liveness();
|
||||||
var = v->var;
|
var = v->var;
|
||||||
|
own_holder();
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,9 +160,11 @@ VarHolder* VarHolder::update(VarHolder* v) {
|
||||||
}
|
}
|
||||||
|
|
||||||
VarHolder* VarHolder::_update(VarHolder* v) {
|
VarHolder* VarHolder::_update(VarHolder* v) {
|
||||||
|
release_holder();
|
||||||
v->var->own_both_liveness();
|
v->var->own_both_liveness();
|
||||||
var->release_both_liveness();
|
var->release_both_liveness();
|
||||||
var = v->var;
|
var = v->var;
|
||||||
|
own_holder();
|
||||||
var->flags.set(NodeFlags::_out_hint);
|
var->flags.set(NodeFlags::_out_hint);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
@ -278,4 +291,46 @@ void migrate_all_to_cpu() {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static auto make_setitem = get_op_info("setitem")
|
||||||
|
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
|
||||||
|
|
||||||
|
inline static bool fast_strcmp(const char* a, const char* b) {
|
||||||
|
return ((const uint64*)a)[0] == ((const uint64*)b)[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
VarHolder* VarHolder::check_cascade_setitem(VarHolder* out) {
|
||||||
|
// return this;
|
||||||
|
auto v = var;
|
||||||
|
int n=0;
|
||||||
|
int64 slices[10];
|
||||||
|
while (n<10) {
|
||||||
|
Op* iop = v->input();
|
||||||
|
if (!iop) break;
|
||||||
|
if (!fast_strcmp(iop->name(), "getitem")) break;
|
||||||
|
v = iop->inputs().front();
|
||||||
|
GetitemOp* gop = (GetitemOp*)iop;
|
||||||
|
if (gop->vs.n == 1 && gop->vs.slices[0].is_int()) {
|
||||||
|
slices[n++] = gop->vs.slices[0].i;
|
||||||
|
} else break;
|
||||||
|
if (v->holder) {
|
||||||
|
// found holder var: v
|
||||||
|
// v[a][b][c][d] = y
|
||||||
|
// ^
|
||||||
|
auto* prev_op = (SetitemOp*)out->var->input();
|
||||||
|
VarSlices& old_slices = prev_op->vs;
|
||||||
|
Var* y = prev_op->input(1);
|
||||||
|
VarSlices new_slices(n+old_slices.n);
|
||||||
|
for (int i=n-1; i>=0; i--)
|
||||||
|
new_slices.slices[n-1-i].set_int(slices[i]);
|
||||||
|
for (int i=0; i<old_slices.n; i++)
|
||||||
|
new_slices.slices[n+i] = old_slices.slices[i];
|
||||||
|
// apply new slice
|
||||||
|
// v[a][b][c][d] = y -> v[a,b,c,d] = y
|
||||||
|
(*v->holder) = make_setitem(v, move(new_slices), y, ns_void);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return assign(out);
|
||||||
|
}
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -58,6 +58,9 @@ struct VarHolder {
|
||||||
// @pyjt(fetch_sync,numpy)
|
// @pyjt(fetch_sync,numpy)
|
||||||
ArrayArgs fetch_sync();
|
ArrayArgs fetch_sync();
|
||||||
|
|
||||||
|
inline void release_holder() {var->holder = nullptr;}
|
||||||
|
inline void own_holder() {var->holder = this;}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* assign the data from another Var.
|
* assign the data from another Var.
|
||||||
*/
|
*/
|
||||||
|
@ -87,7 +90,11 @@ struct VarHolder {
|
||||||
*/
|
*/
|
||||||
// @pyjt(swap)
|
// @pyjt(swap)
|
||||||
// @attrs(return_self)
|
// @attrs(return_self)
|
||||||
inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };
|
inline VarHolder* swap(VarHolder* v) {
|
||||||
|
std::swap(var, v->var);
|
||||||
|
own_holder(); v->own_holder();
|
||||||
|
return this;
|
||||||
|
};
|
||||||
|
|
||||||
void operator=(VarPtr&& v);
|
void operator=(VarPtr&& v);
|
||||||
|
|
||||||
|
@ -330,6 +337,11 @@ struct VarHolder {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* check a[x][y] = c
|
||||||
|
*/
|
||||||
|
// @pyjt(check_cascade_setitem)
|
||||||
|
// @attrs(return_self)
|
||||||
|
VarHolder* check_cascade_setitem(VarHolder* out);
|
||||||
};
|
};
|
||||||
|
|
||||||
// @pyjt(sync)
|
// @pyjt(sync)
|
||||||
|
|
|
@ -437,7 +437,11 @@ class TestSetitem(unittest.TestCase):
|
||||||
b.sync(True)
|
b.sync(True)
|
||||||
assert b.item() == 1
|
assert b.item() == 1
|
||||||
|
|
||||||
|
def test_cascade_setitem(self):
|
||||||
|
a = jt.zeros(3,3,3,3)
|
||||||
|
a[1][2][0][0] = 1
|
||||||
|
assert a[1,2,0,0] == 1
|
||||||
|
# TODO: convert a[x] = a[x] + b -> a[x] += b
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
Loading…
Reference in New Issue