mirror of https://github.com/Jittor/Jittor
tune interface
This commit is contained in:
parent
eab8bcc049
commit
c6fc9c05ca
|
@ -579,8 +579,8 @@ can store value for backward computation)::
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y, x/y
|
return x*y, x/y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad0, grad1):
|
||||||
return grads[0] * self.y, grads[1] * self.x
|
return grad0 * self.y, grad1 * self.x
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
b = jt.array(4.0)
|
b = jt.array(4.0)
|
||||||
func = MyFunc()
|
func = MyFunc()
|
||||||
|
@ -601,9 +601,9 @@ can also be None)::
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y, x/y
|
return x*y, x/y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad0, grad1):
|
||||||
assert grads[1] is None
|
assert grad1 is None
|
||||||
return grads[0] * self.y, None
|
return grad0 * self.y, None
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
b = jt.array(4.0)
|
b = jt.array(4.0)
|
||||||
func = MyFunc()
|
func = MyFunc()
|
||||||
|
@ -643,7 +643,7 @@ can also be None)::
|
||||||
taped_outputs.append(v)
|
taped_outputs.append(v)
|
||||||
# tape output and input together so
|
# tape output and input together so
|
||||||
# backward treat them as one operator
|
# backward treat them as one operator
|
||||||
tape_together(taped_inputs, taped_outputs, lambda args: self.grad(args))
|
tape_together(taped_inputs, taped_outputs, lambda *args: self.grad(*args))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def dfs(self, parents, k, callback, callback_leave=None):
|
def dfs(self, parents, k, callback, callback_leave=None):
|
||||||
|
|
|
@ -18,8 +18,8 @@ class TestFunction(unittest.TestCase):
|
||||||
def execute(self, x):
|
def execute(self, x):
|
||||||
return x+1
|
return x+1
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad):
|
||||||
return grads[0]-2
|
return grad-2
|
||||||
a = jt.ones(1)
|
a = jt.ones(1)
|
||||||
func = MyFunc()
|
func = MyFunc()
|
||||||
b = func(a)
|
b = func(a)
|
||||||
|
@ -32,8 +32,8 @@ class TestFunction(unittest.TestCase):
|
||||||
self.x = x
|
self.x = x
|
||||||
return x+1
|
return x+1
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad):
|
||||||
return (grads[0]-2) * self.x
|
return (grad-2) * self.x
|
||||||
a = jt.ones(1) * 10
|
a = jt.ones(1) * 10
|
||||||
func = MyFunc()
|
func = MyFunc()
|
||||||
b = func(a)
|
b = func(a)
|
||||||
|
@ -47,8 +47,8 @@ class TestFunction(unittest.TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y
|
return x*y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad):
|
||||||
return (grads[0]-2) * self.x
|
return (grad-2) * self.x
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
b = jt.array(4.0)
|
b = jt.array(4.0)
|
||||||
func = MyFunc()
|
func = MyFunc()
|
||||||
|
@ -62,8 +62,8 @@ class TestFunction(unittest.TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y
|
return x*y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad):
|
||||||
return (grads[0]-2) * self.y, (grads[0]-2) * self.x
|
return (grad-2) * self.y, (grad-2) * self.x
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
b = jt.array(4.0)
|
b = jt.array(4.0)
|
||||||
func = MyFunc()
|
func = MyFunc()
|
||||||
|
@ -79,8 +79,8 @@ class TestFunction(unittest.TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y
|
return x*y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad):
|
||||||
return (grads[0]-2) * self.y, None
|
return (grad-2) * self.y, None
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
b = jt.array(4.0)
|
b = jt.array(4.0)
|
||||||
func = MyFunc()
|
func = MyFunc()
|
||||||
|
@ -96,8 +96,8 @@ class TestFunction(unittest.TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y, x/y
|
return x*y, x/y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad0, grad1):
|
||||||
return grads[0] * self.y, grads[1] * self.x
|
return grad0 * self.y, grad1 * self.x
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
b = jt.array(4.0)
|
b = jt.array(4.0)
|
||||||
func = MyFunc()
|
func = MyFunc()
|
||||||
|
@ -113,8 +113,8 @@ class TestFunction(unittest.TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y, x/y
|
return x*y, x/y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad0, grad1):
|
||||||
return grads[0] * self.y, grads[1] * self.x
|
return grad0 * self.y, grad1 * self.x
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
b = jt.array(4.0)
|
b = jt.array(4.0)
|
||||||
b.stop_grad()
|
b.stop_grad()
|
||||||
|
@ -131,9 +131,9 @@ class TestFunction(unittest.TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y, x/y
|
return x*y, x/y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad0, grad1):
|
||||||
assert grads[1] is None
|
assert grad1 is None
|
||||||
return grads[0] * self.y, None
|
return grad0 * self.y, None
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
b = jt.array(4.0)
|
b = jt.array(4.0)
|
||||||
func = MyFunc()
|
func = MyFunc()
|
||||||
|
@ -150,8 +150,8 @@ class TestFunction(unittest.TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y, x/y
|
return x*y, x/y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad0, grad1):
|
||||||
res = (grads[0] * self.y, grads[1] * self.x)
|
res = (grad0 * self.y, grad1 * self.x)
|
||||||
print(res)
|
print(res)
|
||||||
return res
|
return res
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
|
@ -169,8 +169,8 @@ class TestFunction(unittest.TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
return x*y, x/y
|
return x*y, x/y
|
||||||
|
|
||||||
def grad(self, grads):
|
def grad(self, grad0, grad1):
|
||||||
res = (grads[0] * self.y, grads[1] * self.x)
|
res = (grad0 * self.y, grad1 * self.x)
|
||||||
print(res)
|
print(res)
|
||||||
return res
|
return res
|
||||||
a = jt.array(3.0)
|
a = jt.array(3.0)
|
||||||
|
|
|
@ -613,10 +613,8 @@ DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
|
||||||
PyTuple_SET_ITEM(list.obj, i, Py_None);
|
PyTuple_SET_ITEM(list.obj, i, Py_None);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
PyObjHolder args(PyTuple_New(1));
|
|
||||||
PyTuple_SET_ITEM(args.obj, 0, list.release());
|
|
||||||
|
|
||||||
PyObjHolder ret(PyObject_Call(obj, args.obj, nullptr));
|
PyObjHolder ret(PyObject_Call(obj, list.obj, nullptr));
|
||||||
auto is_seq = PyList_CheckExact(ret.obj) || PyTuple_CheckExact(ret.obj);
|
auto is_seq = PyList_CheckExact(ret.obj) || PyTuple_CheckExact(ret.obj);
|
||||||
auto check = [&](int i, PyObject* obj) {
|
auto check = [&](int i, PyObject* obj) {
|
||||||
if (obj == Py_None) {
|
if (obj == Py_None) {
|
||||||
|
|
Loading…
Reference in New Issue