// Mirror of https://github.com/llvm/circt.git (C++ source, 408 lines, 15 KiB).
//===- CalyxToFSM.cpp - Calyx to FSM conversion pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the main implementation of the Calyx control to FSM conversion pass.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Conversion/CalyxToFSM.h"
|
|
#include "circt/Dialect/Calyx/CalyxOps.h"
|
|
#include "circt/Dialect/Comb/CombOps.h"
|
|
#include "circt/Dialect/FSM/FSMDialect.h"
|
|
#include "circt/Dialect/FSM/FSMGraph.h"
|
|
#include "circt/Dialect/FSM/FSMOps.h"
|
|
#include "circt/Support/Namespace.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
namespace circt {
|
|
#define GEN_PASS_DEF_CALYXTOFSM
|
|
#include "circt/Conversion/Passes.h.inc"
|
|
} // namespace circt
|
|
|
|
using namespace mlir;
|
|
using namespace circt;
|
|
using namespace calyx;
|
|
using namespace fsm;
|
|
|
|
namespace {
|
|
|
|
class CompileFSMVisitor {
|
|
public:
|
|
CompileFSMVisitor(SymbolCache &sc, FSMGraph &graph)
|
|
: graph(graph), sc(sc), ctx(graph.getMachine().getContext()),
|
|
builder(graph.getMachine().getContext()) {
|
|
ns.add(sc);
|
|
}
|
|
|
|
/// Lowers the provided 'op' into a new FSM StateOp.
|
|
LogicalResult dispatch(StateOp currentState, Operation *op,
|
|
StateOp nextState) {
|
|
return TypeSwitch<Operation *, LogicalResult>(op)
|
|
.template Case<SeqOp, EnableOp, IfOp, WhileOp>(
|
|
[&](auto opNode) { return visit(currentState, opNode, nextState); })
|
|
.Default([&](auto) {
|
|
return op->emitError() << "Operation '" << op->getName()
|
|
<< "' not supported for FSM compilation";
|
|
});
|
|
}
|
|
|
|
ArrayRef<Attribute> getCompiledGroups() { return compiledGroups; }
|
|
|
|
private:
|
|
/// Operation visitors;
|
|
/// Apart from the visited operation, a visitor is provided with two extra
|
|
/// arguments:
|
|
/// currentState:
|
|
/// This represents a state which the callee has allocated to this visitor;
|
|
/// the visitor is free to use this state to its liking.
|
|
/// nextState:
|
|
/// This represent the next state which this visitor eventually must
|
|
/// transition to.
|
|
LogicalResult visit(StateOp currentState, SeqOp, StateOp nextState);
|
|
LogicalResult visit(StateOp currentState, EnableOp, StateOp nextState);
|
|
LogicalResult visit(StateOp currentState, IfOp, StateOp nextState);
|
|
LogicalResult visit(StateOp currentState, WhileOp, StateOp nextState);
|
|
|
|
/// Represents unique state name scopes generated from pushing states onto
|
|
/// the state stack. The guard carries a unique name as well as managing the
|
|
/// lifetime of suffixes on the state stack.
|
|
struct StateScopeGuard {
|
|
public:
|
|
StateScopeGuard(CompileFSMVisitor &visitor, StringRef name,
|
|
StringRef suffix)
|
|
: visitor(visitor), name(name) {
|
|
visitor.stateStack.push_back(suffix.str());
|
|
}
|
|
~StateScopeGuard() {
|
|
assert(!visitor.stateStack.empty());
|
|
visitor.stateStack.pop_back();
|
|
}
|
|
|
|
StringRef getName() { return name; }
|
|
|
|
private:
|
|
CompileFSMVisitor &visitor;
|
|
std::string name;
|
|
};
|
|
|
|
/// Generates a new state name based on the current state stack and the
|
|
/// provided suffix. The new suffix is pushed onto the state stack. Returns a
|
|
/// guard object which pops the new suffix upon destruction.
|
|
StateScopeGuard pushStateScope(StringRef suffix) {
|
|
std::string name;
|
|
llvm::raw_string_ostream ss(name);
|
|
llvm::interleave(
|
|
stateStack, ss, [&](const auto &it) { ss << it; }, "_");
|
|
ss << "_" << suffix.str();
|
|
return StateScopeGuard(*this, ns.newName(name), suffix);
|
|
}
|
|
|
|
FSMGraph &graph;
|
|
SymbolCache ≻
|
|
MLIRContext *ctx;
|
|
OpBuilder builder;
|
|
Namespace ns;
|
|
SmallVector<std::string, 4> stateStack;
|
|
|
|
/// Maintain the set of compiled groups within this FSM, to pass Calyx
|
|
/// verifiers.
|
|
SmallVector<Attribute, 8> compiledGroups;
|
|
};
|
|
|
|
/// Lowers a calyx.if. 'currentState' becomes the if header. It transitions to
/// a "then" state when the condition holds. Otherwise it transitions to an
/// "else" state if an else body exists, or directly to 'nextState' if not.
/// Both branch bodies exit to 'nextState'.
LogicalResult CompileFSMVisitor::visit(StateOp currentState, IfOp ifOp,
                                       StateOp nextState) {
  auto stateGuard = pushStateScope("if");
  Location loc = ifOp.getLoc();

  // Rename the current state now that we know it's an if header.
  graph.renameState(currentState, stateGuard.getName());

  // Adds a transition from the if header to 'target'. The transition is
  // guarded by the if condition, or by its negation when 'invert' is set.
  auto addGuardedTransition = [&](StateOp target, bool invert) {
    auto transitionOp =
        graph.createTransition(builder, loc, currentState, target)
            ->getTransition();
    transitionOp.ensureGuard(builder);
    fsm::ReturnOp returnOp = transitionOp.getGuardReturn();

    Value branchTaken = ifOp.getCond();
    if (invert) {
      // Materialize the negated condition inside the guard region. Only this
      // step needs a particular insertion point.
      OpBuilder::InsertionGuard g(builder);
      builder.setInsertionPointToStart(&transitionOp.getGuard().front());
      branchTaken = comb::createOrFoldNot(loc, branchTaken, builder);
    }
    returnOp.setOperand(branchTaken);
  };

  // Lowers one branch body into a fresh state, entered from the if header,
  // which eventually exits to 'nextState'.
  auto lowerBranch = [&](StringRef nextStateSuffix, bool invert,
                         Operation *innerBranchOp) -> LogicalResult {
    auto branchStateGuard = pushStateScope(nextStateSuffix);
    auto branchStateOp =
        graph.createState(builder, loc, branchStateGuard.getName())
            ->getState();
    addGuardedTransition(branchStateOp, invert);

    // Recurse into the body of the branch, with an exit state targeting
    // 'nextState'.
    return dispatch(branchStateOp, innerBranchOp, nextState);
  };

  // Then branch.
  if (failed(lowerBranch("then", /*invert=*/false,
                         &ifOp.getThenBody()->front())))
    return failure();

  if (ifOp.elseBodyExists()) {
    // Else branch.
    if (failed(lowerBranch("else", /*invert=*/true,
                           &ifOp.getElseBody()->front())))
      return failure();
  } else if (nextState) {
    // Without an else branch, a false condition must skip straight to
    // 'nextState'. Otherwise the header would have no enabled transition and
    // control would never leave it (mirrors the exit edge of the while
    // visitor).
    addGuardedTransition(nextState, /*invert=*/true);
  }

  return success();
}
|
|
|
|
/// Lowers a calyx.seq. Every nested operation gets its own state, named after
/// its index in the sequence, and each one exits to the state of its
/// successor. The last one exits to 'nextState'. 'currentState' transitions
/// to the first of these states, or directly to 'nextState' if the sequence
/// is empty.
LogicalResult CompileFSMVisitor::visit(StateOp currentState, SeqOp seqOp,
                                       StateOp nextState) {
  Location loc = seqOp.getLoc();
  auto seqStateGuard = pushStateScope("seq");

  // Iterate over the operations within the sequence in reverse order, so
  // that the state each operation must exit to is always already known.
  // After the loop, 'successor' holds the state of the first operation (or
  // 'nextState' when the body is empty).
  auto *body = seqOp.getBodyBlock();
  size_t index = body->getOperations().size();
  StateOp successor = nextState;
  for (Operation &op : llvm::reverse(*body)) {
    auto subStateGuard = pushStateScope(std::to_string(--index));
    auto thisStateOp =
        graph.createState(builder, op.getLoc(), subStateGuard.getName())
            ->getState();
    sc.addSymbol(thisStateOp);

    // Recurse into the current operation.
    if (failed(dispatch(thisStateOp, &op, successor)))
      return failure();

    // This state is now the next state for the preceding operation.
    successor = thisStateOp;
  }

  // Make 'currentState' transition directly to the first state in the
  // sequence. An empty sequence has nothing to execute and falls through to
  // 'nextState' (this previously indexed an empty vector).
  graph.createTransition(builder, loc, currentState, successor);

  return success();
}
|
|
|
|
/// Lowers a calyx.while. 'currentState' is turned into the loop header
/// ("<scope>_header"). The body is lowered into a new "<scope>_entry" state
/// whose exit leads back to the header. The header enters the body while the
/// condition holds and otherwise leaves for 'nextState'.
LogicalResult CompileFSMVisitor::visit(StateOp currentState, WhileOp whileOp,
                                       StateOp nextState) {
  // The builder is moved below to materialize the negated condition. Restore
  // it on every exit path.
  OpBuilder::InsertionGuard insertionGuard(builder);
  auto scope = pushStateScope("while");
  Location whileLoc = whileOp.getLoc();

  // Reuse the state we were handed as the loop header.
  StateOp headerState = currentState;
  graph.renameState(headerState, (scope.getName() + "_header").str());
  sc.addSymbol(headerState);

  // Lower the body into its own entry state. Once the body finishes, control
  // always returns to the header.
  auto entryState =
      graph.createState(builder, whileLoc, (scope.getName() + "_entry").str())
          ->getState();
  sc.addSymbol(entryState);
  Operation *bodyOp = &whileOp.getBodyBlock()->front();
  if (failed(dispatch(entryState, bodyOp, headerState)))
    return failure();

  // Header edges: into the body on 'cond', out of the loop on '!cond'.
  auto enterBody =
      graph.createTransition(builder, whileLoc, headerState, entryState)
          ->getTransition();
  auto exitLoop = graph.createTransition(builder, whileLoc, headerState, nextState)
                      ->getTransition();

  Value cond = whileOp.getCond();
  enterBody.ensureGuard(builder);
  enterBody.getGuardReturn().setOperand(cond);

  exitLoop.ensureGuard(builder);
  fsm::ReturnOp exitReturn = exitLoop.getGuardReturn();
  builder.setInsertionPoint(exitReturn);
  exitReturn.setOperand(comb::createOrFoldNot(whileLoc, cond, builder));
  return success();
}
|
|
|
|
/// Lowers a calyx.enable. 'currentState' is renamed after the enabled group,
/// and a copy of the enable is placed in its output region. The state
/// transitions to 'nextState' when one is given. The group is also recorded in
/// the set of compiled groups.
LogicalResult CompileFSMVisitor::visit(StateOp currentState, EnableOp enableOp,
                                       StateOp nextState) {
  assert(currentState &&
         "Expected this enableOp to be nested into some provided state");

  // Rename the current state now that we know it's an enable state.
  auto enableStateGuard = pushStateScope(enableOp.getGroupName());
  graph.renameState(currentState, enableStateGuard.getName());

  // Create a new calyx.enable in the output state referencing the enabled
  // group. We create a new op here as opposed to moving the existing, to make
  // callers iterating over nested ops safer.
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPointToStart(&currentState.getOutput().front());
  calyx::EnableOp::create(builder, enableOp.getLoc(), enableOp.getGroupName());

  if (nextState)
    graph.createTransition(builder, enableOp.getLoc(), currentState, nextState);

  // Append this group to the set of compiled groups.
  compiledGroups.push_back(
      SymbolRefAttr::get(builder.getContext(), enableOp.getGroupName()));

  return success();
}
|
|
|
|
// CompileInvoke is used to convert invoke operations to group operations and
|
|
// enable operations.
|
|
class CompileInvoke {
|
|
public:
|
|
CompileInvoke(ComponentOp component, OpBuilder builder)
|
|
: component(component), builder(builder) {}
|
|
void compile();
|
|
|
|
private:
|
|
void lowerInvokeOp(InvokeOp invokeOp);
|
|
std::string getTransitionName(InvokeOp invokeOp);
|
|
ComponentOp component;
|
|
OpBuilder builder;
|
|
// Part of the group name. It is used to generate unique group names, the
|
|
// unique counter is reused across multiple calls to lowerInvokeOp, so the
|
|
// loop that's checking for name uniqueness usually finds a unique name on the
|
|
// first try.
|
|
size_t transitionNameTail = 0;
|
|
};
|
|
|
|
// Access all invokeOp.
|
|
void CompileInvoke::compile() {
|
|
llvm::SmallVector<InvokeOp> invokeOps =
|
|
component.getControlOp().getInvokeOps();
|
|
for (InvokeOp op : invokeOps)
|
|
lowerInvokeOp(op);
|
|
}
|
|
|
|
// Get the name of the generation group.
|
|
std::string CompileInvoke::getTransitionName(InvokeOp invokeOp) {
|
|
llvm::StringRef callee = invokeOp.getCallee();
|
|
std::string transitionNameHead = "invoke_" + callee.str() + "_";
|
|
std::string transitionName;
|
|
|
|
// The following loop is used to check if the transitionName already exists.
|
|
// If it does, the loop regenerates the transitionName.
|
|
do {
|
|
transitionName = transitionNameHead + std::to_string(transitionNameTail++);
|
|
} while (component.getWiresOp().lookupSymbol(transitionName));
|
|
return transitionName;
|
|
}
|
|
|
|
// Convert an invoke operation to a group operation and an enable operation.
|
|
void CompileInvoke::lowerInvokeOp(InvokeOp invokeOp) {
|
|
// Create a ConstantOp to assign a value to the go port.
|
|
Operation *prevNode = component.getWiresOp().getOperation()->getPrevNode();
|
|
builder.setInsertionPointAfter(prevNode);
|
|
hw::ConstantOp constantOp = hw::ConstantOp::create(
|
|
builder, prevNode->getLoc(), builder.getI1Type(), 1);
|
|
Location loc = component.getWiresOp().getLoc();
|
|
|
|
// Set the insertion point at the end of the wires block.
|
|
builder.setInsertionPointToEnd(component.getWiresOp().getBodyBlock());
|
|
std::string transitionName = getTransitionName(invokeOp);
|
|
GroupOp groupOp = GroupOp::create(builder, loc, transitionName);
|
|
builder.setInsertionPointToStart(groupOp.getBodyBlock());
|
|
Value go = invokeOp.getInstGoValue();
|
|
|
|
// Assign a value to the go port.
|
|
AssignOp::create(builder, loc, go, constantOp);
|
|
auto ports = invokeOp.getPorts();
|
|
auto inputs = invokeOp.getInputs();
|
|
|
|
// Generate a series of assignment operations from a list of parameters.
|
|
for (auto [port, input] : llvm::zip(ports, inputs))
|
|
AssignOp::create(builder, loc, port, input);
|
|
Value done = invokeOp.getInstDoneValue();
|
|
|
|
// Generate a group_done operation with the instance's done port.
|
|
calyx::GroupDoneOp::create(builder, loc, done);
|
|
builder.setInsertionPointAfter(invokeOp.getOperation());
|
|
EnableOp::create(builder, invokeOp.getLoc(), transitionName);
|
|
invokeOp.erase();
|
|
}
|
|
|
|
/// Pass converting the control schedule of a calyx.component into an
/// fsm.machine.
class CalyxToFSMPass : public circt::impl::CalyxToFSMBase<CalyxToFSMPass> {
public:
  void runOnOperation() override;
};
|
|
|
|
void CalyxToFSMPass::runOnOperation() {
|
|
ComponentOp component = getOperation();
|
|
OpBuilder builder(&getContext());
|
|
auto ctrlOp = component.getControlOp();
|
|
assert(ctrlOp.getBodyBlock()->getOperations().size() == 1 &&
|
|
"Expected a single top-level operation in the schedule");
|
|
CompileInvoke compileInvoke(component, builder);
|
|
compileInvoke.compile();
|
|
Operation &topLevelCtrlOp = ctrlOp.getBodyBlock()->front();
|
|
builder.setInsertionPoint(&topLevelCtrlOp);
|
|
|
|
// Create a side-effect-only FSM (no inputs, no outputs) which will strictly
|
|
// refer to the symbols and SSA values defined in the regions of the
|
|
// ComponentOp. This makes for an intermediate step, which allows for
|
|
// outlining the FSM (materializing FSM I/O) at a later point.
|
|
auto machineName = ("control_" + component.getName()).str();
|
|
auto funcType = FunctionType::get(&getContext(), {}, {});
|
|
auto machine = MachineOp::create(builder, ctrlOp.getLoc(), machineName,
|
|
/*initialState=*/"fsm_entry", funcType);
|
|
auto graph = FSMGraph(machine);
|
|
|
|
SymbolCache sc;
|
|
sc.addDefinitions(machine);
|
|
|
|
// Create entry and exit states
|
|
auto entryState =
|
|
graph.createState(builder, ctrlOp.getLoc(), calyxToFSM::sEntryStateName)
|
|
->getState();
|
|
auto exitState =
|
|
graph.createState(builder, ctrlOp.getLoc(), calyxToFSM::sExitStateName)
|
|
->getState();
|
|
|
|
auto visitor = CompileFSMVisitor(sc, graph);
|
|
if (failed(visitor.dispatch(entryState, &topLevelCtrlOp, exitState))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
// Remove the top-level calyx control operation that we've now converted to an
|
|
// FSM.
|
|
topLevelCtrlOp.erase();
|
|
|
|
// Add the set of compiled groups as an attribute to the fsm.
|
|
machine->setAttr(
|
|
"compiledGroups",
|
|
ArrayAttr::get(builder.getContext(), visitor.getCompiledGroups()));
|
|
}
|
|
|
|
} // namespace
|
|
|
|
std::unique_ptr<mlir::Pass> circt::createCalyxToFSMPass() {
|
|
return std::make_unique<CalyxToFSMPass>();
|
|
}
|