This commit is contained in:
Dun Liang 2020-03-24 16:28:35 +08:00
commit 4dee9cc6fa
2 changed files with 9 additions and 9 deletions

View File

@ -166,11 +166,11 @@ void ConvTuner::forwardTune(FusedOp* fop) {
for (Op* op : fop->ops)
if (op->name_ex()=="reduce.add") {
auto rop = (ReduceOp*)op;
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply"))
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
continue;
auto bop = (BinaryOp*)(rop->x->input());
if (!(bop->y->input() && bop->x->input())) continue;
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
if (!((bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="broadcast_to") ||
(bop->y->input()->name_ex()=="reindex" && bop->x->input()->name_ex()=="broadcast_to"))) return;
// riop1 reindex -> xx
@ -290,11 +290,11 @@ void ConvTuner::backwardTune(FusedOp* fop) {
string xformat, yformat, wformat;
if (op->name_ex() == "reduce.add") {
auto rop = (ReduceOp*)op;
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply"))
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
continue;
auto bop = (BinaryOp*)(rop->x->input());
if (!(bop->y->input() && bop->x->input())) continue;
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
if (!((bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="broadcast_to") ||
(bop->y->input()->name_ex()=="reindex" && bop->x->input()->name_ex()=="broadcast_to"))) continue;
auto riop1 = bop->x->input()->name_ex()=="reindex" ? (ReindexOp*)(bop->x->input()) : (ReindexOp*)(bop->y->input());
@ -365,11 +365,11 @@ void ConvTuner::backwardTune(FusedOp* fop) {
bo++;
} else if (op->name_ex() == "reindex_reduce.add") {
auto rop = (ReindexReduceOp*)op;
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply"))
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
continue;
auto bop = (BinaryOp*)(rop->y->input());
if (!(bop->y->input() && bop->x->input())) continue;
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
if (!((bop->x->input()->name_ex()=="broadcast_to" && bop->y->input()->name_ex()=="broadcast_to"))) return;
auto riop1 = (BroadcastToOp*)(bop->x->input());
auto riop2 = (BroadcastToOp*)(bop->y->input());

View File

@ -22,12 +22,12 @@ void MatmulTuner::run(PassManager* pm, TunerManager* tm) {
for (Op* op : fop->ops) {
if (op->name_ex()!="reduce.add") continue;
auto rop = (ReduceOp*)op;
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply"))
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
continue;
auto bop = (BinaryOp*)(rop->x->input());
if (!(bop->x->input() && bop->x->input()->name_ex()=="broadcast_to"))
if (!(bop->x->input() && bop->x->input()->name_ex()=="broadcast_to" && bop->x->input()->tflag==op->tflag))
continue;
if (!(bop->y->input() && bop->y->input()->name_ex()=="broadcast_to"))
if (!(bop->y->input() && bop->y->input()->name_ex()=="broadcast_to" && bop->y->input()->tflag==op->tflag))
continue;
auto bcop1 = (BroadcastToOp*)(bop->x->input());
auto bcop2 = (BroadcastToOp*)(bop->y->input());