[Comb] Fix more recursive mux folders (#8756)

Co-authored-by: Fabian Schuiki <fabian@schuiki.ch>
This commit is contained in:
Bea Healy 2025-07-25 15:57:07 +01:00 committed by GitHub
parent 30a4a38fca
commit 82a8af4acf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 8 deletions

View File

@ -2078,6 +2078,16 @@ static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
return false;
}
auto isARecursiveMux = [](Value v) {
if (auto muxOp = v.getDefiningOp<MuxOp>())
return muxOp.getTrueValue() == v || muxOp.getFalseValue() == v;
return false;
};
// Avoid infinitely recursing canonicalizations
if (isARecursiveMux(otherValue) || isARecursiveMux(subCond))
return false;
// Invert the outer cond if needed, and combine the mux conditions.
if (!isTrueOperand)
cond = createOrFoldNot(op.getLoc(), cond, rewriter);
@ -2515,10 +2525,16 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
return success();
// mux(cond, opA(cond), opB(cond)) -> mux(cond, opA(1), opB(1))
if (assumeMuxCondInOperand(op.getCond(), op.getTrueValue(), true, rewriter))
return success();
if (assumeMuxCondInOperand(op.getCond(), op.getFalseValue(), false, rewriter))
return success();
if (op.getTrueValue().getDefiningOp() &&
op.getTrueValue().getDefiningOp() != op)
if (assumeMuxCondInOperand(op.getCond(), op.getTrueValue(), true, rewriter))
return success();
if (op.getFalseValue().getDefiningOp() &&
op.getFalseValue().getDefiningOp() != op)
if (assumeMuxCondInOperand(op.getCond(), op.getFalseValue(), false,
rewriter))
return success();
return failure();
}

View File

@ -1227,11 +1227,11 @@ hw.module @muxConstantsFold(in %cond: i1, out o: i25) {
// CHECK-LABEL: hw.module @muxCommon
// This handles various cases of mux(cond, x, someop(x, y, z)).
hw.module @muxCommon(in %cond: i1, in %cond2: i1,
hw.module @muxCommon(in %cond: i1, in %cond2: i1, in %cond3: i1,
in %arg0 : i32, in %arg1 : i32, in %arg2: i32, in %arg3: i32,
out o1: i32, out o2: i32, out o3: i32, out o4: i32, out o5: i32,
out orResult: i32, out o6: i32, out o7: i32, out o8 : i1, out o9: i32,
out o10: i32) {
out o10: i32, out o11: i1, out o12: i32) {
// CHECK: [[TRUE:%.+]] = hw.constant true
// CHECK: [[FALSE:%.+]] = hw.constant false
%true = hw.constant true
@ -1291,10 +1291,18 @@ hw.module @muxCommon(in %cond: i1, in %cond2: i1,
// CHECK: [[O11:%.+]] = comb.mux [[O11]], %cond, %cond2 : i1
%o11 = comb.mux %o11, %cond, %cond2 : i1
// Avoid a non-terminating case
// CHECK: [[LOOPMUX:%.+]] = comb.mux %cond, %arg1, %2
// CHECK: [[MUXONLOOP:%.+]] = comb.mux %cond2, %arg0, [[LOOPMUX]] : i32
// CHECK: [[O12:%.+]] = comb.mux %cond3, %arg0, [[MUXONLOOP]] : i32
%2 = comb.mux %cond, %arg1, %2 : i32
%3 = comb.mux %cond2, %arg0, %2 : i32
%o12 = comb.mux %cond3, %arg0, %3 : i32
// CHECK: hw.output [[O1]], [[O2]], [[O3]], [[O4]], [[O5]], [[ORRESULT]],
// CHECK: [[O6]], [[O7]], [[O8]], [[O9]], [[O10]]
hw.output %o1, %o2, %o3, %o4, %o5, %orResult, %o6, %o7, %o8, %o9, %o10
: i32, i32, i32, i32, i32, i32, i32, i32, i1, i32, i32
hw.output %o1, %o2, %o3, %o4, %o5, %orResult, %o6, %o7, %o8, %o9, %o10, %o11, %o12
: i32, i32, i32, i32, i32, i32, i32, i32, i1, i32, i32, i1, i32
}
// CHECK-LABEL: @flatten_multi_use_and