mirror of https://github.com/Jittor/Jittor
test array reindex fuse
This commit is contained in:
parent
03dd698e65
commit
bc083360e7
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.23'
|
||||
__version__ = '1.3.1.24'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -28,7 +28,7 @@ EXTERN_LIB string_view_map<FusedOpContext*> jit_fused_ops;
|
|||
|
||||
struct FusedOp final : Op {
|
||||
vector<Op*> ops;
|
||||
// edges: [[i,j,k,l], ...] represents opi.output(j) == opk.input(i)
|
||||
// edges: [[i,j,k,l], ...] represents opi.output(j) == opk.input(l)
|
||||
vector<std::tuple<uint,uint,uint,uint>> edges;
|
||||
vector<VarInfo> vars;
|
||||
loop_options_t loop_options_merged, loop_options_tuned;
|
||||
|
|
|
@ -266,6 +266,15 @@ void LoopVarAnalyzePass::run() {
|
|||
|
||||
LOGvvv << "replace_vars" << replace_vars;
|
||||
ir->replace(replace_vars);
|
||||
|
||||
for (int i=0; i<this->op->ops.size(); i++) {
|
||||
auto op = this->op->ops[i];
|
||||
if (op->type() == OpType::element &&
|
||||
op->name() == string("array") &&
|
||||
op->outputs().front()->num == 1) {
|
||||
ir->replace({{"op"+S(i)+"_outputshape0", "1"}});
|
||||
}
|
||||
}
|
||||
LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true);
|
||||
// move define
|
||||
ir->move_loop_back();
|
||||
|
|
|
@ -236,6 +236,11 @@ class TestFusedOp(unittest.TestCase):
|
|||
check(64, 60, 64, 1, 0, 42)
|
||||
check(64, 60, 64, 0, 0, 30) # TODO: why slower?
|
||||
|
||||
def test_array_reindex(self):
|
||||
a = jt.array([1])
|
||||
b = a.reindex([3], ['i0-1'])
|
||||
np.testing.assert_allclose(b.data, [0,1,0])
|
||||
|
||||
|
||||
@unittest.skipIf(skip_slow_test, "Skip slow test")
|
||||
def test_profile_fused_op_restride(self):
|
||||
|
|
Loading…
Reference in New Issue