mirror of https://github.com/llvm/circt.git
[FIRRTL] FlattenMemories: code cleanup, nfci
This commit is contained in:
parent
b4da671b10
commit
959e0f18e2
|
@ -34,30 +34,25 @@ using namespace firrtl;
|
|||
namespace {
|
||||
struct FlattenMemoryPass
|
||||
: public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
|
||||
|
||||
/// Returns true if the the memory has annotations on a subfield of any of the
|
||||
/// ports.
|
||||
static bool hasSubAnno(MemOp op) {
|
||||
for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
|
||||
for (auto attr : op.getPortAnnotation(portIdx))
|
||||
if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
/// This pass flattens the aggregate data of memory into a UInt, and inserts
|
||||
/// appropriate bitcasts to access the data.
|
||||
void runOnOperation() override {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
|
||||
<< getOperation().getName());
|
||||
SmallVector<Operation *> opsToErase;
|
||||
auto hasSubAnno = [&](MemOp op) -> bool {
|
||||
for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
|
||||
for (auto attr : op.getPortAnnotation(portIdx))
|
||||
if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
};
|
||||
getOperation().getBodyBlock()->walk([&](MemOp memOp) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
|
||||
// The vector of leaf elements type after flattening the data.
|
||||
SmallVector<IntType> flatMemType;
|
||||
// MaskGranularity : how many bits each mask bit controls.
|
||||
size_t maskGran = 1;
|
||||
// Total mask bitwidth after flattening.
|
||||
uint32_t totalmaskWidths = 0;
|
||||
// How many mask bits each field type requires.
|
||||
SmallVector<unsigned> maskWidths;
|
||||
|
||||
// Cannot flatten a memory if it has debug ports, because debug port
|
||||
// implies a memtap and we cannot transform the datatype for a memory that
|
||||
|
@ -65,36 +60,51 @@ struct FlattenMemoryPass
|
|||
for (auto res : memOp.getResults())
|
||||
if (isa<RefType>(res.getType()))
|
||||
return;
|
||||
|
||||
// If subannotations present on aggregate fields, we cannot flatten the
|
||||
// memory. It must be split into one memory per aggregate field.
|
||||
// Do not overwrite the pass flag!
|
||||
if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
|
||||
if (hasSubAnno(memOp))
|
||||
return;
|
||||
|
||||
SmallVector<Operation *, 8> flatData;
|
||||
SmallVector<int32_t> memWidths;
|
||||
// The vector of leaf elements type after flattening the data. If any of
|
||||
// the datatypes cannot be flattened, then we cannot flatten the memory.
|
||||
SmallVector<FIRRTLBaseType> flatMemType;
|
||||
if (!flattenType(memOp.getDataType(), flatMemType))
|
||||
return;
|
||||
|
||||
// Calculate the width of the memory data type, and the width of
|
||||
// each individual aggregate leaf elements.
|
||||
size_t memFlatWidth = 0;
|
||||
// Get the width of individual aggregate leaf elements.
|
||||
SmallVector<int32_t> memWidths;
|
||||
for (auto f : flatMemType) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
|
||||
auto w = *f.getWidth();
|
||||
auto w = f.getBitWidthOrSentinel();
|
||||
memWidths.push_back(w);
|
||||
memFlatWidth += w;
|
||||
}
|
||||
// If all the widths are zero, ignore the memory.
|
||||
if (!memFlatWidth)
|
||||
return;
|
||||
maskGran = memWidths[0];
|
||||
// Compute the GCD of all data bitwidths.
|
||||
for (auto w : memWidths) {
|
||||
|
||||
// Calculate the mask granularity of this memory, which is how many bits
|
||||
// of the data each mask bit controls. This is the greatest common
|
||||
// denominator of the widths of the flattened data types.
|
||||
auto maskGran = memWidths.front();
|
||||
for (auto w : ArrayRef(memWidths).drop_front())
|
||||
maskGran = std::gcd(maskGran, w);
|
||||
}
|
||||
|
||||
// Total mask bitwidth after flattening.
|
||||
uint32_t totalmaskWidths = 0;
|
||||
// How many mask bits each field type requires.
|
||||
SmallVector<unsigned> maskWidths;
|
||||
for (auto w : memWidths) {
|
||||
// How many mask bits required for each flattened field.
|
||||
auto mWidth = w / maskGran;
|
||||
maskWidths.push_back(mWidth);
|
||||
totalmaskWidths += mWidth;
|
||||
}
|
||||
|
||||
// Now create a new memory of type flattened data.
|
||||
// ----------------------------------------------
|
||||
SmallVector<Type, 8> ports;
|
||||
|
@ -102,25 +112,29 @@ struct FlattenMemoryPass
|
|||
|
||||
auto *context = memOp.getContext();
|
||||
ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
|
||||
// Create a new memoty data type of unsigned and computed width.
|
||||
// Create a new memory data type of unsigned and computed width.
|
||||
auto flatType = UIntType::get(context, memFlatWidth);
|
||||
auto opPorts = memOp.getPorts();
|
||||
for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
|
||||
auto port = opPorts[portIdx];
|
||||
for (auto port : memOp.getPorts()) {
|
||||
ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
|
||||
port.second, totalmaskWidths));
|
||||
portNames.push_back(port.first);
|
||||
}
|
||||
|
||||
// Create the new flattened memory.
|
||||
auto flatMem = builder.create<MemOp>(
|
||||
ports, memOp.getReadLatency(), memOp.getWriteLatency(),
|
||||
memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
|
||||
memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
|
||||
memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
|
||||
memOp.getInitAttr(), memOp.getPrefixAttr());
|
||||
|
||||
// Hook up the new memory to the wires the old memory was replaced with.
|
||||
for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
|
||||
++index) {
|
||||
|
||||
// Create a wire with the original type, and replace all uses of the old
|
||||
// memory with the wire. We will be reconstructing the original type
|
||||
// in the wire from the bitvector of the flattened memory.
|
||||
auto result = memOp.getResult(index);
|
||||
auto wire = builder
|
||||
.create<WireOp>(result.getType(),
|
||||
|
@ -134,7 +148,7 @@ struct FlattenMemoryPass
|
|||
auto rType = type_cast<BundleType>(result.getType());
|
||||
for (size_t fieldIndex = 0, fend = rType.getNumElements();
|
||||
fieldIndex != fend; ++fieldIndex) {
|
||||
auto name = rType.getElement(fieldIndex).name.getValue();
|
||||
auto name = rType.getElement(fieldIndex).name;
|
||||
auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
|
||||
FIRRTLBaseValue newField =
|
||||
builder.create<SubfieldOp>(newResult, fieldIndex);
|
||||
|
@ -153,7 +167,6 @@ struct FlattenMemoryPass
|
|||
// Write the aggregate read data.
|
||||
emitConnect(builder, realOldField, castField);
|
||||
} else {
|
||||
// Cast the input aggregate write data to flat type.
|
||||
// Cast the input aggregate write data to flat type.
|
||||
auto newFieldType = newField.getType();
|
||||
auto oldFieldBitWidth = getBitWidth(oldField.getType());
|
||||
|
@ -197,10 +210,11 @@ struct FlattenMemoryPass
|
|||
}
|
||||
|
||||
private:
|
||||
// Convert an aggregate type into a flat list of fields.
|
||||
// This is used to flatten the aggregate memory datatype.
|
||||
// Recursively populate the results with each ground type field.
|
||||
static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
|
||||
// Convert an aggregate type into a flat list of fields. This is used to
|
||||
// flatten the aggregate memory datatype. Recursively populate the results
|
||||
// with each ground type field.
|
||||
static bool flattenType(FIRRTLType type,
|
||||
SmallVectorImpl<FIRRTLBaseType> &results) {
|
||||
std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
|
||||
return FIRRTLTypeSwitch<FIRRTLType, bool>(type)
|
||||
.Case<BundleType>([&](auto bundle) {
|
||||
|
@ -226,9 +240,7 @@ private:
|
|||
.Default([&](auto) { return false; });
|
||||
};
|
||||
// Return true only if this is an aggregate with more than one element.
|
||||
if (flatten(type) && results.size() > 1)
|
||||
return true;
|
||||
return false;
|
||||
return flatten(type) && results.size() > 1;
|
||||
}
|
||||
|
||||
Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {
|
||||
|
|
Loading…
Reference in New Issue