mirror of https://github.com/Jittor/Jittor
fix mkl don't support bacd dilation
This commit is contained in:
parent
95ea57d095
commit
602d705609
|
@ -235,10 +235,12 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
LOGvvvv << "Expr not match" << src_h << expr_h;
|
||||
continue;
|
||||
}
|
||||
LOGvvvv << "H Expr matched" << src_h << expr_h;
|
||||
if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number) || !rh[2]->is(expr::_number)) return;
|
||||
auto src_w = expr::make(riop1->indexes[xw]);
|
||||
if (!expr::match(src_w.get(), expr_w.get(), {"stride", "padding", "dilation"}, {"i"+S(zw), "i"+S(zww)}, rw))
|
||||
return;
|
||||
LOGvvvv << "W Expr matched" << src_w << expr_w;
|
||||
if (!rw[0]->is(expr::_number) || !rw[1]->is(expr::_number) || !rw[2]->is(expr::_number)) return;
|
||||
int stride_h = rh[0]->as_int();
|
||||
int padding_h = -rh[1]->as_int();
|
||||
|
@ -253,7 +255,10 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
continue;
|
||||
}
|
||||
LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h;
|
||||
|
||||
if (xformat == "bacd" && dilation_h != 1) {
|
||||
LOGvvvv << "mkl not support bacd dilation, continue";
|
||||
continue;
|
||||
}
|
||||
int stride = stride_h;
|
||||
int padding = padding_h;
|
||||
int dilation = dilation_h;
|
||||
|
@ -363,6 +368,7 @@ void ConvTuner::backwardTune(FusedOp* fop) {
|
|||
x = riop1->x;
|
||||
y = riop2->x;
|
||||
bo++;
|
||||
LOGvvvv << "backward_w get stride padding and dilation" << stride << padding << dilation;
|
||||
} else if (op->name_ex() == "reindex_reduce.add") {
|
||||
auto rop = (ReindexReduceOp*)op;
|
||||
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
|
||||
|
@ -438,6 +444,7 @@ void ConvTuner::backwardTune(FusedOp* fop) {
|
|||
w = riop1->x;
|
||||
y = riop2->x;
|
||||
bo+=2;
|
||||
LOGvvvv << "backward_x get stride padding and dilation" << stride << padding << dilation;
|
||||
}
|
||||
|
||||
// TODO: CUDA only support nchw(abcd)
|
||||
|
|
Loading…
Reference in New Issue