mirror of https://github.com/Jittor/Jittor
add eager execution
This commit is contained in:
parent
a62b45d6ca
commit
eae3fffcf1
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.1.6.8'
|
||||
__version__ = '1.1.7.0'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
@ -284,6 +284,9 @@ def detach_inplace(x):
|
|||
return x.swap(x.stop_grad().clone())
|
||||
Var.start_grad = Var.detach_inplace = detach_inplace
|
||||
|
||||
def detach(x):
|
||||
return x.detach()
|
||||
|
||||
def unsqueeze(x, dim):
|
||||
shape = list(x.shape)
|
||||
if dim < 0: dim += len(shape) + 1
|
||||
|
@ -623,12 +626,17 @@ can also be None)::
|
|||
|
||||
'''
|
||||
def __call__(self, *args):
|
||||
backup = args
|
||||
args = list(args)
|
||||
taped_inputs = []
|
||||
taped_outputs = []
|
||||
input_mask = [-1] * len(args)
|
||||
for i,v in enumerate(args):
|
||||
if isinstance(v, Var):
|
||||
if v.is_stop_grad():
|
||||
# -2 in input_mask represents it is stop_grad
|
||||
input_mask[i] = -2
|
||||
continue
|
||||
v = v.tape()
|
||||
input_mask[i] = len(taped_inputs)
|
||||
args[i] = v
|
||||
|
@ -664,7 +672,8 @@ can also be None)::
|
|||
for i, r in enumerate(ret):
|
||||
j = self.input_mask[i]
|
||||
if j<0:
|
||||
assert r is None, f"{type(self)}'s {i}-th returned grad should be None, "\
|
||||
# -2 in input_mask represents it is stop_grad
|
||||
assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\
|
||||
"because the input value is not jittor variable."
|
||||
else:
|
||||
new_ret.append(r)
|
||||
|
|
|
@ -21,5 +21,14 @@ class TestClone(unittest.TestCase):
|
|||
c.stop_grad()
|
||||
assert jt.number_of_lived_vars()==3
|
||||
|
||||
def test2(self):
|
||||
a = jt.array([1,2])
|
||||
print(a.detach())
|
||||
|
||||
@jt.flag_scope(eager_execution=1)
|
||||
def test3(self):
|
||||
a = jt.array([1,2])
|
||||
print(a.detach())
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -280,5 +280,14 @@ class TestFunction(unittest.TestCase):
|
|||
assert t3 < t2 + 10, (t1,t2,t3)
|
||||
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
|
||||
|
||||
|
||||
class TestFunctionWithEagerExecution(TestFunction):
|
||||
@classmethod
|
||||
def setUpClass(self):
|
||||
jt.flags.eager_execution = 1
|
||||
@classmethod
|
||||
def tearDownClass(self):
|
||||
jt.flags.eager_execution = 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -1 +1 @@
|
|||
b27082f9444a4e627f7dfc574d0114302ba27b5e
|
||||
a62b45d6caf9c1c18a9118630ec8a591c576e635
|
||||
|
|
|
@ -146,7 +146,6 @@ void count_fuse(int64_t tt, int start_var_num, const vector<Op*>& ops, const vec
|
|||
for (uint i=0; i<ops.size(); i++)
|
||||
LOGvvvv << ops[i] << dis[i] << deps[i];
|
||||
}
|
||||
|
||||
for (uint i=0; i<vars.size(); i++) {
|
||||
Var* v = vars[i];
|
||||
if (!v || v->tflag!=tt) {
|
||||
|
|
|
@ -15,24 +15,41 @@
|
|||
#include "update_queue.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
||||
DEFINE_FLAG(int, eager_execution, 0, "Use Eager execution rather than lazy execution, This flag makes error message and traceback infomation better.");
|
||||
|
||||
list<VarHolder*> VarHolder::hold_vars;
|
||||
|
||||
void add_hold_vars(VarHolder* self) {
|
||||
VarHolder::hold_vars.push_front(self);
|
||||
self->iter = VarHolder::hold_vars.begin();
|
||||
if (!eager_execution) return;
|
||||
auto v = self->var;
|
||||
for (int i=0; i<5; i++) {
|
||||
auto op = v->input();
|
||||
if (!op) break;
|
||||
if (i==0 && op->name() == string("tape")) return;
|
||||
if (op->type() == OpType::other) break;
|
||||
if (op->type() == OpType::reduce) break;
|
||||
if (op->inputs().size() == 0)
|
||||
break;
|
||||
if (op->type() == OpType::broadcast)
|
||||
return;
|
||||
v = op->inputs().front();
|
||||
}
|
||||
self->sync(true);
|
||||
}
|
||||
|
||||
VarHolder::VarHolder(Var* v) : var(v) {
|
||||
add_hold_vars(this);
|
||||
// Var holder has both forward and backward liveness
|
||||
var->own_both_liveness();
|
||||
add_hold_vars(this);
|
||||
}
|
||||
|
||||
VarHolder::VarHolder(VarPtr&& v) {
|
||||
add_hold_vars(this);
|
||||
var = v.ptr;
|
||||
v.ptr = nullptr;
|
||||
add_hold_vars(this);
|
||||
}
|
||||
|
||||
VarHolder::VarHolder(VarHolder* v) : var(v->var) {
|
||||
|
|
Loading…
Reference in New Issue