mirror of https://github.com/Jittor/Jittor
polish rnn grad
This commit is contained in:
parent
1e1fe66a30
commit
9af14f4f55
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.20'
|
||||
__version__ = '1.3.1.21'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -116,7 +116,7 @@ void CudnnRnnOp::grads(Var** dout, VarPtr* dins) {
|
|||
VarPtr dhy = dout[1];
|
||||
VarPtr dcy = cx ? dout[2] : nullptr;
|
||||
if (!dhy.ptr) dhy = make_number(0.0, hy);
|
||||
if (!dcy.ptr) dcy = make_number(0.0, cy);
|
||||
if (!dcy.ptr && cx) dcy = make_number(0.0, cy);
|
||||
|
||||
|
||||
vector<VarPtr> dInput;
|
||||
|
|
Loading…
Reference in New Issue