This commit is contained in:
Gword 2020-07-13 22:09:56 +08:00
parent a6caacf54e
commit 96883622c6
4 changed files with 101 additions and 36 deletions

View File

@ -11,38 +11,60 @@ import jittor as jt
import numpy as np
class TestCodeOp(unittest.TestCase):
def forward_code(self, np, data):
a,b = data["inputs"]
c,d = data["outputs"]
np.add(a,b,out=c)
np.subtract(a,b,out=d)
def backward_code1(self, np, data):
dout = data["dout"]
out = data["outputs"][0]
np.copyto(out, dout)
def backward_code2(self, np, data):
dout = data["dout"]
out_index = data["out_index"]
out = data["outputs"][0]
if out_index==0:
np.copyto(out, dout)
else:
np.negative(dout, out)
def test(self):
def forward_code(np, data):
a = data["inputs"][0]
b = data["outputs"][0]
np.add(a,a,out=b)
def backward_code(np, data):
dout = data["dout"]
out = data["outputs"][0]
np.copyto(out, dout*2.0)
a = jt.random((5,1))
b = jt.numpy_code(
a.shape,
a.dtype,
[a],
forward_code,
[backward_code],
)
assert np.allclose(b.data,(a+a).data)
da = jt.grad(b,a)
one=np.ones(a.shape)
assert np.allclose(da.data,one*2.0)
def test_multi_input(self):
def forward_code(np, data):
a,b = data["inputs"]
c,d = data["outputs"]
np.add(a,b,out=c)
np.subtract(a,b,out=d)
def backward_code1(np, data):
dout = data["dout"]
out = data["outputs"][0]
np.copyto(out, dout)
def backward_code2(np, data):
dout = data["dout"]
out_index = data["out_index"]
out = data["outputs"][0]
if out_index==0:
np.copyto(out, dout)
else:
np.negative(dout, out)
a = jt.random((5,1))
b = jt.random((5,1))
c, d = jt.numpy_code(
[a.shape, a.shape],
[a.dtype, a.dtype],
[a, b],
self.forward_code,
[self.backward_code1,self.backward_code2],
forward_code,
[backward_code1,backward_code2],
)
assert np.allclose(c.data,(a+b).data)
assert np.allclose(d.data,(a-b).data)
dca, dcb = jt.grad(c,[a,b])
@ -54,5 +76,48 @@ class TestCodeOp(unittest.TestCase):
assert np.allclose(dda.data,one)
assert np.allclose(ddb.data,mone)
@unittest.skipIf(True, "Memory leak testing is not in progress, Skip")
def test_memory_leak(self):
def forward_code(np, data):
a,b = data["inputs"]
c,d = data["outputs"]
np.add(a,b,out=c)
np.subtract(a,b,out=d)
def backward_code1(np, data):
dout = data["dout"]
out = data["outputs"][0]
np.copyto(out, dout)
def backward_code2(np, data):
dout = data["dout"]
out_index = data["out_index"]
out = data["outputs"][0]
if out_index==0:
np.copyto(out, dout)
else:
np.negative(dout, out)
for i in range(1000000):
a = jt.random((10000,1))
b = jt.random((10000,1))
c, d = jt.numpy_code(
[a.shape, a.shape],
[a.dtype, a.dtype],
[a, b],
forward_code,
[backward_code1,backward_code2],
)
assert np.allclose(c.data,(a+b).data)
assert np.allclose(d.data,(a-b).data)
dca, dcb = jt.grad(c,[a,b])
dda, ddb = jt.grad(d,[a,b])
one=np.ones(a.shape)
mone=one*-1.0
assert np.allclose(dca.data,one)
assert np.allclose(dcb.data,one)
assert np.allclose(dda.data,one)
assert np.allclose(ddb.data,mone)
if __name__ == "__main__":
unittest.main()

View File

@ -57,7 +57,7 @@ NumpyCodeOp::NumpyCodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inpu
: _inputs(inputs), forward(forward), _results(move(results))
{
_outputs.push_back(create_output(shape, dtype));
CHECKop(_inputs.size(),<=,10)
CHECKop(_inputs.size(),<=,10);
ASSERT(_outputs[0]->num >= 0);
}

View File

@ -57,18 +57,18 @@ struct NumpyCodeOp : Op {
Example-1::
def forward_code(self, np, data):
a = data["inputs"]
b = data["outputs"]
def forward_code(np, data):
a = data["inputs"][0]
b = data["outputs"][0]
np.add(a,a,out=b)
def backward_code(self, np, data):
def backward_code(np, data):
dout = data["dout"]
out = data["outputs"][0]
np.copyto(out, dout)
np.copyto(out, dout*2.0)
a = jt.random((5,1))
c, d = jt.numpy_code(
b = jt.numpy_code(
a.shape,
a.dtype,
[a],
@ -78,18 +78,18 @@ struct NumpyCodeOp : Op {
Example-2::
def forward_code(self, np, data):
def forward_code(np, data):
a,b = data["inputs"]
c,d = data["outputs"]
np.add(a,b,out=c)
np.subtract(a,b,out=d)
def backward_code1(self, np, data):
def backward_code1(np, data):
dout = data["dout"]
out = data["outputs"][0]
np.copyto(out, dout)
def backward_code2(self, np, data):
def backward_code2(np, data):
dout = data["dout"]
out_index = data["out_index"]
out = data["outputs"][0]

View File

@ -578,8 +578,8 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
// args = []
PyObjHolder args(PyTuple_New(2));
PyTuple_SET_ITEM(args.obj, 0, np.obj);
PyTuple_SET_ITEM(args.obj, 1, data.obj);
PyTuple_SET_ITEM(args.obj, 0, np.release());
PyTuple_SET_ITEM(args.obj, 1, data.release());
PyObjHolder ret(PyObject_Call(obj, args.obj, nullptr));
},
// deleter