mirror of https://github.com/Jittor/Jittor
fix function grad
This commit is contained in:
parent
f2bf93ae56
commit
eab8bcc049
|
@ -162,5 +162,23 @@ class TestFunction(unittest.TestCase):
|
|||
assert da.data == 4, da.data
|
||||
assert db.data == 9
|
||||
|
||||
def test_multi_grads_multi_out3(self):
|
||||
class MyFunc(Function):
|
||||
def execute(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
return x*y, x/y
|
||||
|
||||
def grad(self, grads):
|
||||
res = (grads[0] * self.y, grads[1] * self.x)
|
||||
print(res)
|
||||
return res
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
c,d = MyFunc()(a, b)
|
||||
da, db = jt.grad(c+d*3, [a, b])
|
||||
assert da.data == 4, da.data
|
||||
assert db.data == 9
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
130
src/grad.cc
130
src/grad.cc
|
@ -35,125 +35,6 @@ inline static void assign_attrs(Var* a, Var* b) {
|
|||
a->flags.set(NodeFlags::_stop_fuse);
|
||||
}
|
||||
|
||||
void tape_together(
|
||||
const vector<VarHolder*>& taped_inputs,
|
||||
const vector<VarHolder*>& taped_outputs,
|
||||
GradCallback&& grad_callback
|
||||
) {
|
||||
auto tapes = new Tapes();
|
||||
tapes->total = tapes->ref = taped_inputs.size() + taped_outputs.size();
|
||||
tapes->callback = move(grad_callback);
|
||||
tapes->flags.set(NodeFlags::_grads);
|
||||
for (int i=0; i<taped_inputs.size(); i++) {
|
||||
auto v = taped_inputs[i]->var;
|
||||
auto op = (TapeOp*)v->input();
|
||||
ASSERT(op);
|
||||
op->flags.set(NodeFlags::_tape);
|
||||
tapes->_inputs.emplace_back(op->inputs().front());
|
||||
op->tapes = tapes;
|
||||
}
|
||||
for (int i=0; i<taped_outputs.size(); i++) {
|
||||
auto v = taped_outputs[i]->var;
|
||||
auto op = (TapeOp*)v->input();
|
||||
ASSERT(op);
|
||||
op->flags.set(NodeFlags::_tape);
|
||||
tapes->_outputs.emplace_back(v,0);
|
||||
op->tapes = tapes;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Func>
|
||||
void bfs_backward_with_tape(vector<Node*>& queue, Func&& func) {
|
||||
auto t = ++Node::tflag_count;
|
||||
size_t i=0;
|
||||
for (Node* node : queue) node->tflag = t;
|
||||
while (i < queue.size()) {
|
||||
Node* node = queue[i++];
|
||||
for (auto i : node->_inputs) {
|
||||
auto inode = i.node;
|
||||
if (inode->flags.get(NodeFlags::_tape)) {
|
||||
Tapes* t = ((TapeOp*)inode)->tapes;
|
||||
inode = t;
|
||||
ASSERT(t->ref == t->total);
|
||||
}
|
||||
if (inode->tflag != t && func(inode)) {
|
||||
inode->tflag = t;
|
||||
queue.push_back(inode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
void bfs_backward_with_tape(vector<Node*>& seed, vector<Node*>& queue, Func&& func) {
|
||||
for (Node* node : seed)
|
||||
if (func(node)) queue.push_back(node);
|
||||
bfs_backward_with_tape(queue, func);
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
void bfs_forward_with_tape(vector<Node*>& queue, Func&& func) {
|
||||
auto t = ++Node::tflag_count;
|
||||
size_t i=0;
|
||||
for (Node* node : queue) node->tflag = t;
|
||||
while (i < queue.size()) {
|
||||
Node* node = queue[i++];
|
||||
for (auto o : node->_outputs) {
|
||||
auto onode = o.node;
|
||||
if (onode->flags.get(NodeFlags::_tape)) {
|
||||
Tapes* t = ((TapeOp*)onode)->tapes;
|
||||
ASSERT(t->ref == t->total) << t->ref << t->total;
|
||||
onode = t;
|
||||
}
|
||||
if (onode->tflag != t && func(onode)) {
|
||||
onode->tflag = t;
|
||||
queue.push_back(onode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Func>
|
||||
void toplogical_sort_backward_with_tape(vector<Node*>& nodes, vector<Node*>& sorted, Func&& func) {
|
||||
auto t = ++Node::tflag_count;
|
||||
sorted.reserve(nodes.size());
|
||||
for (auto node : nodes) node->tflag = t;
|
||||
for (auto node : nodes) {
|
||||
auto& deps = node->custom_data;
|
||||
deps = 0;
|
||||
for (auto o : node->_outputs) {
|
||||
auto onode = o.node;
|
||||
if (onode->flags.get(NodeFlags::_tape)) {
|
||||
Tapes* t = ((TapeOp*)onode)->tapes;
|
||||
onode = t;
|
||||
}
|
||||
if (onode->tflag == t)
|
||||
deps++;
|
||||
}
|
||||
if (deps == 0) sorted.push_back(node);
|
||||
}
|
||||
size_t i=0;
|
||||
while (i < sorted.size()) {
|
||||
Node* node = sorted[i++];
|
||||
for (auto i : node->_inputs) {
|
||||
auto inode = i.node;
|
||||
if (inode->flags.get(NodeFlags::_tape)) {
|
||||
Tapes* t = ((TapeOp*)inode)->tapes;
|
||||
inode = t;
|
||||
}
|
||||
if (inode->tflag == t) {
|
||||
inode->custom_data--;
|
||||
if (inode->custom_data == 0)
|
||||
sorted.push_back(inode);
|
||||
}
|
||||
}
|
||||
func(node);
|
||||
}
|
||||
ASSERTop(nodes.size(),==,sorted.size());
|
||||
}
|
||||
|
||||
vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
||||
LOGvv << "loss:" >> loss << "targets:" >> targets;
|
||||
CHECK(loss->is_float()) << "Loss should be float";
|
||||
|
@ -163,13 +44,13 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
vector<Node*> ts(targets.begin(), targets.end());
|
||||
// bfs visit find all successors of targets
|
||||
LOGvv << "Size of successors:" << ts.size();
|
||||
bfs_forward_with_tape(ts, [](Node*){ return true; });
|
||||
bfs_forward(ts, [](Node*){ return true; });
|
||||
vector<Node*> gnodes;
|
||||
gnodes.reserve(ts.size());
|
||||
auto nt = Node::tflag_count;
|
||||
if (loss->tflag == nt)
|
||||
gnodes.push_back(loss);
|
||||
bfs_backward_with_tape(gnodes, [&](Node* node) {
|
||||
bfs_backward(gnodes, [&](Node* node) {
|
||||
if (node->tflag != nt)
|
||||
return false;
|
||||
if (node->is_stop_grad())
|
||||
|
@ -182,7 +63,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
LOGvv << "Size of grad nodes:" << gnodes.size();
|
||||
|
||||
vector<Node*> sorted;
|
||||
toplogical_sort_backward_with_tape(gnodes, sorted, [](Node*){});
|
||||
toplogical_sort_backward(gnodes, sorted, [](Node*){});
|
||||
nt = Node::tflag_count;
|
||||
vector<Var*> gvars;
|
||||
gvars.reserve(sorted.size());
|
||||
|
@ -217,9 +98,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
Var* var = gvars[i];
|
||||
for (auto it : var->outputs_with_index()) {
|
||||
Op* op = it.op;
|
||||
if (op->flags.get(NodeFlags::_tape)) {
|
||||
op = ((TapeOp*)op)->tapes;
|
||||
}
|
||||
auto index = it.index;
|
||||
if (op->tflag != nt) continue;
|
||||
id_buffer.emplace_back(op, index);
|
||||
|
@ -302,7 +180,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
Var* dout = grads[id];
|
||||
VarPtr dvar = make_grad(op, out, dout, var, index);
|
||||
registe_node_trace_grad(dvar.ptr, op, index);
|
||||
if (dvar)
|
||||
if (dvar && var->num)
|
||||
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
|
||||
<< "dvar" << dvar << "var" << var;
|
||||
if (!grad)
|
||||
|
|
|
@ -44,10 +44,8 @@ struct NodeFlags {
|
|||
_vary_shape=_n+3,
|
||||
// bit4~5: op type
|
||||
_op_type=_n+4, _op_type_nbits=2,
|
||||
// bit6: is tape op
|
||||
_tape=_n+6,
|
||||
// bit7: backprop grad at ones
|
||||
_grads=_n+7,
|
||||
// bit6: backprop grad at ones
|
||||
_grads=_n+6,
|
||||
};
|
||||
|
||||
inline void set(Flags f, int a=1, int nbits=1) {
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace jittor {
|
|||
static auto make_tape = get_op_info("tape")
|
||||
.get_constructor<VarPtr, Var*>();
|
||||
|
||||
TapeOp::TapeOp(Var* x) : tapes(nullptr) {
|
||||
TapeOp::TapeOp(Var* x) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
auto y = create_output(nullptr, x->dtype());
|
||||
|
@ -23,16 +23,6 @@ TapeOp::TapeOp(Var* x) : tapes(nullptr) {
|
|||
y->name = x->name;
|
||||
}
|
||||
|
||||
TapeOp::~TapeOp() {
|
||||
if (tapes) {
|
||||
if (! --tapes->ref) {
|
||||
tapes->_inputs.clear();
|
||||
tapes->_outputs.clear();
|
||||
delete tapes;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
VarPtr TapeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
return dout;
|
||||
}
|
||||
|
@ -48,4 +38,49 @@ void Tapes::grads(Var** douts, VarPtr* dins) {
|
|||
callback.func(_outputs.size(), douts, _inputs.size(), dins);
|
||||
}
|
||||
|
||||
Tapes::Tapes(
|
||||
const vector<VarHolder*>& taped_inputs,
|
||||
const vector<VarHolder*>& taped_outputs,
|
||||
GradCallback&& grad_callback
|
||||
) {
|
||||
callback = move(grad_callback);
|
||||
flags.set(NodeFlags::_grads);
|
||||
|
||||
/*
|
||||
stop grad stop grad
|
||||
i --> tape --> t_i ---> .... ---> o --> tape --> t_o
|
||||
| ^
|
||||
+---> tapes ------------------------------+
|
||||
*/
|
||||
// set tape output
|
||||
for (int i=0; i<taped_outputs.size(); i++) {
|
||||
VarPtr out(0, ns_float32);
|
||||
out->add_inputs({this});
|
||||
auto v = taped_outputs[i]->var;
|
||||
auto op = v->input();
|
||||
op->add_inputs(vector<Node*>{out.ptr});
|
||||
}
|
||||
// set tapes input
|
||||
vector<Var*> tin(taped_inputs.size());
|
||||
for (int i=0; i<taped_inputs.size(); i++) {
|
||||
tin[i] = taped_inputs[i]->var->input()->inputs().front();
|
||||
}
|
||||
add_inputs(tin);
|
||||
// stop grad for input and output
|
||||
for (int i=0; i<taped_inputs.size(); i++) {
|
||||
taped_inputs[i]->var->set_stop_grad();
|
||||
}
|
||||
for (int i=0; i<taped_outputs.size(); i++) {
|
||||
taped_outputs[i]->var->input()->inputs().front()->set_stop_grad();
|
||||
}
|
||||
}
|
||||
|
||||
void tape_together(
|
||||
const vector<VarHolder*>& taped_inputs,
|
||||
const vector<VarHolder*>& taped_outputs,
|
||||
GradCallback&& grad_callback
|
||||
) {
|
||||
new Tapes(taped_inputs, taped_outputs, move(grad_callback));
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -33,9 +33,7 @@ struct GradCallback {
|
|||
};
|
||||
|
||||
struct TapeOp final : Op {
|
||||
Tapes* tapes;
|
||||
TapeOp(Var* x);
|
||||
~TapeOp();
|
||||
|
||||
const char* name() const override { return "tape"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
|
@ -44,8 +42,12 @@ struct TapeOp final : Op {
|
|||
|
||||
|
||||
struct Tapes final : Op {
|
||||
int ref, total;
|
||||
GradCallback callback;
|
||||
Tapes(
|
||||
const vector<VarHolder*>& taped_inputs,
|
||||
const vector<VarHolder*>& taped_outputs,
|
||||
GradCallback&& grad_callback
|
||||
);
|
||||
const char* name() const override { return "tapes"; }
|
||||
void grads(Var** douts, VarPtr* dins) override;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue