diff --git a/lib/Dialect/Arc/Transforms/MergeIfs.cpp b/lib/Dialect/Arc/Transforms/MergeIfs.cpp index cb25c42ac8..4ae8cf28c9 100644 --- a/lib/Dialect/Arc/Transforms/MergeIfs.cpp +++ b/lib/Dialect/Arc/Transforms/MergeIfs.cpp @@ -133,13 +133,19 @@ void MergeIfsPass::sinkOps(Block &rootBlock) { // Assign an order to this op. auto order = OpOrder{opOrder.size() + 1, 0}; opOrder[&op] = order; + // Track whether the op is, or contains, any writes (and thus can't + // generally be moved into a block) + bool opContainsWrites = false; // Analyze the side effects in the op. op.walk([&](Operation *subOp) { - if (auto ptr = getPointerWrittenByOp(subOp)) + if (auto ptr = getPointerWrittenByOp(subOp)) { nextWrite[ptr] = &op; - else if (!isa(subOp) && hasSideEffects(subOp)) + opContainsWrites = true; + } else if (!isa(subOp) && + hasSideEffects(subOp)) { nextSideEffect = &op; + } }); // Determine how much the op can be moved. @@ -151,7 +157,7 @@ void MergeIfsPass::sinkOps(Block &rootBlock) { // Don't move across general side-effecting ops. if (nextSideEffect) moveLimit.maximize({nextSideEffect, opOrder.lookup(nextSideEffect)}); - } else if (isa(&op) || nextSideEffect == &op) { + } else if (opContainsWrites || nextSideEffect == &op) { // Don't move writes or side-effecting ops. continue; } diff --git a/test/Dialect/Arc/merge-ifs.mlir b/test/Dialect/Arc/merge-ifs.mlir index cc245f7554..bab275674d 100644 --- a/test/Dialect/Arc/merge-ifs.mlir +++ b/test/Dialect/Arc/merge-ifs.mlir @@ -281,3 +281,28 @@ func.func @MergeNestedIfs(%arg0: i42, %arg1: i1, %arg2: i1) { } return } + +// Check that ops containing a write aren't sunk +// CHECK-LABEL: func.func @DontNestWrites +func.func @DontNestWrites(%arg0: !arc.state, %arg1: i1, %arg2: i1) { + // We just want to check that the first if hasn't been moved into the second + // CHECK-NEXT: {{%.+}} = scf.if %arg1 -> (i1) { + // CHECK: } else { + // CHECK: } + // CHECK-NEXT: scf.if %arg2 { + // CHECK: } + // CHECK-NEXT: return + + %1 = scf.if %arg1 -> (i1) { + %0 = hw.constant true + arc.state_write %arg0 = %0 : + scf.yield %0 : i1 + } else { + %0 = hw.constant false + scf.yield %0 : i1 + } + scf.if %arg2 { + %0 = comb.or %1, %1 : i1 + } + return +}