fix mkl don't support bacd dilation

This commit is contained in:
Dun Liang 2020-04-09 13:54:34 +08:00
parent 95ea57d095
commit 602d705609
1 changed files with 8 additions and 1 deletions

View File

@ -235,10 +235,12 @@ void ConvTuner::forwardTune(FusedOp* fop) {
LOGvvvv << "Expr not match" << src_h << expr_h;
continue;
}
LOGvvvv << "H Expr matched" << src_h << expr_h;
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number) || !rh[2]->is(expr::_number)) return;
auto src_w = expr::make(riop1->indexes[xw]);
if (!expr::match(src_w.get(), expr_w.get(), {"stride", "padding", "dilation"}, {"i"+S(zw), "i"+S(zww)}, rw))
return;
LOGvvvv << "W Expr matched" << src_w << expr_w;
if (!rw[0]->is(expr::_number) || !rw[1]->is(expr::_number) || !rw[2]->is(expr::_number)) return;
int stride_h = rh[0]->as_int();
int padding_h = -rh[1]->as_int();
@ -253,7 +255,10 @@ void ConvTuner::forwardTune(FusedOp* fop) {
continue;
}
LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h;
if (xformat == "bacd" && dilation_h != 1) {
LOGvvvv << "mkl not support bacd dilation, continue";
continue;
}
int stride = stride_h;
int padding = padding_h;
int dilation = dilation_h;
@ -363,6 +368,7 @@ void ConvTuner::backwardTune(FusedOp* fop) {
x = riop1->x;
y = riop2->x;
bo++;
LOGvvvv << "backward_w get stride padding and dilation" << stride << padding << dilation;
} else if (op->name_ex() == "reindex_reduce.add") {
auto rop = (ReindexReduceOp*)op;
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
@ -438,6 +444,7 @@ void ConvTuner::backwardTune(FusedOp* fop) {
w = riop1->x;
y = riop2->x;
bo+=2;
LOGvvvv << "backward_x get stride padding and dilation" << stride << padding << dilation;
}
// TODO: CUDA only support nchw(abcd)