polish interface

This commit is contained in:
Dun Liang 2020-07-14 15:48:14 +08:00
parent c6fc9c05ca
commit a60c73ea39
4 changed files with 82 additions and 25 deletions

1
.gitignore vendored
View File

@ -23,3 +23,4 @@ venv/
python/jittor.egg-info
dist/
!doc/source/*
core

View File

@ -614,37 +614,53 @@ can also be None)::
assert db.data == 0
'''
# NOTE(review): this span appears to be a rendered diff whose +/- markers and
# indentation were lost: two `__call__` signatures and two `tape_together`
# calls are present, so lines from two versions of the method are interleaved.
# The comments below describe individual statements as written; confirm the
# final shape of the method against the real file.
def __call__(self, *args, **kw):
args2 = list(args)
kw = dict(kw)
def __call__(self, *args):
# Copy to a list so individual arguments can be replaced in place.
args = list(args)
taped_inputs = []
taped_outputs = []
for i,v in enumerate(args2):
# input_mask[i] stays -1 unless args[i] is a Var; for a Var it records that
# argument's index within taped_inputs.
input_mask = [-1] * len(args)
for i,v in enumerate(args):
if isinstance(v, Var):
v = v.tape()
args2[i] = v
input_mask[i] = len(taped_inputs)
args[i] = v
taped_inputs.append(v)
for k,v in kw.items():
if isinstance(v, Var):
v = v.tape()
kw[k] = v
taped_inputs.append(v)
res = self.execute(*args2, **kw)
if isinstance(res, Var):
res = res.tape()
taped_outputs.append(res)
ori_res = self.execute(*args)
# A non-Sequence result is wrapped in a one-element list so that the
# output loop below can treat both cases the same way.
if not isinstance(ori_res, Sequence):
res = [ori_res]
else:
assert isinstance(res, Sequence)
res = list(res)
for i,v in enumerate(res):
if isinstance(v, Var):
v = v.tape()
res[i] = v
taped_outputs.append(v)
res = list(ori_res)
# output_mask[i] stays -1 unless res[i] is a Var; for a Var it records that
# output's index within taped_outputs.
output_mask = [-1] * len(res)
for i,v in enumerate(res):
if isinstance(v, Var):
v = v.tape()
output_mask[i] = len(taped_outputs)
res[i] = v
taped_outputs.append(v)
# Stored on the instance because _grad reads both masks.
self.input_mask = input_mask
self.output_mask = output_mask
# tape output and input together so
# backward treat them as one operator
tape_together(taped_inputs, taped_outputs, lambda *args: self.grad(*args))
return res
tape_together(taped_inputs, taped_outputs, self._grad)
# A Sequence result from execute() is returned as the list `res`;
# otherwise the single wrapped value is unwrapped again.
if isinstance(ori_res, Sequence):
return res
else:
return res[0]
def _grad(self, *args):
    """Adapt the taped gradient callback to the user-defined ``grad``.

    ``args`` holds one gradient per taped output. They are expanded to one
    argument per entry of ``self.output_mask`` (``None`` where the mask is
    negative) and passed to ``self.grad``. The gradients it returns are
    then filtered down to the positions whose ``self.input_mask`` entry is
    non-negative.

    Returns:
        list: the gradients for the masked-in inputs, in input order.

    Raises:
        AssertionError: if ``self.grad`` returns a value other than
            ``None`` at a position whose input mask is negative.
    """
    out_grads = [args[k] if k >= 0 else None for k in self.output_mask]
    ret = self.grad(*out_grads)
    if not isinstance(ret, Sequence):
        ret = (ret,)
    new_ret = []
    for i, r in enumerate(ret):
        if self.input_mask[i] >= 0:
            new_ret.append(r)
        elif r is not None:
            # Raised explicitly rather than through an `assert` statement so
            # the check still runs when Python is started with -O; otherwise
            # the offending gradient would be dropped silently.
            raise AssertionError(
                f"The {i}-th returned grad should be None, "
                "because the input value is not jittor variable.")
    return new_ret
def dfs(self, parents, k, callback, callback_leave=None):
    """Do nothing: every argument is accepted and ignored.

    NOTE(review): presumably a traversal hook that is deliberately left
    empty here -- confirm against the enclosing class and its base.
    """
    return None

View File

@ -180,5 +180,45 @@ class TestFunction(unittest.TestCase):
assert da.data == 4, da.data
assert db.data == 9
def test_multi_grads_multi_out4(self):
    """Check grads of a Function that mixes Var and non-Var values.

    ``execute`` takes a string between its two Var inputs and returns a
    string between its two results. ``grad`` asserts that it receives
    ``None`` for that output and returns ``None`` for that input.
    """
    class MyFunc(Function):
        def execute(self, x, z, y):
            self.x, self.y = x, y
            return x*y, "test", x/y

        def grad(self, grad0, _, grad1):
            assert _ is None
            res = (grad0 * self.y, None, grad1 * self.x)
            print(res)
            return res

    lhs = jt.array(3.0)
    rhs = jt.array(4.0)
    prod, _, quot = MyFunc()(lhs, "a", rhs)
    d_lhs, d_rhs = jt.grad(prod + quot*3, [lhs, rhs])
    assert d_lhs.data == 4, d_lhs.data
    assert d_rhs.data == 9
def test_multi_grads_multi_out5(self):
    """Check that a non-None grad for a non-Var input is rejected.

    ``grad`` returns ``1`` at the position of the string input, and the
    ``jt.grad`` call is wrapped in ``expect_error``.
    """
    class MyFunc(Function):
        def execute(self, x, z, y):
            self.x, self.y = x, y
            return x*y, "test", x/y

        def grad(self, grad0, _, grad1):
            assert _ is None
            res = (grad0 * self.y, 1, grad1 * self.x)
            print(res)
            return res

    lhs = jt.array(3.0)
    rhs = jt.array(4.0)
    prod, _, quot = MyFunc()(lhs, "a", rhs)
    expect_error(lambda: jt.grad(prod + quot*3, [lhs, rhs]))
def test_zz_last_test(self):
    """Check that ``jt.liveness_info()`` reports zero ``lived_vars``."""
    lived_vars = jt.liveness_info()["lived_vars"]
    self.assertEqual(lived_vars, 0)
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
unittest.main()

View File

@ -626,11 +626,11 @@ DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
}
};
if (!is_seq) {
CHECKop(n_i,==,1) << "returned grad size not match";
CHECKop(n_i,==,1) << n_i >> " returned grad required, but 1 given.";
check(0, ret.obj);
} else {
auto size = Py_SIZE(ret.obj);
CHECKop(n_i,==,size) << "returned grad size not match";
CHECKop(n_i,==,size) << n_i >> " returned grad required, but " >> size >> " given.";
auto arr = PySequence_Fast_ITEMS(ret.obj);
for (int i=0; i<size; i++) {
auto oi = arr[i];