mirror of https://github.com/Jittor/Jittor
fix multi output grad
This commit is contained in:
parent
130ecea0c5
commit
b27082f944
2
setup.py
2
setup.py
|
@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
|
|||
|
||||
setuptools.setup(
|
||||
name='jittor',
|
||||
version='1.1.5.3',
|
||||
version='1.1.5.4',
|
||||
# scripts=[],
|
||||
author="Jittor Group",
|
||||
author_email="ran.donglang@gmail.com",
|
||||
|
|
45
src/grad.cc
45
src/grad.cc
|
@ -92,29 +92,30 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
Op* op = it.op;
|
||||
auto index = it.index;
|
||||
if (op->tflag != nt) continue;
|
||||
// TODO: support two outputs backprop.
|
||||
Var* out = op->outputs().back();
|
||||
Var* dout = grads[out->custom_data];
|
||||
VarPtr dvar = make_grad(op, out, dout, var, index);
|
||||
registe_node_trace_grad(dvar.ptr, op, index);
|
||||
if (dvar)
|
||||
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
|
||||
<< "dvar" << dvar << "var" << var;
|
||||
if (!grad)
|
||||
grad = move(dvar);
|
||||
else if (dvar) {
|
||||
grad = make_binary(grad, dvar, ns_add);
|
||||
#ifdef PREVENT_LARGE_FUSED_OP
|
||||
gsum ++;
|
||||
if (gsum>=PREVENT_LARGE_FUSED_OP) {
|
||||
// TODO: this is a dirty fix for
|
||||
// stopping fuse lots of op together,
|
||||
// try to find a better solution
|
||||
grad->flags.set(NodeFlags::_stop_fuse);
|
||||
for (Var* out : op->outputs()) {
|
||||
if (out->tflag != nt) continue;
|
||||
Var* dout = grads[out->custom_data];
|
||||
VarPtr dvar = make_grad(op, out, dout, var, index);
|
||||
registe_node_trace_grad(dvar.ptr, op, index);
|
||||
if (dvar)
|
||||
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
|
||||
<< "dvar" << dvar << "var" << var;
|
||||
if (!grad)
|
||||
grad = move(dvar);
|
||||
else if (dvar) {
|
||||
grad = make_binary(grad, dvar, ns_add);
|
||||
#ifdef PREVENT_LARGE_FUSED_OP
|
||||
gsum ++;
|
||||
if (gsum>=PREVENT_LARGE_FUSED_OP) {
|
||||
// TODO: this is a dirty fix for
|
||||
// stopping fuse lots of op together,
|
||||
// try to find a better solution
|
||||
grad->flags.set(NodeFlags::_stop_fuse);
|
||||
}
|
||||
#endif
|
||||
assign_attrs(grad.ptr, var);
|
||||
registe_node_trace_grad(grad.ptr, var, index);
|
||||
}
|
||||
#endif
|
||||
assign_attrs(grad.ptr, var);
|
||||
registe_node_trace_grad(grad.ptr, var, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue