code op with output

Dun Liang 2020-10-16 15:02:04 +08:00
parent aac16547d0
commit db30ead79e
6 changed files with 71 additions and 16 deletions
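
A quick Python-level sketch of what this commit enables (mirroring the test_return_multi_output test added below): jt.code can now take a list of existing Vars as outputs and fill them in place, through a new CodeOp constructor marked with the replace_outputs attribute.

    import jittor as jt

    a = jt.array([3,2,1])          # input
    b = jt.array([1,2])            # existing Vars passed as out0 / out1
    c = jt.array([3,4,5,6])
    jt.code([a], [b,c], cpu_src="""
        @alias(a, in0)
        @alias(b, out0)
        @alias(c, out1)
        for (int i=0; i<a_shape0; i++) {
            if (i<b_shape0) @b(i) += @a(i);
            if (i<c_shape0) @c(i) += @a(i);
        }
    """)
    assert (b.data == [4,4]).all()         # kernel accumulated a into b's old contents
    assert (c.data[:3] == [6,6,6]).all()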

View File

@@ -7,7 +7,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.0.3'
+__version__ = '1.2.0.4'
 from . import lock
 with lock.lock_scope():
     from . import compiler

View File

@@ -248,15 +248,15 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
         LOGf << "Wrong output size of" << \"{op_name}\";
     }}
     if (_op->flags.get(NodeFlags::_forwarded)) {{
-        VarPtr output(move(_op->outputs_holder[0]));
+        VarPtr _out(move(_op->outputs_holder[0]));
         delete _op;
-        return output;
+        return _out;
     }}
     _op->outputs_holder[0]->set_inputs({{_op}});
-    VarPtr output(move(_op->outputs_holder[0]));
+    VarPtr _out(move(_op->outputs_holder[0]));
     {src.replace("->var","")};
     _op->init();
-    return output;
+    return _out;
 }}
 """)
else:
@@ -264,16 +264,16 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
 vector<VarPtr> make_{cc_func_name}({", ".join(cc_make_args)}) {{
     auto _op = new {op_name}({", ".join(op_make_args)});
     if (_op->flags.get(NodeFlags::_forwarded)) {{
-        vector<VarPtr> outputs = move(_op->outputs_holder);
+        vector<VarPtr> _outs = move(_op->outputs_holder);
         delete _op;
-        return outputs;
+        return _outs;
     }}
-    vector<VarPtr> outputs = move(_op->outputs_holder);
-    for (uint i=0; i<outputs.size(); i++)
-        outputs[i]->set_inputs({{_op}});
+    vector<VarPtr> _outs = move(_op->outputs_holder);
+    for (uint i=0; i<_outs.size(); i++)
+        _outs[i]->set_inputs({{_op}});
     {src.replace("->var","")};
     _op->init();
-    return outputs;
+    return _outs;
 }}
 """)
 if pybind_name == 'None':
@@ -291,7 +291,14 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
 /*{doc_string}*/
 // @pyjt({",".join(pyjt_names)})
 vector<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
-    return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));
+    { f'return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));'
+      if "replace_outputs" not in attrs else
+      f'''auto rt = make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));
+    ASSERT(rt.size() == outputs.size());
+    for (int i=0; i<outputs.size(); i++)
+        outputs[i]->assign(rt[i]);
+    return rt;
+    '''}
 }}
 """)
 else:

View File

@@ -115,6 +115,26 @@ class TestCodeOp(unittest.TestCase):
         assert (b.data == [3,2]).all()
         assert (c.data[:3] == [3,2,1]).all()
 
+    def test_return_multi_output(self):
+        a = jt.array([3,2,1])
+        b = jt.array([1,2])
+        c = jt.array([3,4,5,6])
+        jt.code([a], [b,c],
+            cpu_src="""
+                @alias(a, in0)
+                @alias(b, out0)
+                @alias(c, out1)
+                for (int i=0; i<a_shape0; i++) {
+                    if (i<b_shape0) @b(i) += @a(i);
+                    if (i<c_shape0) @c(i) += @a(i);
+                }
+            """
+        )
+        assert b.shape == [2]
+        assert c.shape == [4]
+        assert (b.data == [4,4]).all()
+        assert (c.data[:3] == [6,6,6]).all()
+
     def test_multi_output2(self):
         a = jt.array([3,2,1])
         b,c = jt.code([(1,), (1,)], [a.dtype, a.dtype], [a],

View File

@@ -32,7 +32,6 @@ CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
     flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
     flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
     _outputs.push_back(create_output(shape, dtype));
-    CHECKop(_inputs.size(),<=,10);
     if (_outputs[0]->num < 0) {
         flags.set(NodeFlags::_vary_shape);
@@ -52,8 +51,6 @@ CodeOp::CodeOp(
     flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
     CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same";
     _outputs.resize(shapes.size());
-    CHECKop(_inputs.size(),<=,10);
-    CHECKop(_outputs.size(),<=,10);
     CHECKop(_outputs.size(),>,0);
     for (int i=0; i<shapes.size(); i++) {
         _outputs[i] = create_output(shapes[i], dtypes[i]);
@@ -64,6 +61,27 @@ CodeOp::CodeOp(
     }
 }
 
+CodeOp::CodeOp(
+    vector<Var*>&& inputs, vector<Var*>&& outputs,
+    string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
+    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
+    : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
+      cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
+{
+    flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
+    flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
+    _outputs.resize(outputs.size());
+    CHECKop(_outputs.size(),>,0);
+    for (int i=0; i<outputs.size(); i++) {
+        auto o = outputs[i];
+        _outputs[i] = create_output(o->shape, o->dtype());
+        _outputs[i]->share_with(o);
+        /*
+            TODO: vary shape not allowed in direct output
+        */
+    }
+}
+
 VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     // Do not have grad to extras input
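
The new constructor above creates each op output with the corresponding caller Var's shape and dtype and then share_with()s it, while the replace_outputs wrapper assigns the results back into the caller's holders. A hedged sketch of the behaviour this implies for the simplest one-output case, assuming it dispatches the same way as the two-output test added in this commit (the names here are illustrative):

    import jittor as jt

    x = jt.array([10, 20, 30])
    acc = jt.array([1, 2, 3])      # existing Var used as out0
    rets = jt.code([x], [acc], cpu_src="""
        @alias(x, in0)
        @alias(acc, out0)
        for (int i=0; i<acc_shape0; i++) @acc(i) += @x(i);
    """)
    # Because the op output shares acc's storage, the kernel sees acc's old
    # contents and += accumulates onto them.
    assert (acc.data == [11, 22, 33]).all()
    # outputs[i]->assign(rt[i]) in the generated wrapper means the returned
    # holders and the passed Vars should hold the same data.
    assert (rets[0].data == acc.data).all()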

View File

@@ -230,6 +230,10 @@ struct CodeOp : Op {
     // @attrs(multiple_outputs)
     CodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
+
+    // @attrs(multiple_outputs,replace_outputs)
+    CodeOp(vector<Var*>&& inputs, vector<Var*>&& outputs, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
+
     const char* name() const override { return "code"; }
     VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
     DECLARE_jit_run;

View File

@@ -431,7 +431,13 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj);
 CHECK_IS_1(vector);
 
 DEF_IS_1(vector, bool) is_type(PyObject* obj) {
-    return PyList_CheckExact(obj) || PyTuple_CheckExact(obj);
+    if (!(PyList_CheckExact(obj) || PyTuple_CheckExact(obj)))
+        return false;
+    auto size = Py_SIZE(obj);
+    if (!size)
+        return true;
+    auto arr = PySequence_Fast_ITEMS(obj);
+    return is_type<typename T::value_type>(arr[0]);
 }
 
 DEF_IS_1(vector, PyObject*) to_py_object(const T& a) {
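
For context on this last change: is_type<vector<T>> previously accepted any list or tuple, so pyjt overload resolution could not distinguish a list of shapes from a list of Vars; checking the first element's type lets the two jt.code forms below bind to different CodeOp constructors. A sketch of the two call shapes (the kernel bodies are illustrative, not taken from this commit):

    import jittor as jt

    a = jt.array([3,2,1])

    # Existing form: lists of shapes and dtypes, fresh outputs are allocated.
    b, c = jt.code([(1,), (1,)], [a.dtype, a.dtype], [a], cpu_src="""
        @alias(a, in0)
        @alias(b, out0)
        @alias(c, out1)
        @b(0) = @a(0);
        @c(0) = @a(1);
    """)

    # New form: a list of existing Vars, used as the outputs in place.
    d = jt.array([1,2])
    e = jt.array([3,4,5,6])
    jt.code([a], [d,e], cpu_src="""
        @alias(a, in0)
        @alias(d, out0)
        @alias(e, out1)
        for (int i=0; i<a_shape0; i++) {
            if (i<d_shape0) @d(i) += @a(i);
            if (i<e_shape0) @e(i) += @a(i);
        }
    """)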