mirror of https://github.com/llvm/circt.git
[Seq] Add a pass to convert an array seq.firreg to seq.firmem (#8716)
This commit introduces a new transformation pass `RegOfVecToMem` that converts register arrays following memory access patterns into `seq.firmem` operations. When a valid pattern is detected, the pass replaces the register array with a `seq.firmem` operation and corresponding read/write ports. This is required for the `circt-verilog` tool, to identify memories, such that other `circt` transformations/analysis can be run in the `seq` dialect on the `mlir` parsed from verilog.
This commit is contained in:
parent
640daf0c92
commit
4ce45d581f
|
@ -31,6 +31,7 @@ std::unique_ptr<mlir::Pass> createLowerSeqFIFOPass();
|
|||
std::unique_ptr<mlir::Pass>
|
||||
createHWMemSimImplPass(const HWMemSimImplOptions &options = {});
|
||||
std::unique_ptr<mlir::Pass> createLowerSeqShiftRegPass();
|
||||
std::unique_ptr<mlir::Pass> createRegOfVecToMem();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
|
|
|
@ -111,4 +111,16 @@ def LowerSeqShiftReg : Pass<"lower-seq-shiftreg", "hw::HWModuleOp"> {
|
|||
let dependentDialects = ["circt::hw::HWDialect"];
|
||||
}
|
||||
|
||||
def RegOfVecToMem : Pass<"seq-reg-of-vec-to-mem", "hw::HWModuleOp"> {
|
||||
let summary = "Convert register arrays to FIRRTL memories";
|
||||
let description = [{
|
||||
This pass identifies register arrays that follow memory access patterns
|
||||
and converts them to seq.firmem operations. It looks for patterns where:
|
||||
1. A register array is updated via array_inject operations
|
||||
2. The array is read via array_get operations
|
||||
3. Updates are controlled by enable signals through mux operations
|
||||
4. Read and write operations use the same clock
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // CIRCT_DIALECT_SEQ_SEQPASSES
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_circt_dialect_library(CIRCTSeqTransforms
|
||||
ExternalizeClockGate.cpp
|
||||
HWMemSimImpl.cpp
|
||||
RegOfVecToMem.cpp
|
||||
LowerSeqHLMem.cpp
|
||||
LowerSeqFIFO.cpp
|
||||
LowerSeqShiftReg.cpp
|
||||
|
|
|
@ -0,0 +1,261 @@
|
|||
//===- RegOfVecToMem.cpp - Convert Register Arrays to Memories -----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This transformation pass converts register arrays that follow memory access
|
||||
// patterns to seq.firmem operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "circt/Dialect/Comb/CombOps.h"
|
||||
#include "circt/Dialect/HW/HWOps.h"
|
||||
#include "circt/Dialect/Seq/SeqOps.h"
|
||||
#include "circt/Dialect/Seq/SeqPasses.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "reg-of-vec-to-mem"
|
||||
|
||||
using namespace circt;
|
||||
using namespace seq;
|
||||
using namespace hw;
|
||||
|
||||
namespace circt {
|
||||
namespace seq {
|
||||
#define GEN_PASS_DEF_REGOFVECTOMEM
|
||||
#include "circt/Dialect/Seq/SeqPasses.h.inc"
|
||||
} // namespace seq
|
||||
} // namespace circt
|
||||
|
||||
namespace {
|
||||
|
||||
struct MemoryPattern {
|
||||
FirRegOp memReg; // The register array representing memory
|
||||
FirRegOp outputReg; // Optional output register
|
||||
Value clock; // Clock signal
|
||||
Value readAddr; // Read address
|
||||
Value writeAddr; // Write address
|
||||
Value writeData; // Write data
|
||||
Value writeEnable; // Write enable
|
||||
Value readEnable; // Read enable (optional)
|
||||
comb::MuxOp writeMux; // Mux selecting between old/new memory state
|
||||
comb::MuxOp readMux; // Mux for read data
|
||||
hw::ArrayGetOp readAccess; // Array read operation
|
||||
hw::ArrayInjectOp writeAccess; // Array write operation
|
||||
};
|
||||
|
||||
class RegOfVecToMemPass : public impl::RegOfVecToMemBase<RegOfVecToMemPass> {
|
||||
public:
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
bool analyzeMemoryPattern(FirRegOp reg, MemoryPattern &pattern);
|
||||
bool createFirMemory(MemoryPattern &pattern);
|
||||
bool isArrayType(Type type);
|
||||
std::optional<std::pair<uint64_t, uint64_t>> getArrayDimensions(Type type);
|
||||
|
||||
SmallVector<Operation *> opsToErase;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
bool RegOfVecToMemPass::isArrayType(Type type) {
|
||||
return isa<hw::ArrayType, hw::UnpackedArrayType>(type);
|
||||
}
|
||||
|
||||
std::optional<std::pair<uint64_t, uint64_t>>
|
||||
RegOfVecToMemPass::getArrayDimensions(Type type) {
|
||||
if (auto arrayType = dyn_cast<hw::ArrayType>(type)) {
|
||||
auto elemType = arrayType.getElementType();
|
||||
if (auto intType = dyn_cast<IntegerType>(elemType)) {
|
||||
return std::make_pair(arrayType.getNumElements(), intType.getWidth());
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
bool RegOfVecToMemPass::analyzeMemoryPattern(FirRegOp reg,
|
||||
MemoryPattern &pattern) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Analyzing register: " << reg << "\n");
|
||||
|
||||
// Check if register has array type
|
||||
if (!isArrayType(reg.getType()))
|
||||
return false;
|
||||
|
||||
ArrayGetOp readAccess;
|
||||
ArrayInjectOp writeAccess;
|
||||
comb::MuxOp writeMux;
|
||||
for (auto *user : reg.getResult().getUsers()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << " Register user: " << *user << "\n");
|
||||
if (auto arrayGet = dyn_cast<hw::ArrayGetOp>(user); !readAccess && arrayGet)
|
||||
readAccess = arrayGet;
|
||||
else if (auto arrayInject = dyn_cast<hw::ArrayInjectOp>(user);
|
||||
!writeAccess && arrayInject)
|
||||
writeAccess = arrayInject;
|
||||
else if (auto mux = dyn_cast<comb::MuxOp>(user); !writeMux && mux)
|
||||
writeMux = mux;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
if (!readAccess || !writeAccess || !writeMux)
|
||||
return false;
|
||||
|
||||
pattern.memReg = reg;
|
||||
pattern.clock = reg.getClk();
|
||||
|
||||
// Find the mux that drives this register
|
||||
auto nextValue = reg.getNext();
|
||||
auto mux = nextValue.getDefiningOp<comb::MuxOp>();
|
||||
if (!mux)
|
||||
return false;
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << " Found driving mux: " << mux << "\n");
|
||||
pattern.writeMux = mux;
|
||||
|
||||
// Check that the mux is only used by this register (safety check)
|
||||
if (!mux.getResult().hasOneUse()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << " Mux has multiple uses, cannot transform\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Analyze mux inputs: sel ? write_result : current_memory
|
||||
Value writeResult = mux.getTrueValue();
|
||||
Value currentMemory = mux.getFalseValue();
|
||||
|
||||
// Check if false value is the current register (feedback)
|
||||
if (currentMemory != reg.getResult())
|
||||
return false;
|
||||
|
||||
// Look for array_inject operation in write path
|
||||
auto arrayInject = writeResult.getDefiningOp<hw::ArrayInjectOp>();
|
||||
if (!arrayInject)
|
||||
return false;
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << " Found array_inject: " << arrayInject << "\n");
|
||||
pattern.writeAccess = arrayInject;
|
||||
pattern.writeAddr = arrayInject.getIndex();
|
||||
pattern.writeData = arrayInject.getElement();
|
||||
pattern.writeEnable = mux.getCond();
|
||||
|
||||
// Look for read pattern - find array_get users
|
||||
auto arrayGet = readAccess;
|
||||
LLVM_DEBUG(llvm::dbgs() << " Found array_get: " << arrayGet << "\n");
|
||||
pattern.readAccess = arrayGet;
|
||||
pattern.readAddr = arrayGet.getIndex();
|
||||
|
||||
// Check if read goes through output register
|
||||
for (auto *readUser : arrayGet.getResult().getUsers()) {
|
||||
if (auto outputReg = dyn_cast<FirRegOp>(readUser)) {
|
||||
if (outputReg.getClk() == pattern.clock) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< " Found output register: " << outputReg << "\n");
|
||||
pattern.outputReg = outputReg;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool success = pattern.readAccess != nullptr;
|
||||
LLVM_DEBUG(llvm::dbgs() << " Pattern analysis "
|
||||
<< (success ? "succeeded" : "failed") << "\n");
|
||||
return success;
|
||||
}
|
||||
|
||||
bool RegOfVecToMemPass::createFirMemory(MemoryPattern &pattern) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Creating FirMemory for pattern\n");
|
||||
|
||||
auto dims = getArrayDimensions(pattern.memReg.getType());
|
||||
if (!dims)
|
||||
return false;
|
||||
|
||||
uint64_t depth = dims->first;
|
||||
uint64_t width = dims->second;
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << " Memory dimensions: " << depth << " x " << width
|
||||
<< "\n");
|
||||
|
||||
ImplicitLocOpBuilder builder(pattern.memReg.getLoc(), pattern.memReg);
|
||||
|
||||
// Create FirMem
|
||||
auto memType =
|
||||
FirMemType::get(builder.getContext(), depth, width, /*maskWidth=*/1);
|
||||
auto firMem = builder.create<seq::FirMemOp>(
|
||||
memType, /*readLatency=*/0, /*writeLatency=*/1,
|
||||
/*readUnderWrite=*/seq::RUW::Undefined,
|
||||
/*writeUnderWrite=*/seq::WUW::Undefined,
|
||||
/*name=*/builder.getStringAttr("mem"), /*innerSym=*/hw::InnerSymAttr{},
|
||||
/*init=*/seq::FirMemInitAttr{}, /*prefix=*/StringAttr{},
|
||||
/*outputFile=*/Attribute{});
|
||||
|
||||
// Create read port
|
||||
Value readData = builder.create<FirMemReadOp>(
|
||||
firMem, pattern.readAddr, pattern.clock,
|
||||
/*enable=*/builder.create<hw::ConstantOp>(builder.getI1Type(), 1));
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << " Created read port\n"
|
||||
<< firMem << "\n " << readData);
|
||||
|
||||
Value mask;
|
||||
// Create write port
|
||||
builder.create<FirMemWriteOp>(firMem, pattern.writeAddr, pattern.clock,
|
||||
pattern.writeEnable, pattern.writeData, mask);
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << " Created write port\n");
|
||||
|
||||
// Replace read access
|
||||
if (pattern.outputReg)
|
||||
// If there's an output register, replace its input
|
||||
pattern.outputReg.getNext().replaceAllUsesWith(readData);
|
||||
else
|
||||
// Replace direct read access
|
||||
pattern.readAccess.getResult().replaceAllUsesWith(readData);
|
||||
|
||||
// Mark old operations for removal
|
||||
opsToErase.push_back(pattern.memReg);
|
||||
if (pattern.readAccess)
|
||||
opsToErase.push_back(pattern.readAccess);
|
||||
if (pattern.writeAccess)
|
||||
opsToErase.push_back(pattern.writeAccess);
|
||||
if (pattern.writeMux)
|
||||
opsToErase.push_back(pattern.writeMux);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void RegOfVecToMemPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
|
||||
SmallVector<FirRegOp> arrayRegs;
|
||||
|
||||
// Collect all FirRegOp with array types
|
||||
module.walk([&](FirRegOp reg) {
|
||||
if (isArrayType(reg.getType())) {
|
||||
arrayRegs.push_back(reg);
|
||||
}
|
||||
});
|
||||
|
||||
// Analyze each array register for memory patterns
|
||||
for (auto reg : arrayRegs) {
|
||||
MemoryPattern pattern;
|
||||
if (analyzeMemoryPattern(reg, pattern)) {
|
||||
createFirMemory(pattern);
|
||||
}
|
||||
}
|
||||
|
||||
// Erase all marked operations
|
||||
for (auto *op : opsToErase) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Erasing operation: " << *op << " number of uses:"
|
||||
<< "\n");
|
||||
op->dropAllUses();
|
||||
op->erase();
|
||||
}
|
||||
opsToErase.clear();
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
// RUN: circt-opt %s --seq-reg-of-vec-to-mem | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: hw.module private @complex_mem
|
||||
hw.module private @complex_mem(in %CLK : i1, in %D : i46, in %ADR : i13, in %WE : i1, in %ME : i1, out Q : i46) {
|
||||
%true = hw.constant true
|
||||
%c0_i46 = hw.constant 0 : i46
|
||||
%0 = comb.xor %WE, %true : i1
|
||||
%1 = comb.and %ME, %0 : i1
|
||||
%2 = hw.array_get %mem_core[%ADR] : !hw.array<8192xi46>, i13
|
||||
%3 = comb.xor %1, %true : i1
|
||||
%4 = comb.mux %3, %c0_i46, %2 : i46
|
||||
%5 = seq.to_clock %CLK
|
||||
%6 = comb.mux %1, %4, %Q_int : i46
|
||||
%Q_int = seq.firreg %6 clock %5 : i46
|
||||
// NOTE: The transformation cannot identify the memory read enable signal.
|
||||
%7 = comb.and %ME, %WE : i1
|
||||
%8 = hw.array_inject %mem_core[%ADR], %D : !hw.array<8192xi46>, i13
|
||||
%9 = comb.mux %7, %8, %mem_core : !hw.array<8192xi46>
|
||||
%mem_core = seq.firreg %9 clock %5 : !hw.array<8192xi46>
|
||||
hw.output %Q_int : i46
|
||||
}
|
||||
|
||||
// CHECK: %[[clock:.+]] = seq.to_clock %CLK
|
||||
// CHECK: %[[V6:.+]] = comb.and %ME, %WE : i1
|
||||
// CHECK: %mem = seq.firmem 0, 1, undefined, undefined : <8192 x 46, mask 1>
|
||||
// CHECK: %[[READ:.+]] = seq.firmem.read_port %mem[%ADR], clock %[[clock]] enable %true
|
||||
// CHECK: seq.firmem.write_port %mem[%ADR] = %D, clock %[[clock]] enable %[[V6]]
|
||||
// CHECK-NOT: seq.firreg %{{.*}} : !hw.array<8192xi46>
|
||||
// CHECK-NOT: hw.array_get
|
||||
// CHECK-NOT: hw.array_inject
|
||||
|
||||
// Simple test case
|
||||
// CHECK-LABEL: hw.module @simple_mem
|
||||
hw.module @simple_mem(in %clk : i1, in %addr : i2, in %data : i8, in %we : i1, out out : i8) {
|
||||
%clock = seq.to_clock %clk
|
||||
%true = hw.constant true
|
||||
%read = hw.array_get %mem[%addr] : !hw.array<4xi8>, i2
|
||||
%write = hw.array_inject %mem[%addr], %data : !hw.array<4xi8>, i2
|
||||
%next = comb.mux %we, %write, %mem : !hw.array<4xi8>
|
||||
%mem = seq.firreg %next clock %clock : !hw.array<4xi8>
|
||||
hw.output %read : i8
|
||||
}
|
||||
|
||||
// CHECK: %[[clock:.+]] = seq.to_clock %clk
|
||||
// CHECK: %mem = seq.firmem 0, 1, undefined, undefined : <4 x 8, mask 1>
|
||||
// CHECK: %[[READ:.+]] = seq.firmem.read_port %mem[%addr], clock %[[clock]] enable %true
|
||||
// CHECK: seq.firmem.write_port %mem[%addr] = %data, clock %[[clock]] enable %we
|
||||
// CHECK: hw.output %[[READ]] : i8
|
||||
// CHECK-NOT: seq.firreg %{{.*}} : !hw.array<8192xi46>
|
||||
// CHECK-NOT: hw.array_get
|
||||
// CHECK-NOT: hw.array_inject
|
||||
|
||||
// Test that transformation is skipped when mux has multiple uses
|
||||
// CHECK-LABEL: hw.module @shared_mux_test(
|
||||
hw.module @shared_mux_test(in %clk: i1, in %addr: i2, in %data: i8, in %we: i1, out other_out: !hw.array<4xi8>) {
|
||||
%clock = seq.to_clock %clk
|
||||
%write = hw.array_inject %mem[%addr], %data : !hw.array<4xi8>, i2
|
||||
%next = comb.mux %we, %write, %mem : !hw.array<4xi8>
|
||||
// CHECK: %mem = seq.firreg %{{.*}} clock %{{.*}} : !hw.array<4xi8>
|
||||
%mem = seq.firreg %next clock %clock : !hw.array<4xi8>
|
||||
|
||||
// Mux result used elsewhere - should prevent transformation
|
||||
hw.output %next : !hw.array<4xi8>
|
||||
}
|
||||
|
||||
// Test that transformation is skipped when register has multiple uses
|
||||
// CHECK-LABEL: hw.module @shared_reg_test(
|
||||
hw.module @shared_reg_test(in %clk: i1, in %addr: i2, in %data: i8, in %we: i1, out reg_out: !hw.array<4xi8>, out read_out: i8) {
|
||||
%clock = seq.to_clock %clk
|
||||
%read = hw.array_get %mem[%addr] : !hw.array<4xi8>, i2
|
||||
%write = hw.array_inject %mem[%addr], %data : !hw.array<4xi8>, i2
|
||||
%next = comb.mux %we, %write, %mem : !hw.array<4xi8>
|
||||
// CHECK: %mem = seq.firreg %{{.*}} clock %{{.*}} : !hw.array<4xi8>
|
||||
%mem = seq.firreg %next clock %clock : !hw.array<4xi8>
|
||||
|
||||
// Register used elsewhere - should prevent transformation
|
||||
hw.output %mem, %read : !hw.array<4xi8>, i8
|
||||
}
|
Loading…
Reference in New Issue