[HW][HWAggregateToComb] Add support for hw.array_inject operation in HWAggregateToComb pass (#8788)

This patch implements lowering of hw.array_inject operations to combinational
logic in the HWAggregateToComb transformation pass.

The implementation creates a 2D array where each row represents the result
of injecting the new value at a specific index. A multiplexer then selects
the appropriate row based on the injection index.
This commit is contained in:
Hideto Ueno 2025-07-28 11:34:15 -07:00 committed by GitHub
parent 2f0ca3c18e
commit 108965f1e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 93 additions and 7 deletions

View File

@ -5,7 +5,7 @@
// RUN: circt-lec %t.mlir %s -c1=array -c2=array --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ARRAY
// COMB_ARRAY: c1 == c2
hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, in %sel1: i2, in %sel2: i2, out out1: i2, out out2: i2) {
hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, in %sel1: i2, in %sel2: i2, in %sel3: i2, in %val: i2, out out1: i2, out out2: i2) {
%0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i2
%1 = hw.array_get %0[%sel1] : !hw.array<4xi2>, i2
%2 = hw.array_create %arg0, %arg1, %arg2 : i2
@ -13,8 +13,11 @@ hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, in %sel
// NOTE: If the index is out of bounds, the result value is undefined.
// In LEC such value is lowered into unbounded SMT variable and cause
// the LEC to fail. So just asssume that the index is in bounds.
%inbound = comb.icmp ult %sel2, %c3_i2 : i2
verif.assume %inbound : i1
%3 = hw.array_get %2[%sel2] : !hw.array<3xi2>, i2
%inbound_sel_2 = comb.icmp ult %sel2, %c3_i2 : i2
verif.assume %inbound_sel_2 : i1
%inbound_sel_3 = comb.icmp ult %sel3, %c3_i2 : i2
verif.assume %inbound_sel_3 : i1
%inject = hw.array_inject %2[%sel3], %val: !hw.array<3xi2>, i2
%3 = hw.array_get %inject[%sel2] : !hw.array<3xi2>, i2
hw.output %1, %3 : i2, i2
}

View File

@ -115,6 +115,66 @@ struct HWArrayGetOpConversion : OpConversionPattern<hw::ArrayGetOp> {
}
};
struct HWArrayInjectOpConversion : OpConversionPattern<hw::ArrayInjectOp> {
using OpConversionPattern<hw::ArrayInjectOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto arrayType = cast<hw::ArrayType>(op.getInput().getType());
auto elemType = arrayType.getElementType();
auto numElements = arrayType.getNumElements();
auto elemWidth = hw::getBitWidth(elemType);
if (elemWidth < 0)
return rewriter.notifyMatchFailure(op.getLoc(), "unknown element width");
Location loc = op.getLoc();
// Extract all elements from the input array
SmallVector<Value> originalElements;
auto inputArray = adaptor.getInput();
for (size_t i = 0; i < numElements; ++i) {
originalElements.push_back(rewriter.createOrFold<comb::ExtractOp>(
loc, inputArray, i * elemWidth, elemWidth));
}
// Create 2D array: each row represents what the array would look like
// if injection happened at that specific index
SmallVector<Value> arrayRows;
arrayRows.reserve(numElements);
for (int injectIdx = numElements - 1; injectIdx >= 0; --injectIdx) {
SmallVector<Value> rowElements;
rowElements.reserve(numElements);
// Build the row: array[n-1], array[n-2], ..., but replace element at
// injectIdx with newVal
for (int originalIdx = numElements - 1; originalIdx >= 0; --originalIdx) {
if (originalIdx == injectIdx) {
rowElements.push_back(adaptor.getElement());
} else {
rowElements.push_back(originalElements[originalIdx]);
}
}
// Concatenate elements to form this row
Value row = rewriter.create<hw::ArrayCreateOp>(loc, rowElements);
arrayRows.push_back(row);
}
// Create the 2D array by concatenating all rows
// arrayRows[0] corresponds to injection at index 0
// arrayRows[1] corresponds to injection at index 1, etc.
Value array2D = rewriter.create<hw::ArrayCreateOp>(loc, arrayRows);
// Create array_get operation to select the row
auto arrayGetOp =
rewriter.create<hw::ArrayGetOp>(loc, array2D, adaptor.getIndex());
rewriter.replaceOp(op, arrayGetOp);
return success();
}
};
/// A type converter is needed to perform the in-flight materialization of
/// aggregate types to integer types.
class AggregateTypeConverter : public TypeConverter {
@ -152,8 +212,8 @@ static void populateHWAggregateToCombOpConversionPatterns(
patterns.add<HWArrayGetOpConversion,
HWArrayCreateLikeOpConversion<hw::ArrayCreateOp>,
HWArrayCreateLikeOpConversion<hw::ArrayConcatOp>,
HWAggregateConstantOpConversion>(typeConverter,
patterns.getContext());
HWAggregateConstantOpConversion, HWArrayInjectOpConversion>(
typeConverter, patterns.getContext());
}
namespace {
@ -169,7 +229,7 @@ void HWAggregateToCombPass::runOnOperation() {
// TODO: Add ArraySliceOp and struct operatons as well.
target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
hw::AggregateConstantOp>();
hw::AggregateConstantOp, hw::ArrayInjectOp>();
target.addLegalDialect<hw::HWDialect, comb::CombDialect>();

View File

@ -57,3 +57,26 @@ hw.module @array(in %arg0: i2, in %arg1: i2, in %arg2: i2, in %arg3: i2, out out
// CHECK-NEXT: hw.output %[[BITCAST]], %[[MUX_2]]
hw.output %0, %1 : !hw.array<4xi2>, i2
}
// CHECK-LABEL: @array_inject(
hw.module @array_inject(in %in: !hw.array<3xi2>, in %sel: i2, in %val: i2, out out_inject: !hw.array<3xi2>) {
// CHECK-NEXT: %[[in_bitcast:.+]] = hw.bitcast %in
// CHECK-NEXT: %[[element_0:.+]] = comb.extract %[[in_bitcast]] from 0 : (i6) -> i2
// CHECK-NEXT: %[[element_1:.+]] = comb.extract %[[in_bitcast]] from 2 : (i6) -> i2
// CHECK-NEXT: %[[element_2:.+]] = comb.extract %[[in_bitcast]] from 4 : (i6) -> i2
// CHECK-NEXT: %[[inject_2:.+]] = comb.concat %val, %[[element_1]], %[[element_0]]
// CHECK-NEXT: %[[inject_1:.+]] = comb.concat %[[element_2]], %val, %[[element_0]]
// CHECK-NEXT: %[[inject_0:.+]] = comb.concat %[[element_2]], %[[element_1]], %val
// CHECK-NEXT: %[[array_2d:.+]] = comb.concat %[[inject_2]], %[[inject_1]], %[[inject_0]]
// CHECK-NEXT: %[[array_0:.+]] = comb.extract %[[array_2d]] from 0 : (i18) -> i6
// CHECK-NEXT: %[[array_1:.+]] = comb.extract %[[array_2d]] from 6 : (i18) -> i6
// CHECK-NEXT: %[[array_2:.+]] = comb.extract %[[array_2d]] from 12 : (i18) -> i6
// CHECK-NEXT: %[[sel_0:.+]] = comb.extract %sel from 0 : (i2) -> i1
// CHECK-NEXT: %[[sel_1:.+]] = comb.extract %sel from 1 : (i2) -> i1
// CHECK-NEXT: %[[mux_0:.+]] = comb.mux %[[sel_0]], %[[array_1]], %[[array_0]]
// CHECK-NEXT: %[[mux_1:.+]] = comb.mux %[[sel_1]], %[[array_2]], %[[mux_0]]
// CHECK-NEXT: %[[result:.+]] = hw.bitcast %[[mux_1]]
// CHECK-NEXT: hw.output %[[result]]
%0 = hw.array_inject %in[%sel], %val : !hw.array<3xi2>, i2
hw.output %0 : !hw.array<3xi2>
}