From c6fc9c05ca460ac7d524b5507f66a35191b59f5a Mon Sep 17 00:00:00 2001
From: Dun Liang
Date: Tue, 14 Jul 2020 14:34:08 +0800
Subject: [PATCH] tune interface

---
 python/jittor/__init__.py           | 12 ++++-----
 python/jittor/test/test_function.py | 42 ++++++++++++++---------------
 src/pyjt/py_converter.h             |  4 +--
 3 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 47038a66..e900f062 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -579,8 +579,8 @@ can store value for backward computation)::
             self.y = y
             return x*y, x/y
 
-        def grad(self, grads):
-            return grads[0] * self.y, grads[1] * self.x
+        def grad(self, grad0, grad1):
+            return grad0 * self.y, grad1 * self.x
     a = jt.array(3.0)
     b = jt.array(4.0)
     func = MyFunc()
@@ -601,9 +601,9 @@ can also be None)::
             self.y = y
             return x*y, x/y
 
-        def grad(self, grads):
-            assert grads[1] is None
-            return grads[0] * self.y, None
+        def grad(self, grad0, grad1):
+            assert grad1 is None
+            return grad0 * self.y, None
     a = jt.array(3.0)
     b = jt.array(4.0)
     func = MyFunc()
@@ -643,7 +643,7 @@ can also be None)::
                 taped_outputs.append(v)
         # tape output and input together so
         # backward treat them as one operator
-        tape_together(taped_inputs, taped_outputs, lambda args: self.grad(args))
+        tape_together(taped_inputs, taped_outputs, lambda *args: self.grad(*args))
         return res
 
     def dfs(self, parents, k, callback, callback_leave=None):
diff --git a/python/jittor/test/test_function.py b/python/jittor/test/test_function.py
index b914a665..6577b75c 100644
--- a/python/jittor/test/test_function.py
+++ b/python/jittor/test/test_function.py
@@ -18,8 +18,8 @@ class TestFunction(unittest.TestCase):
             def execute(self, x):
                 return x+1
 
-            def grad(self, grads):
-                return grads[0]-2
+            def grad(self, grad):
+                return grad-2
         a = jt.ones(1)
         func = MyFunc()
         b = func(a)
@@ -32,8 +32,8 @@ class TestFunction(unittest.TestCase):
                 self.x = x
                 return x+1
 
-            def grad(self, grads):
-                return (grads[0]-2) * self.x
+            def grad(self, grad):
+                return (grad-2) * self.x
         a = jt.ones(1) * 10
         func = MyFunc()
         b = func(a)
@@ -47,8 +47,8 @@ class TestFunction(unittest.TestCase):
                 self.y = y
                 return x*y
 
-            def grad(self, grads):
-                return (grads[0]-2) * self.x
+            def grad(self, grad):
+                return (grad-2) * self.x
         a = jt.array(3.0)
         b = jt.array(4.0)
         func = MyFunc()
@@ -62,8 +62,8 @@ class TestFunction(unittest.TestCase):
                 self.y = y
                 return x*y
 
-            def grad(self, grads):
-                return (grads[0]-2) * self.y, (grads[0]-2) * self.x
+            def grad(self, grad):
+                return (grad-2) * self.y, (grad-2) * self.x
         a = jt.array(3.0)
         b = jt.array(4.0)
         func = MyFunc()
@@ -79,8 +79,8 @@ class TestFunction(unittest.TestCase):
                 self.y = y
                 return x*y
 
-            def grad(self, grads):
-                return (grads[0]-2) * self.y, None
+            def grad(self, grad):
+                return (grad-2) * self.y, None
         a = jt.array(3.0)
         b = jt.array(4.0)
         func = MyFunc()
@@ -96,8 +96,8 @@ class TestFunction(unittest.TestCase):
                 self.y = y
                 return x*y, x/y
 
-            def grad(self, grads):
-                return grads[0] * self.y, grads[1] * self.x
+            def grad(self, grad0, grad1):
+                return grad0 * self.y, grad1 * self.x
         a = jt.array(3.0)
         b = jt.array(4.0)
         func = MyFunc()
@@ -113,8 +113,8 @@ class TestFunction(unittest.TestCase):
                 self.y = y
                 return x*y, x/y
-            def grad(self, grads):
-                return grads[0] * self.y, grads[1] * self.x
+            def grad(self, grad0, grad1):
+                return grad0 * self.y, grad1 * self.x
         a = jt.array(3.0)
         b = jt.array(4.0)
         b.stop_grad()
         func = MyFunc()
@@ -131,9 +131,9 @@ class TestFunction(unittest.TestCase):
                 self.y = y
                 return x*y, x/y
 
-            def grad(self, grads):
-                assert grads[1] is None
-                return grads[0] * self.y, None
+            def grad(self, grad0, grad1):
+                assert grad1 is None
+                return grad0 * self.y, None
         a = jt.array(3.0)
         b = jt.array(4.0)
         func = MyFunc()
@@ -150,8 +150,8 @@ class TestFunction(unittest.TestCase):
                 self.y = y
                 return x*y, x/y
 
-            def grad(self, grads):
-                res = (grads[0] * self.y, grads[1] * self.x)
+            def grad(self, grad0, grad1):
+                res = (grad0 * self.y, grad1 * self.x)
                 print(res)
                 return res
         a = jt.array(3.0)
@@ -169,8 +169,8 @@ class TestFunction(unittest.TestCase):
                 self.y = y
                 return x*y, x/y
 
-            def grad(self, grads):
-                res = (grads[0] * self.y, grads[1] * self.x)
+            def grad(self, grad0, grad1):
+                res = (grad0 * self.y, grad1 * self.x)
                 print(res)
                 return res
         a = jt.array(3.0)
diff --git a/src/pyjt/py_converter.h b/src/pyjt/py_converter.h
index fc2cc127..66fc9d1f 100644
--- a/src/pyjt/py_converter.h
+++ b/src/pyjt/py_converter.h
@@ -613,10 +613,8 @@ DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
             PyTuple_SET_ITEM(list.obj, i, Py_None);
         }
     }
 
-    PyObjHolder args(PyTuple_New(1));
-    PyTuple_SET_ITEM(args.obj, 0, list.release());
-    PyObjHolder ret(PyObject_Call(obj, args.obj, nullptr));
+    PyObjHolder ret(PyObject_Call(obj, list.obj, nullptr));
     auto is_seq = PyList_CheckExact(ret.obj) || PyTuple_CheckExact(ret.obj);
     auto check = [&](int i, PyObject* obj) {
         if (obj == Py_None) {
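
Note (reviewer sketch, not part of the patch): a minimal usage example of the tuned interface, mirroring the docstring example touched above. It assumes the usual imports and that Function and jt.grad are available as in the existing examples. After this change, grad receives one positional gradient per output instead of a single list, because the C++ GradCallback now calls the Python hook with the gradient tuple expanded into positional arguments::

    import jittor as jt
    from jittor import Function

    class MyFunc(Function):
        def execute(self, x, y):
            # save inputs for the backward pass
            self.x = x
            self.y = y
            return x*y, x/y

        # one positional gradient per output (was: a single list `grads`);
        # a gradient may still be None when the corresponding output is unused
        def grad(self, grad0, grad1):
            return grad0 * self.y, grad1 * self.x

    a = jt.array(3.0)
    b = jt.array(4.0)
    func = MyFunc()
    c, d = func(a, b)
    # loss = c + 3*d, so grad0 = 1 and grad1 = 3;
    # the custom grad gives da = grad0 * b = 4, db = grad1 * a = 9
    da, db = jt.grad(c + d*3, [a, b])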