mirror of https://github.com/Jittor/Jittor
polish rnn grad
This commit is contained in:
parent
38694a1b6e
commit
5cebf93e31
|
@ -108,11 +108,16 @@ static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x")
|
|||
.get_constructor<vector<VarPtr>, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>();
|
||||
static auto make_backwardx_without_cx = get_op_info("cudnn_rnn_backward_x")
|
||||
.get_constructor<vector<VarPtr>, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>();
|
||||
static auto make_number = get_op_info("number")
|
||||
.get_constructor<VarPtr, float, Var*>();
|
||||
|
||||
void CudnnRnnOp::grads(Var** dout, VarPtr* dins) {
|
||||
Var *dy = dout[0];
|
||||
Var *dhy = dout[1];
|
||||
Var *dcy = cx ? dout[2] : nullptr;
|
||||
VarPtr dy = dout[0];
|
||||
VarPtr dhy = dout[1];
|
||||
VarPtr dcy = cx ? dout[2] : nullptr;
|
||||
if (!dhy.ptr) dhy = make_number(0.0, hy);
|
||||
if (!dcy.ptr) dcy = make_number(0.0, cy);
|
||||
|
||||
|
||||
vector<VarPtr> dInput;
|
||||
if (cx)
|
||||
|
|
|
@ -174,12 +174,16 @@ def Resnet50(pretrained=False, **kwargs):
|
|||
|
||||
resnet50 = Resnet50
|
||||
|
||||
def Resnet38(**kwargs):
|
||||
return _resnet(Bottleneck, [2, 3, 5, 2], **kwargs)
|
||||
def Resnet38(pretrained=False, **kwargs):
|
||||
model = _resnet(Bottleneck, [2, 3, 5, 2], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet38.pkl")
|
||||
return model
|
||||
resnet38 = Resnet38
|
||||
|
||||
def Resnet26(**kwargs):
|
||||
return _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
|
||||
def Resnet26(pretrained=False, **kwargs):
|
||||
model = _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
|
||||
if pretrained: model.load("jittorhub://resnet26.pkl")
|
||||
return model
|
||||
resnet26 = Resnet26
|
||||
|
||||
def Resnet101(pretrained=False, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue