[FIRRTL] FlattenMemories: code cleanup, nfci

Andrew Young 2025-07-18 17:58:57 -07:00
parent b4da671b10
commit 959e0f18e2
1 changed file with 50 additions and 38 deletions


@@ -34,30 +34,25 @@ using namespace firrtl;
namespace {
struct FlattenMemoryPass
: public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
/// Returns true if the memory has annotations on a subfield of any of the
/// ports.
static bool hasSubAnno(MemOp op) {
for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
for (auto attr : op.getPortAnnotation(portIdx))
if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
return true;
return false;
};
/// This pass flattens the aggregate data of memory into a UInt, and inserts
/// appropriate bitcasts to access the data.
void runOnOperation() override {
LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
<< getOperation().getName());
SmallVector<Operation *> opsToErase;
auto hasSubAnno = [&](MemOp op) -> bool {
for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
for (auto attr : op.getPortAnnotation(portIdx))
if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
return true;
return false;
};
getOperation().getBodyBlock()->walk([&](MemOp memOp) {
LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
// The vector of leaf elements type after flattening the data.
SmallVector<IntType> flatMemType;
// MaskGranularity : how many bits each mask bit controls.
size_t maskGran = 1;
// Total mask bitwidth after flattening.
uint32_t totalmaskWidths = 0;
// How many mask bits each field type requires.
SmallVector<unsigned> maskWidths;
// Cannot flatten a memory if it has debug ports, because debug port
// implies a memtap and we cannot transform the datatype for a memory that
@@ -65,36 +60,51 @@ struct FlattenMemoryPass
for (auto res : memOp.getResults())
if (isa<RefType>(res.getType()))
return;
// If subannotations present on aggregate fields, we cannot flatten the
// memory. It must be split into one memory per aggregate field.
// Do not overwrite the pass flag!
if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
if (hasSubAnno(memOp))
return;
SmallVector<Operation *, 8> flatData;
SmallVector<int32_t> memWidths;
// The vector of leaf elements type after flattening the data. If any of
// the datatypes cannot be flattened, then we cannot flatten the memory.
SmallVector<FIRRTLBaseType> flatMemType;
if (!flattenType(memOp.getDataType(), flatMemType))
return;
// Calculate the width of the memory data type, and the width of
// each individual aggregate leaf element.
size_t memFlatWidth = 0;
// Get the width of individual aggregate leaf elements.
SmallVector<int32_t> memWidths;
for (auto f : flatMemType) {
LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
auto w = *f.getWidth();
auto w = f.getBitWidthOrSentinel();
memWidths.push_back(w);
memFlatWidth += w;
}
// If all the widths are zero, ignore the memory.
if (!memFlatWidth)
return;
maskGran = memWidths[0];
// Compute the GCD of all data bitwidths.
for (auto w : memWidths) {
// Calculate the mask granularity of this memory, which is how many bits
// of the data each mask bit controls. This is the greatest common
// divisor of the widths of the flattened data types.
auto maskGran = memWidths.front();
for (auto w : ArrayRef(memWidths).drop_front())
maskGran = std::gcd(maskGran, w);
}
// Total mask bitwidth after flattening.
uint32_t totalmaskWidths = 0;
// How many mask bits each field type requires.
SmallVector<unsigned> maskWidths;
for (auto w : memWidths) {
// How many mask bits required for each flattened field.
auto mWidth = w / maskGran;
maskWidths.push_back(mWidth);
totalmaskWidths += mWidth;
}
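As a side note on the mask computation above (not part of this commit): the mask granularity is the GCD of all flattened field widths, and each field then contributes width / granularity mask bits. The standalone sketch below, in plain C++ with hypothetical field widths, walks through the same arithmetic.
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical widths of the flattened leaf fields of the data type.
  std::vector<int32_t> memWidths = {8, 32, 12};

  // Mask granularity: the GCD of all leaf widths.
  int32_t maskGran = memWidths.front();
  for (size_t i = 1; i < memWidths.size(); ++i)
    maskGran = std::gcd(maskGran, memWidths[i]);

  // Each field needs (width / maskGran) mask bits.
  uint32_t totalMaskWidth = 0;
  for (int32_t w : memWidths) {
    uint32_t mWidth = w / maskGran;
    totalMaskWidth += mWidth;
    std::printf("field width %2d -> %u mask bits\n", w, mWidth);
  }
  // For {8, 32, 12}: maskGran = 4 and total mask width = 2 + 8 + 3 = 13.
  std::printf("maskGran = %d, total mask width = %u\n", maskGran, totalMaskWidth);
}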
// Now create a new memory of type flattened data.
// ----------------------------------------------
SmallVector<Type, 8> ports;
@@ -102,25 +112,29 @@ struct FlattenMemoryPass
auto *context = memOp.getContext();
ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
// Create a new memoty data type of unsigned and computed width.
// Create a new memory data type of unsigned and computed width.
auto flatType = UIntType::get(context, memFlatWidth);
auto opPorts = memOp.getPorts();
for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
auto port = opPorts[portIdx];
for (auto port : memOp.getPorts()) {
ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
port.second, totalmaskWidths));
portNames.push_back(port.first);
}
// Create the new flattened memory.
auto flatMem = builder.create<MemOp>(
ports, memOp.getReadLatency(), memOp.getWriteLatency(),
memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
memOp.getInitAttr(), memOp.getPrefixAttr());
// Hook up the new memory to the wires the old memory was replaced with.
for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
++index) {
// Create a wire with the original type, and replace all uses of the old
// memory with the wire. We will be reconstructing the original type
// in the wire from the bitvector of the flattened memory.
auto result = memOp.getResult(index);
auto wire = builder
.create<WireOp>(result.getType(),
@@ -134,7 +148,7 @@ struct FlattenMemoryPass
auto rType = type_cast<BundleType>(result.getType());
for (size_t fieldIndex = 0, fend = rType.getNumElements();
fieldIndex != fend; ++fieldIndex) {
auto name = rType.getElement(fieldIndex).name.getValue();
auto name = rType.getElement(fieldIndex).name;
auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
FIRRTLBaseValue newField =
builder.create<SubfieldOp>(newResult, fieldIndex);
@@ -153,7 +167,6 @@ struct FlattenMemoryPass
// Write the aggregate read data.
emitConnect(builder, realOldField, castField);
} else {
// Cast the input aggregate write data to flat type.
// Cast the input aggregate write data to flat type.
auto newFieldType = newField.getType();
auto oldFieldBitWidth = getBitWidth(oldField.getType());
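The bitcasts wired up in this hunk amount to packing aggregate fields into one flat bit vector on writes and slicing fields back out on reads. The sketch below illustrates that idea with plain C++ integers; the field values, widths, and low-bits-first layout are assumptions made for the illustration, not taken from the pass.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Pack leaf-field values into one flat integer, low bits first.
// Assumes every width is < 64 so the shift and mask stay defined.
uint64_t packFields(const std::vector<std::pair<uint64_t, unsigned>> &fields) {
  uint64_t flat = 0;
  unsigned offset = 0;
  for (auto [value, width] : fields) {
    flat |= (value & ((1ULL << width) - 1)) << offset;
    offset += width;
  }
  return flat;
}

// Recover one field from the flat integer given its offset and width.
uint64_t extractField(uint64_t flat, unsigned offset, unsigned width) {
  return (flat >> offset) & ((1ULL << width) - 1);
}

int main() {
  // Two hypothetical fields: a is 8 bits wide (0xAB), b is 16 bits (0x1234).
  std::vector<std::pair<uint64_t, unsigned>> fields = {{0xAB, 8}, {0x1234, 16}};
  uint64_t flat = packFields(fields); // 0x1234AB
  std::printf("flat = 0x%llX\n", (unsigned long long)flat);
  std::printf("a    = 0x%llX\n", (unsigned long long)extractField(flat, 0, 8));
  std::printf("b    = 0x%llX\n", (unsigned long long)extractField(flat, 8, 16));
}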
@@ -197,10 +210,11 @@ struct FlattenMemoryPass
}
private:
// Convert an aggregate type into a flat list of fields.
// This is used to flatten the aggregate memory datatype.
// Recursively populate the results with each ground type field.
static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
// Convert an aggregate type into a flat list of fields. This is used to
// flatten the aggregate memory datatype. Recursively populate the results
// with each ground type field.
static bool flattenType(FIRRTLType type,
SmallVectorImpl<FIRRTLBaseType> &results) {
std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
return FIRRTLTypeSwitch<FIRRTLType, bool>(type)
.Case<BundleType>([&](auto bundle) {
@@ -226,9 +240,7 @@ private:
.Default([&](auto) { return false; });
};
// Return true only if this is an aggregate with more than one element.
if (flatten(type) && results.size() > 1)
return true;
return false;
return flatten(type) && results.size() > 1;
}
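To make the recursion in flattenType concrete without depending on FIRRTL types, here is a toy sketch of the same idea: walk bundles and vectors, collect every ground leaf's width, and fail if any leaf width is unknown. The ToyType struct and the example aggregate are invented for the illustration and are not CIRCT APIs.
#include <cstdio>
#include <vector>

// Toy stand-ins for ground/bundle/vector types, for illustration only.
struct ToyType {
  enum Kind { Ground, Bundle, Vector } kind = Ground;
  int width = 0;                 // Ground: bit width (-1 if unknown)
  std::vector<ToyType> elements; // Bundle: fields; Vector: the element type
  int numElements = 0;           // Vector: number of elements
};

// Recursively collect the width of every ground leaf, mirroring the shape of
// flattenType above; returns false if any leaf width is unknown.
static bool flatten(const ToyType &type, std::vector<int> &results) {
  switch (type.kind) {
  case ToyType::Ground:
    if (type.width < 0)
      return false;
    results.push_back(type.width);
    return true;
  case ToyType::Bundle:
    for (const ToyType &field : type.elements)
      if (!flatten(field, results))
        return false;
    return true;
  case ToyType::Vector:
    for (int i = 0; i < type.numElements; ++i)
      if (!flatten(type.elements.front(), results))
        return false;
    return true;
  }
  return false;
}

int main() {
  // {a: UInt<8>, b: UInt<16>[3]} -> leaves {8, 16, 16, 16}.
  ToyType b16{ToyType::Ground, 16, {}, 0};
  ToyType vec{ToyType::Vector, 0, {b16}, 3};
  ToyType a8{ToyType::Ground, 8, {}, 0};
  ToyType bundle{ToyType::Bundle, 0, {a8, vec}, 0};

  std::vector<int> widths;
  bool ok = flatten(bundle, widths) && widths.size() > 1;
  std::printf("flattenable: %s, leaves:", ok ? "yes" : "no");
  for (int w : widths)
    std::printf(" %d", w);
  std::printf("\n");
}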
Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {