circt/lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp

606 lines
22 KiB
C++

//===- ArcCanonicalizer.cpp -------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
//
// Simulation centric canonicalizations for non-arc operations and
// canonicalizations that require efficient symbol lookups.
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/Arc/ArcPasses.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Support/Namespace.h"
#include "circt/Support/SymCache.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "arc-canonicalizer"
namespace circt {
namespace arc {
#define GEN_PASS_DEF_ARCCANONICALIZER
#include "circt/Dialect/Arc/ArcPasses.h.inc"
} // namespace arc
} // namespace circt
using namespace circt;
using namespace arc;
//===----------------------------------------------------------------------===//
// Datastructures
//===----------------------------------------------------------------------===//
namespace {
/// A combination of SymbolCache and SymbolUserMap that also allows to add users
/// and remove symbols on-demand.
///
/// Users are recorded per defining operation (the map key is the `Operation *`
/// of the definition, not its symbol name).
class SymbolHandler : public SymbolCache {
public:
  /// Return the users of the provided symbol operation, or an empty range if
  /// none are recorded. The range refers to internal storage; copy it if this
  /// handler may be mutated while the range is in use.
  ArrayRef<Operation *> getUsers(Operation *symbol) const {
    auto it = userMap.find(symbol);
    if (it == userMap.end())
      return {};
    return it->second.getArrayRef();
  }

  /// Return true if the given symbol has no uses.
  bool useEmpty(Operation *symbol) const {
    auto it = userMap.find(symbol);
    return it == userMap.end() || it->second.empty();
  }

  /// Record `user` as a user of `def`. If the symbol name of `def` is not
  /// cached yet, `def` is also registered as the definition for that name.
  void addUser(Operation *def, Operation *user) {
    assert(isa<mlir::SymbolOpInterface>(def));
    auto name = cast<mlir::SymbolOpInterface>(def).getNameAttr();
    if (!symbolCache.contains(name))
      symbolCache.insert({name, def});
    userMap[def].insert(user);
  }

  /// Stop tracking `user` as a user of `def`. The user is only removed while
  /// the symbol name of `def` is cached. A user list that becomes empty is
  /// dropped from the map.
  void removeUser(Operation *def, Operation *user) {
    assert(isa<mlir::SymbolOpInterface>(def));
    // Look the entry up instead of using operator[] so that no empty entry is
    // created (and immediately erased again) for an untracked definition.
    auto it = userMap.find(def);
    if (it == userMap.end())
      return;
    if (symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
      it->second.remove(user);
    if (it->second.empty())
      userMap.erase(it);
  }

  /// Drop the cached symbol name of `def` together with all users recorded for
  /// `def`.
  void removeDefinitionAndAllUsers(Operation *def) {
    assert(isa<mlir::SymbolOpInterface>(def));
    symbolCache.erase(cast<mlir::SymbolOpInterface>(def).getNameAttr());
    userMap.erase(def);
  }

  /// Record every symbol use found in the symbol tables nested under
  /// `symbolTableOp`. Each use is resolved through `symbolTable`.
  void collectAllSymbolUses(Operation *symbolTableOp,
                            SymbolTableCollection &symbolTable) {
    // NOTE: the following is almost 1-1 taken from the SymbolUserMap
    // constructor. They made it difficult to extend the implementation by
    // having a lot of members private and non-virtual methods.
    SmallVector<Operation *> symbols;
    auto walkFn = [&](Operation *tableOp, bool /*allUsesVisible*/) {
      for (Operation &nestedOp : tableOp->getRegion(0).getOps()) {
        auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
        assert(symbolUses && "expected uses to be valid");
        for (const SymbolTable::SymbolUse &use : *symbolUses) {
          symbols.clear();
          (void)symbolTable.lookupSymbolIn(tableOp, use.getSymbolRef(),
                                           symbols);
          for (Operation *symbolOp : symbols)
            userMap[symbolOp].insert(use.getUser());
        }
      }
    };
    // We just set `allSymUsesVisible` to false here because it isn't necessary
    // for building the user map.
    SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
                                  walkFn);
  }

private:
  /// Definition op -> ops that reference it, in insertion order.
  DenseMap<Operation *, SetVector<Operation *>> userMap;
};
/// Rewriter listener that mirrors IR changes into a SymbolHandler. This is
/// especially important for simplifications the rewriter performs on its own
/// (e.g. DCE), which the patterns cannot account for and which we cannot or do
/// not want to turn off.
class ArcListener : public mlir::RewriterBase::Listener {
public:
  explicit ArcListener(SymbolHandler *handler) : Listener(), handler(handler) {}

  void notifyOperationReplaced(Operation *op, Operation *replacement) override {
    // Swapping a symbol op for another op defining the same symbol name (e.g.
    // one DefineOp for another) must not discard the recorded users.
    auto oldSymbol = dyn_cast<mlir::SymbolOpInterface>(op);
    auto newSymbol = dyn_cast<mlir::SymbolOpInterface>(replacement);
    const bool keepsSymbolName =
        oldSymbol && newSymbol &&
        oldSymbol.getNameAttr() == newSymbol.getNameAttr();
    if (keepsSymbolName)
      return;
    remove(op);
    // TODO: if an operation is inserted that defines a symbol and the symbol
    // already has uses, those users are not added.
    add(replacement);
  }

  void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
    remove(op);
  }

  void notifyOperationRemoved(Operation *op) override { remove(op); }

  void notifyOperationInserted(Operation *op) override {
    // TODO: if an operation is inserted that defines a symbol and the symbol
    // already has uses, those users are not added.
    add(op);
  }

private:
  /// Resolve the callee of `op` to its cached definition. Fails for ops that
  /// are not calls, for callees that are not symbol references, and for
  /// callees without a cached definition.
  FailureOr<Operation *> maybeGetDefinition(Operation *op) {
    auto callOp = dyn_cast<mlir::CallOpInterface>(op);
    if (!callOp)
      return failure();
    auto calleeRef =
        callOp.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>();
    if (!calleeRef)
      return failure();
    Operation *definition = handler->getDefinition(calleeRef.getLeafReference());
    if (!definition)
      return failure();
    return definition;
  }

  /// Unregister `op` as a user of its callee. If `op` defines a symbol, also
  /// drop that definition and all of its recorded users.
  void remove(Operation *op) {
    if (auto definition = maybeGetDefinition(op); !failed(definition))
      handler->removeUser(*definition, op);
    if (isa<mlir::SymbolOpInterface>(op))
      handler->removeDefinitionAndAllUsers(op);
  }

  /// Register `op` as a user of its callee. If `op` defines a symbol, also
  /// register it as the definition for its symbol name.
  void add(Operation *op) {
    if (auto definition = maybeGetDefinition(op); !failed(definition))
      handler->addUser(*definition, op);
    if (auto symbolOp = dyn_cast<mlir::SymbolOpInterface>(op))
      handler->addDefinition(symbolOp.getNameAttr(), op);
  }

  /// Handler being kept in sync; the listener never deletes it.
  SymbolHandler *handler;
};
/// Counters collected by the symbol-aware patterns and copied into the pass
/// statistics after the rewrite.
struct PatternStatistics {
  /// Number of arc arguments removed by RemoveUnusedArcArgumentsPattern.
  unsigned removeUnusedArcArgumentsPatternNumArgsRemoved{0};
};
} // namespace
//===----------------------------------------------------------------------===//
// Canonicalization patterns
//===----------------------------------------------------------------------===//
namespace {
/// Base class for rewrite patterns that need the shared symbol bookkeeping:
/// a SymbolHandler to look up and modify symbol definitions and their users, a
/// Namespace to obtain fresh symbol names, and a statistics sink. Each derived
/// pattern has to keep the SymbolHandler up-to-date, whether the pattern
/// succeeds or fails.
template <typename SourceOp>
class SymOpRewritePattern : public OpRewritePattern<SourceOp> {
public:
  SymOpRewritePattern(MLIRContext *context, SymbolHandler &symbolHandler,
                      Namespace &nameGenerator, PatternStatistics &patternStats,
                      mlir::PatternBenefit benefit = 1,
                      ArrayRef<StringRef> generatedNames = {})
      : OpRewritePattern<SourceOp>(context, benefit, generatedNames),
        names(nameGenerator), symbolCache(symbolHandler),
        statistics(patternStats) {}

protected:
  /// Source of fresh symbol names.
  Namespace &names;
  /// Shared symbol definitions and their users.
  SymbolHandler &symbolCache;
  /// Shared counters reported by the pass.
  PatternStatistics &statistics;
};
/// Simplifies memory write ports whose enable is a constant produced by the
/// called arc.
class MemWritePortEnableAndMaskCanonicalizer
    : public SymOpRewritePattern<MemoryWritePortOp> {
public:
  MemWritePortEnableAndMaskCanonicalizer(
      MLIRContext *ctxt, SymbolHandler &symbolCache, Namespace &names,
      PatternStatistics &stats, DenseMap<StringAttr, StringAttr> &arcMapping)
      : SymOpRewritePattern<MemoryWritePortOp>(ctxt, symbolCache, names, stats),
        arcMapping(arcMapping) {}

  LogicalResult matchAndRewrite(MemoryWritePortOp op,
                                PatternRewriter &rewriter) const final;

private:
  /// Name of an arc that still produces the enable -> name of the split-off arc
  /// without the enable output. Shared across pattern applications.
  DenseMap<StringAttr, StringAttr> &arcMapping;
};

/// Replaces calls to passthrough arcs with the call's arguments.
class CallPassthroughArc : public SymOpRewritePattern<CallOp> {
public:
  using SymOpRewritePattern::SymOpRewritePattern;

  LogicalResult matchAndRewrite(CallOp op,
                                PatternRewriter &rewriter) const final;
};

/// Erases arc definitions that have no recorded users.
class RemoveUnusedArcs : public SymOpRewritePattern<DefineOp> {
public:
  using SymOpRewritePattern::SymOpRewritePattern;

  LogicalResult matchAndRewrite(DefineOp op,
                                PatternRewriter &rewriter) const final;
};

/// Turns equality comparisons of a concatenation against all-ones/all-zeros
/// into and/or reductions.
class ICMPCanonicalizer : public OpRewritePattern<comb::ICmpOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(comb::ICmpOp op,
                                PatternRewriter &rewriter) const final;
};

/// Removes arc arguments without uses from the definition and all callers.
class RemoveUnusedArcArgumentsPattern : public SymOpRewritePattern<DefineOp> {
public:
  using SymOpRewritePattern::SymOpRewritePattern;

  LogicalResult matchAndRewrite(DefineOp op,
                                PatternRewriter &rewriter) const final;
};

/// Moves constants that all callers pass for an argument into the arc body.
class SinkArcInputsPattern : public SymOpRewritePattern<DefineOp> {
public:
  using SymOpRewritePattern::SymOpRewritePattern;

  LogicalResult matchAndRewrite(DefineOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
/// If the arc called by `callOp` is a passthrough arc (per
/// `DefineOp::isPassthrough`), replace the call's results with its argument
/// operands. The call is also unregistered as a user of the arc.
///
/// Returns failure if the callee is not a symbol reference or does not resolve
/// to a cached DefineOp, or if the arc is not a passthrough.
static LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp,
                                               SymbolHandler &symbolCache,
                                               PatternRewriter &rewriter) {
  // Indirect calls have no symbol to resolve; bail out instead of asserting.
  auto calleeRef = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!calleeRef)
    return failure();
  auto defOp = dyn_cast_or_null<DefineOp>(
      symbolCache.getDefinition(calleeRef.getLeafReference()));
  if (!defOp || !defOp.isPassthrough())
    return failure();
  symbolCache.removeUser(defOp, callOp);
  rewriter.replaceOp(callOp, callOp.getArgOperands());
  return success();
}
//===----------------------------------------------------------------------===//
// Canonicalization pattern implementations
//===----------------------------------------------------------------------===//
/// Simplify a memory write port whose enable is a constant output of its arc.
///
/// - Constant-false enable: the port is erased. Its arc is erased too if the
///   port was its last user.
/// - Constant-true enable: the port's enable flag is cleared.
///   - If the arc was split before (recorded in `arcMapping`) and the
///     enable-free arc is still usable, the port is retargeted to that arc.
///   - Otherwise the arc is split. The existing DefineOp is renamed and loses
///     its enable output; this port calls it directly. A new wrapper arc with
///     the original name calls the renamed arc and re-inserts a constant-true
///     enable for all other users.
LogicalResult MemWritePortEnableAndMaskCanonicalizer::matchAndRewrite(
    MemoryWritePortOp op, PatternRewriter &rewriter) const {
  auto defOp = cast<DefineOp>(symbolCache.getDefinition(op.getArcAttr()));
  APInt enable;
  if (!op.getEnable() ||
      !mlir::matchPattern(
          defOp.getBodyBlock().getTerminator()->getOperand(op.getEnableIdx()),
          mlir::m_ConstantInt(&enable)))
    return failure();

  if (enable.isZero()) {
    symbolCache.removeUser(defOp, op);
    rewriter.eraseOp(op);
    if (symbolCache.useEmpty(defOp)) {
      symbolCache.removeDefinitionAndAllUsers(defOp);
      rewriter.eraseOp(defOp);
    }
    return success();
  }

  if (!enable.isAllOnes())
    return failure();

  if (auto it = arcMapping.find(defOp.getNameAttr()); it != arcMapping.end()) {
    StringAttr arcWithoutEnable = it->second;
    // The enable-free arc may have been erased, or had its signature changed
    // by other patterns (e.g. unused-argument removal), since it was split off.
    // Only reuse it if it still exists and takes the same arguments; otherwise
    // fall through and split again.
    auto targetDef = dyn_cast_or_null<DefineOp>(
        symbolCache.getDefinition(arcWithoutEnable));
    if (targetDef &&
        targetDef.getArgumentTypes() == defOp.getArgumentTypes()) {
      // Remove the enable attribute
      rewriter.modifyOpInPlace(op, [&]() {
        op.setEnable(false);
        op.setArc(arcWithoutEnable.getValue());
      });
      symbolCache.removeUser(defOp, op);
      symbolCache.addUser(targetDef, op);
      return success();
    }
  }

  auto newName = names.newName(defOp.getName());
  auto users = SmallVector<Operation *>(symbolCache.getUsers(defOp));
  symbolCache.removeDefinitionAndAllUsers(defOp);

  // Remove the enable attribute
  rewriter.modifyOpInPlace(op, [&]() {
    op.setEnable(false);
    op.setArc(newName);
  });
  auto newResultTypes = op.getArcResultTypes();

  // Create a new arc that acts as replacement for other users. It keeps the
  // original symbol name and forwards to the renamed arc.
  rewriter.setInsertionPoint(defOp);
  auto newDefOp = rewriter.cloneWithoutRegions(defOp);
  auto *block = rewriter.createBlock(
      &newDefOp.getBody(), newDefOp.getBody().end(),
      newDefOp.getArgumentTypes(),
      SmallVector<Location>(newDefOp.getNumArguments(), defOp.getLoc()));
  auto callOp = rewriter.create<CallOp>(newDefOp.getLoc(), newResultTypes,
                                        newName, block->getArguments());
  SmallVector<Value> results(callOp->getResults());
  Value constTrue = rewriter.create<hw::ConstantOp>(
      newDefOp.getLoc(), rewriter.getI1Type(), 1);
  results.insert(results.begin() + op.getEnableIdx(), constTrue);
  rewriter.create<OutputOp>(newDefOp.getLoc(), results);

  // Remove the enable output from the current arc
  auto *terminator = defOp.getBodyBlock().getTerminator();
  rewriter.modifyOpInPlace(
      terminator, [&]() { terminator->eraseOperand(op.getEnableIdx()); });
  rewriter.modifyOpInPlace(defOp, [&]() {
    defOp.setName(newName);
    defOp.setFunctionType(
        rewriter.getFunctionType(defOp.getArgumentTypes(), newResultTypes));
  });

  // Update symbol cache: this port now uses the renamed arc, every other
  // former user uses the wrapper, and the wrapper calls the renamed arc.
  symbolCache.addDefinition(defOp.getNameAttr(), defOp);
  symbolCache.addDefinition(newDefOp.getNameAttr(), newDefOp);
  symbolCache.addUser(defOp, callOp);
  for (auto *user : users)
    symbolCache.addUser(user == op ? defOp : newDefOp, user);
  arcMapping[newDefOp.getNameAttr()] = defOp.getNameAttr();
  return success();
}
/// Fold calls to passthrough arcs. The work is done by the shared
/// passthrough-call helper, which also updates the symbol cache.
LogicalResult
CallPassthroughArc::matchAndRewrite(CallOp op,
                                    PatternRewriter &rewriter) const {
  LogicalResult result = canonicalizePassthoughCall(op, symbolCache, rewriter);
  return result;
}
/// Erase an arc definition that has no recorded users.
///
/// Before erasing, every call nested in the arc body is unregistered as a user
/// of its callee. Those callees can then be found unused by this pattern too.
LogicalResult
RemoveUnusedArcs::matchAndRewrite(DefineOp op,
                                  PatternRewriter &rewriter) const {
  if (!symbolCache.useEmpty(op))
    return failure();

  op.getBody().walk([&](mlir::CallOpInterface nestedCall) {
    auto calleeRef =
        nestedCall.getCallableForCallee().dyn_cast<SymbolRefAttr>();
    if (!calleeRef)
      return;
    Operation *calleeDef =
        symbolCache.getDefinition(calleeRef.getLeafReference());
    if (!calleeDef)
      return;
    symbolCache.removeUser(calleeDef, nestedCall);
  });

  symbolCache.removeDefinitionAndAllUsers(op);
  rewriter.eraseOp(op);
  return success();
}
/// Simplify an `eq`/`ne` comparison of a concatenation against an all-ones or
/// all-zeros constant. All concatenated operands must be integers of the same
/// width:
///
///   icmp eq/ne (concat a, b, ...), 1...1  ->  icmp eq/ne (and a, b, ...), 1...1
///   icmp eq/ne (concat a, b, ...), 0...0  ->  icmp eq/ne (or a, b, ...), 0...0
///
/// For 1-bit operands the comparison is dropped and the and/or reduction is
/// used directly. It is negated for `ne` against all-ones and for `eq` against
/// all-zeros.
///
/// The replacement constant is built directly at the operand width with
/// APInt::getAllOnes/getZero. Going through `getZExtValue()` would assert for
/// concatenations wider than 64 bits, and would need an implicit truncation
/// for the all-ones case.
LogicalResult
ICMPCanonicalizer::matchAndRewrite(comb::ICmpOp op,
                                   PatternRewriter &rewriter) const {
  auto getConstant = [&](const APInt &constant) -> Value {
    return rewriter.create<hw::ConstantOp>(op.getLoc(), constant);
  };
  auto sameWidthIntegers = [](TypeRange types) -> std::optional<unsigned> {
    if (llvm::all_equal(types) && !types.empty())
      if (auto intType = dyn_cast<IntegerType>(*types.begin()))
        return intType.getWidth();
    return std::nullopt;
  };
  auto negate = [&](Value input) -> Value {
    auto constTrue = rewriter.create<hw::ConstantOp>(op.getLoc(), APInt(1, 1));
    return rewriter.create<comb::XorOp>(op.getLoc(), input, constTrue,
                                        op.getTwoState());
  };

  const bool isEq = op.getPredicate() == comb::ICmpPredicate::eq;
  const bool isNe = op.getPredicate() == comb::ICmpPredicate::ne;
  if (!isEq && !isNe)
    return failure();

  APInt rhs;
  if (!matchPattern(op.getRhs(), mlir::m_ConstantInt(&rhs)))
    return failure();
  auto concatOp = op.getLhs().getDefiningOp<comb::ConcatOp>();
  if (!concatOp)
    return failure();
  auto optionalWidth = sameWidthIntegers(concatOp->getOperands().getTypes());
  if (!optionalWidth)
    return failure();
  const unsigned width = *optionalWidth;

  if (rhs.isAllOnes()) {
    // The concatenation is all ones iff every operand is all ones, i.e. iff
    // the bitwise AND of the operands is all ones.
    Value andOp = rewriter.create<comb::AndOp>(
        op.getLoc(), concatOp.getInputs(), op.getTwoState());
    if (width == 1) {
      if (isNe)
        andOp = negate(andOp);
      rewriter.replaceOp(op, andOp);
      return success();
    }
    rewriter.replaceOpWithNewOp<comb::ICmpOp>(
        op, op.getPredicate(), andOp, getConstant(APInt::getAllOnes(width)),
        op.getTwoState());
    return success();
  }

  if (rhs.isZero()) {
    // The concatenation is zero iff every operand is zero, i.e. iff the
    // bitwise OR of the operands is zero.
    Value orOp = rewriter.create<comb::OrOp>(op.getLoc(), concatOp.getInputs(),
                                             op.getTwoState());
    if (width == 1) {
      if (isEq)
        orOp = negate(orOp);
      rewriter.replaceOp(op, orOp);
      return success();
    }
    rewriter.replaceOpWithNewOp<comb::ICmpOp>(
        op, op.getPredicate(), orOp, getConstant(APInt::getZero(width)),
        op.getTwoState());
    return success();
  }

  return failure();
}
/// Remove arguments of an arc that have no uses in its body. The matching
/// operands are removed from every user as well.
///
/// The rewrite only happens if every user implements CallOpInterface, since
/// otherwise the caller operands cannot be updated. All in-place modifications
/// go through the rewriter so that the greedy driver is notified and revisits
/// the changed operations.
LogicalResult RemoveUnusedArcArgumentsPattern::matchAndRewrite(
    DefineOp op, PatternRewriter &rewriter) const {
  BitVector toDelete(op.getNumArguments());
  for (auto [i, arg] : llvm::enumerate(op.getArguments()))
    if (arg.use_empty())
      toDelete.set(i);
  if (toDelete.none())
    return failure();

  // Collect the mutable callers in a first iteration. If there is a user that
  // does not implement the interface, we have to abort the rewrite and have to
  // make sure that we didn't change anything so far.
  SmallVector<mlir::CallOpInterface> mutableUsers;
  for (auto *user : symbolCache.getUsers(op)) {
    auto callOpMutable = dyn_cast<mlir::CallOpInterface>(user);
    if (!callOpMutable)
      return failure();
    mutableUsers.push_back(callOpMutable);
  }

  // Do the actual rewrites.
  for (auto user : mutableUsers) {
    rewriter.modifyOpInPlace(user, [&]() {
      // Erase from the back so the remaining indices stay valid.
      for (unsigned i = toDelete.size(); i-- > 0;)
        if (toDelete[i])
          user.getArgOperandsMutable().erase(i);
    });
  }
  rewriter.modifyOpInPlace(op, [&]() {
    op.eraseArguments(toDelete);
    op.setFunctionType(
        rewriter.getFunctionType(op.getArgumentTypes(), op.getResultTypes()));
  });
  statistics.removeUnusedArcArgumentsPatternNumArgsRemoved += toDelete.count();
  return success();
}
/// Sink constant arc inputs into the arc body.
///
/// An argument qualifies when every user passes a ConstantLike op with the
/// same op name and attribute dictionary for it. For each such argument:
/// - the constant is cloned into the arc body and replaces the argument,
/// - the argument is erased from the definition,
/// - the operand is dropped from all users.
///
/// Caller-side constants left without uses are erased. The pattern requires
/// every user to implement CallOpInterface. It returns failure without
/// touching the IR when nothing can be sunk.
LogicalResult
SinkArcInputsPattern::matchAndRewrite(DefineOp op,
                                      PatternRewriter &rewriter) const {
  // First check that all users implement the interface we need to be able to
  // modify the users. The user list is copied so that symbol-handler updates
  // made while rewriting cannot invalidate it.
  SmallVector<Operation *> users(symbolCache.getUsers(op));
  if (llvm::any_of(users, [](Operation *user) {
        return !isa<mlir::CallOpInterface>(user);
      }))
    return failure();

  // Find all arguments that use constant operands only. stateConsts[i] ends
  // up holding the constant op all users pass for argument i, or null.
  SmallVector<Operation *> stateConsts(op.getNumArguments());
  bool first = true;
  for (auto *user : users) {
    auto callOp = cast<mlir::CallOpInterface>(user);
    for (auto [constArg, input] :
         llvm::zip(stateConsts, callOp.getArgOperands())) {
      auto *constOp = input.getDefiningOp();
      const bool isConstant =
          constOp && constOp->hasTrait<OpTrait::ConstantLike>();
      if (isConstant && first) {
        constArg = constOp;
        continue;
      }
      if (isConstant && constArg &&
          constArg->getName() == constOp->getName() &&
          constArg->getAttrDictionary() == constOp->getAttrDictionary())
        continue;
      constArg = nullptr;
    }
    first = false;
  }

  // Determine the arguments to sink. Bail out before touching the IR if there
  // are none. A pattern that reports failure must leave the IR unchanged, and
  // rewriting the users anyway would notify the driver of spurious
  // modifications.
  llvm::BitVector toDelete(op.getBodyBlock().getNumArguments());
  for (auto [index, constArg] : llvm::enumerate(stateConsts))
    if (constArg)
      toDelete.set(index);
  if (toDelete.none())
    return failure();

  // Move the constants into the arc and erase the block arguments.
  rewriter.setInsertionPointToStart(&op.getBodyBlock());
  for (auto [constArg, arg] : llvm::zip(stateConsts, op.getArguments())) {
    if (!constArg)
      continue;
    auto *inlinedConst = rewriter.clone(*constArg);
    rewriter.replaceAllUsesWith(arg, inlinedConst->getResult(0));
  }
  rewriter.modifyOpInPlace(op, [&]() {
    op.getBodyBlock().eraseArguments(toDelete);
    op.setType(rewriter.getFunctionType(op.getBodyBlock().getArgumentTypes(),
                                        op.getResultTypes()));
  });

  // Rewrite all arc uses to not pass in the constant anymore.
  for (auto *user : users) {
    auto callOp = cast<mlir::CallOpInterface>(user);
    SmallPtrSet<Value, 4> maybeUnusedValues;
    SmallVector<Value> newInputs;
    for (auto [index, value] : llvm::enumerate(callOp.getArgOperands())) {
      if (toDelete[index])
        maybeUnusedValues.insert(value);
      else
        newInputs.push_back(value);
    }
    rewriter.modifyOpInPlace(
        callOp, [&]() { callOp.getArgOperandsMutable().assign(newInputs); });
    // Every dropped operand was produced by a constant op (checked above), so
    // getDefiningOp() is non-null here.
    for (auto value : maybeUnusedValues)
      if (value.use_empty())
        rewriter.eraseOp(value.getDefiningOp());
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ArcCanonicalizerPass implementation
//===----------------------------------------------------------------------===//
namespace {
/// Pass that first runs the symbol-aware arc patterns and then the
/// canonicalization patterns of loaded dialects and registered operations.
class ArcCanonicalizerPass
    : public arc::impl::ArcCanonicalizerBase<ArcCanonicalizerPass> {
public:
  /// Entry point invoked by the pass manager.
  void runOnOperation() override;
};
} // namespace
/// Run the pass in two greedy phases that share one rewrite config.
///
/// 1. The symbol-aware arc patterns. The pass fails if this phase reports
///    failure.
/// 2. All dialect and operation canonicalization patterns plus
///    ICMPCanonicalizer, ignoring convergence.
void ArcCanonicalizerPass::runOnOperation() {
  MLIRContext &ctxt = getContext();
  auto root = getOperation();

  // Symbol bookkeeping shared by all symbol-aware patterns.
  SymbolTableCollection symbolTable;
  SymbolHandler cache;
  cache.addDefinitions(root);
  cache.collectAllSymbolUses(root, symbolTable);
  Namespace names;
  names.add(cache);
  DenseMap<StringAttr, StringAttr> arcMapping;

  // The listener keeps `cache` in sync with changes made by the driver itself.
  ArcListener listener(&cache);
  mlir::GreedyRewriteConfig config;
  config.enableRegionSimplification = false;
  config.maxIterations = 10;
  config.useTopDownTraversal = true;
  config.listener = &listener;

  // Phase 1: symbol-aware arc patterns.
  PatternStatistics statistics;
  RewritePatternSet symbolPatterns(&ctxt);
  symbolPatterns.add<CallPassthroughArc, RemoveUnusedArcs,
                     RemoveUnusedArcArgumentsPattern, SinkArcInputsPattern>(
      &ctxt, cache, names, statistics);
  symbolPatterns.add<MemWritePortEnableAndMaskCanonicalizer>(
      &ctxt, cache, names, statistics, arcMapping);
  if (failed(mlir::applyPatternsAndFoldGreedily(
          root, std::move(symbolPatterns), config)))
    return signalPassFailure();
  numArcArgsRemoved = statistics.removeUnusedArcArgumentsPatternNumArgsRemoved;

  // Phase 2: generic canonicalization.
  RewritePatternSet patterns(&ctxt);
  for (auto *dialect : ctxt.getLoadedDialects())
    dialect->getCanonicalizationPatterns(patterns);
  for (mlir::RegisteredOperationName opName : ctxt.getRegisteredOperations())
    opName.getCanonicalizationPatterns(patterns, &ctxt);
  patterns.add<ICMPCanonicalizer>(&ctxt);
  // Don't test for convergence since it is often not reached.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns), config);
}
/// Create a new instance of the arc canonicalizer pass.
std::unique_ptr<mlir::Pass> arc::createArcCanonicalizerPass() {
  auto pass = std::make_unique<ArcCanonicalizerPass>();
  return pass;
}