circt/lib/Dialect/Ibis/Transforms/IbisCallPrep.cpp

274 lines
9.7 KiB
C++

//===- IbisCallPrep.cpp - Implementation of call prep lowering ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetails.h"
#include "circt/Dialect/Ibis/IbisDialect.h"
#include "circt/Dialect/Ibis/IbisOps.h"
#include "circt/Dialect/Ibis/IbisPasses.h"
#include "circt/Dialect/Ibis/IbisTypes.h"
#include "circt/Dialect/HW/ConversionPatterns.h"
#include "circt/Dialect/HW/HWTypes.h"
#include "circt/Support/BackedgeBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace circt;
using namespace ibis;
/// Build indexes to make lookups faster. Create the new argument types as well.
struct CallPrepPrecomputed {
CallPrepPrecomputed(ModuleOp mod);
// Lookup a class from its symbol.
DenseMap<StringAttr, ClassOp> classSymbols;
// Mapping of method to argument type.
DenseMap<SymbolRefAttr, std::pair<hw::StructType, Location>> argTypes;
// Lookup the class to which a particular instance (in a particular class) is
// referring.
DenseMap<std::pair<ClassOp, StringAttr>, ClassOp> instanceMap;
// Lookup an entry in instanceMap. If not found, return null.
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const {
auto entry = instanceMap.find(std::make_pair(scope, instSym));
if (entry == instanceMap.end())
return {};
return entry->second;
}
// Given an instance path, get the class::func symbolref for it.
SymbolRefAttr resolveInstancePath(Operation *scope, SymbolRefAttr path) const;
// Utility function to create a symbolref to a method.
static SymbolRefAttr getSymbol(MethodOp method) {
ClassOp cls = method.getParentOp();
return SymbolRefAttr::get(
cls.getSymNameAttr(),
{FlatSymbolRefAttr::get(method.getContext(), *method.getInnerName())});
}
};
CallPrepPrecomputed::CallPrepPrecomputed(ModuleOp mod) {
auto *ctxt = mod.getContext();
// Populate the class-symbol lookup table.
for (auto cls : mod.getOps<ClassOp>())
classSymbols[cls.getSymNameAttr()] = cls;
for (auto cls : mod.getOps<ClassOp>()) {
// Compute new argument types for each method.
for (auto method : cls.getOps<MethodOp>()) {
// Create the struct type.
SmallVector<hw::StructType::FieldInfo> argFields;
for (auto [argName, argType] :
llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
cast<MethodLikeOpInterface>(method.getOperation())
.getArgumentTypes()))
argFields.push_back({argName, argType});
auto argStruct = hw::StructType::get(ctxt, argFields);
// Later we're gonna want the block locations, so compute a fused location
// and store it.
Location argLoc = UnknownLoc::get(ctxt);
if (method->getNumRegions() > 0) {
SmallVector<Location> argLocs;
Block *body = &method.getBody().front();
for (auto arg : body->getArguments())
argLocs.push_back(arg.getLoc());
argLoc = FusedLoc::get(ctxt, argLocs);
}
// Add both to the lookup table.
argTypes.insert(
std::make_pair(getSymbol(method), std::make_pair(argStruct, argLoc)));
}
// Populate the instances table.
for (auto inst : cls.getOps<InstanceOp>()) {
auto clsEntry = classSymbols.find(inst.getTargetNameAttr().getAttr());
assert(clsEntry != classSymbols.end() &&
"class being instantiated doesn't exist");
instanceMap[std::make_pair(cls, inst.getInnerSym().getSymName())] =
clsEntry->second;
}
}
}
SymbolRefAttr
CallPrepPrecomputed::resolveInstancePath(Operation *scope,
SymbolRefAttr path) const {
auto cls = scope->getParentOfType<ClassOp>();
assert(cls && "scope outside of ibis class");
// SymbolRefAttr is rather silly. The start of the path is root reference...
cls = lookupNext(cls, path.getRootReference());
if (!cls)
return {};
// ... then the rest are the nested references. The last one is the function
// name rather than an instance.
for (auto instSym : path.getNestedReferences().drop_back()) {
cls = lookupNext(cls, instSym.getAttr());
if (!cls)
return {};
}
// The last one is the function symbol.
return SymbolRefAttr::get(cls.getSymNameAttr(),
{FlatSymbolRefAttr::get(path.getLeafReference())});
}
namespace {
/// For each CallOp, the corresponding method signature will have changed. Pack
/// all the operands into a struct.
struct MergeCallArgs : public OpConversionPattern<CallOp> {
MergeCallArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
: OpConversionPattern(ctxt), info(info) {}
void rewrite(CallOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final;
LogicalResult match(CallOp) const override { return success(); }
private:
const CallPrepPrecomputed &info;
};
} // anonymous namespace
void MergeCallArgs::rewrite(CallOp call, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = call.getLoc();
rewriter.setInsertionPoint(call);
auto method = call->getParentOfType<ibis::MethodLikeOpInterface>();
// Use the 'info' accelerator structures to find the argument type.
SymbolRefAttr calleeSym =
info.resolveInstancePath(method, adaptor.getCalleeAttr());
auto argStructEntry = info.argTypes.find(calleeSym);
assert(argStructEntry != info.argTypes.end() && "Method symref not found!");
auto [argStruct, argLoc] = argStructEntry->second;
// Pack all of the operands into it.
auto newArg = rewriter.create<hw::StructCreateOp>(loc, argStruct,
adaptor.getOperands());
newArg->setAttr("sv.namehint",
rewriter.getStringAttr(
call.getCalleeAttr().getLeafReference().getValue() +
"_args_called_from_" +
method.getMethodName().getValue()));
// Update the call to use just the new struct.
rewriter.modifyOpInPlace(call, [&]() {
call.getOperandsMutable().clear();
call.getOperandsMutable().append(newArg.getResult());
});
}
namespace {
/// Change the method signatures to only have one argument: a struct capturing
/// all of the original arguments.
struct MergeMethodArgs : public OpConversionPattern<MethodOp> {
MergeMethodArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
: OpConversionPattern(ctxt), info(info) {}
void rewrite(MethodOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final;
LogicalResult match(MethodOp) const override { return success(); }
private:
const CallPrepPrecomputed &info;
};
} // anonymous namespace
void MergeMethodArgs::rewrite(MethodOp func, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = func.getLoc();
auto *ctxt = getContext();
// Find the pre-computed arg struct for this method.
auto argStructEntry =
info.argTypes.find(CallPrepPrecomputed::getSymbol(func));
assert(argStructEntry != info.argTypes.end() && "Cannot find symref!");
auto [argStruct, argLoc] = argStructEntry->second;
// Create a new method with the new signature.
FunctionType funcType = func.getFunctionType();
FunctionType newFuncType =
FunctionType::get(ctxt, {argStruct}, funcType.getResults());
auto newArgNames = ArrayAttr::get(ctxt, {StringAttr::get(ctxt, "arg")});
auto newMethod =
rewriter.create<MethodOp>(loc, func.getInnerSym(), newFuncType,
newArgNames, ArrayAttr(), ArrayAttr());
if (func->getNumRegions() > 0) {
// Create a body block with a struct explode to the arg struct into the
// original arguments.
Block *b =
rewriter.createBlock(&newMethod.getRegion(), {}, {argStruct}, {argLoc});
rewriter.setInsertionPointToStart(b);
auto replacementArgs =
rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));
// Merge the original method body, rewiring the args.
Block *funcBody = &func.getBody().front();
rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
}
rewriter.eraseOp(func);
}
namespace {
/// Run all the physical lowerings.
struct CallPrepPass : public IbisCallPrepBase<CallPrepPass> {
void runOnOperation() override;
private:
// Merge the arguments into one struct.
LogicalResult merge(const CallPrepPrecomputed &);
};
} // anonymous namespace
void CallPrepPass::runOnOperation() {
CallPrepPrecomputed info(getOperation());
if (failed(merge(info))) {
signalPassFailure();
return;
}
}
LogicalResult CallPrepPass::merge(const CallPrepPrecomputed &info) {
// Set up a conversion and give it a set of laws.
ConversionTarget target(getContext());
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
auto argValues = call.getArgOperands();
return argValues.size() == 1 &&
hw::type_isa<hw::StructType>(argValues.front().getType());
});
target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
return argTypes.size() == 1 &&
hw::type_isa<hw::StructType>(argTypes.front());
});
// Add patterns to merge the args on both the call and method sides.
RewritePatternSet patterns(&getContext());
patterns.insert<MergeCallArgs>(&getContext(), info);
patterns.insert<MergeMethodArgs>(&getContext(), info);
return applyPartialConversion(getOperation(), target, std::move(patterns));
}
std::unique_ptr<Pass> circt::ibis::createCallPrepPass() {
return std::make_unique<CallPrepPass>();
}