tune interface

This commit is contained in:
Dun Liang 2020-07-14 14:34:08 +08:00
parent eab8bcc049
commit c6fc9c05ca
3 changed files with 28 additions and 30 deletions

View File

@ -579,8 +579,8 @@ can store value for backward computation)::
self.y = y
return x*y, x/y
def grad(self, grads):
return grads[0] * self.y, grads[1] * self.x
def grad(self, grad0, grad1):
return grad0 * self.y, grad1 * self.x
a = jt.array(3.0)
b = jt.array(4.0)
func = MyFunc()
@ -601,9 +601,9 @@ can also be None)::
self.y = y
return x*y, x/y
def grad(self, grads):
assert grads[1] is None
return grads[0] * self.y, None
def grad(self, grad0, grad1):
assert grad1 is None
return grad0 * self.y, None
a = jt.array(3.0)
b = jt.array(4.0)
func = MyFunc()
@ -643,7 +643,7 @@ can also be None)::
taped_outputs.append(v)
# tape output and input together so
# backward treat them as one operator
tape_together(taped_inputs, taped_outputs, lambda args: self.grad(args))
tape_together(taped_inputs, taped_outputs, lambda *args: self.grad(*args))
return res
def dfs(self, parents, k, callback, callback_leave=None):

View File

@ -18,8 +18,8 @@ class TestFunction(unittest.TestCase):
def execute(self, x):
return x+1
def grad(self, grads):
return grads[0]-2
def grad(self, grad):
return grad-2
a = jt.ones(1)
func = MyFunc()
b = func(a)
@ -32,8 +32,8 @@ class TestFunction(unittest.TestCase):
self.x = x
return x+1
def grad(self, grads):
return (grads[0]-2) * self.x
def grad(self, grad):
return (grad-2) * self.x
a = jt.ones(1) * 10
func = MyFunc()
b = func(a)
@ -47,8 +47,8 @@ class TestFunction(unittest.TestCase):
self.y = y
return x*y
def grad(self, grads):
return (grads[0]-2) * self.x
def grad(self, grad):
return (grad-2) * self.x
a = jt.array(3.0)
b = jt.array(4.0)
func = MyFunc()
@ -62,8 +62,8 @@ class TestFunction(unittest.TestCase):
self.y = y
return x*y
def grad(self, grads):
return (grads[0]-2) * self.y, (grads[0]-2) * self.x
def grad(self, grad):
return (grad-2) * self.y, (grad-2) * self.x
a = jt.array(3.0)
b = jt.array(4.0)
func = MyFunc()
@ -79,8 +79,8 @@ class TestFunction(unittest.TestCase):
self.y = y
return x*y
def grad(self, grads):
return (grads[0]-2) * self.y, None
def grad(self, grad):
return (grad-2) * self.y, None
a = jt.array(3.0)
b = jt.array(4.0)
func = MyFunc()
@ -96,8 +96,8 @@ class TestFunction(unittest.TestCase):
self.y = y
return x*y, x/y
def grad(self, grads):
return grads[0] * self.y, grads[1] * self.x
def grad(self, grad0, grad1):
return grad0 * self.y, grad1 * self.x
a = jt.array(3.0)
b = jt.array(4.0)
func = MyFunc()
@ -113,8 +113,8 @@ class TestFunction(unittest.TestCase):
self.y = y
return x*y, x/y
def grad(self, grads):
return grads[0] * self.y, grads[1] * self.x
def grad(self, grad0, grad1):
return grad0 * self.y, grad1 * self.x
a = jt.array(3.0)
b = jt.array(4.0)
b.stop_grad()
@ -131,9 +131,9 @@ class TestFunction(unittest.TestCase):
self.y = y
return x*y, x/y
def grad(self, grads):
assert grads[1] is None
return grads[0] * self.y, None
def grad(self, grad0, grad1):
assert grad1 is None
return grad0 * self.y, None
a = jt.array(3.0)
b = jt.array(4.0)
func = MyFunc()
@ -150,8 +150,8 @@ class TestFunction(unittest.TestCase):
self.y = y
return x*y, x/y
def grad(self, grads):
res = (grads[0] * self.y, grads[1] * self.x)
def grad(self, grad0, grad1):
res = (grad0 * self.y, grad1 * self.x)
print(res)
return res
a = jt.array(3.0)
@ -169,8 +169,8 @@ class TestFunction(unittest.TestCase):
self.y = y
return x*y, x/y
def grad(self, grads):
res = (grads[0] * self.y, grads[1] * self.x)
def grad(self, grad0, grad1):
res = (grad0 * self.y, grad1 * self.x)
print(res)
return res
a = jt.array(3.0)

View File

@ -613,10 +613,8 @@ DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
PyTuple_SET_ITEM(list.obj, i, Py_None);
}
}
PyObjHolder args(PyTuple_New(1));
PyTuple_SET_ITEM(args.obj, 0, list.release());
PyObjHolder ret(PyObject_Call(obj, args.obj, nullptr));
PyObjHolder ret(PyObject_Call(obj, list.obj, nullptr));
auto is_seq = PyList_CheckExact(ret.obj) || PyTuple_CheckExact(ret.obj);
auto check = [&](int i, PyObject* obj) {
if (obj == Py_None) {