polish rnn grad

This commit is contained in:
Dun Liang 2021-11-22 13:29:04 +08:00
parent 1e1fe66a30
commit 9af14f4f55
2 changed files with 2 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.20'
__version__ = '1.3.1.21'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -116,7 +116,7 @@ void CudnnRnnOp::grads(Var** dout, VarPtr* dins) {
VarPtr dhy = dout[1];
VarPtr dcy = cx ? dout[2] : nullptr;
if (!dhy.ptr) dhy = make_number(0.0, hy);
if (!dcy.ptr) dcy = make_number(0.0, cy);
if (!dcy.ptr && cx) dcy = make_number(0.0, cy);
vector<VarPtr> dInput;