mirror of https://github.com/Jittor/Jittor
fix mkl conv
This commit is contained in:
parent
557a4bf9d2
commit
87f51fc13d
|
@ -45,7 +45,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
|
|||
shape[0], shape[1], shape[2], shape[3]));
|
||||
}
|
||||
|
||||
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
|
||||
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
|
||||
: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation),
|
||||
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
|
||||
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
|
||||
|
|
|
@ -16,7 +16,7 @@ struct MklConvBackwardWOp : Op {
|
|||
int kernel_size, stride, padding, dilation;
|
||||
string xformat, wformat, yformat;
|
||||
|
||||
MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
|
||||
MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
|
||||
|
||||
const char* name() const override { return "mkl_conv_backward_w"; }
|
||||
void infer_shape() override;
|
||||
|
|
|
@ -45,7 +45,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
|
|||
shape[0], shape[1], shape[2], shape[3]));
|
||||
}
|
||||
|
||||
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
|
||||
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
|
||||
: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation),
|
||||
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
|
||||
dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
|
||||
|
|
|
@ -16,7 +16,7 @@ struct MklConvBackwardXOp : Op {
|
|||
int xh, xw, stride, padding, dilation;
|
||||
string xformat, wformat, yformat;
|
||||
|
||||
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
|
||||
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
|
||||
|
||||
const char* name() const override { return "mkl_conv_backward_x"; }
|
||||
void infer_shape() override;
|
||||
|
|
|
@ -22,7 +22,7 @@ struct OpInfo {
|
|||
for (uint i=0; i<constructors.size(); i++)
|
||||
if (std::type_index(*(constructors[i].first)) == std::type_index(tid))
|
||||
return func_t(constructors[i].second);
|
||||
LOGf << "constructor" << tid.name() << "not found.";
|
||||
LOGf << "constructor" << name << tid.name() << "not found.";
|
||||
return func_t(nullptr);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -160,6 +160,24 @@ struct OpInspector {
|
|||
return 0;
|
||||
}
|
||||
|
||||
string format2(const string& fmt, const vector<int>& order) {
|
||||
string new_fmt = fmt;
|
||||
if (order.size() != fmt.size()) {
|
||||
failed = 1;
|
||||
return "";
|
||||
}
|
||||
if (check_overlap(order))
|
||||
return "";
|
||||
for (uint i=0; i<order.size(); i++) {
|
||||
if (order[i]>=(int)new_fmt.size()) {
|
||||
failed = 1;
|
||||
return "";
|
||||
}
|
||||
new_fmt[order[i]] = fmt[i];
|
||||
}
|
||||
return new_fmt;
|
||||
}
|
||||
|
||||
string format(const string& fmt, const vector<int>& order) {
|
||||
string new_fmt = fmt;
|
||||
if (order.size() != fmt.size()) {
|
||||
|
@ -234,19 +252,19 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
|
||||
auto xformat = xoi.format("abcd", {xn, xc, xh, xw});
|
||||
auto xformat = xoi.format2("abcd", {xn, xc, xh, xw});
|
||||
LOGvvvv << "xformat =" << xformat;
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
auto wformat = xoi.format("iohw", {wci, wco, wh, ww});
|
||||
auto wformat = xoi.format2("iohw", {wci, wco, wh, ww});
|
||||
LOGvvvv << "wformat =" << wformat;
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
auto yformat = xoi.format("abcd", {yn, yc, yh, yw});
|
||||
auto yformat = xoi.format2("abcd", {yn, yc, yh, yw});
|
||||
LOGvvvv << "yformat =" << yformat;
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
|
@ -307,7 +325,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
continue;
|
||||
auto make_conv = get_op_info(relay_conv_name)
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
|
||||
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
|
||||
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
|
||||
auto rid = fop->context->vrm.add_relay_group({{rvar, rop->y}});
|
||||
if (rid>=0) {
|
||||
auto srid = "relay"+S(rid);
|
||||
|
@ -359,19 +377,19 @@ void ConvTuner::backwardTune(FusedOp* fop) {
|
|||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
|
||||
xformat = xoi.format("abcd", {xn, xc, xh, xw});
|
||||
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
|
||||
LOGvvvv << "xformat =" << xformat;
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
wformat = xoi.format("iohw", {wci, wco, wh, ww});
|
||||
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
|
||||
LOGvvvv << "wformat =" << wformat;
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
yformat = xoi.format("abcd", {yn, yc, yh, yw});
|
||||
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
|
||||
LOGvvvv << "yformat =" << yformat;
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
|
@ -434,19 +452,19 @@ void ConvTuner::backwardTune(FusedOp* fop) {
|
|||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
|
||||
xformat = xoi.format("abcd", {xn, xc, xh, xw});
|
||||
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
|
||||
LOGvvvv << "xformat =" << xformat;
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
wformat = xoi.format("iohw", {wci, wco, wh, ww});
|
||||
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
|
||||
LOGvvvv << "wformat =" << wformat;
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
yformat = xoi.format("abcd", {yn, yc, yh, yw});
|
||||
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
|
||||
LOGvvvv << "yformat =" << yformat;
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
|
|
Loading…
Reference in New Issue