test array reindex fuse

This commit is contained in:
Dun Liang 2021-12-15 15:38:16 +08:00
parent 03dd698e65
commit bc083360e7
4 changed files with 16 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.23'
__version__ = '1.3.1.24'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -28,7 +28,7 @@ EXTERN_LIB string_view_map<FusedOpContext*> jit_fused_ops;
struct FusedOp final : Op {
vector<Op*> ops;
// edges: [[i,j,k,l], ...] represents opi.output(j) == opk.input(i)
// edges: [[i,j,k,l], ...] represents opi.output(j) == opk.input(l)
vector<std::tuple<uint,uint,uint,uint>> edges;
vector<VarInfo> vars;
loop_options_t loop_options_merged, loop_options_tuned;

View File

@ -266,6 +266,15 @@ void LoopVarAnalyzePass::run() {
LOGvvv << "replace_vars" << replace_vars;
ir->replace(replace_vars);
for (int i=0; i<this->op->ops.size(); i++) {
auto op = this->op->ops[i];
if (op->type() == OpType::element &&
op->name() == string("array") &&
op->outputs().front()->num == 1) {
ir->replace({{"op"+S(i)+"_outputshape0", "1"}});
}
}
LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true);
// move define
ir->move_loop_back();

View File

@ -236,6 +236,11 @@ class TestFusedOp(unittest.TestCase):
check(64, 60, 64, 1, 0, 42)
check(64, 60, 64, 0, 0, 30) # TODO: why slower?
def test_array_reindex(self):
a = jt.array([1])
b = a.reindex([3], ['i0-1'])
np.testing.assert_allclose(b.data, [0,1,0])
@unittest.skipIf(skip_slow_test, "Skip slow test")
def test_profile_fused_op_restride(self):