add eager execution

This commit is contained in:
Dun Liang 2020-07-30 15:33:51 +08:00
parent a62b45d6ca
commit eae3fffcf1
6 changed files with 50 additions and 7 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.1.6.8'
__version__ = '1.1.7.0'
from . import lock
with lock.lock_scope():
from . import compiler
@ -284,6 +284,9 @@ def detach_inplace(x):
return x.swap(x.stop_grad().clone())
Var.start_grad = Var.detach_inplace = detach_inplace
def detach(x):
return x.detach()
def unsqueeze(x, dim):
shape = list(x.shape)
if dim < 0: dim += len(shape) + 1
@ -623,12 +626,17 @@ can also be None)::
'''
def __call__(self, *args):
backup = args
args = list(args)
taped_inputs = []
taped_outputs = []
input_mask = [-1] * len(args)
for i,v in enumerate(args):
if isinstance(v, Var):
if v.is_stop_grad():
# -2 in input_mask represents it is stop_grad
input_mask[i] = -2
continue
v = v.tape()
input_mask[i] = len(taped_inputs)
args[i] = v
@ -664,7 +672,8 @@ can also be None)::
for i, r in enumerate(ret):
j = self.input_mask[i]
if j<0:
assert r is None, f"{type(self)}'s {i}-th returned grad should be None, "\
# -2 in input_mask represents it is stop_grad
assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\
"because the input value is not jittor variable."
else:
new_ret.append(r)

View File

@ -21,5 +21,14 @@ class TestClone(unittest.TestCase):
c.stop_grad()
assert jt.number_of_lived_vars()==3
def test2(self):
a = jt.array([1,2])
print(a.detach())
@jt.flag_scope(eager_execution=1)
def test3(self):
a = jt.array([1,2])
print(a.detach())
if __name__ == "__main__":
unittest.main()

View File

@ -280,5 +280,14 @@ class TestFunction(unittest.TestCase):
assert t3 < t2 + 10, (t1,t2,t3)
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
class TestFunctionWithEagerExecution(TestFunction):
@classmethod
def setUpClass(self):
jt.flags.eager_execution = 1
@classmethod
def tearDownClass(self):
jt.flags.eager_execution = 0
if __name__ == "__main__":
unittest.main()

View File

@ -1 +1 @@
b27082f9444a4e627f7dfc574d0114302ba27b5e
a62b45d6caf9c1c18a9118630ec8a591c576e635

View File

@ -146,7 +146,6 @@ void count_fuse(int64_t tt, int start_var_num, const vector<Op*>& ops, const vec
for (uint i=0; i<ops.size(); i++)
LOGvvvv << ops[i] << dis[i] << deps[i];
}
for (uint i=0; i<vars.size(); i++) {
Var* v = vars[i];
if (!v || v->tflag!=tt) {

View File

@ -15,24 +15,41 @@
#include "update_queue.h"
namespace jittor {
DEFINE_FLAG(int, eager_execution, 0, "Use Eager execution rather than lazy execution, This flag makes error message and traceback infomation better.");
list<VarHolder*> VarHolder::hold_vars;
void add_hold_vars(VarHolder* self) {
VarHolder::hold_vars.push_front(self);
self->iter = VarHolder::hold_vars.begin();
if (!eager_execution) return;
auto v = self->var;
for (int i=0; i<5; i++) {
auto op = v->input();
if (!op) break;
if (i==0 && op->name() == string("tape")) return;
if (op->type() == OpType::other) break;
if (op->type() == OpType::reduce) break;
if (op->inputs().size() == 0)
break;
if (op->type() == OpType::broadcast)
return;
v = op->inputs().front();
}
self->sync(true);
}
VarHolder::VarHolder(Var* v) : var(v) {
add_hold_vars(this);
// Var holder has both forward and backward liveness
var->own_both_liveness();
add_hold_vars(this);
}
VarHolder::VarHolder(VarPtr&& v) {
add_hold_vars(this);
var = v.ptr;
v.ptr = nullptr;
add_hold_vars(this);
}
VarHolder::VarHolder(VarHolder* v) : var(v->var) {