mirror of https://github.com/llvm/circt.git
339 lines
12 KiB
C++
339 lines
12 KiB
C++
//===- MuxToControlFlow.cpp - Implement the MuxToControlFlow Pass ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Implement a pass to convert muxes to control flow branches whenever it is
|
|
// beneficial for performance (i.e., when expected work avoided is more than
|
|
// branching costs)
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "circt/Dialect/Arc/ArcInterfaces.h"
|
|
#include "circt/Dialect/Arc/ArcOps.h"
|
|
#include "circt/Dialect/Arc/ArcPasses.h"
|
|
#include "circt/Dialect/Comb/CombOps.h"
|
|
#include "circt/Dialect/HW/HWOps.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "arc-mux-to-control-flow"
|
|
|
|
namespace circt {
|
|
namespace arc {
|
|
#define GEN_PASS_DEF_MUXTOCONTROLFLOW
|
|
#include "circt/Dialect/Arc/ArcPasses.h.inc"
|
|
} // namespace arc
|
|
} // namespace circt
|
|
|
|
using namespace circt;
|
|
using namespace arc;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MuxToControlFlow pass declarations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Convert muxes to if-statements.
|
|
struct MuxToControlFlowPass
|
|
: public arc::impl::MuxToControlFlowBase<MuxToControlFlowPass> {
|
|
MuxToControlFlowPass() = default;
|
|
MuxToControlFlowPass(const MuxToControlFlowPass &pass)
|
|
: MuxToControlFlowPass() {}
|
|
|
|
void runOnOperation() override;
|
|
|
|
Statistic numMuxesConverted{
|
|
this, "num-muxes-converted",
|
|
"Number of muxes that were converted to if-statements"};
|
|
Statistic numMuxesRetained{this, "num-muxes-retained",
|
|
"Number of muxes that were not converted"};
|
|
};
|
|
|
|
/// Abstract over muxes to easy addition of support for other operations.
|
|
struct BranchInfo {
|
|
BranchInfo() = default;
|
|
BranchInfo(Value condition, Value trueValue, Value falseValue)
|
|
: condition(condition), trueValue(trueValue), falseValue(falseValue) {}
|
|
|
|
Value condition;
|
|
Value trueValue;
|
|
Value falseValue;
|
|
|
|
operator bool() { return condition && trueValue && falseValue; }
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Check whether @param curr is valid to be moved into the if-branch, which is
|
|
/// the stopping condition of the BFS traversal.
|
|
static bool isValidToProceedTraversal(Operation *mux, Operation *curr,
|
|
Value useValue,
|
|
SmallPtrSetImpl<Operation *> &visited) {
|
|
for (auto res : curr->getResults()) {
|
|
for (auto *user : res.getUsers()) {
|
|
// The use-sites of all results have to be within the same branch, thus
|
|
// already have to be visited already. The only exception is the first
|
|
// operation in the branch used by the mux itself.
|
|
if (!visited.contains(user) && user != mux)
|
|
return false;
|
|
|
|
// The second part of the special case mentioned above (because otherwise
|
|
// we would also include the first operation of the other branch).
|
|
if (user == mux && res != useValue)
|
|
return false;
|
|
|
|
if (user->getBlock() != curr->getBlock())
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Compute the set of operations that would only be used in the branch
|
|
/// represented by @param useValue.
|
|
static void computeFanIn(Operation *mux, Value useValue,
|
|
SmallPtrSetImpl<Operation *> &visited) {
|
|
auto *op = useValue.getDefiningOp();
|
|
if (!op)
|
|
return;
|
|
|
|
SmallVector<Operation *> worklist{op};
|
|
|
|
while (!worklist.empty()) {
|
|
auto *curr = worklist.front();
|
|
worklist.erase(worklist.begin());
|
|
|
|
if (visited.contains(curr))
|
|
continue;
|
|
|
|
if (!isValidToProceedTraversal(mux, curr, useValue, visited))
|
|
continue;
|
|
|
|
visited.insert(curr);
|
|
|
|
for (auto val : curr->getOperands()) {
|
|
if (auto *defOp = val.getDefiningOp())
|
|
worklist.push_back(defOp);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Clone ops that are used in both branches of an if-statement but not outside
|
|
/// of it. This is just here because of experimentation reasons. Doing this
|
|
/// might allow for better instruction scheduling to slightly reduce ISA
|
|
/// register pressure (however, it is currently too naive to only take the
|
|
/// beneficial situations), but it will increase binary size which is especially
|
|
/// bad when the hot part would otherwise fit in instruction cache (but doesn't
|
|
/// really matter when it doesn't fit anyways as there is no temporal locality
|
|
/// anyways).
|
|
[[maybe_unused]] static void
|
|
cloneOpsIntoBranchesWhenUsedInBoth(mlir::scf::IfOp ifOp) {
|
|
// Iterate over all operations at the same nesting level as the if-statement
|
|
// (not the operations inside the if-statement).
|
|
for (auto &op : llvm::reverse(*ifOp->getBlock())) {
|
|
if (op.getNumResults() == 0)
|
|
continue;
|
|
|
|
// Collect all users of the current operations results.
|
|
SmallVector<Operation *> users;
|
|
for (auto result : op.getResults())
|
|
users.append(llvm::to_vector(result.getUsers()));
|
|
|
|
auto parentsOfUsers =
|
|
llvm::map_range(users, [](auto user) { return user->getParentOp(); });
|
|
|
|
auto allUsersNestedInIf = llvm::any_of(parentsOfUsers, [&](auto *parent) {
|
|
return !(isa<mlir::scf::IfOp>(parent) &&
|
|
parent->getBlock() == op.getBlock());
|
|
});
|
|
|
|
// Check that all users of the results are nested inside the same scf.if
|
|
// operation
|
|
if (allUsersNestedInIf || !llvm::all_equal(parentsOfUsers))
|
|
continue;
|
|
|
|
DenseMap<Region *, Value> cloneMap;
|
|
for (auto &use : llvm::make_early_inc_range(op.getUses())) {
|
|
auto *parentRegion = use.getOwner()->getParentRegion();
|
|
if (!cloneMap.count(parentRegion)) {
|
|
OpBuilder builder(&parentRegion->front().front());
|
|
cloneMap[parentRegion] = builder.clone(op)->getResult(0);
|
|
}
|
|
use.set(cloneMap[parentRegion]);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Perform the actual conversion. Create the if-statement, move the operations
|
|
/// in its regions and delete the mux.
|
|
static void doConversion(Operation *op, BranchInfo info,
|
|
const SmallPtrSetImpl<Operation *> &thenOps,
|
|
const SmallPtrSetImpl<Operation *> &elseOps) {
|
|
if (op->getNumResults() != 1)
|
|
return;
|
|
|
|
// Build the scf.if operation with the scf.yields inside.
|
|
ImplicitLocOpBuilder builder(op->getLoc(), op);
|
|
mlir::scf::IfOp ifOp = mlir::scf::IfOp::create(
|
|
builder, info.condition,
|
|
[&](OpBuilder &builder, Location loc) {
|
|
mlir::scf::YieldOp::create(builder, loc, info.trueValue);
|
|
},
|
|
[&](OpBuilder &builder, Location loc) {
|
|
mlir::scf::YieldOp::create(builder, loc, info.falseValue);
|
|
});
|
|
|
|
op->getResult(0).replaceAllUsesWith(ifOp.getResult(0));
|
|
|
|
for (auto &ops :
|
|
llvm::make_early_inc_range(op->getParentRegion()->getOps())) {
|
|
// Move operations into the then-branch if they are only used in there.
|
|
// The original lexicographical order is preserved.
|
|
if (thenOps.contains(&ops))
|
|
ops.moveBefore(ifOp.thenBlock()->getTerminator());
|
|
|
|
// Move operations into the else-branch if they are only used in there.
|
|
// The original lexicographical order is preserved.
|
|
if (elseOps.contains(&ops))
|
|
ops.moveBefore(ifOp.elseBlock()->getTerminator());
|
|
}
|
|
|
|
op->erase();
|
|
|
|
// NOTE: this is just here for some experimentation purposes
|
|
// cloneOpsIntoBranchesWhenUsedInBoth(ifOp);
|
|
}
|
|
|
|
/// Simple helper to invoke the runtime cost interface for every operation in a
|
|
/// set and sum up the costs.
|
|
static uint32_t getCostEstimate(const SmallPtrSetImpl<Operation *> &ops) {
|
|
uint32_t cost = 0;
|
|
|
|
for (auto *op : ops) {
|
|
if (auto *runtimeCostIF =
|
|
dyn_cast<RuntimeCostEstimateDialectInterface>(op->getDialect())) {
|
|
cost += runtimeCostIF->getCostEstimate(op);
|
|
} else {
|
|
LLVM_DEBUG(llvm::dbgs() << "No runtime cost estimate was provided for '"
|
|
<< op->getName() << "', using default of 10\n");
|
|
cost += 10;
|
|
}
|
|
}
|
|
|
|
return cost;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Decision functions (configure the pass here)
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Convert concrete operations that should be converted to if-statements to a
|
|
/// more abstract representation the rest of the pass works with. This is the
|
|
/// place where support for more operations can be added (nothing else has to be
|
|
/// changed).
|
|
static BranchInfo getConversionInfo(Operation *op) {
|
|
if (auto mux = dyn_cast<comb::MuxOp>(op))
|
|
return BranchInfo{mux.getCond(), mux.getTrueValue(), mux.getFalseValue()};
|
|
|
|
// TODO: we can also check for arith.select or other operations here
|
|
|
|
return {};
|
|
}
|
|
|
|
/// Use the cost measure of each branch to heuristically decide whether to
|
|
/// actually perform the conversion.
|
|
/// TODO: improve and fine-tune this
|
|
static bool isBeneficialToConvert(Operation *op,
|
|
const SmallPtrSetImpl<Operation *> &thenOps,
|
|
const SmallPtrSetImpl<Operation *> &elseOps) {
|
|
const uint32_t thenCost = getCostEstimate(thenOps);
|
|
const uint32_t elseCost = getCostEstimate(elseOps);
|
|
|
|
// Due to the nature of mux sequences we need to make sure that a reasonable
|
|
// amount of operations stay in each if-branch because otherwise we end up
|
|
// with if-statements that only contain anther if-statement, which is usually
|
|
// more costly than keeping some muxes unconverted.
|
|
if (auto parent = op->getParentOfType<mlir::scf::IfOp>()) {
|
|
SmallPtrSet<Operation *, 32> ifBranchOps;
|
|
|
|
for (auto &nestedOp : *op->getBlock()) {
|
|
if (!thenOps.contains(&nestedOp) && !elseOps.contains(&nestedOp))
|
|
ifBranchOps.insert(&nestedOp);
|
|
}
|
|
|
|
if (getCostEstimate(ifBranchOps) < 100)
|
|
return false;
|
|
}
|
|
|
|
// return thenCost + elseCost >= 100 && (thenCost == 0 || elseCost == 0);
|
|
return (thenCost >= 100 || thenCost == 0) &&
|
|
(elseCost >= 100 || elseCost == 0) &&
|
|
std::abs((int)thenCost - (int)elseCost) >= 100;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MuxToControlFlow pass definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// FIXME: Assumes that the regions in which muxes exist are topologically
|
|
// ordered.
|
|
// FIXME: does not consider side-effects
|
|
void MuxToControlFlowPass::runOnOperation() {
|
|
// Collect all operations that support the conversion to scf.if operations.
|
|
// Use 'walk' instead of 'getOps' as we also want to visit nested regions.
|
|
// We need to collect them because moving ops while iterating over them
|
|
// would require complicated iterator advancing/skipping but also tracking
|
|
// back to not miss supported operations.
|
|
SmallVector<Operation *> supportedOps;
|
|
getOperation()->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
|
|
// Skip ops with graph regions and ops that can contain ops with write
|
|
// semantics for now until side-effects and topological ordering is properly
|
|
// handled.
|
|
if (isa<hw::HWModuleOp, arc::ModelOp>(op))
|
|
return WalkResult::skip();
|
|
|
|
if (getConversionInfo(op))
|
|
supportedOps.push_back(op);
|
|
|
|
return WalkResult::advance();
|
|
});
|
|
|
|
// We want to visit the operations bottom-up to visit the operations with the
|
|
// longest fan-in first. However, the other direction would also work with the
|
|
// current implementation.
|
|
for (auto *op : llvm::reverse(supportedOps)) {
|
|
auto info = getConversionInfo(op);
|
|
|
|
// Compute the operations in the fan-in of each branch and use them to
|
|
// decide whether the operation should be converted.
|
|
// Stop at the first value that's also used outside of the branch.
|
|
llvm::SmallPtrSet<Operation *, 32> thenOps, elseOps;
|
|
computeFanIn(op, info.trueValue, thenOps);
|
|
computeFanIn(op, info.falseValue, elseOps);
|
|
|
|
// Apply a cost measure to the operations in the branches and only convert
|
|
// when a performance increase can be expected.
|
|
if (isBeneficialToConvert(op, thenOps, elseOps)) {
|
|
doConversion(op, info, thenOps, elseOps);
|
|
++numMuxesConverted;
|
|
} else {
|
|
++numMuxesRetained;
|
|
}
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<Pass> arc::createMuxToControlFlowPass() {
|
|
return std::make_unique<MuxToControlFlowPass>();
|
|
}
|