circt/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp

238 lines
10 KiB
C++

//===- FlattenMemroy.cpp - Flatten Memory Pass ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the FlattenMemory pass.
//
//===----------------------------------------------------------------------===//
#include "PassDetails.h"
#include "circt/Dialect/FIRRTL/FIRRTLAnnotations.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <numeric>
#define DEBUG_TYPE "lower-memory"
using namespace circt;
using namespace firrtl;
namespace {
struct FlattenMemoryPass : public FlattenMemoryBase<FlattenMemoryPass> {
/// This pass flattens the aggregate data of memory into a UInt, and inserts
/// appropriate bitcasts to access the data.
void runOnOperation() override {
LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
<< getOperation().getName());
SmallVector<Operation *> opsToErase;
auto hasSubAnno = [&](MemOp op) -> bool {
for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
for (auto attr : op.getPortAnnotation(portIdx))
if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
return true;
return false;
};
getOperation().getBodyBlock()->walk([&](MemOp memOp) {
LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
// The vector of leaf elements type after flattening the data.
SmallVector<IntType> flatMemType;
// MaskGranularity : how many bits each mask bit controls.
size_t maskGran = 1;
// Total mask bitwidth after flattening.
uint32_t totalmaskWidths = 0;
// How many mask bits each field type requires.
SmallVector<unsigned> maskWidths;
// Cannot flatten a memory if it has debug ports, because debug port
// implies a memtap and we cannot transform the datatype for a memory that
// is tapped.
for (auto res : memOp.getResults())
if (isa<RefType>(res.getType()))
return;
// If subannotations present on aggregate fields, we cannot flatten the
// memory. It must be split into one memory per aggregate field.
// Do not overwrite the pass flag!
if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
return;
SmallVector<Operation *, 8> flatData;
SmallVector<int32_t> memWidths;
size_t memFlatWidth = 0;
// Get the width of individual aggregate leaf elements.
for (auto f : flatMemType) {
LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
auto w = *f.getWidth();
memWidths.push_back(w);
memFlatWidth += w;
}
// If all the widths are zero, ignore the memory.
if (!memFlatWidth)
return;
maskGran = memWidths[0];
// Compute the GCD of all data bitwidths.
for (auto w : memWidths) {
maskGran = std::gcd(maskGran, w);
}
for (auto w : memWidths) {
// How many mask bits required for each flattened field.
auto mWidth = w / maskGran;
maskWidths.push_back(mWidth);
totalmaskWidths += mWidth;
}
// Now create a new memory of type flattened data.
// ----------------------------------------------
SmallVector<Type, 8> ports;
SmallVector<Attribute, 8> portNames;
auto *context = memOp.getContext();
ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
// Create a new memoty data type of unsigned and computed width.
auto flatType = UIntType::get(context, memFlatWidth);
auto opPorts = memOp.getPorts();
for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
auto port = opPorts[portIdx];
ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
port.second, totalmaskWidths));
portNames.push_back(port.first);
}
auto flatMem = builder.create<MemOp>(
ports, memOp.getReadLatency(), memOp.getWriteLatency(),
memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
memOp.getInitAttr(), memOp.getPrefixAttr());
// Hook up the new memory to the wires the old memory was replaced with.
for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
++index) {
auto result = memOp.getResult(index);
auto wire = builder
.create<WireOp>(result.getType(),
(memOp.getName() + "_" +
memOp.getPortName(index).getValue())
.str())
.getResult();
result.replaceAllUsesWith(wire);
result = wire;
auto newResult = flatMem.getResult(index);
auto rType = type_cast<BundleType>(result.getType());
for (size_t fieldIndex = 0, fend = rType.getNumElements();
fieldIndex != fend; ++fieldIndex) {
auto name = rType.getElement(fieldIndex).name.getValue();
auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
FIRRTLBaseValue newField =
builder.create<SubfieldOp>(newResult, fieldIndex);
// data and mask depend on the memory type which was split. They can
// also go both directions, depending on the port direction.
if (!(name == "data" || name == "mask" || name == "wdata" ||
name == "wmask" || name == "rdata")) {
emitConnect(builder, newField, oldField);
continue;
}
Value realOldField = oldField;
if (rType.getElement(fieldIndex).isFlip) {
// Cast the memory read data from flat type to aggregate.
auto castField =
builder.createOrFold<BitCastOp>(oldField.getType(), newField);
// Write the aggregate read data.
emitConnect(builder, realOldField, castField);
} else {
// Cast the input aggregate write data to flat type.
// Cast the input aggregate write data to flat type.
auto newFieldType = newField.getType();
auto oldFieldBitWidth = getBitWidth(oldField.getType());
// Following condition is true, if a data field is 0 bits. Then
// newFieldType is of smaller bits than old.
if (getBitWidth(newFieldType) != *oldFieldBitWidth)
newFieldType = UIntType::get(context, *oldFieldBitWidth);
realOldField = builder.create<BitCastOp>(newFieldType, oldField);
// Mask bits require special handling, since some of the mask bits
// need to be repeated, direct bitcasting wouldn't work. Depending
// on the mask granularity, some mask bits will be repeated.
if ((name == "mask" || name == "wmask") &&
(maskWidths.size() != totalmaskWidths)) {
Value catMasks;
for (const auto &m : llvm::enumerate(maskWidths)) {
// Get the mask bit.
auto mBit = builder.createOrFold<BitsPrimOp>(
realOldField, m.index(), m.index());
// Check how many times the mask bit needs to be prepend.
for (size_t repeat = 0; repeat < m.value(); repeat++)
if ((m.index() == 0 && repeat == 0) || !catMasks)
catMasks = mBit;
else
catMasks = builder.createOrFold<CatPrimOp>(mBit, catMasks);
}
realOldField = catMasks;
}
// Now set the mask or write data.
// Ensure that the types match.
emitConnect(builder, newField,
builder.createOrFold<BitCastOp>(newField.getType(),
realOldField));
}
}
}
++numFlattenedMems;
memOp.erase();
return;
});
}
private:
// Convert an aggregate type into a flat list of fields.
// This is used to flatten the aggregate memory datatype.
// Recursively populate the results with each ground type field.
static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
return FIRRTLTypeSwitch<FIRRTLType, bool>(type)
.Case<BundleType>([&](auto bundle) {
for (auto &elt : bundle)
if (!flatten(elt.type))
return false;
return true;
})
.Case<FVectorType>([&](auto vector) {
for (size_t i = 0, e = vector.getNumElements(); i != e; ++i)
if (!flatten(vector.getElementType()))
return false;
return true;
})
.Case<IntType>([&](auto iType) {
results.push_back({iType});
return iType.getWidth().has_value();
})
.Default([&](auto) { return false; });
};
// Return true only if this is an aggregate with more than one element.
if (flatten(type) && results.size() > 1)
return true;
return false;
}
Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {
if (BundleType bundle = type_dyn_cast<BundleType>(val.getType()))
return builder->create<SubfieldOp>(val, index);
if (FVectorType fvector = type_dyn_cast<FVectorType>(val.getType()))
return builder->create<SubindexOp>(val, index);
llvm_unreachable("Unknown aggregate type");
return nullptr;
}
};
} // end anonymous namespace
std::unique_ptr<mlir::Pass> circt::firrtl::createFlattenMemoryPass() {
return std::make_unique<FlattenMemoryPass>();
}