[FIRRTL] Make enums behave less like aggregates (#8742)

This makes enums behave more like ground types, but doesn't quite take
the jump and make them ground types.  It removes the ability to index
into enum variants using field refs, which was used produce error
messages in InferWidths (and enums no longer support having unknown
widths).  This change adds some helpers to determine the sizes of
enumerations.

This change modifes InferWidths to to take advantage of the fact that
enumerations do not support containing uninferred widths, and removes
some dead code.  In addition to this, field refs no longer index into
the variants of an enum type, so we can handle them in a similar way to
ground types.
This commit is contained in:
Andrew Young 2025-07-22 15:20:23 -07:00 committed by GitHub
parent 4ce45d581f
commit 0c5f60cb8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 32 additions and 108 deletions

View File

@ -465,12 +465,6 @@ def IsTagOp : FIRRTLExprOp<"istag"> {
];
let firrtlExtraClassDeclaration = [{
/// Return a `FieldRef` to the accessed field.
FieldRef getAccessedField() {
return FieldRef(getInput(), firrtl::type_cast<FEnumType>(getInput().getType())
.getFieldID(getFieldIndex()));
}
/// Return the name of the accessed field.
StringAttr getFieldNameAttr() {
return firrtl::type_cast<FEnumType>(getInput().getType())
@ -509,12 +503,6 @@ def SubtagOp : FIRRTLExprOp<"subtag"> {
];
let firrtlExtraClassDeclaration = [{
/// Return a `FieldRef` to the accessed field.
FieldRef getAccessedField() {
return FieldRef(getInput(), getInput().getType().base()
.getFieldID(getFieldIndex()));
}
/// Return the name of the accessed field.
StringAttr getFieldNameAttr() {
return getInput().getType().base().getElementNameAttr(getFieldIndex());

View File

@ -297,7 +297,7 @@ def OpenBundleImpl : BaseBundleTypeImpl<"OpenBundle","::circt::firrtl::FIRRTLTyp
let genVerifyDecl = 1;
}
def FEnumImpl : FIRRTLImplType<"FEnum", [DeclareTypeInterfaceMethods<FieldIDTypeInterface>]> {
def FEnumImpl : FIRRTLImplType<"FEnum"> {
let summary = "a sum type of named elements.";
let parameters = (ins "ArrayRef<EnumElement>":$elements, "bool":$isConst);
let storageClass = "FEnumTypeStorage";
@ -344,6 +344,12 @@ def FEnumImpl : FIRRTLImplType<"FEnum", [DeclareTypeInterfaceMethods<FieldIDType
/// Return this type with any type alias types recursively removed from itself.
FIRRTLBaseType getAnonymousType();
/// Get the width of this
size_t getBitWidth();
/// Get the width of the data. Equal to the width of the largest element.
size_t getDataWidth();
/// Get the width of the the tag field.
size_t getTagWidth();

View File

@ -886,7 +886,8 @@ int32_t FIRRTLBaseType::getBitWidthOrSentinel() {
[&](IntType intType) { return intType.getWidthOrSentinel(); })
.Case<AnalogType>(
[](AnalogType analogType) { return analogType.getWidthOrSentinel(); })
.Case<BundleType, FVectorType, FEnumType>([](Type) { return -2; })
.Case<FEnumType>([&](FEnumType fenum) { return fenum.getBitWidth(); })
.Case<BundleType, FVectorType>([](Type) { return -2; })
.Case<BaseTypeAliasType>([](BaseTypeAliasType type) {
// It's faster to use its anonymous type.
return type.getAnonymousType().getBitWidthOrSentinel();
@ -2161,24 +2162,15 @@ struct circt::firrtl::detail::FEnumTypeStorage : detail::FIRRTLBaseTypeStorage {
elements(elements.begin(), elements.end()) {
RecursiveTypeProperties props{true, false, false, isConst,
false, false, false};
uint64_t fieldID = 0;
fieldIDs.reserve(elements.size());
dataSize = 0;
for (auto &element : elements) {
auto type = element.type;
auto eltInfo = type.getRecursiveTypeProperties();
props.isPassive &= eltInfo.isPassive;
props.containsAnalog |= eltInfo.containsAnalog;
props.containsConst |= eltInfo.containsConst;
props.containsReference |= eltInfo.containsReference;
props.containsTypeAlias |= eltInfo.containsTypeAlias;
props.hasUninferredReset |= eltInfo.hasUninferredReset;
props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
fieldID += 1;
fieldIDs.push_back(fieldID);
// Increment the field ID for the next field by the number of subfields.
fieldID += hw::FieldIdImpl::getMaxFieldID(type);
dataSize = std::max((size_t)type.getBitWidthOrSentinel(), dataSize);
}
maxFieldID = fieldID;
recProps = props;
}
@ -2199,10 +2191,8 @@ struct circt::firrtl::detail::FEnumTypeStorage : detail::FIRRTLBaseTypeStorage {
}
SmallVector<FEnumType::EnumElement, 4> elements;
SmallVector<uint64_t, 4> fieldIDs;
uint64_t maxFieldID;
RecursiveTypeProperties recProps;
size_t dataSize;
FIRRTLBaseType anonymousType;
};
@ -2246,7 +2236,10 @@ std::optional<unsigned> FEnumType::getElementIndex(StringAttr name) {
return std::nullopt;
}
/// Get the width of the the tag field.
size_t FEnumType::getBitWidth() { return getDataWidth() + getTagWidth(); }
size_t FEnumType::getDataWidth() { return getImpl()->dataSize; }
size_t FEnumType::getTagWidth() {
if (getElements().size() == 0)
return 0;
@ -2326,45 +2319,6 @@ FIRRTLBaseType FEnumType::getElementTypePreservingConst(size_t index) {
return type.getConstType(type.isConst() || isConst());
}
uint64_t FEnumType::getFieldID(uint64_t index) const {
return getImpl()->fieldIDs[index];
}
uint64_t FEnumType::getIndexForFieldID(uint64_t fieldID) const {
assert(!getElements().empty() && "Enum must have >0 fields");
auto fieldIDs = getImpl()->fieldIDs;
auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
return std::distance(fieldIDs.begin(), it);
}
std::pair<uint64_t, uint64_t>
FEnumType::getIndexAndSubfieldID(uint64_t fieldID) const {
auto index = getIndexForFieldID(fieldID);
auto elementFieldID = getFieldID(index);
return {index, fieldID - elementFieldID};
}
std::pair<Type, uint64_t>
FEnumType::getSubTypeByFieldID(uint64_t fieldID) const {
if (fieldID == 0)
return {*this, 0};
auto subfieldIndex = getIndexForFieldID(fieldID);
auto subfieldType = getElementType(subfieldIndex);
auto subfieldID = fieldID - getFieldID(subfieldIndex);
return {subfieldType, subfieldID};
}
uint64_t FEnumType::getMaxFieldID() const { return getImpl()->maxFieldID; }
std::pair<uint64_t, bool>
FEnumType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
auto childRoot = getFieldID(index);
auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
: (getFieldID(index + 1) - 1);
return std::make_pair(fieldID - childRoot,
fieldID >= childRoot && fieldID <= rangeEnd);
}
LogicalResult FEnumType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
ArrayRef<EnumElement> elements, bool isConst) {
bool first = true;

View File

@ -648,13 +648,6 @@ circt::firrtl::getFieldName(const FieldRef &fieldRef, bool nameSafe) {
// Recurse in to the element type.
type = vecType.getElementType();
localID = localID - vecType.getFieldID(index);
} else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
auto index = enumType.getIndexForFieldID(localID);
auto &element = enumType.getElements()[index];
name += nameSafe ? "_" : ".";
name += element.name.getValue();
type = element.type;
localID = localID - enumType.getFieldID(index);
} else if (auto classType = type_dyn_cast<ClassType>(type)) {
auto index = classType.getIndexForFieldID(localID);
auto &element = classType.getElement(index);
@ -747,10 +740,10 @@ void circt::firrtl::walkGroundTypes(
}
})
.template Case<FEnumType>([&](FEnumType fenum) {
for (size_t i = 0, e = fenum.getNumElements(); i < e; ++i) {
fieldID++;
f(f, fenum.getElementType(i), isFlip);
}
// TODO: are enums aggregates or not? Where is walkGroundTypes called
// from? They are required to have passive types internally, so they
// don't really form an aggregate value.
fn(fieldID, fenum, isFlip);
})
.Default([&](FIRRTLBaseType groundType) {
assert(groundType.isGround() &&

View File

@ -1427,13 +1427,6 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {
unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 1),
op.getType());
})
.Case<SubtagOp>([&](auto op) {
FEnumType enumType = op.getInput().getType();
auto fieldID = enumType.getFieldID(op.getFieldIndex());
unifyTypes(FieldRef(op.getResult(), 0),
FieldRef(op.getInput(), fieldID), op.getType());
})
.Case<RefSubOp>([&](RefSubOp op) {
uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
op.getInput().getType().getType())
@ -1782,10 +1775,6 @@ void InferenceMapping::declareVars(Value value, bool isDerived) {
declare(vecType.getElementType());
// Skip past the rest of the elements
fieldID = save + vecType.getMaxFieldID();
} else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
fieldID++;
for (auto &element : enumType.getElements())
declare(element.type);
} else {
llvm_unreachable("Unknown type inside a bundle!");
}
@ -1813,9 +1802,10 @@ void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
maximize(vecType.getElementType());
fieldID = save + vecType.getMaxFieldID();
} else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
auto *e = solver.max(getExpr(FieldRef(rhs, fieldID)),
getExpr(FieldRef(lhs, fieldID)));
setExpr(FieldRef(result, fieldID), e);
fieldID++;
for (auto &element : enumType.getElements())
maximize(element.type);
} else if (type.isGround()) {
auto *e = solver.max(getExpr(FieldRef(rhs, fieldID)),
getExpr(FieldRef(lhs, fieldID)));
@ -1861,9 +1851,9 @@ void InferenceMapping::constrainTypes(Value larger, Value smaller, bool equal) {
}
fieldID = save + vecType.getMaxFieldID();
} else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
constrainTypes(getExpr(FieldRef(larger, fieldID)),
getExpr(FieldRef(smaller, fieldID)), false, equal);
fieldID++;
for (auto &element : enumType.getElements())
constrain(element.type, larger, smaller);
} else if (type.isGround()) {
// Leaf element, look up their expressions, and create the constraint.
constrainTypes(getExpr(FieldRef(larger, fieldID)),
@ -1972,9 +1962,13 @@ void InferenceMapping::unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type) {
}
fieldID = save + vecType.getMaxFieldID();
} else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
FieldRef lhsFieldRef(lhs.getValue(), lhs.getFieldID() + fieldID);
FieldRef rhsFieldRef(rhs.getValue(), rhs.getFieldID() + fieldID);
LLVM_DEBUG(llvm::dbgs()
<< "Unify " << getFieldName(lhsFieldRef).first << " = "
<< getFieldName(rhsFieldRef).first << "\n");
setExpr(lhsFieldRef, getExpr(rhsFieldRef));
fieldID++;
for (auto &element : enumType.getElements())
unify(element.type);
} else {
llvm_unreachable("Unknown type inside a bundle!");
}
@ -2234,17 +2228,6 @@ FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
// If this is a 0 length vector return the original type.
return type;
}
if (auto enumType = type_dyn_cast<FEnumType>(type)) {
fieldID++;
llvm::SmallVector<FEnumType::EnumElement> elements;
for (auto &element : enumType.getElements()) {
auto updatedBase = updateBase(element.type);
if (!updatedBase)
return {};
elements.emplace_back(element.name, element.value, updatedBase);
}
return FEnumType::get(context, elements, enumType.isConst());
}
llvm_unreachable("Unknown type inside a bundle!");
};