mirror of https://github.com/Jittor/Jittor
Merge pull request #28 from Jittor/fix_tuner_bug
fix conv&matmul tuner recognize bug
This commit is contained in:
commit
5f84bf11f3
|
@ -166,11 +166,11 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
for (Op* op : fop->ops)
|
||||
if (op->name_ex()=="reduce.add") {
|
||||
auto rop = (ReduceOp*)op;
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply"))
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->x->input());
|
||||
|
||||
if (!(bop->y->input() && bop->x->input())) continue;
|
||||
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
|
||||
if (!((bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="broadcast_to") ||
|
||||
(bop->y->input()->name_ex()=="reindex" && bop->x->input()->name_ex()=="broadcast_to"))) return;
|
||||
// riop1 reindex -> xx
|
||||
|
@ -290,11 +290,11 @@ void ConvTuner::backwardTune(FusedOp* fop) {
|
|||
string xformat, yformat, wformat;
|
||||
if (op->name_ex() == "reduce.add") {
|
||||
auto rop = (ReduceOp*)op;
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply"))
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->x->input());
|
||||
|
||||
if (!(bop->y->input() && bop->x->input())) continue;
|
||||
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
|
||||
if (!((bop->x->input()->name_ex()=="reindex" && bop->y->input()->name_ex()=="broadcast_to") ||
|
||||
(bop->y->input()->name_ex()=="reindex" && bop->x->input()->name_ex()=="broadcast_to"))) continue;
|
||||
auto riop1 = bop->x->input()->name_ex()=="reindex" ? (ReindexOp*)(bop->x->input()) : (ReindexOp*)(bop->y->input());
|
||||
|
@ -365,11 +365,11 @@ void ConvTuner::backwardTune(FusedOp* fop) {
|
|||
bo++;
|
||||
} else if (op->name_ex() == "reindex_reduce.add") {
|
||||
auto rop = (ReindexReduceOp*)op;
|
||||
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply"))
|
||||
if (!(rop->y->input() && rop->y->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->y->input());
|
||||
|
||||
if (!(bop->y->input() && bop->x->input())) continue;
|
||||
if (!(bop->y->input() && bop->x->input() && bop->x->input()->tflag==op->tflag && bop->y->input()->tflag==op->tflag)) continue;
|
||||
if (!((bop->x->input()->name_ex()=="broadcast_to" && bop->y->input()->name_ex()=="broadcast_to"))) return;
|
||||
auto riop1 = (BroadcastToOp*)(bop->x->input());
|
||||
auto riop2 = (BroadcastToOp*)(bop->y->input());
|
||||
|
|
|
@ -22,12 +22,12 @@ void MatmulTuner::run(PassManager* pm, TunerManager* tm) {
|
|||
for (Op* op : fop->ops) {
|
||||
if (op->name_ex()!="reduce.add") continue;
|
||||
auto rop = (ReduceOp*)op;
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply"))
|
||||
if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && rop->x->input()->tflag==op->tflag))
|
||||
continue;
|
||||
auto bop = (BinaryOp*)(rop->x->input());
|
||||
if (!(bop->x->input() && bop->x->input()->name_ex()=="broadcast_to"))
|
||||
if (!(bop->x->input() && bop->x->input()->name_ex()=="broadcast_to" && bop->x->input()->tflag==op->tflag))
|
||||
continue;
|
||||
if (!(bop->y->input() && bop->y->input()->name_ex()=="broadcast_to"))
|
||||
if (!(bop->y->input() && bop->y->input()->name_ex()=="broadcast_to" && bop->y->input()->tflag==op->tflag))
|
||||
continue;
|
||||
auto bcop1 = (BroadcastToOp*)(bop->x->input());
|
||||
auto bcop2 = (BroadcastToOp*)(bop->y->input());
|
||||
|
|
Loading…
Reference in New Issue