polish rnn grad

This commit is contained in:
Dun Liang 2021-11-22 13:16:34 +08:00
parent 38694a1b6e
commit 5cebf93e31
2 changed files with 16 additions and 7 deletions

View File

@@ -108,11 +108,16 @@ static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x")
.get_constructor<vector<VarPtr>, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>();
// Bound constructor for the "cudnn_rnn_backward_x" op for cells without a
// cell state: compared with make_backwardx_with_cx above it takes one fewer
// Var* argument (presumably the missing cx input — confirm against the op's
// registered signature).
static auto make_backwardx_without_cx = get_op_info("cudnn_rnn_backward_x")
.get_constructor<vector<VarPtr>, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>();
// Bound constructor for the "number" op: (float scalar, Var* reference) ->
// VarPtr. Used in grads() below to materialize a zero grad when an output
// gradient is absent; presumably yields a Var matching the reference Var's
// shape filled with the scalar — confirm against the number op definition.
static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>();
void CudnnRnnOp::grads(Var** dout, VarPtr* dins) {
Var *dy = dout[0];
Var *dhy = dout[1];
Var *dcy = cx ? dout[2] : nullptr;
VarPtr dy = dout[0];
VarPtr dhy = dout[1];
VarPtr dcy = cx ? dout[2] : nullptr;
if (!dhy.ptr) dhy = make_number(0.0, hy);
if (!dcy.ptr) dcy = make_number(0.0, cy);
vector<VarPtr> dInput;
if (cx)

View File

@@ -174,12 +174,16 @@ def Resnet50(pretrained=False, **kwargs):
# Lowercase alias of Resnet50 (same pattern as resnet38/resnet26 below).
resnet50 = Resnet50
def Resnet38(**kwargs):
    """Build a ResNet-38 model from Bottleneck blocks laid out as [2, 3, 5, 2].

    Keyword arguments are forwarded unchanged to ``_resnet``.
    """
    layers = [2, 3, 5, 2]
    return _resnet(Bottleneck, layers, **kwargs)
def Resnet38(pretrained=False, **kwargs):
    """Build a ResNet-38 model from Bottleneck blocks laid out as [2, 3, 5, 2].

    Args:
        pretrained (bool): when True, load pretrained weights from jittorhub.
        **kwargs: forwarded unchanged to ``_resnet``.

    Returns:
        The constructed model.
    """
    net = _resnet(Bottleneck, [2, 3, 5, 2], **kwargs)
    if pretrained:
        net.load("jittorhub://resnet38.pkl")
    return net
resnet38 = Resnet38
def Resnet26(**kwargs):
    """Build a ResNet-26 model from Bottleneck blocks laid out as [1, 2, 4, 1].

    Keyword arguments are forwarded unchanged to ``_resnet``.
    """
    layers = [1, 2, 4, 1]
    return _resnet(Bottleneck, layers, **kwargs)
def Resnet26(pretrained=False, **kwargs):
    """Build a ResNet-26 model from Bottleneck blocks laid out as [1, 2, 4, 1].

    Args:
        pretrained (bool): when True, load pretrained weights from jittorhub.
        **kwargs: forwarded unchanged to ``_resnet``.

    Returns:
        The constructed model.
    """
    net = _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
    if pretrained:
        net.load("jittorhub://resnet26.pkl")
    return net
resnet26 = Resnet26
def Resnet101(pretrained=False, **kwargs):