diff --git a/python/jittor/test/test_function.py b/python/jittor/test/test_function.py index 46d23cd6..b914a665 100644 --- a/python/jittor/test/test_function.py +++ b/python/jittor/test/test_function.py @@ -162,5 +162,23 @@ class TestFunction(unittest.TestCase): assert da.data == 4, da.data assert db.data == 9 + def test_multi_grads_multi_out3(self): + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + def grad(self, grads): + res = (grads[0] * self.y, grads[1] * self.x) + print(res) + return res + a = jt.array(3.0) + b = jt.array(4.0) + c,d = MyFunc()(a, b) + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4, da.data + assert db.data == 9 + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/src/grad.cc b/src/grad.cc index 53d3c10a..b5d53889 100644 --- a/src/grad.cc +++ b/src/grad.cc @@ -35,125 +35,6 @@ inline static void assign_attrs(Var* a, Var* b) { a->flags.set(NodeFlags::_stop_fuse); } -void tape_together( - const vector& taped_inputs, - const vector& taped_outputs, - GradCallback&& grad_callback -) { - auto tapes = new Tapes(); - tapes->total = tapes->ref = taped_inputs.size() + taped_outputs.size(); - tapes->callback = move(grad_callback); - tapes->flags.set(NodeFlags::_grads); - for (int i=0; ivar; - auto op = (TapeOp*)v->input(); - ASSERT(op); - op->flags.set(NodeFlags::_tape); - tapes->_inputs.emplace_back(op->inputs().front()); - op->tapes = tapes; - } - for (int i=0; ivar; - auto op = (TapeOp*)v->input(); - ASSERT(op); - op->flags.set(NodeFlags::_tape); - tapes->_outputs.emplace_back(v,0); - op->tapes = tapes; - } -} - - -template -void bfs_backward_with_tape(vector& queue, Func&& func) { - auto t = ++Node::tflag_count; - size_t i=0; - for (Node* node : queue) node->tflag = t; - while (i < queue.size()) { - Node* node = queue[i++]; - for (auto i : node->_inputs) { - auto inode = i.node; - if (inode->flags.get(NodeFlags::_tape)) { - Tapes* t = ((TapeOp*)inode)->tapes; - inode = t; - ASSERT(t->ref == t->total); - } - if (inode->tflag != t && func(inode)) { - inode->tflag = t; - queue.push_back(inode); - } - } - } -} - -template -void bfs_backward_with_tape(vector& seed, vector& queue, Func&& func) { - for (Node* node : seed) - if (func(node)) queue.push_back(node); - bfs_backward_with_tape(queue, func); -} - -template -void bfs_forward_with_tape(vector& queue, Func&& func) { - auto t = ++Node::tflag_count; - size_t i=0; - for (Node* node : queue) node->tflag = t; - while (i < queue.size()) { - Node* node = queue[i++]; - for (auto o : node->_outputs) { - auto onode = o.node; - if (onode->flags.get(NodeFlags::_tape)) { - Tapes* t = ((TapeOp*)onode)->tapes; - ASSERT(t->ref == t->total) << t->ref << t->total; - onode = t; - } - if (onode->tflag != t && func(onode)) { - onode->tflag = t; - queue.push_back(onode); - } - } - } -} - - -template -void toplogical_sort_backward_with_tape(vector& nodes, vector& sorted, Func&& func) { - auto t = ++Node::tflag_count; - sorted.reserve(nodes.size()); - for (auto node : nodes) node->tflag = t; - for (auto node : nodes) { - auto& deps = node->custom_data; - deps = 0; - for (auto o : node->_outputs) { - auto onode = o.node; - if (onode->flags.get(NodeFlags::_tape)) { - Tapes* t = ((TapeOp*)onode)->tapes; - onode = t; - } - if (onode->tflag == t) - deps++; - } - if (deps == 0) sorted.push_back(node); - } - size_t i=0; - while (i < sorted.size()) { - Node* node = sorted[i++]; - for (auto i : node->_inputs) { - auto inode = i.node; - if (inode->flags.get(NodeFlags::_tape)) { - Tapes* t = ((TapeOp*)inode)->tapes; - inode = t; - } - if (inode->tflag == t) { - inode->custom_data--; - if (inode->custom_data == 0) - sorted.push_back(inode); - } - } - func(node); - } - ASSERTop(nodes.size(),==,sorted.size()); -} - vector grad(Var* loss, vector targets) { LOGvv << "loss:" >> loss << "targets:" >> targets; CHECK(loss->is_float()) << "Loss should be float"; @@ -163,13 +44,13 @@ vector grad(Var* loss, vector targets) { vector ts(targets.begin(), targets.end()); // bfs visit find all successors of targets LOGvv << "Size of successors:" << ts.size(); - bfs_forward_with_tape(ts, [](Node*){ return true; }); + bfs_forward(ts, [](Node*){ return true; }); vector gnodes; gnodes.reserve(ts.size()); auto nt = Node::tflag_count; if (loss->tflag == nt) gnodes.push_back(loss); - bfs_backward_with_tape(gnodes, [&](Node* node) { + bfs_backward(gnodes, [&](Node* node) { if (node->tflag != nt) return false; if (node->is_stop_grad()) @@ -182,7 +63,7 @@ vector grad(Var* loss, vector targets) { LOGvv << "Size of grad nodes:" << gnodes.size(); vector sorted; - toplogical_sort_backward_with_tape(gnodes, sorted, [](Node*){}); + toplogical_sort_backward(gnodes, sorted, [](Node*){}); nt = Node::tflag_count; vector gvars; gvars.reserve(sorted.size()); @@ -217,9 +98,6 @@ vector grad(Var* loss, vector targets) { Var* var = gvars[i]; for (auto it : var->outputs_with_index()) { Op* op = it.op; - if (op->flags.get(NodeFlags::_tape)) { - op = ((TapeOp*)op)->tapes; - } auto index = it.index; if (op->tflag != nt) continue; id_buffer.emplace_back(op, index); @@ -302,7 +180,7 @@ vector grad(Var* loss, vector targets) { Var* dout = grads[id]; VarPtr dvar = make_grad(op, out, dout, var, index); registe_node_trace_grad(dvar.ptr, op, index); - if (dvar) + if (dvar && var->num) ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size()) << "dvar" << dvar << "var" << var; if (!grad) diff --git a/src/node.h b/src/node.h index 749795fb..2d306c1b 100644 --- a/src/node.h +++ b/src/node.h @@ -44,10 +44,8 @@ struct NodeFlags { _vary_shape=_n+3, // bit4~5: op type _op_type=_n+4, _op_type_nbits=2, - // bit6: is tape op - _tape=_n+6, - // bit7: backprop grad at ones - _grads=_n+7, + // bit6: backprop grad at ones + _grads=_n+6, }; inline void set(Flags f, int a=1, int nbits=1) { diff --git a/src/ops/tape_op.cc b/src/ops/tape_op.cc index a044bc80..61aaff3a 100644 --- a/src/ops/tape_op.cc +++ b/src/ops/tape_op.cc @@ -15,7 +15,7 @@ namespace jittor { static auto make_tape = get_op_info("tape") .get_constructor(); -TapeOp::TapeOp(Var* x) : tapes(nullptr) { +TapeOp::TapeOp(Var* x) { flags.set(NodeFlags::_cpu); flags.set(NodeFlags::_cuda); auto y = create_output(nullptr, x->dtype()); @@ -23,16 +23,6 @@ TapeOp::TapeOp(Var* x) : tapes(nullptr) { y->name = x->name; } -TapeOp::~TapeOp() { - if (tapes) { - if (! --tapes->ref) { - tapes->_inputs.clear(); - tapes->_outputs.clear(); - delete tapes; - } - } -} - VarPtr TapeOp::grad(Var* out, Var* dout, Var* v, int v_index) { return dout; } @@ -48,4 +38,49 @@ void Tapes::grads(Var** douts, VarPtr* dins) { callback.func(_outputs.size(), douts, _inputs.size(), dins); } +Tapes::Tapes( + const vector& taped_inputs, + const vector& taped_outputs, + GradCallback&& grad_callback +) { + callback = move(grad_callback); + flags.set(NodeFlags::_grads); + + /* + stop grad stop grad + i --> tape --> t_i ---> .... ---> o --> tape --> t_o + | ^ + +---> tapes ------------------------------+ + */ + // set tape output + for (int i=0; iadd_inputs({this}); + auto v = taped_outputs[i]->var; + auto op = v->input(); + op->add_inputs(vector{out.ptr}); + } + // set tapes input + vector tin(taped_inputs.size()); + for (int i=0; ivar->input()->inputs().front(); + } + add_inputs(tin); + // stop grad for input and output + for (int i=0; ivar->set_stop_grad(); + } + for (int i=0; ivar->input()->inputs().front()->set_stop_grad(); + } +} + +void tape_together( + const vector& taped_inputs, + const vector& taped_outputs, + GradCallback&& grad_callback +) { + new Tapes(taped_inputs, taped_outputs, move(grad_callback)); +} + } // jittor \ No newline at end of file diff --git a/src/ops/tape_op.h b/src/ops/tape_op.h index 414bf01d..14cf0134 100644 --- a/src/ops/tape_op.h +++ b/src/ops/tape_op.h @@ -33,9 +33,7 @@ struct GradCallback { }; struct TapeOp final : Op { - Tapes* tapes; TapeOp(Var* x); - ~TapeOp(); const char* name() const override { return "tape"; } VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; @@ -44,8 +42,12 @@ struct TapeOp final : Op { struct Tapes final : Op { - int ref, total; GradCallback callback; + Tapes( + const vector& taped_inputs, + const vector& taped_outputs, + GradCallback&& grad_callback + ); const char* name() const override { return "tapes"; } void grads(Var** douts, VarPtr* dins) override; };