circt/lib/Dialect/Sim/Transforms/LowerDPIFunc.cpp

178 lines
5.9 KiB
C++

//===- LowerDPIFunc.cpp - Lower sim.dpi.func to func.func ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers Sim DPI func ops to MLIR func and call.
//
// sim.dpi.func @foo(input %a: i32, output %b: i64)
// hw.module @top (..) {
// %result = sim.dpi.call @foo(%a) clock %clock
// }
//
// ->
//
// func.func @foo(%a: i32, %b: !llvm.ptr) // Output is passed by a reference.
// func.func @foo_wrapper(%a: i32) -> (i64) {
// %0 = llvm.alloca: !llvm.ptr
// func.call @foo (%a, %0)
// %v = llvm.load %0
// func.return %v
// }
// hw.module @top (..) {
// }
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Dialect/Sim/SimPasses.h"
#include "circt/Support/Namespace.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "sim-lower-dpi-func"
namespace circt {
namespace sim {
#define GEN_PASS_DEF_LOWERDPIFUNC
#include "circt/Dialect/Sim/SimPasses.h.inc"
} // namespace sim
} // namespace circt
using namespace mlir;
using namespace circt;
//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//
namespace {
/// State shared across the lowering of all `sim.dpi.func` ops in one run of
/// the pass.
struct LoweringState {
/// Maps the symbol name of each lowered `sim.dpi.func` to the wrapper
/// `func.func` created for it; used afterwards to retarget `sim.dpi.call` ops.
DenseMap<StringAttr, func::FuncOp> dpiFuncDeclMapping;
/// Namespace seeded from the operation being processed; used to pick a unique
/// name for every generated `*_wrapper` function.
circt::Namespace nameSpace;
};
/// Pass that replaces `sim.dpi.func` declarations with `func.func` ops and
/// redirects `sim.dpi.call` ops to the generated wrapper functions.
struct LowerDPIFuncPass : public sim::impl::LowerDPIFuncBase<LowerDPIFuncPass> {
/// Lowers every `sim.dpi.func` directly nested in the pass's operation and
/// then updates the callee of the `sim.dpi.call` ops. Returns failure if any
/// single function fails to lower.
LogicalResult lowerDPI();
/// Lowers one `sim.dpi.func`: creates (or reuses) the DPI function
/// declaration, builds the wrapper, records it in `loweringState` and erases
/// `simFunc`.
LogicalResult lowerDPIFuncOp(sim::DPIFuncOp simFunc,
LoweringState &loweringState,
SymbolTable &symbolTable);
/// Pass entry point; signals pass failure when `lowerDPI` fails.
void runOnOperation() override;
};
} // namespace
/// Lowers one `sim.dpi.func` into a `func.func` for the DPI function itself
/// plus a `func.func` wrapper that calls it.
///
/// The DPI function takes every input port by value and every output port
/// through an `!llvm.ptr`, and has no results. The wrapper uses the function
/// type derived from the port list (`moduleType.getFuncType()`): it allocates
/// one slot per output, calls the DPI function, loads the slots back and
/// returns the loaded values.
///
/// \param simFunc        The declaration to lower; erased on success.
/// \param loweringState  Receives the mapping from the old symbol name to the
///                       wrapper and supplies a unique name for the wrapper.
/// \param symbolTable    Used to find an already existing `func.func` whose
///                       symbol equals the Verilog name of `simFunc`.
/// \returns failure (after emitting a diagnostic on `simFunc`) if any port
///          has a non-integer type or is an inout port; success otherwise.
LogicalResult LowerDPIFuncPass::lowerDPIFuncOp(sim::DPIFuncOp simFunc,
                                               LoweringState &loweringState,
                                               SymbolTable &symbolTable) {
  ImplicitLocOpBuilder builder(simFunc.getLoc(), simFunc);
  auto moduleType = simFunc.getModuleType();

  // Validate every port before any IR is created, so that a rejected
  // declaration does not leave a half-built wrapper or a stray external
  // declaration behind. Non-integer ports are reported in preference to inout
  // ports.
  llvm::SmallVector<Type> dpiFunctionArgumentTypes;
  bool hasInOutPort = false;
  for (auto arg : moduleType.getPorts()) {
    // TODO: Support a non-integer type.
    if (!arg.type.isInteger())
      return simFunc->emitError()
             << "non-integer type argument is unsupported now";

    if (arg.dir == hw::ModulePort::InOut) {
      hasInOutPort = true;
      continue;
    }

    if (arg.dir == hw::ModulePort::Input)
      dpiFunctionArgumentTypes.push_back(arg.type);
    else
      // Output must be passed by a reference.
      dpiFunctionArgumentTypes.push_back(
          LLVM::LLVMPointerType::get(arg.type.getContext()));
  }
  if (hasInOutPort)
    return simFunc->emitError() << "inout is currently not supported";

  auto funcType = builder.getFunctionType(dpiFunctionArgumentTypes, {});
  auto verilogName = simFunc.getVerilogName();
  func::FuncOp func;

  // Look up func.func by verilog name since the function name is equal to the
  // symbol name in MLIR
  if (verilogName) {
    func = symbolTable.lookup<func::FuncOp>(*verilogName);
    // TODO: Check if function type matches.
  }

  // If a referred function is not in the same module, create an external
  // function declaration.
  if (!func) {
    func = builder.create<func::FuncOp>(
        verilogName ? *verilogName : simFunc.getSymName(), funcType);
    // External function needs to be private.
    func.setPrivate();
  }

  // Create a wrapper function that calls the DPI function.
  auto funcOp = builder.create<func::FuncOp>(
      loweringState.nameSpace.newName(simFunc.getSymName() + "_wrapper"),
      moduleType.getFuncType());

  // Map old symbol to a new func op.
  loweringState.dpiFuncDeclMapping[simFunc.getSymNameAttr()] = funcOp;

  builder.setInsertionPointToStart(funcOp.addEntryBlock());
  SmallVector<Value> functionInputs;
  SmallVector<LLVM::AllocaOp> functionOutputAllocas;

  // Wrapper block arguments correspond to the input ports only, so they are
  // indexed separately from the port list.
  size_t inputIndex = 0;
  for (auto arg : moduleType.getPorts()) {
    if (arg.dir == hw::ModulePort::Input) {
      functionInputs.push_back(funcOp.getArgument(inputIndex));
      ++inputIndex;
    } else {
      // Allocate an output placeholder.
      auto one = builder.create<LLVM::ConstantOp>(builder.getI64IntegerAttr(1));
      auto alloca = builder.create<LLVM::AllocaOp>(
          builder.getType<LLVM::LLVMPointerType>(), arg.type, one);
      functionInputs.push_back(alloca);
      functionOutputAllocas.push_back(alloca);
    }
  }

  builder.create<func::CallOp>(func, functionInputs);

  // Read back what the DPI function wrote through the output pointers.
  SmallVector<Value> results;
  for (auto functionOutputAlloca : functionOutputAllocas)
    results.push_back(builder.create<LLVM::LoadOp>(
        functionOutputAlloca.getElemType(), functionOutputAlloca));

  builder.create<func::ReturnOp>(results);

  simFunc.erase();
  return success();
}
/// Lowers all `sim.dpi.func` ops directly nested in the pass's operation and
/// retargets `sim.dpi.call` ops to the generated wrapper functions.
///
/// \returns failure as soon as one `sim.dpi.func` cannot be lowered; success
///          otherwise.
LogicalResult LowerDPIFuncPass::lowerDPI() {
  LLVM_DEBUG(llvm::dbgs() << "Lowering sim DPI func to func.func\n");
  auto op = getOperation();
  LoweringState state;
  state.nameSpace.add(op);
  auto &symbolTable = getAnalysis<SymbolTable>();

  // lowerDPIFuncOp erases the op it is given, hence the early-increment range.
  for (auto simFunc : llvm::make_early_inc_range(op.getOps<sim::DPIFuncOp>()))
    if (failed(lowerDPIFuncOp(simFunc, state, symbolTable)))
      return failure();

  op.walk([&](sim::DPICallOp callOp) {
    // Only calls whose callee was a `sim.dpi.func` lowered above have an entry
    // in the mapping. Calls that already target something else (for example a
    // wrapper produced by an earlier run of this pass) are left untouched
    // instead of tripping the missing-key assertion in `DenseMap::at`.
    auto wrapper =
        state.dpiFuncDeclMapping.lookup(callOp.getCalleeAttr().getAttr());
    if (!wrapper)
      return;
    callOp.setCallee(wrapper.getSymNameAttr());
  });
  return success();
}
void LowerDPIFuncPass::runOnOperation() {
if (failed(lowerDPI()))
return signalPassFailure();
}