[Handshake+ESI] Generate ESI memory service wrapper duing extmem lowering (#4033)

This turned out to be a bit more complicated than I had expected (assumptions placed on the hardware interface of the module, this pass still need high-level memory info, index types are used for memory in Handshake, ESI strictly requires clog2(memory size) for address signals). I'm sure there's a more principled way of doing this, but that was not what flowed out of the fingers.

A wrapper is created which instantiates an `esi.mem.ram` service for each `memref` argument of the original handshake function. This wrapper also instantiates an external module which is a stand-in for the module that will be created during `HandshakeToHW`. The load- and store ports of the to-be-lowered instance are plumbed up with esi service requests.

Given a handshake function as follows, with 1 load and 1 store port:
```mlir
handshake.func @main(%arg0: index, %arg1: index, %v: i32, %mem : memref<10xi32>, %argCtrl: none) -> none
```

the following IR is generated:
```mlir
  hw.module.extern @_main_hw(%arg0: !esi.channel<i64>, %arg1: !esi.channel<i64>, %v: !esi.channel<i32>, %mem_ld0.data: !esi.channel<i32>, %mem_st0.done: !esi.channel<i0>, %argCtrl: !esi.channel<i0>, %clock: i1, %reset: i1) -> (out0: !esi.channel<i0>, mem_ld0.addr: !esi.channel<i4>, mem_st0: !esi.channel<!hw.struct<address: i4, data: i32>>)
  esi.mem.ram @mem i32 x 10
  hw.module @main_esi_wrapper(%arg0: !esi.channel<i64>, %arg1: !esi.channel<i64>, %v: !esi.channel<i32>, %argCtrl: !esi.channel<i0>, %clock: i1, %reset: i1) -> (out0: !esi.channel<i0>) {
    esi.service.instance @mem impl as "cosim" opts {}(%clock, %reset) : (i1, i1) -> ()
    %0 = esi.service.req.inout %main.mem_ld0.addr -> <@mem::@read>([]) : !esi.channel<i4> -> !esi.channel<i32>
    %1 = esi.service.req.inout %main.mem_st0 -> <@mem::@write>([]) : !esi.channel<!hw.struct<address: i4, data: i32>> -> !esi.channel<i0>
    %main.out0, %main.mem_ld0.addr, %main.mem_st0 = hw.instance "main" @_main_hw(arg0: %arg0: !esi.channel<i64>, arg1: %arg1: !esi.channel<i64>, v: %v: !esi.channel<i32>, mem_ld0.data: %0: !esi.channel<i32>, mem_st0.done: %1: !esi.channel<i0>, argCtrl: %argCtrl: !esi.channel<i0>, clock: %clock: i1, reset: %reset: i1) -> (out0: !esi.channel<i0>, mem_ld0.addr: !esi.channel<i4>, mem_st0: !esi.channel<!hw.struct<address: i4, data: i32>>)
    hw.output %main.out0 : !esi.channel<i0>
  }
  handshake.func @main(%arg0: index, %arg1: index, %arg2: i32, %arg3: i32, %arg4: none, %arg5: none, ...) -> (none, i4, !hw.struct<address: i4, data: i32>) attributes {argNames = ["arg0", "arg1", "v", "mem_ld0.data", "mem_st0.done", "argCtrl"], resNames = ["out0", "mem_ld0.addr", "mem_st0"]}
```
This commit is contained in:
Morten Borup Petersen 2022-10-07 10:43:58 +02:00 committed by GitHub
parent ea7b547478
commit da1ad9df93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 434 additions and 178 deletions

View File

@ -14,6 +14,8 @@
#ifndef CIRCT_CONVERSION_HANDSHAKETOHW_H
#define CIRCT_CONVERSION_HANDSHAKETOHW_H
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/IR/Builders.h"
#include <memory>
namespace mlir {
@ -21,7 +23,29 @@ class Pass;
} // namespace mlir
namespace circt {
namespace esi {
class ChannelType;
} // namespace esi
std::unique_ptr<mlir::Pass> createHandshakeToHWPass();
namespace handshake {
// Converts 't' into a valid HW type. This is strictly used for converting
// 'index' types into a fixed-width type.
Type toValidType(Type t);
// Wraps a type into an ESI ChannelType type. The inner type is converted to
// ensure comprehensability with the RTL dialects.
esi::ChannelType esiWrapper(mlir::Type t);
// Returns the hw::ModulePortInfo that corresponds to the given handshake
// operation and its in- and output types.
hw::ModulePortInfo getPortInfoForOpTypes(mlir::Operation *op, TypeRange inputs,
TypeRange outputs);
} // namespace handshake
} // namespace circt
#endif // CIRCT_CONVERSION_HANDSHAKETOHW_H

View File

@ -70,7 +70,12 @@ def HandshakeLowerExtmemToHW : Pass<"handshake-lower-extmem-to-hw", "mlir::Modul
level interface).
}];
let constructor = "circt::handshake::createHandshakeLowerExtmemToHWPass()";
let dependentDialects = ["circt::hw::HWDialect"];
let dependentDialects = ["circt::hw::HWDialect", "circt::esi::ESIDialect",
"circt::comb::CombDialect", "mlir::arith::ArithDialect"];
let options = [
Option<"createESIWrapper", "wrap-esi", "bool", "false",
"Create an ESI wrapper for the module. Any extmem will be served by an esi.mem.ram service">,
];
}
def HandshakeAddIDs : Pass<"handshake-add-ids", "handshake::FuncOp"> {

View File

@ -40,6 +40,10 @@ using NameUniquer = std::function<std::string(Operation *)>;
namespace {
static Type tupleToStruct(TypeRange types) {
return toValidType(mlir::TupleType::get(types[0].getContext(), types));
}
// Shared state used by various functions; captured in a struct to reduce the
// number of arguments that we have to pass around.
struct HandshakeLoweringState {
@ -47,59 +51,6 @@ struct HandshakeLoweringState {
NameUniquer nameUniquer;
};
// NOLINTNEXTLINE(misc-no-recursion)
static Type tupleToStruct(TupleType tuple) {
auto *ctx = tuple.getContext();
mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
for (auto [i, innerType] : llvm::enumerate(tuple)) {
Type convertedInnerType = innerType;
if (auto tupleInnerType = innerType.dyn_cast<TupleType>())
convertedInnerType = tupleToStruct(tupleInnerType);
hwfields.push_back({StringAttr::get(ctx, "field" + std::to_string(i)),
convertedInnerType});
}
return hw::StructType::get(ctx, hwfields);
}
static Type tupleToStruct(TypeRange types) {
return tupleToStruct(mlir::TupleType::get(types[0].getContext(), types));
}
// Converts 't' into a valid HW type. This is strictly used for converting
// 'index' types into a fixed-width type.
static Type toValidType(Type t) {
return TypeSwitch<Type, Type>(t)
.Case<IndexType>(
[&](IndexType it) { return IntegerType::get(it.getContext(), 64); })
.Case<TupleType>([&](TupleType tt) {
llvm::SmallVector<Type> types;
for (auto innerType : tt)
types.push_back(toValidType(innerType));
return tupleToStruct(
mlir::TupleType::get(types[0].getContext(), types));
})
.Case<NoneType>(
[&](NoneType nt) { return IntegerType::get(nt.getContext(), 0); })
.Default([&](Type t) { return t; });
}
// Wraps a type into an ESI ChannelType type. The inner type is converted to
// ensure comprehensability by the RTL dialects.
static esi::ChannelType esiWrapper(Type t) {
return TypeSwitch<Type, esi::ChannelType>(t)
.Case<esi::ChannelType>([](auto t) { return t; })
.Case<TupleType>(
[&](TupleType tt) { return esiWrapper(tupleToStruct(tt)); })
.Case<NoneType>([](NoneType nt) {
// todo: change when handshake switches to i0
return esiWrapper(IntegerType::get(nt.getContext(), 0));
})
.Default([](auto t) {
return esi::ChannelType::get(t.getContext(), toValidType(t));
});
}
// A type converter is needed to perform the in-flight materialization of "raw"
// (non-ESI channel) types to their ESI channel correspondents. This comes into
// effect when backedges exist in the input IR.
@ -210,67 +161,6 @@ static std::string getTypeName(Location loc, Type type) {
return typeName;
}
namespace {
/// A class to be used with getPortInfoForOp. Provides an opaque interface for
/// generating the port names of an operation; handshake operations generate
/// names by the Handshake NamedIOInterface; and other operations, such as
/// arith ops, are assigned default names.
class HandshakePortNameGenerator {
public:
explicit HandshakePortNameGenerator(Operation *op)
: builder(op->getContext()) {
auto namedOpInterface = dyn_cast<handshake::NamedIOInterface>(op);
if (namedOpInterface)
inferFromNamedOpInterface(namedOpInterface);
else if (auto funcOp = dyn_cast<handshake::FuncOp>(op))
inferFromFuncOp(funcOp);
else
inferDefault(op);
}
StringAttr inputName(unsigned idx) { return inputs[idx]; }
StringAttr outputName(unsigned idx) { return outputs[idx]; }
private:
using IdxToStrF = const std::function<std::string(unsigned)> &;
void infer(Operation *op, IdxToStrF &inF, IdxToStrF &outF) {
llvm::transform(
llvm::enumerate(op->getOperandTypes()), std::back_inserter(inputs),
[&](auto it) { return builder.getStringAttr(inF(it.index())); });
llvm::transform(
llvm::enumerate(op->getResultTypes()), std::back_inserter(outputs),
[&](auto it) { return builder.getStringAttr(outF(it.index())); });
}
void inferDefault(Operation *op) {
infer(
op, [](unsigned idx) { return "in" + std::to_string(idx); },
[](unsigned idx) { return "out" + std::to_string(idx); });
}
void inferFromNamedOpInterface(handshake::NamedIOInterface op) {
infer(
op, [&](unsigned idx) { return op.getOperandName(idx); },
[&](unsigned idx) { return op.getResultName(idx); });
}
void inferFromFuncOp(handshake::FuncOp op) {
auto inF = [&](unsigned idx) { return op.getArgName(idx).str(); };
auto outF = [&](unsigned idx) { return op.getResName(idx).str(); };
llvm::transform(
llvm::enumerate(op.getArgumentTypes()), std::back_inserter(inputs),
[&](auto it) { return builder.getStringAttr(inF(it.index())); });
llvm::transform(
llvm::enumerate(op.getResultTypes()), std::back_inserter(outputs),
[&](auto it) { return builder.getStringAttr(outF(it.index())); });
}
Builder builder;
llvm::SmallVector<StringAttr> inputs;
llvm::SmallVector<StringAttr> outputs;
};
/// Construct a name for creating HW sub-module.
static std::string getSubModuleName(Operation *oldOp) {
if (auto instanceOp = dyn_cast<handshake::InstanceOp>(oldOp); instanceOp)
@ -347,8 +237,6 @@ static std::string getSubModuleName(Operation *oldOp) {
return subModuleName;
}
} // namespace
//===----------------------------------------------------------------------===//
// HW Sub-module Related Functions
//===----------------------------------------------------------------------===//
@ -376,45 +264,10 @@ static Operation *checkSubModuleOp(mlir::ModuleOp parentModule,
return moduleOp;
}
static ModulePortInfo getPortInfoForOp(OpBuilder &builder, Operation *op,
TypeRange inputs, TypeRange outputs) {
ModulePortInfo ports({}, {});
HandshakePortNameGenerator portNames(op);
// Add all inputs of funcOp.
unsigned inIdx = 0;
for (auto &arg : llvm::enumerate(inputs)) {
ports.inputs.push_back({portNames.inputName(arg.index()),
PortDirection::INPUT, esiWrapper(arg.value()),
arg.index(), StringAttr{}});
inIdx++;
}
// Add all outputs of funcOp.
for (auto &res : llvm::enumerate(outputs)) {
ports.outputs.push_back({portNames.outputName(res.index()),
PortDirection::OUTPUT, esiWrapper(res.value()),
res.index(), StringAttr{}});
}
// Add clock and reset signals.
if (op->hasTrait<mlir::OpTrait::HasClock>()) {
ports.inputs.push_back({builder.getStringAttr("clock"),
PortDirection::INPUT, builder.getI1Type(), inIdx++,
StringAttr{}});
ports.inputs.push_back({builder.getStringAttr("reset"),
PortDirection::INPUT, builder.getI1Type(), inIdx,
StringAttr{}});
}
return ports;
}
/// Returns a vector of PortInfo's which defines the HW interface of the
/// to-be-converted op.
static ModulePortInfo getPortInfoForOp(OpBuilder &builder, Operation *op) {
return getPortInfoForOp(builder, op, op->getOperandTypes(),
op->getResultTypes());
static ModulePortInfo getPortInfoForOp(Operation *op) {
return getPortInfoForOpTypes(op, op->getOperandTypes(), op->getResultTypes());
}
static llvm::SmallVector<hw::detail::FieldInfo>
@ -857,7 +710,7 @@ public:
// builder.
hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
if (!implModule) {
auto portInfo = ModulePortInfo(getPortInfoForOp(rewriter, op));
auto portInfo = ModulePortInfo(getPortInfoForOp(op));
implModule = submoduleBuilder.create<hw::HWModuleOp>(
op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
@ -1798,7 +1651,7 @@ public:
hw::HWModuleLike implModule = checkSubModuleOp(ls.parentModule, op);
if (!implModule) {
auto portInfo = ModulePortInfo(getPortInfoForOp(rewriter, op));
auto portInfo = ModulePortInfo(getPortInfoForOp(op));
implModule = submoduleBuilder.create<hw::HWModuleExternOp>(
op.getLoc(), submoduleBuilder.getStringAttr(getSubModuleName(op)),
portInfo);
@ -1823,8 +1676,8 @@ public:
LogicalResult
matchAndRewrite(handshake::FuncOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
ModulePortInfo ports = getPortInfoForOp(rewriter, op, op.getArgumentTypes(),
op.getResultTypes());
ModulePortInfo ports =
getPortInfoForOpTypes(op, op.getArgumentTypes(), op.getResultTypes());
if (op.isExternal()) {
rewriter.create<hw::HWModuleExternOp>(

View File

@ -11,6 +11,7 @@ add_circt_dialect_library(CIRCTHandshakeTransforms
LINK_LIBS PUBLIC
CIRCTHW
CIRCTESI
CIRCTHandshake
CIRCTSupport
MLIRIR

View File

@ -11,11 +11,14 @@
//===----------------------------------------------------------------------===//
#include "PassDetails.h"
#include "circt/Conversion/HandshakeToHW.h"
#include "circt/Dialect/ESI/ESIOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/HWTypes.h"
#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "circt/Dialect/Handshake/HandshakePasses.h"
#include "circt/Support/BackedgeBuilder.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
@ -27,6 +30,8 @@ namespace {
using NamedType = std::pair<StringAttr, Type>;
struct HandshakeMemType {
llvm::SmallVector<NamedType> inputTypes, outputTypes;
MemRefType memRefType;
unsigned loadPorts, storePorts;
};
struct LoadName {
@ -51,6 +56,14 @@ struct StoreNames {
} // namespace
static Type indexToMemAddr(Type t, MemRefType memRef) {
assert(t.isa<IndexType>() && "Expected index type");
auto shape = memRef.getShape();
assert(shape.size() == 1 && "Expected 1D memref");
unsigned addrWidth = llvm::Log2_64_Ceil(shape[0]);
return IntegerType::get(t.getContext(), addrWidth);
}
static HandshakeMemType getMemTypeForExtmem(Value v) {
auto *ctx = v.getContext();
assert(v.getType().isa<mlir::MemRefType>() && "Value is not a memref type");
@ -58,11 +71,18 @@ static HandshakeMemType getMemTypeForExtmem(Value v) {
HandshakeMemType memType;
llvm::SmallVector<hw::detail::FieldInfo> inFields, outFields;
// Add memory type.
memType.memRefType = v.getType().cast<MemRefType>();
memType.loadPorts = extmemOp.getLdCount();
memType.storePorts = extmemOp.getStCount();
// Add load ports.
for (auto [i, ldif] : llvm::enumerate(extmemOp.getLoadPorts())) {
auto names = LoadName::get(ctx, i);
memType.inputTypes.push_back({names.dataIn, ldif.dataOut.getType()});
memType.outputTypes.push_back({names.addrOut, ldif.addressIn.getType()});
memType.outputTypes.push_back(
{names.addrOut,
indexToMemAddr(ldif.addressIn.getType(), memType.memRefType)});
}
// Add store ports.
@ -72,18 +92,19 @@ static HandshakeMemType getMemTypeForExtmem(Value v) {
// Incoming store data and address
llvm::SmallVector<hw::StructType::FieldInfo> storeOutFields;
storeOutFields.push_back(
{StringAttr::get(ctx, "data"), stif.dataIn.getType()});
{StringAttr::get(ctx, "address"),
indexToMemAddr(stif.addressIn.getType(), memType.memRefType)});
storeOutFields.push_back(
{StringAttr::get(ctx, "addr"), stif.addressIn.getType()});
{StringAttr::get(ctx, "data"), stif.dataIn.getType()});
auto inType = hw::StructType::get(ctx, storeOutFields);
memType.outputTypes.push_back({names.out, inType});
memType.inputTypes.push_back({names.doneIn, stif.doneOut.getType()});
}
return memType;
}
namespace {
struct HandshakeLowerExtmemToHWPass
: public HandshakeLowerExtmemToHWBase<HandshakeLowerExtmemToHWPass> {
void runOnOperation() override {
@ -97,11 +118,152 @@ struct HandshakeLowerExtmemToHWPass
};
LogicalResult lowerExtmemToHW(handshake::FuncOp func);
LogicalResult
wrapESI(handshake::FuncOp func, hw::ModulePortInfo origPorts,
llvm::DenseMap<unsigned, HandshakeMemType> argReplacements);
};
LogicalResult HandshakeLowerExtmemToHWPass::wrapESI(
handshake::FuncOp func, hw::ModulePortInfo origPorts,
llvm::DenseMap<unsigned, HandshakeMemType> argReplacements) {
auto *ctx = func.getContext();
OpBuilder b(func);
auto loc = func.getLoc();
// Create external module which will match the interface of 'func' after it's
// been lowered to HW.
b.setInsertionPoint(func);
auto newPortInfo = handshake::getPortInfoForOpTypes(
func, func.getArgumentTypes(), func.getResultTypes());
auto extMod = b.create<hw::HWModuleExternOp>(
loc, StringAttr::get(ctx, "_" + func.getName() + "_hw"), newPortInfo);
// Create wrapper module. This will have the same ports as the original
// module, sans the replaced arguments.
auto wrapperModPortInfo = origPorts;
llvm::SmallVector<unsigned> argReplacementsIdxs;
llvm::transform(argReplacements, std::back_inserter(argReplacementsIdxs),
[](auto &pair) { return pair.first; });
for (auto i : llvm::reverse(argReplacementsIdxs))
wrapperModPortInfo.inputs.erase(wrapperModPortInfo.inputs.begin() + i);
auto wrapperMod = b.create<hw::HWModuleOp>(
loc, StringAttr::get(ctx, func.getName() + "_esi_wrapper"),
wrapperModPortInfo);
Value clk = wrapperMod.getArgument(wrapperMod.getNumArguments() - 2);
Value rst = wrapperMod.getArgument(wrapperMod.getNumArguments() - 1);
SmallVector<Value> clkRes = {clk, rst};
b.setInsertionPointToStart(wrapperMod.getBodyBlock());
BackedgeBuilder bb(b, loc);
// Create backedges for the results of the external module. These will be
// replaced by the service instance requests if associated with a memory.
llvm::SmallVector<Backedge> backedges;
for (auto resType : extMod.getResultTypes())
backedges.push_back(bb.get(resType));
// Maintain which index we're currently at in the lowered handshake module's
// return.
unsigned resIdx = origPorts.outputs.size();
// Maintain the arguments which each memory will add to the inner module
// instance.
llvm::SmallVector<ValueRange> instanceArgsForMem;
for (auto [i, memType] : argReplacements) {
b.setInsertionPoint(wrapperMod);
// Create a memory service declaration for each memref argument that was
// served.
auto origPortInfo = origPorts.inputs[i];
auto memrefShape = memType.memRefType.getShape();
auto dataType = memType.memRefType.getElementType();
assert(memrefShape.size() == 1 && "Only 1D memrefs are supported");
unsigned memrefSize = memrefShape[0];
auto memServiceDecl = b.create<esi::RandomAccessMemoryDeclOp>(
loc, origPortInfo.name, TypeAttr::get(dataType),
b.getI64IntegerAttr(memrefSize));
SmallVector<Value> instanceArgsFromThisMem;
// Create service requests. This MUST follow the order of which ports were
// added in other parts of this pass (load ports first, then store ports).
b.setInsertionPointToStart(wrapperMod.getBodyBlock());
// Load ports:
auto loadServicePort = hw::InnerRefAttr::get(memServiceDecl.getNameAttr(),
b.getStringAttr("read"));
for (unsigned i = 0; i < memType.loadPorts; ++i) {
auto loadReq = b.create<esi::RequestInOutChannelOp>(
loc, handshake::esiWrapper(dataType), loadServicePort,
backedges[resIdx], b.getArrayAttr({}));
instanceArgsFromThisMem.push_back(loadReq);
++resIdx;
}
// Store ports:
auto storeServicePort = hw::InnerRefAttr::get(memServiceDecl.getNameAttr(),
b.getStringAttr("write"));
for (unsigned i = 0; i < memType.storePorts; ++i) {
auto storeReq = b.create<esi::RequestInOutChannelOp>(
loc, handshake::esiWrapper(b.getIntegerType(0)), storeServicePort,
backedges[resIdx], b.getArrayAttr({}));
instanceArgsFromThisMem.push_back(storeReq);
++resIdx;
}
instanceArgsForMem.push_back(instanceArgsFromThisMem);
}
// Stitch together arguments from the top-level ESI wrapper and the instance
// arguments generated from the service requests.
llvm::SmallVector<Value> instanceArgs;
for (unsigned i = 0, e = wrapperMod.getNumArguments(); i < e; ++i) {
if (argReplacements.count(i)) {
// This index was originally a memref - pop the instance arguments for the
// next-in-line memory and add them.
auto memArgs = instanceArgsForMem.front();
instanceArgsForMem.erase(instanceArgsForMem.begin());
instanceArgs.append(memArgs.begin(), memArgs.end());
}
// Add the argument from the wrapper mod.
instanceArgs.push_back(wrapperMod.getArgument(i));
}
// Instantiate the inner module.
auto instance =
b.create<hw::InstanceOp>(loc, extMod, func.getName(), instanceArgs);
// And resolve the backedges.
for (auto [res, be] : llvm::zip(instance.getResults(), backedges))
be.setValue(res);
// Finally, grab the (non-memory) outputs from the inner module and return
// them through the wrapper.
auto outputOp =
cast<hw::OutputOp>(wrapperMod.getBodyBlock()->getTerminator());
b.setInsertionPoint(outputOp);
b.create<hw::OutputOp>(outputOp.getLoc(), instance.getResults().take_front(
wrapperMod.getNumResults()));
outputOp.erase();
return success();
}
// Truncates the index-typed 'v' into an integer-type of the same width as the
// 'memref' argument.
// Uses arith operations since these are supported in the HandshakeToHW
// lowering.
static Value truncateToMemoryWidth(Location loc, OpBuilder &b, Value v,
MemRefType memRefType) {
assert(v.getType().isa<IndexType>() && "Expected an index-typed value");
auto addrWidth = llvm::Log2_64_Ceil(memRefType.getShape().front());
return b.create<arith::IndexCastOp>(loc, b.getIntegerType(addrWidth), v);
}
static Value plumbLoadPort(Location loc, OpBuilder &b,
const handshake::ExtMemLoadInterface &ldif,
Value loadData) {
Value loadData, MemRefType memrefType) {
// We need to feed both the load data and the load done outputs.
// Fork the extracted load data into two, and 'join' the second one to
// generate a none-typed output to drive the load done.
@ -114,16 +276,20 @@ static Value plumbLoadPort(Location loc, OpBuilder &b,
ldif.dataOut.replaceAllUsesWith(dataOut);
ldif.doneOut.replaceAllUsesWith(dataDone);
// Return load address, to be fed to the top-level output.
return ldif.addressIn;
// Return load address, to be fed to the top-level output, truncated to the
// width of the memory that is accessed.
return truncateToMemoryWidth(loc, b, ldif.addressIn, memrefType);
}
static Value plumbStorePort(Location loc, OpBuilder &b,
const handshake::ExtMemStoreInterface &stif,
Value done, Type outType) {
Value done, Type outType, MemRefType memrefType) {
stif.doneOut.replaceAllUsesWith(done);
// Return the store and data to be fed to the top-level output.
llvm::SmallVector<Value> structArgs = {stif.dataIn, stif.addressIn};
// Return the store address and data to be fed to the top-level output.
// Address is truncated to the width of the memory that is accessed.
llvm::SmallVector<Value> structArgs = {
truncateToMemoryWidth(loc, b, stif.addressIn, memrefType), stif.dataIn};
return b
.create<hw::StructCreateOp>(loc, outType.cast<hw::StructType>(),
structArgs)
@ -160,6 +326,12 @@ static void eraseFromArrayAttr(Operation *op, StringRef attrName,
op->setAttr(attrName, ArrayAttr::get(ctx, newArr));
}
struct ArgTypeReplacement {
unsigned index;
TypeRange ins;
TypeRange outs;
};
LogicalResult
HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
// Gather memref ports to be converted.
@ -171,6 +343,13 @@ HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
if (memrefArgs.empty())
return success(); // nothing to do.
// Record which arg indices were replaces with handshake memory ports.
llvm::DenseMap<unsigned, HandshakeMemType> argReplacements;
// Record the hw.module i/o of the original func (used for ESI wrapper).
auto origPortInfo = handshake::getPortInfoForOpTypes(
func, func.getArgumentTypes(), func.getResultTypes());
OpBuilder b(func);
for (auto it : memrefArgs) {
// Do not use structured bindings for 'it' - cannot reference inside lambda.
@ -183,6 +362,7 @@ HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
// Add memory input - this is the output of the extmemory op.
auto memIOTypes = getMemTypeForExtmem(arg);
MemRefType memrefType = arg.getType().cast<MemRefType>();
auto oldReturnOp =
cast<handshake::ReturnOp>(func.getBody().front().getTerminator());
@ -211,7 +391,8 @@ HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
for (auto loadPort : extmemOp.getLoadPorts()) {
auto newInPort = addArgRes(loadPort.index, memIOTypes.inputTypes[portIdx],
memIOTypes.outputTypes[portIdx]);
newReturnOperands.push_back(plumbLoadPort(loc, b, loadPort, newInPort));
newReturnOperands.push_back(
plumbLoadPort(loc, b, loadPort, newInPort, memrefType));
++portIdx;
}
@ -222,7 +403,7 @@ HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
memIOTypes.outputTypes[portIdx]);
newReturnOperands.push_back(
plumbStorePort(loc, b, storePort, newInPort,
memIOTypes.outputTypes[portIdx].second));
memIOTypes.outputTypes[portIdx].second, memrefType));
++portIdx;
}
@ -240,8 +421,14 @@ HandshakeLowerExtmemToHWPass::lowerExtmemToHW(handshake::FuncOp func) {
// use has been removed.
func.eraseArgument(i + addedInPorts);
eraseFromArrayAttr(func, "argNames", i + addedInPorts);
argReplacements[i] = memIOTypes;
}
if (createESIWrapper)
if (failed(wrapESI(func, origPortInfo, argReplacements)))
return failure();
return success();
}

View File

@ -16,8 +16,11 @@
#ifndef DIALECT_HANDSHAKE_TRANSFORMS_PASSDETAILS_H
#define DIALECT_HANDSHAKE_TRANSFORMS_PASSDETAILS_H
#include "circt/Dialect/Comb/CombDialect.h"
#include "circt/Dialect/ESI/ESIDialect.h"
#include "circt/Dialect/HW/HWDialect.h"
#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Pass/Pass.h"
namespace circt {

View File

@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "PassDetails.h"
#include "circt/Dialect/ESI/ESIOps.h"
#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "circt/Dialect/Handshake/HandshakePasses.h"
#include "circt/Support/LLVM.h"
@ -26,6 +27,7 @@ using namespace handshake;
using namespace mlir;
namespace circt {
namespace handshake {
/// Iterates over the handshake::FuncOp's in the program to build an instance
@ -124,5 +126,160 @@ LogicalResult verifyAllValuesHasOneUse(handshake::FuncOp funcOp) {
return success();
}
// NOLINTNEXTLINE(misc-no-recursion)
static Type tupleToStruct(TupleType tuple) {
auto *ctx = tuple.getContext();
mlir::SmallVector<hw::StructType::FieldInfo, 8> hwfields;
for (auto [i, innerType] : llvm::enumerate(tuple)) {
Type convertedInnerType = innerType;
if (auto tupleInnerType = innerType.dyn_cast<TupleType>())
convertedInnerType = tupleToStruct(tupleInnerType);
hwfields.push_back({StringAttr::get(ctx, "field" + std::to_string(i)),
convertedInnerType});
}
return hw::StructType::get(ctx, hwfields);
}
// Converts 't' into a valid HW type. This is strictly used for converting
// 'index' types into a fixed-width type.
Type toValidType(Type t) {
return TypeSwitch<Type, Type>(t)
.Case<IndexType>(
[&](IndexType it) { return IntegerType::get(it.getContext(), 64); })
.Case<TupleType>([&](TupleType tt) {
llvm::SmallVector<Type> types;
for (auto innerType : tt)
types.push_back(toValidType(innerType));
return tupleToStruct(
mlir::TupleType::get(types[0].getContext(), types));
})
.Case<hw::StructType>([&](auto st) {
llvm::SmallVector<hw::StructType::FieldInfo> structFields(
st.getElements());
for (auto &field : structFields)
field.type = toValidType(field.type);
return hw::StructType::get(st.getContext(), structFields);
})
.Case<NoneType>(
[&](NoneType nt) { return IntegerType::get(nt.getContext(), 0); })
.Default([&](Type t) { return t; });
}
// Wraps a type into an ESI ChannelType type. The inner type is converted to
// ensure comprehensability by the RTL dialects.
esi::ChannelType esiWrapper(Type t) {
return TypeSwitch<Type, esi::ChannelType>(t)
.Case<esi::ChannelType>([](auto t) { return t; })
.Case<TupleType>(
[&](TupleType tt) { return esiWrapper(tupleToStruct(tt)); })
.Case<NoneType>([](NoneType nt) {
// todo: change when handshake switches to i0
return esiWrapper(IntegerType::get(nt.getContext(), 0));
})
.Default([](auto t) {
return esi::ChannelType::get(t.getContext(), toValidType(t));
});
}
namespace {
/// A class to be used with getPortInfoForOp. Provides an opaque interface for
/// generating the port names of an operation; handshake operations generate
/// names by the Handshake NamedIOInterface; and other operations, such as
/// arith ops, are assigned default names.
class HandshakePortNameGenerator {
public:
explicit HandshakePortNameGenerator(Operation *op)
: builder(op->getContext()) {
auto namedOpInterface = dyn_cast<handshake::NamedIOInterface>(op);
if (namedOpInterface)
inferFromNamedOpInterface(namedOpInterface);
else if (auto funcOp = dyn_cast<handshake::FuncOp>(op))
inferFromFuncOp(funcOp);
else
inferDefault(op);
}
StringAttr inputName(unsigned idx) { return inputs[idx]; }
StringAttr outputName(unsigned idx) { return outputs[idx]; }
private:
using IdxToStrF = const std::function<std::string(unsigned)> &;
void infer(Operation *op, IdxToStrF &inF, IdxToStrF &outF) {
llvm::transform(
llvm::enumerate(op->getOperandTypes()), std::back_inserter(inputs),
[&](auto it) { return builder.getStringAttr(inF(it.index())); });
llvm::transform(
llvm::enumerate(op->getResultTypes()), std::back_inserter(outputs),
[&](auto it) { return builder.getStringAttr(outF(it.index())); });
}
void inferDefault(Operation *op) {
infer(
op, [](unsigned idx) { return "in" + std::to_string(idx); },
[](unsigned idx) { return "out" + std::to_string(idx); });
}
void inferFromNamedOpInterface(handshake::NamedIOInterface op) {
infer(
op, [&](unsigned idx) { return op.getOperandName(idx); },
[&](unsigned idx) { return op.getResultName(idx); });
}
void inferFromFuncOp(handshake::FuncOp op) {
auto inF = [&](unsigned idx) { return op.getArgName(idx).str(); };
auto outF = [&](unsigned idx) { return op.getResName(idx).str(); };
llvm::transform(
llvm::enumerate(op.getArgumentTypes()), std::back_inserter(inputs),
[&](auto it) { return builder.getStringAttr(inF(it.index())); });
llvm::transform(
llvm::enumerate(op.getResultTypes()), std::back_inserter(outputs),
[&](auto it) { return builder.getStringAttr(outF(it.index())); });
}
Builder builder;
llvm::SmallVector<StringAttr> inputs;
llvm::SmallVector<StringAttr> outputs;
};
} // namespace
hw::ModulePortInfo getPortInfoForOpTypes(Operation *op, TypeRange inputs,
TypeRange outputs) {
hw::ModulePortInfo ports({}, {});
HandshakePortNameGenerator portNames(op);
auto *ctx = op->getContext();
Type i1Type = IntegerType::get(ctx, 1);
// Add all inputs of funcOp.
unsigned inIdx = 0;
for (auto &arg : llvm::enumerate(inputs)) {
ports.inputs.push_back({portNames.inputName(arg.index()),
hw::PortDirection::INPUT, esiWrapper(arg.value()),
arg.index(), StringAttr{}});
inIdx++;
}
// Add all outputs of funcOp.
for (auto &res : llvm::enumerate(outputs)) {
ports.outputs.push_back({portNames.outputName(res.index()),
hw::PortDirection::OUTPUT, esiWrapper(res.value()),
res.index(), StringAttr{}});
}
// Add clock and reset signals.
if (op->hasTrait<mlir::OpTrait::HasClock>()) {
ports.inputs.push_back({StringAttr::get(ctx, "clock"),
hw::PortDirection::INPUT, i1Type, inIdx++,
StringAttr{}});
ports.inputs.push_back({StringAttr::get(ctx, "reset"),
hw::PortDirection::INPUT, i1Type, inIdx,
StringAttr{}});
}
return ports;
}
} // namespace handshake
} // namespace circt

View File

@ -182,3 +182,4 @@ hw.module @MemoryAccess1(%clk: i1, %rst: i1, %write: !esi.channel<!write>, %read
%readData = esi.service.req.inout %readAddress -> <@MemA::@read> ([]) : !esi.channel<i5> -> !esi.channel<i64>
hw.output %readData, %done : !esi.channel<i64>, !esi.channel<i0>
}

View File

@ -0,0 +1,23 @@
// RUN: circt-opt -handshake-lower-extmem-to-hw="wrap-esi=true" %s | FileCheck %s
//CHECK-LABEL: hw.module.extern @_main_hw(%arg0: !esi.channel<i64>, %arg1: !esi.channel<i64>, %v: !esi.channel<i32>, %mem_ld0.data: !esi.channel<i32>, %mem_st0.done: !esi.channel<i0>, %argCtrl: !esi.channel<i0>, %clock: i1, %reset: i1) -> (out0: !esi.channel<i0>, mem_ld0.addr: !esi.channel<i4>, mem_st0: !esi.channel<!hw.struct<address: i4, data: i32>>)
//CHECK-LABEL: esi.mem.ram @mem i32 x 10
//CHECK-LABEL: hw.module @main_esi_wrapper(%arg0: !esi.channel<i64>, %arg1: !esi.channel<i64>, %v: !esi.channel<i32>, %argCtrl: !esi.channel<i0>, %clock: i1, %reset: i1) -> (out0: !esi.channel<i0>) {
//CHECK-NEXT: %0 = esi.service.req.inout %main.mem_ld0.addr -> <@mem::@read>([]) : !esi.channel<i4> -> !esi.channel<i32>
//CHECK-NEXT: %1 = esi.service.req.inout %main.mem_st0 -> <@mem::@write>([]) : !esi.channel<!hw.struct<address: i4, data: i32>> -> !esi.channel<i0>
//CHECK-NEXT: %main.out0, %main.mem_ld0.addr, %main.mem_st0 = hw.instance "main" @_main_hw(arg0: %arg0: !esi.channel<i64>, arg1: %arg1: !esi.channel<i64>, v: %v: !esi.channel<i32>, mem_ld0.data: %0: !esi.channel<i32>, mem_st0.done: %1: !esi.channel<i0>, argCtrl: %argCtrl: !esi.channel<i0>, clock: %clock: i1, reset: %reset: i1) -> (out0: !esi.channel<i0>, mem_ld0.addr: !esi.channel<i4>, mem_st0: !esi.channel<!hw.struct<address: i4, data: i32>>)
//CHECK-NEXT: hw.output %main.out0 : !esi.channel<i0>
//CHECK-NEXT: }
handshake.func @main(%arg0: index, %arg1: index, %v: i32, %mem : memref<10xi32>, %argCtrl: none) -> none {
%ldData, %stCtrl, %ldCtrl = handshake.extmemory[ld=1, st=1](%mem : memref<10xi32>)(%storeData, %storeAddr, %loadAddr) {id = 0 : i32} : (i32, index, index) -> (i32, none, none)
%fCtrl:2 = fork [2] %argCtrl : none
%loadData, %loadAddr = load [%arg0] %ldData, %fCtrl#0 : index, i32
%storeData, %storeAddr = store [%arg1] %v, %fCtrl#1 : index, i32
sink %loadData : i32
%finCtrl = join %stCtrl, %ldCtrl : none, none
return %finCtrl : none
}

View File

@ -1,16 +1,18 @@
// RUN: circt-opt -handshake-lower-extmem-to-hw %s | FileCheck %s
// CHECK-LABEL: handshake.func @main(
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: none, %[[VAL_5:.*]]: none, ...) -> (none, index, !hw.struct<data: i32, addr: index>) attributes {argNames = ["arg0", "arg1", "v", "mem_ld0.data", "mem_st0.done", "argCtrl"], resNames = ["out0", "mem_ld0.addr", "mem_st0"]} {
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: none, %[[VAL_5:.*]]: none, ...) -> (none, i4, !hw.struct<address: i4, data: i32>)
// CHECK: %[[VAL_6:.*]]:2 = fork [2] %[[VAL_3]] : i32
// CHECK: %[[VAL_7:.*]] = join %[[VAL_6]]#1 : i32
// CHECK: %[[VAL_8:.*]] = hw.struct_create (%[[VAL_9:.*]], %[[VAL_10:.*]]) : !hw.struct<data: i32, addr: index>
// CHECK: %[[VAL_11:.*]]:2 = fork [2] %[[VAL_5]] : none
// CHECK: %[[VAL_12:.*]], %[[VAL_13:.*]] = load {{\[}}%[[VAL_0]]] %[[VAL_6]]#0, %[[VAL_11]]#0 : index, i32
// CHECK: %[[VAL_9]], %[[VAL_10]] = store {{\[}}%[[VAL_1]]] %[[VAL_2]], %[[VAL_11]]#1 : index, i32
// CHECK: sink %[[VAL_12]] : i32
// CHECK: %[[VAL_14:.*]] = join %[[VAL_4]], %[[VAL_7]] : none, none
// CHECK: return %[[VAL_14]], %[[VAL_13]], %[[VAL_8]] : none, index, !hw.struct<data: i32, addr: index>
// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_9:.*]] : index to i4
// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_11:.*]] : index to i4
// CHECK: %[[VAL_12:.*]] = hw.struct_create (%[[VAL_10]], %[[VAL_13:.*]]) : !hw.struct<address: i4, data: i32>
// CHECK: %[[VAL_14:.*]]:2 = fork [2] %[[VAL_5]] : none
// CHECK: %[[VAL_15:.*]], %[[VAL_9]] = load {{\[}}%[[VAL_0]]] %[[VAL_6]]#0, %[[VAL_14]]#0 : index, i32
// CHECK: %[[VAL_13]], %[[VAL_11]] = store {{\[}}%[[VAL_1]]] %[[VAL_2]], %[[VAL_14]]#1 : index, i32
// CHECK: sink %[[VAL_15]] : i32
// CHECK: %[[VAL_16:.*]] = join %[[VAL_4]], %[[VAL_7]] : none, none
// CHECK: return %[[VAL_16]], %[[VAL_8]], %[[VAL_12]] : none, i4, !hw.struct<address: i4, data: i32>
// CHECK: }
handshake.func @main(%arg0: index, %arg1: index, %v: i32, %mem : memref<10xi32>, %argCtrl: none) -> none {