polish fuser

This commit is contained in:
Dun Liang 2022-01-11 17:25:00 +08:00
parent 5b4576c4dd
commit 6d1b5e42bc
3 changed files with 21 additions and 1 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.34'
__version__ = '1.3.1.35'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -276,6 +276,19 @@ void LoopVarAnalyzePass::run() {
ir->replace({{"op"+S(i)+"_outputshape0", "1"}});
}
}
// fix index op stride not found
replace_vars.clear();
for (int i=0; i<this->op->ops.size(); i++) {
auto op = this->op->ops[i];
if (op->type() == OpType::element &&
op->name() == string("index")) {
for (int j=1; j<op->outputs().size(); i++)
replace_vars.push_back({"op"+S(i)+"_x"+S(j)+"stride", "op"+S(i)+"_x0stride"});
}
}
if (replace_vars.size())
ir->replace(replace_vars);
LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true);
// move define
ir->move_loop_back();

View File

@ -56,5 +56,12 @@ class TestIndexOp(unittest.TestCase):
def test_doc(self):
assert "Index Operator" in jt.index.__doc__
def test_wrong_fuse(self):
a,b = jt.index([10,10])
c = jt.zeros([10,10])
c = c.reindex([b+1,a])
x = b.clone()
jt.sync([c, x])
if __name__ == "__main__":
unittest.main()