[Arc][Sim] Lower Sim DPI func to func.func and support dpi call in Arc (#7386)

This PR implements initial support for lowering Sim DPI operations to Arc. 

* sim::LowerDPIFuncPass implements lowering from `sim.dpi.func` to `func.func` that respects C-level ABI. 
* arc::LowerStatePass is modified to allocate states and call functions for `sim.dpi.call` op. 

Currently unclocked call is not supported yet.
This commit is contained in:
Hideto Ueno 2024-08-07 13:51:14 +09:00 committed by GitHub
parent 1a8f82e7a6
commit 9828707817
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 378 additions and 41 deletions

View File

@ -23,4 +23,9 @@ def ProceduralizeSim : Pass<"sim-proceduralize", "hw::HWModuleOp"> {
let dependentDialects = ["circt::hw::HWDialect, circt::seq::SeqDialect, mlir::scf::SCFDialect"]; let dependentDialects = ["circt::hw::HWDialect, circt::seq::SeqDialect, mlir::scf::SCFDialect"];
} }
def LowerDPIFunc : Pass<"sim-lower-dpi-func", "mlir::ModuleOp"> {
let summary = "Lower sim.dpi.func into func.func for the simulation flow";
let dependentDialects = ["mlir::func::FuncDialect", "mlir::LLVM::LLVMDialect"];
}
#endif // CIRCT_DIALECT_SIM_SEQPASSES #endif // CIRCT_DIALECT_SIM_SEQPASSES

View File

@ -0,0 +1,39 @@
// RUN: arcilator %s --run --jit-entry=main | FileCheck %s
// REQUIRES: arcilator-jit
// CHECK: c = 0
// CHECK-NEXT: c = 5
sim.func.dpi @dpi(in %a : i32, in %b : i32, out c : i32) attributes {verilogName = "adder_func"}
func.func @adder_func(%arg0: i32, %arg1: i32, %arg2: !llvm.ptr) {
%0 = arith.addi %arg0, %arg1 : i32
llvm.store %0, %arg2 : i32, !llvm.ptr
return
}
hw.module @adder(in %clock : i1, in %a : i32, in %b : i32, out c : i32) {
%seq_clk = seq.to_clock %clock
%0 = sim.func.dpi.call @dpi(%a, %b) clock %seq_clk : (i32, i32) -> i32
hw.output %0 : i32
}
func.func @main() {
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%one = arith.constant 1 : i1
%zero = arith.constant 0 : i1
arc.sim.instantiate @adder as %arg0 {
arc.sim.set_input %arg0, "a" = %c2_i32 : i32, !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "b" = %c3_i32 : i32, !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@adder>
arc.sim.step %arg0 : !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "clock" = %zero : i1, !arc.sim.instance<@adder>
%0 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@adder>
arc.sim.emit "c", %0 : i32
arc.sim.step %arg0 : !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@adder>
%2 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@adder>
arc.sim.emit "c", %2 : i32
}
return
}

View File

@ -11,5 +11,6 @@ add_circt_conversion_library(CIRCTConvertToArcs
CIRCTArc CIRCTArc
CIRCTHW CIRCTHW
CIRCTSeq CIRCTSeq
CIRCTSim
MLIRTransforms MLIRTransforms
) )

View File

@ -10,6 +10,7 @@
#include "circt/Dialect/Arc/ArcOps.h" #include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/HW/HWOps.h" #include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Seq/SeqOps.h" #include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Support/Namespace.h" #include "circt/Support/Namespace.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@ -25,7 +26,7 @@ using llvm::MapVector;
static bool isArcBreakingOp(Operation *op) { static bool isArcBreakingOp(Operation *op) {
return op->hasTrait<OpTrait::ConstantLike>() || return op->hasTrait<OpTrait::ConstantLike>() ||
isa<hw::InstanceOp, seq::CompRegOp, MemoryOp, ClockedOpInterface, isa<hw::InstanceOp, seq::CompRegOp, MemoryOp, ClockedOpInterface,
seq::ClockGateOp>(op) || seq::ClockGateOp, sim::DPICallOp>(op) ||
op->getNumResults() > 1; op->getNumResults() > 1;
} }

View File

@ -35,6 +35,7 @@ add_circt_dialect_library(CIRCTArcTransforms
CIRCTOM CIRCTOM
CIRCTSV CIRCTSV
CIRCTSeq CIRCTSeq
CIRCTSim
CIRCTSupport CIRCTSupport
MLIRFuncDialect MLIRFuncDialect
MLIRLLVMDialect MLIRLLVMDialect

View File

@ -11,6 +11,7 @@
#include "circt/Dialect/Comb/CombOps.h" #include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h" #include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Seq/SeqOps.h" #include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Support/BackedgeBuilder.h" #include "circt/Support/BackedgeBuilder.h"
#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
@ -117,7 +118,12 @@ struct ModuleLowering {
LogicalResult lowerPrimaryInputs(); LogicalResult lowerPrimaryInputs();
LogicalResult lowerPrimaryOutputs(); LogicalResult lowerPrimaryOutputs();
LogicalResult lowerStates(); LogicalResult lowerStates();
template <typename CallTy>
LogicalResult lowerStateLike(Operation *op, Value clock, Value enable,
Value reset, ArrayRef<Value> inputs,
FlatSymbolRefAttr callee);
LogicalResult lowerState(StateOp stateOp); LogicalResult lowerState(StateOp stateOp);
LogicalResult lowerState(sim::DPICallOp dpiCallOp);
LogicalResult lowerState(MemoryOp memOp); LogicalResult lowerState(MemoryOp memOp);
LogicalResult lowerState(MemoryWritePortOp memWriteOp); LogicalResult lowerState(MemoryWritePortOp memWriteOp);
LogicalResult lowerState(TapOp tapOp); LogicalResult lowerState(TapOp tapOp);
@ -139,7 +145,7 @@ static bool shouldMaterialize(Operation *op) {
return !isa<MemoryOp, AllocStateOp, AllocMemoryOp, AllocStorageOp, return !isa<MemoryOp, AllocStateOp, AllocMemoryOp, AllocStorageOp,
ClockTreeOp, PassThroughOp, RootInputOp, RootOutputOp, ClockTreeOp, PassThroughOp, RootInputOp, RootOutputOp,
StateWriteOp, MemoryWritePortOp, igraph::InstanceOpInterface, StateWriteOp, MemoryWritePortOp, igraph::InstanceOpInterface,
StateOp>(op); StateOp, sim::DPICallOp>(op);
} }
static bool shouldMaterialize(Value value) { static bool shouldMaterialize(Value value) {
@ -390,53 +396,48 @@ LogicalResult ModuleLowering::lowerPrimaryOutputs() {
LogicalResult ModuleLowering::lowerStates() { LogicalResult ModuleLowering::lowerStates() {
SmallVector<Operation *> opsToLower; SmallVector<Operation *> opsToLower;
for (auto &op : *moduleOp.getBodyBlock()) for (auto &op : *moduleOp.getBodyBlock())
if (isa<StateOp, MemoryOp, MemoryWritePortOp, TapOp>(&op)) if (isa<StateOp, MemoryOp, MemoryWritePortOp, TapOp, sim::DPICallOp>(&op))
opsToLower.push_back(&op); opsToLower.push_back(&op);
for (auto *op : opsToLower) { for (auto *op : opsToLower) {
LLVM_DEBUG(llvm::dbgs() << "- Lowering " << *op << "\n"); LLVM_DEBUG(llvm::dbgs() << "- Lowering " << *op << "\n");
auto result = TypeSwitch<Operation *, LogicalResult>(op) auto result =
.Case<StateOp, MemoryOp, MemoryWritePortOp, TapOp>( TypeSwitch<Operation *, LogicalResult>(op)
[&](auto op) { return lowerState(op); }) .Case<StateOp, MemoryOp, MemoryWritePortOp, TapOp, sim::DPICallOp>(
.Default(success()); [&](auto op) { return lowerState(op); })
.Default(success());
if (failed(result)) if (failed(result))
return failure(); return failure();
} }
return success(); return success();
} }
LogicalResult ModuleLowering::lowerState(StateOp stateOp) { template <typename CallOpTy>
// We don't support arcs beyond latency 1 yet. These should be easy to add in LogicalResult ModuleLowering::lowerStateLike(
// the future though. Operation *stateOp, Value stateClock, Value stateEnable, Value stateReset,
if (stateOp.getLatency() > 1) ArrayRef<Value> stateInputs, FlatSymbolRefAttr callee) {
return stateOp.emitError("state with latency > 1 not supported"); // Grab all operands from the state op at the callsite and make it drop all
// its references. This allows `materializeValue` to move an operation if this
// Grab all operands from the state op and make it drop all its references. // state was the last user.
// This allows `materializeValue` to move an operation if this state was the
// last user.
auto stateClock = stateOp.getClock();
auto stateEnable = stateOp.getEnable();
auto stateReset = stateOp.getReset();
auto stateInputs = SmallVector<Value>(stateOp.getInputs());
// Get the clock tree and enable condition for this state's clock. If this arc // Get the clock tree and enable condition for this state's clock. If this arc
// carries an explicit enable condition, fold that into the enable provided by // carries an explicit enable condition, fold that into the enable provided by
// the clock gates in the arc's clock tree. // the clock gates in the arc's clock tree.
auto info = getOrCreateClockLowering(stateClock); auto info = getOrCreateClockLowering(stateClock);
info.enable = info.clock.getOrCreateAnd( info.enable = info.clock.getOrCreateAnd(
info.enable, info.clock.materializeValue(stateEnable), stateOp.getLoc()); info.enable, info.clock.materializeValue(stateEnable), stateOp->getLoc());
// Allocate the necessary state within the model. // Allocate the necessary state within the model.
SmallVector<Value> allocatedStates; SmallVector<Value> allocatedStates;
for (unsigned stateIdx = 0; stateIdx < stateOp.getNumResults(); ++stateIdx) { for (unsigned stateIdx = 0; stateIdx < stateOp->getNumResults(); ++stateIdx) {
auto type = stateOp.getResult(stateIdx).getType(); auto type = stateOp->getResult(stateIdx).getType();
auto intType = dyn_cast<IntegerType>(type); auto intType = dyn_cast<IntegerType>(type);
if (!intType) if (!intType)
return stateOp.emitOpError("result ") return stateOp->emitOpError("result ")
<< stateIdx << " has non-integer type " << type << stateIdx << " has non-integer type " << type
<< "; only integer types are supported"; << "; only integer types are supported";
auto stateType = StateType::get(intType); auto stateType = StateType::get(intType);
auto state = stateBuilder.create<AllocStateOp>(stateOp.getLoc(), stateType, auto state = stateBuilder.create<AllocStateOp>(stateOp->getLoc(), stateType,
storageArg); storageArg);
if (auto names = stateOp->getAttrOfType<ArrayAttr>("names")) if (auto names = stateOp->getAttrOfType<ArrayAttr>("names"))
state->setAttr("name", names[stateIdx]); state->setAttr("name", names[stateIdx]);
@ -455,18 +456,18 @@ LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
OpBuilder nonResetBuilder = info.clock.builder; OpBuilder nonResetBuilder = info.clock.builder;
if (stateReset) { if (stateReset) {
auto materializedReset = info.clock.materializeValue(stateReset); auto materializedReset = info.clock.materializeValue(stateReset);
auto ifOp = info.clock.builder.create<scf::IfOp>(stateOp.getLoc(), auto ifOp = info.clock.builder.create<scf::IfOp>(stateOp->getLoc(),
materializedReset, true); materializedReset, true);
for (auto [alloc, resTy] : for (auto [alloc, resTy] :
llvm::zip(allocatedStates, stateOp.getResultTypes())) { llvm::zip(allocatedStates, stateOp->getResultTypes())) {
if (!isa<IntegerType>(resTy)) if (!isa<IntegerType>(resTy))
stateOp->emitOpError("Non-integer result not supported yet!"); stateOp->emitOpError("Non-integer result not supported yet!");
auto thenBuilder = ifOp.getThenBodyBuilder(); auto thenBuilder = ifOp.getThenBodyBuilder();
Value constZero = Value constZero =
thenBuilder.create<hw::ConstantOp>(stateOp.getLoc(), resTy, 0); thenBuilder.create<hw::ConstantOp>(stateOp->getLoc(), resTy, 0);
thenBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, constZero, thenBuilder.create<StateWriteOp>(stateOp->getLoc(), alloc, constZero,
Value()); Value());
} }
@ -475,24 +476,50 @@ LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
stateOp->dropAllReferences(); stateOp->dropAllReferences();
auto newStateOp = nonResetBuilder.create<CallOp>( auto newStateOp = nonResetBuilder.create<CallOpTy>(
stateOp.getLoc(), stateOp.getResultTypes(), stateOp.getArcAttr(), stateOp->getLoc(), stateOp->getResultTypes(), callee,
materializedOperands); materializedOperands);
// Create the write ops that write the result of the transfer function to the // Create the write ops that write the result of the transfer function to the
// allocated state storage. // allocated state storage.
for (auto [alloc, result] : for (auto [alloc, result] :
llvm::zip(allocatedStates, newStateOp.getResults())) llvm::zip(allocatedStates, newStateOp.getResults()))
nonResetBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, result, nonResetBuilder.create<StateWriteOp>(stateOp->getLoc(), alloc, result,
info.enable); info.enable);
// Replace all uses of the arc with reads from the allocated state. // Replace all uses of the arc with reads from the allocated state.
for (auto [alloc, result] : llvm::zip(allocatedStates, stateOp.getResults())) for (auto [alloc, result] : llvm::zip(allocatedStates, stateOp->getResults()))
replaceValueWithStateRead(result, alloc); replaceValueWithStateRead(result, alloc);
stateOp.erase(); stateOp->erase();
return success(); return success();
} }
LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
// We don't support arcs beyond latency 1 yet. These should be easy to add in
// the future though.
if (stateOp.getLatency() > 1)
return stateOp.emitError("state with latency > 1 not supported");
auto stateInputs = SmallVector<Value>(stateOp.getInputs());
return lowerStateLike<arc::CallOp>(stateOp, stateOp.getClock(),
stateOp.getEnable(), stateOp.getReset(),
stateInputs, stateOp.getArcAttr());
}
LogicalResult ModuleLowering::lowerState(sim::DPICallOp callOp) {
// Clocked call op can be considered as arc state with single latency.
auto stateClock = callOp.getClock();
if (!stateClock)
return callOp.emitError("unclocked DPI call not implemented yet");
auto stateInputs = SmallVector<Value>(callOp.getInputs());
return lowerStateLike<func::CallOp>(callOp, stateClock, callOp.getEnable(),
Value(), stateInputs,
callOp.getCalleeAttr());
}
LogicalResult ModuleLowering::lowerState(MemoryOp memOp) { LogicalResult ModuleLowering::lowerState(MemoryOp memOp) {
auto allocMemOp = stateBuilder.create<AllocMemoryOp>( auto allocMemOp = stateBuilder.create<AllocMemoryOp>(
memOp.getLoc(), memOp.getType(), storageArg, memOp->getAttrs()); memOp.getLoc(), memOp.getType(), storageArg, memOp->getAttrs());

View File

@ -29,6 +29,7 @@ add_circt_dialect_library(CIRCTSim
CIRCTHW CIRCTHW
CIRCTSeq CIRCTSeq
CIRCTSV CIRCTSV
MLIRFuncDialect
MLIRIR MLIRIR
MLIRPass MLIRPass
MLIRTransforms MLIRTransforms

View File

@ -13,6 +13,7 @@
#include "circt/Dialect/Sim/SimOps.h" #include "circt/Dialect/Sim/SimOps.h"
#include "circt/Dialect/HW/ModuleImplementation.h" #include "circt/Dialect/HW/ModuleImplementation.h"
#include "circt/Dialect/SV/SVOps.h" #include "circt/Dialect/SV/SVOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Interfaces/FunctionImplementation.h"
@ -69,12 +70,15 @@ ParseResult DPIFuncOp::parse(OpAsmParser &parser, OperationState &result) {
LogicalResult LogicalResult
sim::DPICallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { sim::DPICallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto referencedOp = dyn_cast_or_null<sim::DPIFuncOp>( auto referencedOp =
symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr())); symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr());
if (!referencedOp) if (!referencedOp)
return emitError("cannot find function declaration '") return emitError("cannot find function declaration '")
<< getCallee() << "'"; << getCallee() << "'";
return success(); if (isa<func::FuncOp, sim::DPIFuncOp>(referencedOp))
return success();
return emitError("callee must be 'sim.dpi.func' or 'func.func' but got '")
<< referencedOp->getName() << "'";
} }
void DPIFuncOp::print(OpAsmPrinter &p) { void DPIFuncOp::print(OpAsmPrinter &p) {

View File

@ -1,4 +1,5 @@
add_circt_dialect_library(CIRCTSimTransforms add_circt_dialect_library(CIRCTSimTransforms
LowerDPIFunc.cpp
ProceduralizeSim.cpp ProceduralizeSim.cpp
@ -12,8 +13,10 @@ add_circt_dialect_library(CIRCTSimTransforms
CIRCTSV CIRCTSV
CIRCTComb CIRCTComb
CIRCTSupport CIRCTSupport
MLIRFuncDialect
MLIRIR MLIRIR
MLIRPass MLIRPass
MLIRLLVMDialect
MLIRSCFDialect MLIRSCFDialect
MLIRTransformUtils MLIRTransformUtils
) )

View File

@ -0,0 +1,177 @@
//===- LowerDPIFunc.cpp - Lower sim.dpi.func to func.func ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===
//
// This pass lowers Sim DPI func ops to MLIR func and call.
//
// sim.dpi.func @foo(input %a: i32, output %b: i64)
// hw.module @top (..) {
// %result = sim.dpi.call @foo(%a) clock %clock
// }
//
// ->
//
// func.func @foo(%a: i32, %b: !llvm.ptr) // Output is passed by a reference.
// func.func @foo_wrapper(%a: i32) -> (i64) {
// %0 = llvm.alloca: !llvm.ptr
// %v = func.call @foo (%a, %0)
// func.return %v
// }
// hw.module @mod(..) {
// %result = sim.dpi.call @foo_wrapper(%a) clock %clock
// }
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Dialect/Sim/SimPasses.h"
#include "circt/Support/Namespace.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "sim-lower-dpi-func"
namespace circt {
namespace sim {
#define GEN_PASS_DEF_LOWERDPIFUNC
#include "circt/Dialect/Sim/SimPasses.h.inc"
} // namespace sim
} // namespace circt
using namespace mlir;
using namespace circt;
//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//
namespace {
struct LoweringState {
DenseMap<StringAttr, func::FuncOp> dpiFuncDeclMapping;
circt::Namespace nameSpace;
};
struct LowerDPIFuncPass : public sim::impl::LowerDPIFuncBase<LowerDPIFuncPass> {
LogicalResult lowerDPI();
LogicalResult lowerDPIFuncOp(sim::DPIFuncOp simFunc,
LoweringState &loweringState,
SymbolTable &symbolTable);
void runOnOperation() override;
};
} // namespace
LogicalResult LowerDPIFuncPass::lowerDPIFuncOp(sim::DPIFuncOp simFunc,
LoweringState &loweringState,
SymbolTable &symbolTable) {
ImplicitLocOpBuilder builder(simFunc.getLoc(), simFunc);
auto moduleType = simFunc.getModuleType();
llvm::SmallVector<Type> dpiFunctionArgumentTypes;
for (auto arg : moduleType.getPorts()) {
// TODO: Support a non-integer type.
if (!arg.type.isInteger())
return simFunc->emitError()
<< "non-integer type argument is unsupported now";
if (arg.dir == hw::ModulePort::Input)
dpiFunctionArgumentTypes.push_back(arg.type);
else
// Output must be passed by a reference.
dpiFunctionArgumentTypes.push_back(
LLVM::LLVMPointerType::get(arg.type.getContext()));
}
auto funcType = builder.getFunctionType(dpiFunctionArgumentTypes, {});
func::FuncOp func;
// Look up func.func by verilog name since the function name is equal to the
// symbol name in MLIR
if (auto verilogName = simFunc.getVerilogName()) {
func = symbolTable.lookup<func::FuncOp>(*verilogName);
// TODO: Check if function type matches.
}
// If a referred function is not in the same module, create an external
// function declaration.
if (!func) {
func = builder.create<func::FuncOp>(simFunc.getVerilogName()
? *simFunc.getVerilogName()
: simFunc.getSymName(),
funcType);
// External function needs to be private.
func.setPrivate();
}
// Create a wrapper module that calls a DPI function.
auto funcOp = builder.create<func::FuncOp>(
loweringState.nameSpace.newName(simFunc.getSymName() + "_wrapper"),
moduleType.getFuncType());
// Map old symbol to a new func op.
loweringState.dpiFuncDeclMapping[simFunc.getSymNameAttr()] = funcOp;
builder.setInsertionPointToStart(funcOp.addEntryBlock());
SmallVector<Value> functionInputs;
SmallVector<LLVM::AllocaOp> functionOutputAllocas;
size_t inputIndex = 0;
for (auto arg : moduleType.getPorts()) {
if (arg.dir == hw::ModulePort::InOut)
return funcOp->emitError() << "inout is currently not supported";
if (arg.dir == hw::ModulePort::Input) {
functionInputs.push_back(funcOp.getArgument(inputIndex));
++inputIndex;
} else {
// Allocate an output placeholder.
auto one = builder.create<LLVM::ConstantOp>(builder.getI64IntegerAttr(1));
auto alloca = builder.create<LLVM::AllocaOp>(
builder.getType<LLVM::LLVMPointerType>(), arg.type, one);
functionInputs.push_back(alloca);
functionOutputAllocas.push_back(alloca);
}
}
builder.create<func::CallOp>(func, functionInputs);
SmallVector<Value> results;
for (auto functionOutputAlloca : functionOutputAllocas)
results.push_back(builder.create<LLVM::LoadOp>(
functionOutputAlloca.getElemType(), functionOutputAlloca));
builder.create<func::ReturnOp>(results);
simFunc.erase();
return success();
}
LogicalResult LowerDPIFuncPass::lowerDPI() {
LLVM_DEBUG(llvm::dbgs() << "Lowering sim DPI func to func.func\n");
auto op = getOperation();
LoweringState state;
state.nameSpace.add(op);
auto &symbolTable = getAnalysis<SymbolTable>();
for (auto simFunc : llvm::make_early_inc_range(op.getOps<sim::DPIFuncOp>()))
if (failed(lowerDPIFuncOp(simFunc, state, symbolTable)))
return failure();
op.walk([&](sim::DPICallOp op) {
auto func = state.dpiFuncDeclMapping.at(op.getCalleeAttr().getAttr());
op.setCallee(func.getSymNameAttr());
});
return success();
}
void LowerDPIFuncPass::runOnOperation() {
if (failed(lowerDPI()))
return signalPassFailure();
}

View File

@ -353,3 +353,16 @@ hw.module @BlackBox(in %clk: !seq.clock) {
} }
// CHECK-NOT: hw.module.extern private @BlackBoxExt // CHECK-NOT: hw.module.extern private @BlackBoxExt
hw.module.extern private @BlackBoxExt(in %a: i42, in %b: i42, out c: i42, out d: i42) hw.module.extern private @BlackBoxExt(in %a: i42, in %b: i42, out c: i42, out d: i42)
func.func private @func(%arg0: i32, %arg1: i32) -> i32
// CHECK-LABEL: arc.model @adder
hw.module @adder(in %clock : i1, in %a : i32, in %b : i32, out c : i32) {
%0 = seq.to_clock %clock
%1 = sim.func.dpi.call @func(%a, %b) clock %0 : (i32, i32) -> i32
// CHECK: arc.clock_tree
// CHECK-NEXT: %[[A:.+]] = arc.state_read %in_a : <i32>
// CHECK-NEXT: %[[B:.+]] = arc.state_read %in_b : <i32>
// CHECK-NEXT: %[[RESULT:.+]] = func.call @func(%6, %7) : (i32, i32) -> i32
hw.output %1 : i32
}

View File

@ -0,0 +1,50 @@
// RUN: circt-opt --sim-lower-dpi-func %s | FileCheck %s
sim.func.dpi @foo(out arg0: i32, in %arg1: i32, out arg2: i32)
// CHECK-LABEL: func.func private @foo(!llvm.ptr, i32, !llvm.ptr)
// CHECK-LABEL: func.func @foo_wrapper(%arg0: i32) -> (i32, i32) {
// CHECK-NEXT: %0 = llvm.mlir.constant(1 : i64) : i64
// CHECK-NEXT: %1 = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
// CHECK-NEXT: %2 = llvm.mlir.constant(1 : i64) : i64
// CHECK-NEXT: %3 = llvm.alloca %2 x i32 : (i64) -> !llvm.ptr
// CHECK-NEXT: call @foo(%1, %arg0, %3) : (!llvm.ptr, i32, !llvm.ptr) -> ()
// CHECK-NEXT: %4 = llvm.load %1 : !llvm.ptr -> i32
// CHECK-NEXT: %5 = llvm.load %3 : !llvm.ptr -> i32
// CHECK-NEXT: return %4, %5 : i32, i32
// CHECK-NEXT: }
// CHECK-LABEL: func.func @bar_wrapper(%arg0: i32) -> (i32, i32) {
// CHECK-NEXT: %0 = llvm.mlir.constant(1 : i64) : i64
// CHECK-NEXT: %1 = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
// CHECK-NEXT: %2 = llvm.mlir.constant(1 : i64) : i64
// CHECK-NEXT: %3 = llvm.alloca %2 x i32 : (i64) -> !llvm.ptr
// CHECK-NEXT: call @bar_c_name(%1, %arg0, %3) : (!llvm.ptr, i32, !llvm.ptr) -> ()
// CHECK-NEXT: %4 = llvm.load %1 : !llvm.ptr -> i32
// CHECK-NEXT: %5 = llvm.load %3 : !llvm.ptr -> i32
// CHECK-NEXT: return %4, %5 : i32, i32
// CHECK-NEXT: }
// CHECK-LABEL: func.func @bar_c_name
sim.func.dpi @bar(out arg0: i32, in %arg1: i32, out arg2: i32) attributes {verilogName="bar_c_name"}
func.func @bar_c_name(%arg0: !llvm.ptr, %arg1: i32, %arg2: !llvm.ptr) {
func.return
}
// CHECK-LABEL: func.func private @baz_c_name(!llvm.ptr, i32, !llvm.ptr)
// CHECK-LABEL: func.func @baz_wrapper(%arg0: i32) -> (i32, i32)
// CHECK: call @baz_c_name(%1, %arg0, %3) : (!llvm.ptr, i32, !llvm.ptr) -> ()
sim.func.dpi @baz(out arg0: i32, in %arg1: i32, out arg2: i32) attributes {verilogName="baz_c_name"}
// CHECK-LABEL: hw.module @dpi_call
hw.module @dpi_call(in %clock : !seq.clock, in %enable : i1, in %in: i32,
out o1: i32, out o2: i32, out o3: i32, out o4: i32, out o5: i32, out o6: i32) {
// CHECK-NEXT: %0:2 = sim.func.dpi.call @foo_wrapper(%in) clock %clock : (i32) -> (i32, i32)
// CHECK-NEXT: %1:2 = sim.func.dpi.call @bar_wrapper(%in) : (i32) -> (i32, i32)
// CHECK-NEXT: %2:2 = sim.func.dpi.call @baz_wrapper(%in) : (i32) -> (i32, i32)
// CHECK-NEXT: hw.output %0#0, %0#1, %1#0, %1#1, %2#0, %2#1 : i32, i32, i32, i32, i32, i32
%0, %1 = sim.func.dpi.call @foo(%in) clock %clock : (i32) -> (i32, i32)
%2, %3 = sim.func.dpi.call @bar(%in) : (i32) -> (i32, i32)
%4, %5 = sim.func.dpi.call @baz(%in) : (i32) -> (i32, i32)
hw.output %0, %1, %2, %3, %4, %5 : i32, i32, i32, i32, i32, i32
}

View File

@ -19,14 +19,15 @@ hw.module @stop_finish(in %clock : !seq.clock, in %cond : i1) {
// CHECK-LABEL: sim.func.dpi @dpi(out arg0 : i1, in %arg1 : i1, out arg2 : i1) // CHECK-LABEL: sim.func.dpi @dpi(out arg0 : i1, in %arg1 : i1, out arg2 : i1)
sim.func.dpi @dpi(out arg0: i1, in %arg1: i1, out arg2: i1) sim.func.dpi @dpi(out arg0: i1, in %arg1: i1, out arg2: i1)
func.func private @func(%arg1: i1) -> (i1, i1)
hw.module @dpi_call(in %clock : !seq.clock, in %enable : i1, in %in: i1) { hw.module @dpi_call(in %clock : !seq.clock, in %enable : i1, in %in: i1) {
// CHECK: sim.func.dpi.call @dpi(%in) clock %clock enable %enable : (i1) -> (i1, i1) // CHECK: sim.func.dpi.call @dpi(%in) clock %clock enable %enable : (i1) -> (i1, i1)
%0, %1 = sim.func.dpi.call @dpi(%in) clock %clock enable %enable: (i1) -> (i1, i1) %0, %1 = sim.func.dpi.call @dpi(%in) clock %clock enable %enable: (i1) -> (i1, i1)
// CHECK: sim.func.dpi.call @dpi(%in) clock %clock : (i1) -> (i1, i1) // CHECK: sim.func.dpi.call @dpi(%in) clock %clock : (i1) -> (i1, i1)
%2, %3 = sim.func.dpi.call @dpi(%in) clock %clock : (i1) -> (i1, i1) %2, %3 = sim.func.dpi.call @dpi(%in) clock %clock : (i1) -> (i1, i1)
// CHECK: sim.func.dpi.call @dpi(%in) enable %enable : (i1) -> (i1, i1) // CHECK: sim.func.dpi.call @func(%in) enable %enable : (i1) -> (i1, i1)
%4, %5 = sim.func.dpi.call @dpi(%in) enable %enable : (i1) -> (i1, i1) %4, %5 = sim.func.dpi.call @func(%in) enable %enable : (i1) -> (i1, i1)
// CHECK: sim.func.dpi.call @dpi(%in) : (i1) -> (i1, i1) // CHECK: sim.func.dpi.call @func(%in) : (i1) -> (i1, i1)
%6, %7 = sim.func.dpi.call @dpi(%in) : (i1) -> (i1, i1) %6, %7 = sim.func.dpi.call @func(%in) : (i1) -> (i1, i1)
} }

View File

@ -42,3 +42,12 @@ hw.module @proc_print_sv() {
sim.proc.print %lit sim.proc.print %lit
} }
} }
// -----
hw.module.extern @non_func(out arg0: i1, in %arg1: i1, out arg2: i1)
hw.module @dpi_call(in %clock : !seq.clock, in %in: i1) {
// expected-error @below {{callee must be 'sim.dpi.func' or 'func.func' but got 'hw.module.extern'}}
%0, %1 = sim.func.dpi.call @non_func(%in) : (i1) -> (i1, i1)
}

View File

@ -20,6 +20,7 @@ target_link_libraries(arcilator
CIRCTOM CIRCTOM
CIRCTSeqToSV CIRCTSeqToSV
CIRCTSeqTransforms CIRCTSeqTransforms
CIRCTSimTransforms
CIRCTSupport CIRCTSupport
CIRCTTransforms CIRCTTransforms
MLIRBuiltinToLLVMIRTranslation MLIRBuiltinToLLVMIRTranslation

View File

@ -22,6 +22,8 @@
#include "circt/Dialect/Emit/EmitDialect.h" #include "circt/Dialect/Emit/EmitDialect.h"
#include "circt/Dialect/HW/HWPasses.h" #include "circt/Dialect/HW/HWPasses.h"
#include "circt/Dialect/Seq/SeqPasses.h" #include "circt/Dialect/Seq/SeqPasses.h"
#include "circt/Dialect/Sim/SimDialect.h"
#include "circt/Dialect/Sim/SimPasses.h"
#include "circt/InitAllDialects.h" #include "circt/InitAllDialects.h"
#include "circt/InitAllPasses.h" #include "circt/InitAllPasses.h"
#include "circt/Support/Passes.h" #include "circt/Support/Passes.h"
@ -249,6 +251,7 @@ static void populateHwModuleToArcPipeline(PassManager &pm) {
opts.tapMemories = observeMemories; opts.tapMemories = observeMemories;
pm.addPass(arc::createInferMemoriesPass(opts)); pm.addPass(arc::createInferMemoriesPass(opts));
} }
pm.addPass(sim::createLowerDPIFunc());
pm.addPass(createCSEPass()); pm.addPass(createCSEPass());
pm.addPass(arc::createArcCanonicalizerPass()); pm.addPass(arc::createArcCanonicalizerPass());
@ -567,6 +570,7 @@ static LogicalResult executeArcilator(MLIRContext &context) {
mlir::scf::SCFDialect, mlir::scf::SCFDialect,
om::OMDialect, om::OMDialect,
seq::SeqDialect, seq::SeqDialect,
sim::SimDialect,
sv::SVDialect sv::SVDialect
>(); >();
// clang-format on // clang-format on