[FIRRTL] make mux type inference support enumeration types

This commit is contained in:
Andrew Young 2025-07-23 11:18:45 -07:00
parent da47d826f6
commit 57a2a9aedb
2 changed files with 38 additions and 0 deletions

View File

@ -5403,6 +5403,33 @@ static FIRRTLBaseType inferMuxReturnType(FIRRTLBaseType high,
return (lowWidth > highWidth ? low : high).getConstType(outerTypeIsConst);
}
// Two different Enum types can be compatible if one is the constant version
// of the other.
auto highEnum = type_dyn_cast<FEnumType>(high);
auto lowEnum = type_dyn_cast<FEnumType>(low);
if (lowEnum && highEnum) {
if (lowEnum.getNumElements() != highEnum.getNumElements())
return emitInferRetTypeError<FIRRTLBaseType>(
loc, "incompatible mux operand types, true value type: ", high,
", false value type: ", low);
SmallVector<FEnumType::EnumElement> elements;
for (auto [high, low] : llvm::zip_equal(highEnum, lowEnum)) {
// Variants should have the same name and value.
if (high.name != low.name || high.value != low.value)
return emitInferRetTypeError<FIRRTLBaseType>(
loc, "incompatible mux operand types, true value type: ", highEnum,
", false value type: ", lowEnum);
// Enumerations can only have constant variants only if the whole
// enumeration is constant, so this logic can differ a bit from bundles.
auto inner =
inferMuxReturnType(high.type, low.type, isConstCondition, loc);
if (!inner)
return {};
elements.emplace_back(high.name, high.value, inner);
}
return FEnumType::get(high.getContext(), elements, outerTypeIsConst);
}
// Infer vector types by comparing the element types.
auto highVector = type_dyn_cast<FVectorType>(high);
auto lowVector = type_dyn_cast<FVectorType>(low);

View File

@ -202,6 +202,17 @@ firrtl.module @ElementwiseMixedConstOperandsNonConstResult(in %a: !firrtl.const.
-> !firrtl.vector<uint<1>, 2>
}
// Mux result is const when all inputs are const.
// CHECK: firrtl.module @MuxConstConditionConstEnumsConstResult(in %p: !firrtl.const.uint<1>,
firrtl.module @MuxConstConditionConstEnumsConstResult(in %p: !firrtl.const.uint<1>,
in %a: !firrtl.const.enum<a: const.uint<1>>,
in %b: !firrtl.const.enum<a: uint<1>>) {
%0 = firrtl.mux(%p, %a, %b) : (!firrtl.const.uint<1>,
!firrtl.const.enum<a: const.uint<1>>,
!firrtl.const.enum<a: uint<1>>)
-> !firrtl.const.enum<a: uint<1>>
}
// Mux result is const when all inputs are const.
firrtl.module @MuxConstConditionConstBundlesConstResult(in %p: !firrtl.const.uint<1>,
in %a: !firrtl.const.bundle<a: uint<1>>,