mirror of https://github.com/Jittor/Jittor
tune interface
This commit is contained in:
parent
eab8bcc049
commit
c6fc9c05ca
|
@ -579,8 +579,8 @@ can store value for backward computation)::
|
|||
self.y = y
|
||||
return x*y, x/y
|
||||
|
||||
def grad(self, grads):
|
||||
return grads[0] * self.y, grads[1] * self.x
|
||||
def grad(self, grad0, grad1):
|
||||
return grad0 * self.y, grad1 * self.x
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
func = MyFunc()
|
||||
|
@ -601,9 +601,9 @@ can also be None)::
|
|||
self.y = y
|
||||
return x*y, x/y
|
||||
|
||||
def grad(self, grads):
|
||||
assert grads[1] is None
|
||||
return grads[0] * self.y, None
|
||||
def grad(self, grad0, grad1):
|
||||
assert grad1 is None
|
||||
return grad0 * self.y, None
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
func = MyFunc()
|
||||
|
@ -643,7 +643,7 @@ can also be None)::
|
|||
taped_outputs.append(v)
|
||||
# tape output and input together so
|
||||
# backward treat them as one operator
|
||||
tape_together(taped_inputs, taped_outputs, lambda args: self.grad(args))
|
||||
tape_together(taped_inputs, taped_outputs, lambda *args: self.grad(*args))
|
||||
return res
|
||||
|
||||
def dfs(self, parents, k, callback, callback_leave=None):
|
||||
|
|
|
@ -18,8 +18,8 @@ class TestFunction(unittest.TestCase):
|
|||
def execute(self, x):
|
||||
return x+1
|
||||
|
||||
def grad(self, grads):
|
||||
return grads[0]-2
|
||||
def grad(self, grad):
|
||||
return grad-2
|
||||
a = jt.ones(1)
|
||||
func = MyFunc()
|
||||
b = func(a)
|
||||
|
@ -32,8 +32,8 @@ class TestFunction(unittest.TestCase):
|
|||
self.x = x
|
||||
return x+1
|
||||
|
||||
def grad(self, grads):
|
||||
return (grads[0]-2) * self.x
|
||||
def grad(self, grad):
|
||||
return (grad-2) * self.x
|
||||
a = jt.ones(1) * 10
|
||||
func = MyFunc()
|
||||
b = func(a)
|
||||
|
@ -47,8 +47,8 @@ class TestFunction(unittest.TestCase):
|
|||
self.y = y
|
||||
return x*y
|
||||
|
||||
def grad(self, grads):
|
||||
return (grads[0]-2) * self.x
|
||||
def grad(self, grad):
|
||||
return (grad-2) * self.x
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
func = MyFunc()
|
||||
|
@ -62,8 +62,8 @@ class TestFunction(unittest.TestCase):
|
|||
self.y = y
|
||||
return x*y
|
||||
|
||||
def grad(self, grads):
|
||||
return (grads[0]-2) * self.y, (grads[0]-2) * self.x
|
||||
def grad(self, grad):
|
||||
return (grad-2) * self.y, (grad-2) * self.x
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
func = MyFunc()
|
||||
|
@ -79,8 +79,8 @@ class TestFunction(unittest.TestCase):
|
|||
self.y = y
|
||||
return x*y
|
||||
|
||||
def grad(self, grads):
|
||||
return (grads[0]-2) * self.y, None
|
||||
def grad(self, grad):
|
||||
return (grad-2) * self.y, None
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
func = MyFunc()
|
||||
|
@ -96,8 +96,8 @@ class TestFunction(unittest.TestCase):
|
|||
self.y = y
|
||||
return x*y, x/y
|
||||
|
||||
def grad(self, grads):
|
||||
return grads[0] * self.y, grads[1] * self.x
|
||||
def grad(self, grad0, grad1):
|
||||
return grad0 * self.y, grad1 * self.x
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
func = MyFunc()
|
||||
|
@ -113,8 +113,8 @@ class TestFunction(unittest.TestCase):
|
|||
self.y = y
|
||||
return x*y, x/y
|
||||
|
||||
def grad(self, grads):
|
||||
return grads[0] * self.y, grads[1] * self.x
|
||||
def grad(self, grad0, grad1):
|
||||
return grad0 * self.y, grad1 * self.x
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
b.stop_grad()
|
||||
|
@ -131,9 +131,9 @@ class TestFunction(unittest.TestCase):
|
|||
self.y = y
|
||||
return x*y, x/y
|
||||
|
||||
def grad(self, grads):
|
||||
assert grads[1] is None
|
||||
return grads[0] * self.y, None
|
||||
def grad(self, grad0, grad1):
|
||||
assert grad1 is None
|
||||
return grad0 * self.y, None
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
func = MyFunc()
|
||||
|
@ -150,8 +150,8 @@ class TestFunction(unittest.TestCase):
|
|||
self.y = y
|
||||
return x*y, x/y
|
||||
|
||||
def grad(self, grads):
|
||||
res = (grads[0] * self.y, grads[1] * self.x)
|
||||
def grad(self, grad0, grad1):
|
||||
res = (grad0 * self.y, grad1 * self.x)
|
||||
print(res)
|
||||
return res
|
||||
a = jt.array(3.0)
|
||||
|
@ -169,8 +169,8 @@ class TestFunction(unittest.TestCase):
|
|||
self.y = y
|
||||
return x*y, x/y
|
||||
|
||||
def grad(self, grads):
|
||||
res = (grads[0] * self.y, grads[1] * self.x)
|
||||
def grad(self, grad0, grad1):
|
||||
res = (grad0 * self.y, grad1 * self.x)
|
||||
print(res)
|
||||
return res
|
||||
a = jt.array(3.0)
|
||||
|
|
|
@ -613,10 +613,8 @@ DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
|
|||
PyTuple_SET_ITEM(list.obj, i, Py_None);
|
||||
}
|
||||
}
|
||||
PyObjHolder args(PyTuple_New(1));
|
||||
PyTuple_SET_ITEM(args.obj, 0, list.release());
|
||||
|
||||
PyObjHolder ret(PyObject_Call(obj, args.obj, nullptr));
|
||||
PyObjHolder ret(PyObject_Call(obj, list.obj, nullptr));
|
||||
auto is_seq = PyList_CheckExact(ret.obj) || PyTuple_CheckExact(ret.obj);
|
||||
auto check = [&](int i, PyObject* obj) {
|
||||
if (obj == Py_None) {
|
||||
|
|
Loading…
Reference in New Issue