mirror of https://github.com/Jittor/Jittor
polish interface
This commit is contained in:
parent
c6fc9c05ca
commit
a60c73ea39
|
@ -23,3 +23,4 @@ venv/
|
|||
python/jittor.egg-info
|
||||
dist/
|
||||
!doc/source/*
|
||||
core
|
||||
|
|
|
@ -614,37 +614,53 @@ can also be None)::
|
|||
assert db.data == 0
|
||||
|
||||
'''
|
||||
def __call__(self, *args, **kw):
|
||||
args2 = list(args)
|
||||
kw = dict(kw)
|
||||
def __call__(self, *args):
|
||||
args = list(args)
|
||||
taped_inputs = []
|
||||
taped_outputs = []
|
||||
for i,v in enumerate(args2):
|
||||
input_mask = [-1] * len(args)
|
||||
for i,v in enumerate(args):
|
||||
if isinstance(v, Var):
|
||||
v = v.tape()
|
||||
args2[i] = v
|
||||
input_mask[i] = len(taped_inputs)
|
||||
args[i] = v
|
||||
taped_inputs.append(v)
|
||||
for k,v in kw.items():
|
||||
if isinstance(v, Var):
|
||||
v = v.tape()
|
||||
kw[k] = v
|
||||
taped_inputs.append(v)
|
||||
res = self.execute(*args2, **kw)
|
||||
if isinstance(res, Var):
|
||||
res = res.tape()
|
||||
taped_outputs.append(res)
|
||||
ori_res = self.execute(*args)
|
||||
if not isinstance(ori_res, Sequence):
|
||||
res = [ori_res]
|
||||
else:
|
||||
assert isinstance(res, Sequence)
|
||||
res = list(res)
|
||||
for i,v in enumerate(res):
|
||||
if isinstance(v, Var):
|
||||
v = v.tape()
|
||||
res[i] = v
|
||||
taped_outputs.append(v)
|
||||
res = list(ori_res)
|
||||
output_mask = [-1] * len(res)
|
||||
for i,v in enumerate(res):
|
||||
if isinstance(v, Var):
|
||||
v = v.tape()
|
||||
output_mask[i] = len(taped_outputs)
|
||||
res[i] = v
|
||||
taped_outputs.append(v)
|
||||
self.input_mask = input_mask
|
||||
self.output_mask = output_mask
|
||||
# tape output and input together so
|
||||
# backward treat them as one operator
|
||||
tape_together(taped_inputs, taped_outputs, lambda *args: self.grad(*args))
|
||||
return res
|
||||
tape_together(taped_inputs, taped_outputs, self._grad)
|
||||
if isinstance(ori_res, Sequence):
|
||||
return res
|
||||
else:
|
||||
return res[0]
|
||||
|
||||
def _grad(self, *args):
|
||||
new_args = ( (args[i] if i>=0 else None) for i in self.output_mask )
|
||||
ret = self.grad(*new_args)
|
||||
if not isinstance(ret, Sequence):
|
||||
ret = (ret,)
|
||||
new_ret = []
|
||||
for i, r in enumerate(ret):
|
||||
j = self.input_mask[i]
|
||||
if j<0:
|
||||
assert r is None, f"The {i}-th returned grad should be None, "\
|
||||
"because the input value is not jittor variable."
|
||||
else:
|
||||
new_ret.append(r)
|
||||
return new_ret
|
||||
|
||||
def dfs(self, parents, k, callback, callback_leave=None):
|
||||
pass
|
||||
|
|
|
@ -180,5 +180,45 @@ class TestFunction(unittest.TestCase):
|
|||
assert da.data == 4, da.data
|
||||
assert db.data == 9
|
||||
|
||||
def test_multi_grads_multi_out4(self):
|
||||
class MyFunc(Function):
|
||||
def execute(self, x, z, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
return x*y, "test", x/y
|
||||
|
||||
def grad(self, grad0, _, grad1):
|
||||
assert _ is None
|
||||
res = (grad0 * self.y, None, grad1 * self.x)
|
||||
print(res)
|
||||
return res
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
c,_,d = MyFunc()(a, "a", b)
|
||||
da, db = jt.grad(c+d*3, [a, b])
|
||||
assert da.data == 4, da.data
|
||||
assert db.data == 9
|
||||
|
||||
|
||||
def test_multi_grads_multi_out5(self):
|
||||
class MyFunc(Function):
|
||||
def execute(self, x, z, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
return x*y, "test", x/y
|
||||
|
||||
def grad(self, grad0, _, grad1):
|
||||
assert _ is None
|
||||
res = (grad0 * self.y, 1, grad1 * self.x)
|
||||
print(res)
|
||||
return res
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
c,_,d = MyFunc()(a, "a", b)
|
||||
expect_error(lambda : jt.grad(c+d*3, [a, b]))
|
||||
|
||||
def test_zz_last_test(self):
|
||||
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -626,11 +626,11 @@ DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
|
|||
}
|
||||
};
|
||||
if (!is_seq) {
|
||||
CHECKop(n_i,==,1) << "returned grad size not match";
|
||||
CHECKop(n_i,==,1) << n_i >> " returned grad required, but 1 given.";
|
||||
check(0, ret.obj);
|
||||
} else {
|
||||
auto size = Py_SIZE(ret.obj);
|
||||
CHECKop(n_i,==,size) << "returned grad size not match";
|
||||
CHECKop(n_i,==,size) << n_i >> " returned grad required, but " >> size >> " given.";
|
||||
auto arr = PySequence_Fast_ITEMS(ret.obj);
|
||||
for (int i=0; i<size; i++) {
|
||||
auto oi = arr[i];
|
||||
|
|
Loading…
Reference in New Issue