mirror of https://github.com/llvm/circt.git
396 lines
14 KiB
C++
396 lines
14 KiB
C++
//===- FirMemLowering.cpp - FirMem lowering utilities ---------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "FirMemLowering.h"
|
|
#include "mlir/IR/Threading.h"
|
|
#include "llvm/ADT/MapVector.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
using namespace circt;
|
|
using namespace hw;
|
|
using namespace seq;
|
|
using llvm::MapVector;
|
|
|
|
#define DEBUG_TYPE "lower-seq-firmem"
|
|
|
|
FirMemLowering::FirMemLowering(ModuleOp circuit)
|
|
: context(circuit.getContext()), circuit(circuit) {
|
|
symbolCache.addDefinitions(circuit);
|
|
globalNamespace.add(symbolCache);
|
|
|
|
// For each module, assign an index. Use it to identify the insertion point
|
|
// for the generated ops.
|
|
for (auto [index, module] : llvm::enumerate(circuit.getOps<HWModuleOp>()))
|
|
moduleIndex[module] = index;
|
|
}
|
|
|
|
/// Collect the memories in a list of HW modules.
|
|
FirMemLowering::UniqueConfigs
|
|
FirMemLowering::collectMemories(ArrayRef<HWModuleOp> modules) {
|
|
// For each module in the list populate a separate vector of `FirMemOp`s in
|
|
// that module. This allows for the traversal of the HW modules to be
|
|
// parallelized.
|
|
using ModuleMemories = SmallVector<std::pair<FirMemConfig, FirMemOp>, 0>;
|
|
SmallVector<ModuleMemories> memories(modules.size());
|
|
|
|
mlir::parallelFor(context, 0, modules.size(), [&](auto idx) {
|
|
// TODO: Check if this module is in the DUT hierarchy.
|
|
// bool isInDut = state.isInDUT(module);
|
|
HWModuleOp(modules[idx]).walk([&](seq::FirMemOp op) {
|
|
memories[idx].push_back({collectMemory(op), op});
|
|
});
|
|
});
|
|
|
|
// Group the gathered memories by unique `FirMemConfig` details.
|
|
MapVector<FirMemConfig, SmallVector<FirMemOp, 1>> grouped;
|
|
for (auto [module, moduleMemories] : llvm::zip(modules, memories))
|
|
for (auto [summary, memOp] : moduleMemories)
|
|
grouped[summary].push_back(memOp);
|
|
|
|
return grouped;
|
|
}
|
|
|
|
/// Trace a value through wires to its original definition.
|
|
static Value lookThroughWires(Value value) {
|
|
while (value) {
|
|
if (auto wireOp = value.getDefiningOp<WireOp>()) {
|
|
value = wireOp.getInput();
|
|
continue;
|
|
}
|
|
break;
|
|
}
|
|
return value;
|
|
}
|
|
|
|
/// Determine the exact parametrization of the memory that should be generated
|
|
/// for a given `FirMemOp`.
|
|
FirMemConfig FirMemLowering::collectMemory(FirMemOp op) {
|
|
FirMemConfig cfg;
|
|
cfg.dataWidth = op.getType().getWidth();
|
|
cfg.depth = op.getType().getDepth();
|
|
cfg.readLatency = op.getReadLatency();
|
|
cfg.writeLatency = op.getWriteLatency();
|
|
cfg.maskBits = op.getType().getMaskWidth().value_or(1);
|
|
cfg.readUnderWrite = op.getRuw();
|
|
cfg.writeUnderWrite = op.getWuw();
|
|
if (auto init = op.getInitAttr()) {
|
|
cfg.initFilename = init.getFilename();
|
|
cfg.initIsBinary = init.getIsBinary();
|
|
cfg.initIsInline = init.getIsInline();
|
|
}
|
|
cfg.outputFile = op.getOutputFileAttr();
|
|
if (auto prefix = op.getPrefixAttr())
|
|
cfg.prefix = prefix.getValue();
|
|
// TODO: Handle modName (maybe not?)
|
|
// TODO: Handle groupID (maybe not?)
|
|
|
|
// Count the read, write, and read-write ports, and identify the clocks
|
|
// driving the write ports.
|
|
SmallDenseMap<Value, unsigned> clockValues;
|
|
for (auto *user : op->getUsers()) {
|
|
if (isa<FirMemReadOp>(user))
|
|
++cfg.numReadPorts;
|
|
else if (isa<FirMemWriteOp>(user))
|
|
++cfg.numWritePorts;
|
|
else if (isa<FirMemReadWriteOp>(user))
|
|
++cfg.numReadWritePorts;
|
|
|
|
// Assign IDs to the values used as clock. This allows later passes to
|
|
// easily detect which clocks are effectively driven by the same value.
|
|
if (isa<FirMemWriteOp, FirMemReadWriteOp>(user)) {
|
|
auto clock = lookThroughWires(user->getOperand(2));
|
|
cfg.writeClockIDs.push_back(
|
|
clockValues.insert({clock, clockValues.size()}).first->second);
|
|
}
|
|
}
|
|
|
|
return cfg;
|
|
}
|
|
|
|
FlatSymbolRefAttr FirMemLowering::getOrCreateSchema() {
|
|
if (!schemaOp) {
|
|
// Create or re-use the generator schema.
|
|
for (auto op : circuit.getOps<hw::HWGeneratorSchemaOp>()) {
|
|
if (op.getDescriptor() == "FIRRTL_Memory") {
|
|
schemaOp = op;
|
|
break;
|
|
}
|
|
}
|
|
if (!schemaOp) {
|
|
auto builder = OpBuilder::atBlockBegin(circuit.getBody());
|
|
std::array<StringRef, 14> schemaFields = {
|
|
"depth", "numReadPorts",
|
|
"numWritePorts", "numReadWritePorts",
|
|
"readLatency", "writeLatency",
|
|
"width", "maskGran",
|
|
"readUnderWrite", "writeUnderWrite",
|
|
"writeClockIDs", "initFilename",
|
|
"initIsBinary", "initIsInline"};
|
|
schemaOp = hw::HWGeneratorSchemaOp::create(
|
|
builder, circuit.getLoc(), "FIRRTLMem", "FIRRTL_Memory",
|
|
builder.getStrArrayAttr(schemaFields));
|
|
}
|
|
}
|
|
return FlatSymbolRefAttr::get(schemaOp);
|
|
}
|
|
|
|
/// Create the `HWModuleGeneratedOp` for a single memory parametrization.
|
|
HWModuleGeneratedOp
|
|
FirMemLowering::createMemoryModule(FirMemConfig &mem,
|
|
ArrayRef<seq::FirMemOp> memOps) {
|
|
auto schemaSymRef = getOrCreateSchema();
|
|
|
|
// Identify the first module which uses the memory configuration.
|
|
// Insert the generated module before it.
|
|
HWModuleOp insertPt;
|
|
for (auto memOp : memOps) {
|
|
auto parent = memOp->getParentOfType<HWModuleOp>();
|
|
if (!insertPt || moduleIndex[parent] < moduleIndex[insertPt])
|
|
insertPt = parent;
|
|
}
|
|
|
|
OpBuilder builder(context);
|
|
builder.setInsertionPoint(insertPt);
|
|
|
|
// Pick a name for the memory. Honor the optional prefix and try to include
|
|
// the common part of the names of the memory instances that use this
|
|
// configuration. The resulting name is of the form:
|
|
//
|
|
// <prefix>_<commonName>_<depth>x<width>
|
|
//
|
|
StringRef baseName = "";
|
|
bool firstFound = false;
|
|
for (auto memOp : memOps) {
|
|
if (auto memName = memOp.getName()) {
|
|
if (!firstFound) {
|
|
baseName = *memName;
|
|
firstFound = true;
|
|
continue;
|
|
}
|
|
unsigned idx = 0;
|
|
for (; idx < memName->size() && idx < baseName.size(); ++idx)
|
|
if ((*memName)[idx] != baseName[idx])
|
|
break;
|
|
baseName = baseName.take_front(idx);
|
|
}
|
|
}
|
|
baseName = baseName.rtrim('_');
|
|
|
|
SmallString<32> nameBuffer;
|
|
nameBuffer += mem.prefix;
|
|
if (!baseName.empty()) {
|
|
nameBuffer += baseName;
|
|
} else {
|
|
nameBuffer += "mem";
|
|
}
|
|
nameBuffer += "_";
|
|
(Twine(mem.depth) + "x" + Twine(mem.dataWidth)).toVector(nameBuffer);
|
|
auto name = builder.getStringAttr(globalNamespace.newName(nameBuffer));
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Creating " << name << " for " << mem.depth
|
|
<< " x " << mem.dataWidth << " memory\n");
|
|
|
|
bool withMask = mem.maskBits > 1;
|
|
SmallVector<hw::PortInfo> ports;
|
|
|
|
// Common types used for memory ports.
|
|
Type clkType = ClockType::get(context);
|
|
Type bitType = IntegerType::get(context, 1);
|
|
Type dataType = IntegerType::get(context, std::max((size_t)1, mem.dataWidth));
|
|
Type maskType = IntegerType::get(context, mem.maskBits);
|
|
Type addrType =
|
|
IntegerType::get(context, std::max(1U, llvm::Log2_64_Ceil(mem.depth)));
|
|
|
|
// Helper to add an input port.
|
|
size_t inputIdx = 0;
|
|
auto addInput = [&](StringRef prefix, size_t idx, StringRef suffix,
|
|
Type type) {
|
|
ports.push_back({{builder.getStringAttr(prefix + Twine(idx) + suffix), type,
|
|
ModulePort::Direction::Input},
|
|
inputIdx++});
|
|
};
|
|
|
|
// Helper to add an output port.
|
|
size_t outputIdx = 0;
|
|
auto addOutput = [&](StringRef prefix, size_t idx, StringRef suffix,
|
|
Type type) {
|
|
ports.push_back({{builder.getStringAttr(prefix + Twine(idx) + suffix), type,
|
|
ModulePort::Direction::Output},
|
|
outputIdx++});
|
|
};
|
|
|
|
// Helper to add the ports common to read, read-write, and write ports.
|
|
auto addCommonPorts = [&](StringRef prefix, size_t idx) {
|
|
addInput(prefix, idx, "_addr", addrType);
|
|
addInput(prefix, idx, "_en", bitType);
|
|
addInput(prefix, idx, "_clk", clkType);
|
|
};
|
|
|
|
// Add the read ports.
|
|
for (size_t i = 0, e = mem.numReadPorts; i != e; ++i) {
|
|
addCommonPorts("R", i);
|
|
addOutput("R", i, "_data", dataType);
|
|
}
|
|
|
|
// Add the read-write ports.
|
|
for (size_t i = 0, e = mem.numReadWritePorts; i != e; ++i) {
|
|
addCommonPorts("RW", i);
|
|
addInput("RW", i, "_wmode", bitType);
|
|
addInput("RW", i, "_wdata", dataType);
|
|
addOutput("RW", i, "_rdata", dataType);
|
|
if (withMask)
|
|
addInput("RW", i, "_wmask", maskType);
|
|
}
|
|
|
|
// Add the write ports.
|
|
for (size_t i = 0, e = mem.numWritePorts; i != e; ++i) {
|
|
addCommonPorts("W", i);
|
|
addInput("W", i, "_data", dataType);
|
|
if (withMask)
|
|
addInput("W", i, "_mask", maskType);
|
|
}
|
|
|
|
// Mask granularity is the number of data bits that each mask bit can
|
|
// guard. By default it is equal to the data bitwidth.
|
|
auto genAttr = [&](StringRef name, Attribute attr) {
|
|
return builder.getNamedAttr(name, attr);
|
|
};
|
|
auto genAttrUI32 = [&](StringRef name, uint32_t value) {
|
|
return genAttr(name, builder.getUI32IntegerAttr(value));
|
|
};
|
|
NamedAttribute genAttrs[] = {
|
|
genAttr("depth", builder.getI64IntegerAttr(mem.depth)),
|
|
genAttrUI32("numReadPorts", mem.numReadPorts),
|
|
genAttrUI32("numWritePorts", mem.numWritePorts),
|
|
genAttrUI32("numReadWritePorts", mem.numReadWritePorts),
|
|
genAttrUI32("readLatency", mem.readLatency),
|
|
genAttrUI32("writeLatency", mem.writeLatency),
|
|
genAttrUI32("width", mem.dataWidth),
|
|
genAttrUI32("maskGran", mem.dataWidth / mem.maskBits),
|
|
genAttr("readUnderWrite",
|
|
seq::RUWAttr::get(builder.getContext(), mem.readUnderWrite)),
|
|
genAttr("writeUnderWrite",
|
|
seq::WUWAttr::get(builder.getContext(), mem.writeUnderWrite)),
|
|
genAttr("writeClockIDs", builder.getI32ArrayAttr(mem.writeClockIDs)),
|
|
genAttr("initFilename", builder.getStringAttr(mem.initFilename)),
|
|
genAttr("initIsBinary", builder.getBoolAttr(mem.initIsBinary)),
|
|
genAttr("initIsInline", builder.getBoolAttr(mem.initIsInline))};
|
|
|
|
// Combine the locations of all actual `FirMemOp`s to be the location of the
|
|
// generated memory.
|
|
Location loc = FirMemOp(memOps.front()).getLoc();
|
|
if (memOps.size() > 1) {
|
|
SmallVector<Location> locs;
|
|
for (auto memOp : memOps)
|
|
locs.push_back(memOp.getLoc());
|
|
loc = FusedLoc::get(context, locs);
|
|
}
|
|
|
|
// Create the module.
|
|
auto genOp =
|
|
hw::HWModuleGeneratedOp::create(builder, loc, schemaSymRef, name, ports,
|
|
StringRef{}, ArrayAttr{}, genAttrs);
|
|
if (mem.outputFile)
|
|
genOp->setAttr("output_file", mem.outputFile);
|
|
|
|
return genOp;
|
|
}
|
|
|
|
/// Replace all `FirMemOp`s in an HW module with an instance of the
|
|
/// corresponding generated module.
|
|
void FirMemLowering::lowerMemoriesInModule(
|
|
HWModuleOp module,
|
|
ArrayRef<std::tuple<FirMemConfig *, HWModuleGeneratedOp, FirMemOp>> mems) {
|
|
LLVM_DEBUG(llvm::dbgs() << "Lowering " << mems.size() << " memories in "
|
|
<< module.getName() << "\n");
|
|
|
|
DenseMap<unsigned, Value> constOneOps;
|
|
auto constOne = [&](unsigned width = 1) {
|
|
auto it = constOneOps.try_emplace(width, Value{});
|
|
if (it.second) {
|
|
auto builder = OpBuilder::atBlockBegin(module.getBodyBlock());
|
|
it.first->second = hw::ConstantOp::create(
|
|
builder, module.getLoc(), builder.getIntegerType(width), 1);
|
|
}
|
|
return it.first->second;
|
|
};
|
|
auto valueOrOne = [&](Value value, unsigned width = 1) {
|
|
return value ? value : constOne(width);
|
|
};
|
|
|
|
for (auto [config, genOp, memOp] : mems) {
|
|
LLVM_DEBUG(llvm::dbgs() << "- Lowering " << memOp.getName() << "\n");
|
|
SmallVector<Value> inputs;
|
|
SmallVector<Value> outputs;
|
|
|
|
auto addInput = [&](Value value) { inputs.push_back(value); };
|
|
auto addOutput = [&](Value value) { outputs.push_back(value); };
|
|
|
|
// Add the read ports.
|
|
for (auto *op : memOp->getUsers()) {
|
|
auto port = dyn_cast<FirMemReadOp>(op);
|
|
if (!port)
|
|
continue;
|
|
addInput(port.getAddress());
|
|
addInput(valueOrOne(port.getEnable()));
|
|
addInput(port.getClk());
|
|
addOutput(port.getData());
|
|
}
|
|
|
|
// Add the read-write ports.
|
|
for (auto *op : memOp->getUsers()) {
|
|
auto port = dyn_cast<FirMemReadWriteOp>(op);
|
|
if (!port)
|
|
continue;
|
|
addInput(port.getAddress());
|
|
addInput(valueOrOne(port.getEnable()));
|
|
addInput(port.getClk());
|
|
addInput(port.getMode());
|
|
addInput(port.getWriteData());
|
|
addOutput(port.getReadData());
|
|
if (config->maskBits > 1)
|
|
addInput(valueOrOne(port.getMask(), config->maskBits));
|
|
}
|
|
|
|
// Add the write ports.
|
|
for (auto *op : memOp->getUsers()) {
|
|
auto port = dyn_cast<FirMemWriteOp>(op);
|
|
if (!port)
|
|
continue;
|
|
addInput(port.getAddress());
|
|
addInput(valueOrOne(port.getEnable()));
|
|
addInput(port.getClk());
|
|
addInput(port.getData());
|
|
if (config->maskBits > 1)
|
|
addInput(valueOrOne(port.getMask(), config->maskBits));
|
|
}
|
|
|
|
// Create the module instance.
|
|
StringRef memName = "mem";
|
|
if (auto name = memOp.getName(); name && !name->empty())
|
|
memName = *name;
|
|
ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
|
|
auto instOp = hw::InstanceOp::create(
|
|
builder, genOp, builder.getStringAttr(memName + "_ext"), inputs,
|
|
ArrayAttr{}, memOp.getInnerSymAttr());
|
|
for (auto [oldOutput, newOutput] : llvm::zip(outputs, instOp.getResults()))
|
|
oldOutput.replaceAllUsesWith(newOutput);
|
|
|
|
// Carry attributes over from the `FirMemOp` to the `InstanceOp`.
|
|
auto defaultAttrNames = memOp.getAttributeNames();
|
|
for (auto namedAttr : memOp->getAttrs())
|
|
if (!llvm::is_contained(defaultAttrNames, namedAttr.getName()))
|
|
instOp->setAttr(namedAttr.getName(), namedAttr.getValue());
|
|
|
|
// Get rid of the `FirMemOp`.
|
|
for (auto *user : llvm::make_early_inc_range(memOp->getUsers()))
|
|
user->erase();
|
|
memOp.erase();
|
|
}
|
|
}
|