mirror of https://github.com/Jittor/Jittor
polish fuser
This commit is contained in:
parent
5b4576c4dd
commit
6d1b5e42bc
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.34'
|
||||
__version__ = '1.3.1.35'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -276,6 +276,19 @@ void LoopVarAnalyzePass::run() {
|
|||
ir->replace({{"op"+S(i)+"_outputshape0", "1"}});
|
||||
}
|
||||
}
|
||||
|
||||
// fix index op stride not found
|
||||
replace_vars.clear();
|
||||
for (int i=0; i<this->op->ops.size(); i++) {
|
||||
auto op = this->op->ops[i];
|
||||
if (op->type() == OpType::element &&
|
||||
op->name() == string("index")) {
|
||||
for (int j=1; j<op->outputs().size(); i++)
|
||||
replace_vars.push_back({"op"+S(i)+"_x"+S(j)+"stride", "op"+S(i)+"_x0stride"});
|
||||
}
|
||||
}
|
||||
if (replace_vars.size())
|
||||
ir->replace(replace_vars);
|
||||
LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true);
|
||||
// move define
|
||||
ir->move_loop_back();
|
||||
|
|
|
@ -56,5 +56,12 @@ class TestIndexOp(unittest.TestCase):
|
|||
def test_doc(self):
|
||||
assert "Index Operator" in jt.index.__doc__
|
||||
|
||||
def test_wrong_fuse(self):
|
||||
a,b = jt.index([10,10])
|
||||
c = jt.zeros([10,10])
|
||||
c = c.reindex([b+1,a])
|
||||
x = b.clone()
|
||||
jt.sync([c, x])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue