fix multi output grad

This commit is contained in:
Dun Liang 2020-07-08 22:27:56 +08:00
parent 130ecea0c5
commit b27082f944
2 changed files with 24 additions and 23 deletions

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
name='jittor',
version='1.1.5.3',
version='1.1.5.4',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",

View File

@ -92,29 +92,30 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
Op* op = it.op;
auto index = it.index;
if (op->tflag != nt) continue;
// TODO: support two outputs backprop.
Var* out = op->outputs().back();
Var* dout = grads[out->custom_data];
VarPtr dvar = make_grad(op, out, dout, var, index);
registe_node_trace_grad(dvar.ptr, op, index);
if (dvar)
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
<< "dvar" << dvar << "var" << var;
if (!grad)
grad = move(dvar);
else if (dvar) {
grad = make_binary(grad, dvar, ns_add);
#ifdef PREVENT_LARGE_FUSED_OP
gsum ++;
if (gsum>=PREVENT_LARGE_FUSED_OP) {
// TODO: this is a dirty fix for
// stopping fuse lots of op together,
// try to find a better solution
grad->flags.set(NodeFlags::_stop_fuse);
for (Var* out : op->outputs()) {
if (out->tflag != nt) continue;
Var* dout = grads[out->custom_data];
VarPtr dvar = make_grad(op, out, dout, var, index);
registe_node_trace_grad(dvar.ptr, op, index);
if (dvar)
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
<< "dvar" << dvar << "var" << var;
if (!grad)
grad = move(dvar);
else if (dvar) {
grad = make_binary(grad, dvar, ns_add);
#ifdef PREVENT_LARGE_FUSED_OP
gsum ++;
if (gsum>=PREVENT_LARGE_FUSED_OP) {
// TODO: this is a dirty fix for
// stopping fuse lots of op together,
// try to find a better solution
grad->flags.set(NodeFlags::_stop_fuse);
}
#endif
assign_attrs(grad.ptr, var);
registe_node_trace_grad(grad.ptr, var, index);
}
#endif
assign_attrs(grad.ptr, var);
registe_node_trace_grad(grad.ptr, var, index);
}
}
}