fix mkl conv

This commit is contained in:
Dun Liang 2020-05-02 12:03:21 +08:00
parent 557a4bf9d2
commit 87f51fc13d
6 changed files with 33 additions and 15 deletions

View File

@ -45,7 +45,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));

View File

@ -16,7 +16,7 @@ struct MklConvBackwardWOp : Op {
int kernel_size, stride, padding, dilation;
string xformat, wformat, yformat;
MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "mkl_conv_backward_w"; }
void infer_shape() override;

View File

@ -45,7 +45,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));

View File

@ -16,7 +16,7 @@ struct MklConvBackwardXOp : Op {
int xh, xw, stride, padding, dilation;
string xformat, wformat, yformat;
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "mkl_conv_backward_x"; }
void infer_shape() override;

View File

@ -22,7 +22,7 @@ struct OpInfo {
for (uint i=0; i<constructors.size(); i++)
if (std::type_index(*(constructors[i].first)) == std::type_index(tid))
return func_t(constructors[i].second);
LOGf << "constructor" << tid.name() << "not found.";
LOGf << "constructor" << name << tid.name() << "not found.";
return func_t(nullptr);
}
};

View File

@ -160,6 +160,24 @@ struct OpInspector {
return 0;
}
string format2(const string& fmt, const vector<int>& order) {
string new_fmt = fmt;
if (order.size() != fmt.size()) {
failed = 1;
return "";
}
if (check_overlap(order))
return "";
for (uint i=0; i<order.size(); i++) {
if (order[i]>=(int)new_fmt.size()) {
failed = 1;
return "";
}
new_fmt[order[i]] = fmt[i];
}
return new_fmt;
}
string format(const string& fmt, const vector<int>& order) {
string new_fmt = fmt;
if (order.size() != fmt.size()) {
@ -234,19 +252,19 @@ void ConvTuner::forwardTune(FusedOp* fop) {
xh = xoi.mm[zh];
xw = xoi.mm[zw];
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
auto xformat = xoi.format("abcd", {xn, xc, xh, xw});
auto xformat = xoi.format2("abcd", {xn, xc, xh, xw});
LOGvvvv << "xformat =" << xformat;
wci = woi.mm[zci];
wco = woi.mm[zco];
wh = woi.mm[zwh];
ww = woi.mm[zww];
auto wformat = xoi.format("iohw", {wci, wco, wh, ww});
auto wformat = xoi.format2("iohw", {wci, wco, wh, ww});
LOGvvvv << "wformat =" << wformat;
yn = yoi.mm[zn];
yc = yoi.mm[zco];
yh = yoi.mm[zh];
yw = yoi.mm[zw];
auto yformat = xoi.format("abcd", {yn, yc, yh, yw});
auto yformat = xoi.format2("abcd", {yn, yc, yh, yw});
LOGvvvv << "yformat =" << yformat;
// mkl doesn't support "cdab" format
if (yformat == "cdab") continue;
@ -307,7 +325,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
continue;
auto make_conv = get_op_info(relay_conv_name)
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
auto rid = fop->context->vrm.add_relay_group({{rvar, rop->y}});
if (rid>=0) {
auto srid = "relay"+S(rid);
@ -359,19 +377,19 @@ void ConvTuner::backwardTune(FusedOp* fop) {
xh = xoi.mm[zh];
xw = xoi.mm[zw];
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
xformat = xoi.format("abcd", {xn, xc, xh, xw});
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
LOGvvvv << "xformat =" << xformat;
wci = woi.mm[zci];
wco = woi.mm[zco];
wh = woi.mm[zwh];
ww = woi.mm[zww];
wformat = xoi.format("iohw", {wci, wco, wh, ww});
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
LOGvvvv << "wformat =" << wformat;
yn = yoi.mm[zn];
yc = yoi.mm[zco];
yh = yoi.mm[zh];
yw = yoi.mm[zw];
yformat = xoi.format("abcd", {yn, yc, yh, yw});
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
LOGvvvv << "yformat =" << yformat;
// mkl doesn't support "cdab" format
if (yformat == "cdab") continue;
@ -434,19 +452,19 @@ void ConvTuner::backwardTune(FusedOp* fop) {
xh = xoi.mm[zh];
xw = xoi.mm[zw];
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
xformat = xoi.format("abcd", {xn, xc, xh, xw});
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
LOGvvvv << "xformat =" << xformat;
wci = woi.mm[zci];
wco = woi.mm[zco];
wh = woi.mm[zwh];
ww = woi.mm[zww];
wformat = xoi.format("iohw", {wci, wco, wh, ww});
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
LOGvvvv << "wformat =" << wformat;
yn = yoi.mm[zn];
yc = yoi.mm[zco];
yh = yoi.mm[zh];
yw = yoi.mm[zw];
yformat = xoi.format("abcd", {yn, yc, yh, yw});
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
LOGvvvv << "yformat =" << yformat;
// mkl doesn't support "cdab" format
if (yformat == "cdab") continue;