// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "pybind/py_var_tracer.h"
#include "grad.h"
#include "var.h"
#include "op.h"
#include "graph.h"
#include "ops/op_register.h"

namespace jittor {

#define PREVENT_LARGE_FUSED_OP 16

static auto make_binary = get_op_info("binary")
    .get_constructor<VarPtr, Var*, Var*, NanoString>();
static auto make_number = get_op_info("number")
    .get_constructor<VarPtr, float, Var*>();

#ifdef _WIN32
// Windows lacks VLAs, so emulate "T a[n]" with _alloca plus explicit
// construction/destruction of the n elements.
template<class T> struct StackIniter {
    T* a;
    int n;
    inline StackIniter(T* a, int n) : a(a), n(n) {
        for (int i=0; i<n; i++)
            new(a+i) T();
    }
    inline ~StackIniter() {
        for (int i=0; i<n; i++)
            a[i].~T();
    }
};

#define STACK_ALLOC2(T, a, n) T* a = (T*)_alloca(sizeof(T)*(n)); StackIniter<T> __init_##a(a, n);
#else
#define STACK_ALLOC2(T, a, n) T a[n]
#endif
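
// Illustrative expansion (not part of the build): STACK_ALLOC2(VarPtr, dins, 3)
// behaves like "VarPtr dins[3];" on non-Windows targets. On _WIN32 it becomes
// an _alloca'd raw buffer whose three VarPtr elements are default-constructed
// by StackIniter and destroyed when the generated __init_dins guard leaves scope.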

struct AmpGradGuard {
    int amp_reg_bk;
    AmpGradGuard(Op* op) {
        amp_reg_bk = amp_reg;
        amp_reg |= (op->flags.flags >> NodeFlags::_prefer_32);
    }

    ~AmpGradGuard() {
        amp_reg = amp_reg_bk;
    }
};
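
// AmpGradGuard scopes the global auto-mixed-precision setting: while a grad op
// is being constructed, the forward op's precision-preference bits are OR-ed
// into amp_reg, and the previous value is restored when the guard is destroyed.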

VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
    if (dout == nullptr) return nullptr;
    if (x_index<0) return nullptr;
    LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs()
        << "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index;
    AmpGradGuard agg(op);
    auto dx = op->grad(out, dout, x, x_index);
    // op->grad may return a null VarPtr; guard before touching loop_options
    if (dx && x->loop_options)
        dx->loop_options = x->loop_options;
    return dx;
}
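
// make_grad builds a single chain-rule term: given dout = d(loss)/d(out), it
// asks the forward op for d(loss)/d(x), the gradient w.r.t. its x_index-th
// input. Contributions from multiple consumers of x are summed elsewhere via
// make_binary(..., ns_add).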

inline static void assign_attrs(Var* a, Var* b) {
    if (b->flags.get(NodeFlags::_stop_fuse))
        a->flags.set(NodeFlags::_stop_fuse);
}

map<string,int> grad_breaks;

void warn_grad_break(int i, Var* v) {
    if (grad_breaks.count(v->name.c_str())) return;
    grad_breaks[v->name.c_str()] = 1;
    LOGw << "grads[">>i>>"] '">> v->name>>"' doesn't have gradient. It will be set to zero:" << v;
}

vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
    LOGvv << "loss:" >> loss << "targets:" >> targets;
    CHECK(loss->is_float()) << "Loss should be float";
    for (Var* var : targets)
        CHECK(var->is_float()) << "Targets of grad should be float";
    // successors of targets
    vector<Node*> ts(targets.begin(), targets.end());
    // bfs visit to find all successors of targets
    bfs_forward(ts, [](Node*){ return true; });
    LOGvv << "Size of successors:" << ts.size();
    vector<Node*> gnodes;
    gnodes.reserve(ts.size());
    auto nt = Node::tflag_count;
    if (loss->tflag == nt)
        gnodes.push_back(loss);
    bfs_backward(gnodes, [&](Node* node) {
        if (node->tflag != nt)
            return false;
        if (node->is_stop_grad())
            return false;
        return true;
    });
    LOGvv << "Size of grad nodes:" << gnodes.size();
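
    // The gradient subgraph is the intersection of the two sweeps: nodes
    // reachable forward from the targets and backward from the loss, with
    // stop_grad nodes excluded.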

    vector<Node*> sorted;
    toplogical_sort_backward(gnodes, sorted, [](Node*){});
    nt = Node::tflag_count;
    vector<Var*> gvars;
    gvars.reserve(sorted.size());
    for (Node* node : sorted)
        if (node->is_var()) {
            Var* v = node->var();
            v->custom_data = gvars.size();
            gvars.push_back(v);
        }
    LOGvv << "Size of grad vars:" << gvars.size();
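
    // custom_data now maps each var in the gradient subgraph to its slot in
    // gvars (and hence in grads below); vars outside the subgraph were not
    // re-tagged, so var->tflag != nt identifies them.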

    vector<VarPtr> grads(gvars.size());
    vector<VarPtr> results(targets.size());
    vector<int> target_id(targets.size());
    for (int i=0; i<targets.size(); i++) {
        Var* var = targets[i];
        target_id[i] = (var->tflag == nt) ?
            var->custom_data : -1;
    }
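
    // The code relies on gvars[0] being the loss (the backward topological
    // order starts there): backpropagation is seeded with d(loss)/d(loss) = 1.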
    if (grads.size()) {
        grads[0] = make_number(1.f, loss);
        assign_attrs(grads[0].ptr, loss);
    }

    vector<pair<Node*, int64>> id_buffer;
    id_buffer.reserve(sorted.size()+10);

    // backup ids in custom data
    for (int i=1; i<gvars.size(); i++) {
        Var* var = gvars[i];
        for (auto it : var->outputs_with_index()) {
            Op* op = it.op;
            auto index = it.index;
            if (op->tflag != nt) continue;
            id_buffer.emplace_back(op, index);

            // backward together
            if (op->flags.get(NodeFlags::_grads)) {
                // don't backward next time
                op->tflag = 0;
                for (Var* out : op->outputs()) {
                    id_buffer.emplace_back(
                        out,
                        out->tflag == nt ? out->custom_data : -1);
                }
                for (Var* in : op->inputs()) {
                    id_buffer.emplace_back(
                        in,
                        in->tflag == nt ? in->custom_data : -1);
                }
            } else {
                // single var backward
                for (Var* out : op->outputs()) {
                    id_buffer.emplace_back(
                        out,
                        out->tflag == nt ? out->custom_data : -1);
                }
            }
        }
        // end of var output
        id_buffer.emplace_back(nullptr, 0);
    }
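
    // Ids are recorded above and replayed below so the traversal stays stable:
    // constructing grad ops mutates the graph, so outputs_with_index() cannot
    // be iterated safely while new ops are being added.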

    // real backward construction from the previously backed-up ids
    int j=0;
    for (int i=1; i<gvars.size(); i++,j++) {
        Var* var = gvars[i];
        auto& grad = grads[i];
        #ifdef PREVENT_LARGE_FUSED_OP
        int gsum = 0;
        #endif
        // replays "for (auto it : var->outputs_with_index())"
        while (id_buffer[j].first) {
            Op* op = id_buffer[j].first->op();
            auto index = id_buffer[j].second;
            j++;
            auto n_o = op->outputs().size();

            if (op->flags.get(NodeFlags::_grads)) {
                // backward together
                auto n_i = op->inputs().size();
                STACK_ALLOC(Var*, douts, n_o);
                STACK_ALLOC2(VarPtr, dins, n_i);
                // replays "for (Var* out : op->outputs())"
                for (int i=0; i<n_o; i++,j++) {
                    auto id = id_buffer[j].second;
                    if (id>=0) {
                        douts[i] = grads[id];
                    } else
                        douts[i] = nullptr;
                }
                trace_grad_op = op;
                {
                    AmpGradGuard agg(op);
                    op->grads(douts, dins);
                }
                // replays "for (Var* in : op->inputs())"
                for (int i=0; i<n_i; i++,j++) {
                    auto id = id_buffer[j].second;
                    if (id>=0) {
                        auto& din = dins[i];
                        auto& grad = grads[id];
                        if (din && grad) {
                            grad = make_binary(grad, din, ns_add);
                        } else if (din)
                            // don't wipe an already-accumulated grad with a null din
                            grad = move(din);
                    }
                }
            } else {
                // single var backward
                // replays "for (Var* out : op->outputs())"
                for (int i=0; i<n_o; i++,j++) {
                    auto id = id_buffer[j].second;
                    auto out = id_buffer[j].first->var();
                    if (id<0) continue;
                    Var* dout = grads[id];
                    trace_grad_op = op;
                    VarPtr dvar = make_grad(op, out, dout, var, index);
                    if (dvar && dvar->num>=0 && var->num>0)
                        // var->num == 0 represents an any-match var
                        ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
                            << "dvar" << dvar << "var" << var;
                    if (!grad)
                        grad = move(dvar);
                    else if (dvar) {
                        grad = make_binary(grad, dvar, ns_add);
                        #ifdef PREVENT_LARGE_FUSED_OP
                        gsum ++;
                        if (gsum>=PREVENT_LARGE_FUSED_OP) {
                            // TODO: this is a dirty fix for
                            // stopping fusing lots of ops together,
                            // try to find a better solution
                            grad->flags.set(NodeFlags::_stop_fuse);
                        }
                        #endif
                        assign_attrs(grad.ptr, var);
                    }
                }
            }
        }
    }
    trace_grad_op = nullptr;
    // set zero grad for targets that received no gradient
    for (size_t i=0; i<results.size(); i++) {
        Var* var = targets[i];
        VarPtr& grad = results[i];
        auto id = target_id[i];
        if (id>=0)
            grad = move(grads[id]);
        if (!grad) {
            // TODO: better warning message
            warn_grad_break(i, var);
            grad = make_number(0.f, var);
            assign_attrs(grad.ptr, var);
        }
    }
    return results;
}
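
/* Usage sketch (illustrative; the variable names are hypothetical):

       Var* loss = ...;                        // scalar float loss
       vector<Var*> params = {w, b};           // float leaf vars
       vector<VarPtr> gs = grad(loss, params);
       // gs[i] is d(loss)/d(params[i]); a target with no path to the loss
       // gets a zero grad plus a warn_grad_break warning.

   The Python-level jt.grad is backed by this routine. */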

} // jittor