mirror of https://github.com/Jittor/Jittor

code op with output

parent aac16547d0
commit db30ead79e
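This commit bumps the version to 1.2.0.4 and lets a code op write into existing Vars: a new CodeOp constructor takes caller-supplied output Vars and shares storage with them, the op-maker generator gains a `replace_outputs` attribute whose generated wrapper assigns results back into the passed VarHolders, the generated makers' temporaries are renamed (`output` -> `_out`, `outputs` -> `_outs`), presumably to avoid clashing with an op argument named `outputs`, the 10-input/10-output CHECKop caps are dropped, and the pyjt vector type check now also inspects the element type. A minimal usage sketch of the new form, taken from the test added below:

    import jittor as jt

    a = jt.array([3, 2, 1])
    b = jt.array([1, 2])            # pre-allocated outputs; existing values stay visible to the kernel
    c = jt.array([3, 4, 5, 6])
    jt.code([a], [b, c], cpu_src="""
        @alias(a, in0)
        @alias(b, out0)
        @alias(c, out1)
        for (int i=0; i<a_shape0; i++) {
            if (i<b_shape0) @b(i) += @a(i);
            if (i<c_shape0) @c(i) += @a(i);
        }
    """)
    assert (b.data == [4, 4]).all()         # [1,2] + [3,2]
    assert (c.data[:3] == [6, 6, 6]).all()  # [3,4,5] + [3,2,1]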
@@ -7,7 +7,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.0.3'
+__version__ = '1.2.0.4'
 from . import lock
 with lock.lock_scope():
     from . import compiler
@@ -248,15 +248,15 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
                 LOGf << "Wrong output size of" << \"{op_name}\";
         }}
         if (_op->flags.get(NodeFlags::_forwarded)) {{
-            VarPtr output(move(_op->outputs_holder[0]));
+            VarPtr _out(move(_op->outputs_holder[0]));
             delete _op;
-            return output;
+            return _out;
         }}
         _op->outputs_holder[0]->set_inputs({{_op}});
-        VarPtr output(move(_op->outputs_holder[0]));
+        VarPtr _out(move(_op->outputs_holder[0]));
         {src.replace("->var","")};
         _op->init();
-        return output;
+        return _out;
    }}
    """)
    else:
@@ -264,16 +264,16 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
    vector<VarPtr> make_{cc_func_name}({", ".join(cc_make_args)}) {{
         auto _op = new {op_name}({", ".join(op_make_args)});
         if (_op->flags.get(NodeFlags::_forwarded)) {{
-            vector<VarPtr> outputs = move(_op->outputs_holder);
+            vector<VarPtr> _outs = move(_op->outputs_holder);
             delete _op;
-            return outputs;
+            return _outs;
         }}
-        vector<VarPtr> outputs = move(_op->outputs_holder);
-        for (uint i=0; i<outputs.size(); i++)
-            outputs[i]->set_inputs({{_op}});
+        vector<VarPtr> _outs = move(_op->outputs_holder);
+        for (uint i=0; i<_outs.size(); i++)
+            _outs[i]->set_inputs({{_op}});
         {src.replace("->var","")};
         _op->init();
-        return outputs;
+        return _outs;
    }}
    """)
    if pybind_name == 'None':
@@ -291,7 +291,14 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
    /*{doc_string}*/
    // @pyjt({",".join(pyjt_names)})
    vector<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
-        return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));
+        { f'return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));'
+            if "replace_outputs" not in attrs else
+            f'''auto rt = make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));
+        ASSERT(rt.size() == outputs.size());
+        for (int i=0; i<outputs.size(); i++)
+            outputs[i]->assign(rt[i]);
+        return rt;
+        '''}
    }}
    """)
    else:
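When an op is declared with the `replace_outputs` attribute (see the header change further down), the generated pyjt wrapper no longer just returns the maker's result; it also copies each returned Var back into the caller-supplied outputs. A rough Python rendering of that behavior, assuming hypothetical names `make_op` for the generated maker and `wrap_replace_outputs` for the wrapper:

    def wrap_replace_outputs(op_args, outputs, make_op):
        # make_op stands in for the generated make_<op_name>() maker
        rt = make_op(*op_args)
        assert len(rt) == len(outputs)   # mirrors the generated ASSERT
        for o, r in zip(outputs, rt):
            o.assign(r)                  # write the fresh result back into the passed-in output
        return rt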
@@ -115,6 +115,26 @@ class TestCodeOp(unittest.TestCase):
         assert (b.data == [3,2]).all()
         assert (c.data[:3] == [3,2,1]).all()
 
+    def test_return_multi_output(self):
+        a = jt.array([3,2,1])
+        b = jt.array([1,2])
+        c = jt.array([3,4,5,6])
+        jt.code([a], [b,c],
+            cpu_src="""
+                @alias(a, in0)
+                @alias(b, out0)
+                @alias(c, out1)
+                for (int i=0; i<a_shape0; i++) {
+                    if (i<b_shape0) @b(i) += @a(i);
+                    if (i<c_shape0) @c(i) += @a(i);
+                }
+            """
+        )
+        assert b.shape == [2]
+        assert c.shape == [4]
+        assert (b.data == [4,4]).all()
+        assert (c.data[:3] == [6,6,6]).all()
+
     def test_multi_output2(self):
         a = jt.array([3,2,1])
         b,c = jt.code([(1,), (1,)], [a.dtype, a.dtype], [a],
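The new test pre-fills `b` and `c` and accumulates with `+=` in the kernel, so the expected values come from adding `a` into the existing contents. The same arithmetic in plain NumPy, as a quick sanity check (illustrative only):

    import numpy as np
    a = np.array([3, 2, 1])
    b = np.array([1, 2])
    c = np.array([3, 4, 5, 6])
    b += a[:2]     # -> [4, 4]
    c[:3] += a     # -> [6, 6, 6, 6]; the test only checks the first three elements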
@@ -32,7 +32,6 @@ CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
     flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
     flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
     _outputs.push_back(create_output(shape, dtype));
-    CHECKop(_inputs.size(),<=,10);
 
     if (_outputs[0]->num < 0) {
         flags.set(NodeFlags::_vary_shape);
@@ -52,8 +51,6 @@ CodeOp::CodeOp(
     flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
     CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same";
     _outputs.resize(shapes.size());
-    CHECKop(_inputs.size(),<=,10);
-    CHECKop(_outputs.size(),<=,10);
     CHECKop(_outputs.size(),>,0);
     for (int i=0; i<shapes.size(); i++) {
         _outputs[i] = create_output(shapes[i], dtypes[i]);
@@ -64,6 +61,27 @@ CodeOp::CodeOp(
     }
 }
 
+CodeOp::CodeOp(
+    vector<Var*>&& inputs, vector<Var*>&& outputs,
+    string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
+    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
+    : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
+      cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
+{
+    flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
+    flags.set(NodeFlags::_cuda, !!this->cuda_src.size());
+    _outputs.resize(outputs.size());
+    CHECKop(_outputs.size(),>,0);
+    for (int i=0; i<outputs.size(); i++) {
+        auto o = outputs[i];
+        _outputs[i] = create_output(o->shape, o->dtype());
+        _outputs[i]->share_with(o);
+        /*
+        TODO: vary shape not allowed in direct output
+        */
+    }
+}
+
 
 VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
     // Do not have grad to extras input
@@ -230,6 +230,10 @@ struct CodeOp : Op {
     // @attrs(multiple_outputs)
     CodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
+
+    // @attrs(multiple_outputs,replace_outputs)
+    CodeOp(vector<Var*>&& inputs, vector<Var*>&& outputs, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
+
 
     const char* name() const override { return "code"; }
     VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
     DECLARE_jit_run;
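The header now declares two multi-output forms: the existing constructor that allocates outputs from shapes and dtypes, and the new `replace_outputs` constructor that takes existing Vars. At the Python level the difference looks roughly like this (sketch based on the tests in this commit; `src` is a placeholder kernel string):

    # 1) the op allocates its outputs; capture the returned vars
    b, c = jt.code([(1,), (1,)], [a.dtype, a.dtype], [a], cpu_src=src)

    # 2) pass existing vars; results are written into them in place
    jt.code([a], [b, c], cpu_src=src)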
@@ -431,7 +431,13 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj);
 CHECK_IS_1(vector);
 
 DEF_IS_1(vector, bool) is_type(PyObject* obj) {
-    return PyList_CheckExact(obj) || PyTuple_CheckExact(obj);
+    if (!(PyList_CheckExact(obj) || PyTuple_CheckExact(obj)))
+        return false;
+    auto size = Py_SIZE(obj);
+    if (!size)
+        return true;
+    auto arr = PySequence_Fast_ITEMS(obj);
+    return is_type<typename T::value_type>(arr[0]);
 }
 
 DEF_IS_1(vector, PyObject*) to_py_object(const T& a) {
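The vector `is_type` specialization used to accept any list or tuple; it now also checks the first element against the vector's element type (an empty sequence still matches), presumably so that overloads that differ only in element type, such as the two multi-output CodeOp constructors above, can be told apart. A rough Python equivalent, with `element_is_type` standing in for the recursive `is_type<typename T::value_type>` call:

    def vector_is_type(obj, element_is_type):
        if not isinstance(obj, (list, tuple)):
            return False
        if len(obj) == 0:
            return True                   # empty sequences still match any vector<T>
        return element_is_type(obj[0])    # only the first element is inspected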