fix function grad

Dun Liang 2020-07-14 14:13:41 +08:00
parent f2bf93ae56
commit eab8bcc049
5 changed files with 75 additions and 144 deletions

View File

@@ -162,5 +162,23 @@ class TestFunction(unittest.TestCase):
        assert da.data == 4, da.data
        assert db.data == 9
    def test_multi_grads_multi_out3(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y
            def grad(self, grads):
                res = (grads[0] * self.y, grads[1] * self.x)
                print(res)
                return res
        a = jt.array(3.0)
        b = jt.array(4.0)
        c,d = MyFunc()(a, b)
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4, da.data
        assert db.data == 9
if __name__ == "__main__":
    unittest.main()
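For reference, the asserted values follow directly from the custom grad rule above: for loss = c + d*3 the incoming gradients with respect to (c, d) are (1, 3), so grad returns (1*b, 3*a) = (4, 9). A minimal standalone sketch of that bookkeeping (plain Python, illustrative only, not part of the commit):

# incoming gradients of loss = c + d*3 with respect to (c, d)
a, b = 3.0, 4.0
grads = (1.0, 3.0)
# MyFunc.grad maps them to (grads[0]*y, grads[1]*x)
da, db = grads[0] * b, grads[1] * a
assert da == 4 and db == 9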

View File

@@ -35,125 +35,6 @@ inline static void assign_attrs(Var* a, Var* b) {
a->flags.set(NodeFlags::_stop_fuse);
}
void tape_together(
const vector<VarHolder*>& taped_inputs,
const vector<VarHolder*>& taped_outputs,
GradCallback&& grad_callback
) {
auto tapes = new Tapes();
tapes->total = tapes->ref = taped_inputs.size() + taped_outputs.size();
tapes->callback = move(grad_callback);
tapes->flags.set(NodeFlags::_grads);
for (int i=0; i<taped_inputs.size(); i++) {
auto v = taped_inputs[i]->var;
auto op = (TapeOp*)v->input();
ASSERT(op);
op->flags.set(NodeFlags::_tape);
tapes->_inputs.emplace_back(op->inputs().front());
op->tapes = tapes;
}
for (int i=0; i<taped_outputs.size(); i++) {
auto v = taped_outputs[i]->var;
auto op = (TapeOp*)v->input();
ASSERT(op);
op->flags.set(NodeFlags::_tape);
tapes->_outputs.emplace_back(v,0);
op->tapes = tapes;
}
}
template <typename Func>
void bfs_backward_with_tape(vector<Node*>& queue, Func&& func) {
auto t = ++Node::tflag_count;
size_t i=0;
for (Node* node : queue) node->tflag = t;
while (i < queue.size()) {
Node* node = queue[i++];
for (auto i : node->_inputs) {
auto inode = i.node;
if (inode->flags.get(NodeFlags::_tape)) {
Tapes* t = ((TapeOp*)inode)->tapes;
inode = t;
ASSERT(t->ref == t->total);
}
if (inode->tflag != t && func(inode)) {
inode->tflag = t;
queue.push_back(inode);
}
}
}
}
template <typename Func>
void bfs_backward_with_tape(vector<Node*>& seed, vector<Node*>& queue, Func&& func) {
for (Node* node : seed)
if (func(node)) queue.push_back(node);
bfs_backward_with_tape(queue, func);
}
template <typename Func>
void bfs_forward_with_tape(vector<Node*>& queue, Func&& func) {
auto t = ++Node::tflag_count;
size_t i=0;
for (Node* node : queue) node->tflag = t;
while (i < queue.size()) {
Node* node = queue[i++];
for (auto o : node->_outputs) {
auto onode = o.node;
if (onode->flags.get(NodeFlags::_tape)) {
Tapes* t = ((TapeOp*)onode)->tapes;
ASSERT(t->ref == t->total) << t->ref << t->total;
onode = t;
}
if (onode->tflag != t && func(onode)) {
onode->tflag = t;
queue.push_back(onode);
}
}
}
}
template <typename Func>
void toplogical_sort_backward_with_tape(vector<Node*>& nodes, vector<Node*>& sorted, Func&& func) {
auto t = ++Node::tflag_count;
sorted.reserve(nodes.size());
for (auto node : nodes) node->tflag = t;
for (auto node : nodes) {
auto& deps = node->custom_data;
deps = 0;
for (auto o : node->_outputs) {
auto onode = o.node;
if (onode->flags.get(NodeFlags::_tape)) {
Tapes* t = ((TapeOp*)onode)->tapes;
onode = t;
}
if (onode->tflag == t)
deps++;
}
if (deps == 0) sorted.push_back(node);
}
size_t i=0;
while (i < sorted.size()) {
Node* node = sorted[i++];
for (auto i : node->_inputs) {
auto inode = i.node;
if (inode->flags.get(NodeFlags::_tape)) {
Tapes* t = ((TapeOp*)inode)->tapes;
inode = t;
}
if (inode->tflag == t) {
inode->custom_data--;
if (inode->custom_data == 0)
sorted.push_back(inode);
}
}
func(node);
}
ASSERTop(nodes.size(),==,sorted.size());
}
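The removed helpers above share one pattern: a reverse, output-to-input traversal that counts, for each node, how many of its consumers lie inside the visited set (kept in custom_data), seeds the order with the zero-consumer nodes, and releases a node once every consumer has been processed; the only tape-specific twist is redirecting a TapeOp to its shared Tapes node before the check. A minimal Python sketch of that counting scheme (illustrative names only, no Jittor types and no tape redirection):

def topo_sort_backward(nodes, inputs_of, outputs_of):
    in_set = set(nodes)
    # deps[n] = number of consumers of n that take part in this traversal
    deps = {n: sum(1 for o in outputs_of(n) if o in in_set) for n in nodes}
    order = [n for n in nodes if deps[n] == 0]   # start from the sinks
    i = 0
    while i < len(order):
        n = order[i]; i += 1
        for m in inputs_of(n):
            if m in in_set:
                deps[m] -= 1
                if deps[m] == 0:                 # all consumers handled
                    order.append(m)
    assert len(order) == len(nodes)              # holds for acyclic graphs
    return order

# tiny usage example: loss consumes x and y, so it is emitted first
outs = {"x": ["loss"], "y": ["loss"], "loss": []}
ins = {"x": [], "y": [], "loss": ["x", "y"]}
print(topo_sort_backward(["x", "y", "loss"], ins.__getitem__, outs.__getitem__))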
vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
LOGvv << "loss:" >> loss << "targets:" >> targets;
CHECK(loss->is_float()) << "Loss should be float";
@@ -163,13 +44,13 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
vector<Node*> ts(targets.begin(), targets.end());
// bfs visit find all successors of targets
LOGvv << "Size of successors:" << ts.size();
bfs_forward_with_tape(ts, [](Node*){ return true; });
bfs_forward(ts, [](Node*){ return true; });
vector<Node*> gnodes;
gnodes.reserve(ts.size());
auto nt = Node::tflag_count;
if (loss->tflag == nt)
gnodes.push_back(loss);
bfs_backward_with_tape(gnodes, [&](Node* node) {
bfs_backward(gnodes, [&](Node* node) {
if (node->tflag != nt)
return false;
if (node->is_stop_grad())
@@ -182,7 +63,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
LOGvv << "Size of grad nodes:" << gnodes.size();
vector<Node*> sorted;
toplogical_sort_backward_with_tape(gnodes, sorted, [](Node*){});
toplogical_sort_backward(gnodes, sorted, [](Node*){});
nt = Node::tflag_count;
vector<Var*> gvars;
gvars.reserve(sorted.size());
@@ -217,9 +98,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
Var* var = gvars[i];
for (auto it : var->outputs_with_index()) {
Op* op = it.op;
if (op->flags.get(NodeFlags::_tape)) {
op = ((TapeOp*)op)->tapes;
}
auto index = it.index;
if (op->tflag != nt) continue;
id_buffer.emplace_back(op, index);
@@ -302,7 +180,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
Var* dout = grads[id];
VarPtr dvar = make_grad(op, out, dout, var, index);
registe_node_trace_grad(dvar.ptr, op, index);
if (dvar)
if (dvar && var->num)
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
<< "dvar" << dvar << "var" << var;
if (!grad)

View File

@@ -44,10 +44,8 @@ struct NodeFlags {
_vary_shape=_n+3,
// bit4~5: op type
_op_type=_n+4, _op_type_nbits=2,
// bit6: is tape op
_tape=_n+6,
// bit7: backprop grad at ones
_grads=_n+7,
// bit6: backprop grad at ones
_grads=_n+6,
};
inline void set(Flags f, int a=1, int nbits=1) {
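The only functional change in this hunk is that removing the _tape bit shifts _grads from bit 7 down to bit 6; the two-bit _op_type field at bits 4~5 is untouched. The body of set(Flags f, int a, int nbits) is not shown here; the following is a plausible sketch of the usual mask-and-shift convention for such an nbits-wide field accessor (illustrative Python, an assumption rather than Jittor's actual implementation):

def set_field(flags, pos, value=1, nbits=1):
    # clear the nbits-wide field starting at bit `pos`, then write `value`
    mask = ((1 << nbits) - 1) << pos
    return (flags & ~mask) | ((value << pos) & mask)

def get_field(flags, pos, nbits=1):
    return (flags >> pos) & ((1 << nbits) - 1)

flags = set_field(0, 4, 0b10, nbits=2)   # e.g. a 2-bit op-type field
assert get_field(flags, 4, nbits=2) == 0b10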

View File

@@ -15,7 +15,7 @@ namespace jittor {
static auto make_tape = get_op_info("tape")
.get_constructor<VarPtr, Var*>();
TapeOp::TapeOp(Var* x) : tapes(nullptr) {
TapeOp::TapeOp(Var* x) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
auto y = create_output(nullptr, x->dtype());
@@ -23,16 +23,6 @@ TapeOp::TapeOp(Var* x) : tapes(nullptr) {
y->name = x->name;
}
TapeOp::~TapeOp() {
if (tapes) {
if (! --tapes->ref) {
tapes->_inputs.clear();
tapes->_outputs.clear();
delete tapes;
}
}
}
VarPtr TapeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return dout;
}
@@ -48,4 +38,49 @@ void Tapes::grads(Var** douts, VarPtr* dins) {
callback.func(_outputs.size(), douts, _inputs.size(), dins);
}
Tapes::Tapes(
const vector<VarHolder*>& taped_inputs,
const vector<VarHolder*>& taped_outputs,
GradCallback&& grad_callback
) {
callback = move(grad_callback);
flags.set(NodeFlags::_grads);
/*
            stop grad         stop grad
i --> tape --> t_i ---> .... ---> o --> tape --> t_o
|                                         ^
+---> tapes ------------------------------+
*/
// set tape output
for (int i=0; i<taped_outputs.size(); i++) {
VarPtr out(0, ns_float32);
out->add_inputs({this});
auto v = taped_outputs[i]->var;
auto op = v->input();
op->add_inputs(vector<Node*>{out.ptr});
}
// set tapes input
vector<Var*> tin(taped_inputs.size());
for (int i=0; i<taped_inputs.size(); i++) {
tin[i] = taped_inputs[i]->var->input()->inputs().front();
}
add_inputs(tin);
// stop grad for input and output
for (int i=0; i<taped_inputs.size(); i++) {
taped_inputs[i]->var->set_stop_grad();
}
for (int i=0; i<taped_outputs.size(); i++) {
taped_outputs[i]->var->input()->inputs().front()->set_stop_grad();
}
}
void tape_together(
const vector<VarHolder*>& taped_inputs,
const vector<VarHolder*>& taped_outputs,
GradCallback&& grad_callback
) {
new Tapes(taped_inputs, taped_outputs, move(grad_callback));
}
} // jittor
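On the Python side this wiring is driven through jt.Function: calling a Function wraps its inputs and outputs in tape ops and registers a single Tapes node carrying the Python grad callback, which is what the diagram in the constructor above depicts. A minimal usage sketch (public API only, imported as in the test file; a single-input, single-output Function keeps the example independent of how multiple gradients are passed back):

import jittor as jt
from jittor import Function

class Double(Function):
    def execute(self, x):
        return x * 2
    def grad(self, g):
        # reached through Tapes::grads -> GradCallback when jt.grad
        # walks backward across the tapes node
        return g * 2

x = jt.array([1.0, 2.0, 3.0])
y = Double()(x)
dx, = jt.grad(y.sum(), [x])
assert (dx.data == 2).all()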

View File

@@ -33,9 +33,7 @@ struct GradCallback {
};
struct TapeOp final : Op {
Tapes* tapes;
TapeOp(Var* x);
~TapeOp();
const char* name() const override { return "tape"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
@@ -44,8 +42,12 @@ struct TapeOp final : Op {
struct Tapes final : Op {
int ref, total;
GradCallback callback;
Tapes(
const vector<VarHolder*>& taped_inputs,
const vector<VarHolder*>& taped_outputs,
GradCallback&& grad_callback
);
const char* name() const override { return "tapes"; }
void grads(Var** douts, VarPtr* dins) override;
};