[HWToLLVM] Add lowering support for 'hw.array_inject' op (#8774)

This commit is contained in:
Michael 2025-07-30 11:31:23 +03:00 committed by GitHub
parent 9e27c4ddae
commit dcb2d92b98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 128 additions and 5 deletions

View File

@ -135,6 +135,67 @@ struct StructExtractOpConversion
};
} // namespace
namespace {
/// Lower hw.array_inject to the LLVM dialect by round-tripping the array
/// through an alloca and overwriting one element in memory.
/// Pattern: array_inject(input, element, index) =>
///   %slot = alloca
///   store(input, %slot)
///   store(element, gep(%slot, 0, zext(index)))
///   load(%slot)
struct ArrayInjectOpConversion
    : public ConvertOpToLLVMPattern<hw::ArrayInjectOp> {
  using ConvertOpToLLVMPattern<hw::ArrayInjectOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto hwArrayTy = cast<hw::ArrayType>(op.getInput().getType());
    // Converted type of the input; this is also the type of the result.
    auto resultTy = adaptor.getInput().getType();
    const size_t numElems = hwArrayTy.getNumElements();

    // A zero-element array has nothing to overwrite: forward the input.
    if (numElems == 0) {
      rewriter.replaceOp(op, adaptor.getInput());
      return success();
    }

    // Element count operand of the alloca (a single array object).
    auto allocaCount = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
    auto gepIndex = zextByOne(loc, rewriter, op.getIndex());

    // By default the alloca has the converted input type. For single-element
    // and non-power-of-two arrays the index is clamped to `numElems` and the
    // alloca gets one extra trailing element, so that an out-of-bounds index
    // writes to that extra element and leaves the original array intact.
    auto slotTy = resultTy;
    const bool needsPadding =
        numElems == 1 || !llvm::isPowerOf2_64(numElems);
    if (needsPadding) {
      auto clampTo = rewriter.create<LLVM::ConstantOp>(
          loc, gepIndex.getType(), rewriter.getI32IntegerAttr(numElems));
      gepIndex = rewriter.create<LLVM::UMinOp>(loc, gepIndex, clampTo);
      slotTy = typeConverter->convertType(
          hw::ArrayType::get(hwArrayTy.getElementType(), numElems + 1));
    }

    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value slot = rewriter.create<LLVM::AllocaOp>(loc, ptrTy, slotTy,
                                                 allocaCount, /*alignment=*/4);

    // Spill the whole input array, then overwrite the selected element.
    rewriter.create<LLVM::StoreOp>(loc, adaptor.getInput(), slot);
    auto elementPtr = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, slotTy, slot, ArrayRef<LLVM::GEPArg>{0, gepIndex});
    rewriter.create<LLVM::StoreOp>(loc, adaptor.getElement(), elementPtr);

    // Reload with the original (unpadded) array type as the op's result.
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, resultTy, slot);
    return success();
  }
};
} // namespace
namespace {
/// Convert an ArrayGetOp to the LLVM dialect.
/// Pattern: array_get(input, index) =>
@ -660,9 +721,10 @@ void circt::populateHWToLLVMConversionPatterns(
patterns.add<BitcastOpConversion>(converter);
// Extraction operation conversion patterns.
patterns.add<ArrayGetOpConversion, ArraySliceOpConversion,
ArrayConcatOpConversion, StructExplodeOpConversion,
StructExtractOpConversion, StructInjectOpConversion>(converter);
patterns.add<ArrayInjectOpConversion, ArrayGetOpConversion,
ArraySliceOpConversion, ArrayConcatOpConversion,
StructExplodeOpConversion, StructExtractOpConversion,
StructInjectOpConversion>(converter);
}
void circt::populateHWToLLVMTypeConversions(LLVMTypeConverter &converter) {

View File

@ -57,7 +57,6 @@ func.func @convertArray(%arg0 : i1, %arg1: !hw.array<2xi32>, %arg2: i32, %arg3:
// CHECK-NEXT: llvm.insertvalue %[[E4]], %[[I3]][3] : !llvm.array<4 x i32>
%2 = hw.array_concat %arg1, %arg1 : !hw.array<2xi32>, !hw.array<2xi32>
// CHECK-NEXT: [[V6:%.*]] = llvm.mlir.undef : !llvm.array<4 x i32>
// CHECK-NEXT: [[V7:%.*]] = llvm.insertvalue %arg5, [[V6]][0] : !llvm.array<4 x i32>
// CHECK-NEXT: [[V8:%.*]] = llvm.insertvalue %arg4, [[V7]][1] : !llvm.array<4 x i32>
@ -68,6 +67,59 @@ func.func @convertArray(%arg0 : i1, %arg1: !hw.array<2xi32>, %arg2: i32, %arg3:
return
}
// COM: Test: lowering of hw.array_inject for arrays of 0, 1, 2 and 5 elements.
// COM: The expected IR for each op is listed directly above that op.
// CHECK-LABEL: @convertArrayInject
func.func @convertArrayInject(
%arg0: !hw.array<0xi32>, %arg1: !hw.array<1xi32>, %arg2: !hw.array<2xi32>, %arg3: !hw.array<5xi32>,
%arg4: i32, %arg5: i64, %arg6: i1, %arg7: i3, %arg8: i64) {
// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %arg0 : !hw.array<0xi32> to !llvm.array<0 x i32>
// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %arg1 : !hw.array<1xi32> to !llvm.array<1 x i32>
// CHECK-DAG: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %arg2 : !hw.array<2xi32> to !llvm.array<2 x i32>
// CHECK-DAG: %[[CAST3:.*]] = builtin.unrealized_conversion_cast %arg3 : !hw.array<5xi32> to !llvm.array<5 x i32>
// COM: Zero-element array: no IR is expected for this op; its result is
// COM: consumed by the array_get at the end of the function.
%0 = hw.array_inject %arg0[%arg5], %arg4 : !hw.array<0xi32>, i64
// COM: Single-element array: the index is clamped to 1 and the alloca holds
// COM: two elements, while the store and load use the 1-element type.
// CHECK-NEXT: %[[ONE1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[ZEXT1:.*]] = llvm.zext %arg6 : i1 to i2
// CHECK-NEXT: %[[MAX1:.*]] = llvm.mlir.constant(1 : i32) : i2
// CHECK-NEXT: %[[UMIN1:.*]] = llvm.intr.umin(%[[ZEXT1]], %[[MAX1]]) : (i2, i2) -> i2
// CHECK-NEXT: %[[ALLOCA1:.*]] = llvm.alloca %[[ONE1]] x !llvm.array<2 x i32> {alignment = 4 : i64} : (i32) -> !llvm.ptr
// CHECK-NEXT: llvm.store %[[CAST1]], %[[ALLOCA1]] : !llvm.array<1 x i32>, !llvm.ptr
// CHECK-NEXT: %[[GEP1:.*]] = llvm.getelementptr %[[ALLOCA1]][0, %[[UMIN1]]] : (!llvm.ptr, i2) -> !llvm.ptr, !llvm.array<2 x i32>
// CHECK-NEXT: llvm.store %arg4, %[[GEP1]] : i32, !llvm.ptr
// CHECK-NEXT: llvm.load %[[ALLOCA1]] : !llvm.ptr -> !llvm.array<1 x i32>
%1 = hw.array_inject %arg1[%arg6], %arg4 : !hw.array<1xi32>, i1
// COM: Two-element (power-of-two) array: no clamp is emitted and the alloca
// COM: uses the unmodified 2-element type.
// CHECK-NEXT: %[[ONE2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[ZEXT2:.*]] = llvm.zext %arg6 : i1 to i2
// CHECK-NEXT: %[[ALLOCA2:.*]] = llvm.alloca %[[ONE2]] x !llvm.array<2 x i32> {alignment = 4 : i64} : (i32) -> !llvm.ptr
// CHECK-NEXT: llvm.store %[[CAST2]], %[[ALLOCA2]] : !llvm.array<2 x i32>, !llvm.ptr
// CHECK-NEXT: %[[GEP2:.*]] = llvm.getelementptr %[[ALLOCA2]][0, %[[ZEXT2]]] : (!llvm.ptr, i2) -> !llvm.ptr, !llvm.array<2 x i32>
// CHECK-NEXT: llvm.store %arg4, %[[GEP2]] : i32, !llvm.ptr
// CHECK-NEXT: llvm.load %[[ALLOCA2]] : !llvm.ptr -> !llvm.array<2 x i32>
%2 = hw.array_inject %arg2[%arg6], %arg4 : !hw.array<2xi32>, i1
// COM: Five-element (non-power-of-two) array: the index is clamped to 5 and
// COM: the alloca holds six elements, while the store and load use the
// COM: 5-element type.
// CHECK-NEXT: %[[ONE3:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[ZEXT3:.*]] = llvm.zext %arg7 : i3 to i4
// CHECK-NEXT: %[[MAX3:.*]] = llvm.mlir.constant(5 : i32) : i4
// CHECK-NEXT: %[[UMIN3:.*]] = llvm.intr.umin(%[[ZEXT3]], %[[MAX3]]) : (i4, i4) -> i4
// CHECK-NEXT: %[[ALLOCA3:.*]] = llvm.alloca %[[ONE3]] x !llvm.array<6 x i32> {alignment = 4 : i64} : (i32) -> !llvm.ptr
// CHECK-NEXT: llvm.store %[[CAST3]], %[[ALLOCA3]] : !llvm.array<5 x i32>, !llvm.ptr
// CHECK-NEXT: %[[GEP3:.*]] = llvm.getelementptr %[[ALLOCA3]][0, %[[UMIN3]]] : (!llvm.ptr, i4) -> !llvm.ptr, !llvm.array<6 x i32>
// CHECK-NEXT: llvm.store %arg4, %[[GEP3]] : i32, !llvm.ptr
// CHECK-NEXT: llvm.load %[[ALLOCA3]] : !llvm.ptr -> !llvm.array<5 x i32>
%3 = hw.array_inject %arg3[%arg7], %arg4 : !hw.array<5xi32>, i3
// COM: Reading from the result of the zero-element inject: the stored value
// COM: is the converted original input (CAST0), i.e. the inject was forwarded.
// CHECK-NEXT: %[[ONE4:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[ALLOCA4:.*]] = llvm.alloca %[[ONE4]] x !llvm.array<0 x i32> {alignment = 4 : i64} : (i32) -> !llvm.ptr
// CHECK-NEXT: llvm.store %[[CAST0]], %[[ALLOCA4]] : !llvm.array<0 x i32>, !llvm.ptr
// CHECK-NEXT: %[[ZEXT4:.*]] = llvm.zext %arg8 : i64 to i65
// CHECK-NEXT: %[[GEP4:.*]] = llvm.getelementptr %[[ALLOCA4]][0, %[[ZEXT4]]] : (!llvm.ptr, i65) -> !llvm.ptr, !llvm.array<0 x i32>
// CHECK-NEXT: llvm.load %[[GEP4]] : !llvm.ptr -> i32
%4 = hw.array_get %0[%arg8] : !hw.array<0xi32>, i64
return
}
// CHECK: llvm.mlir.global internal constant @[[GLOB1:.+]](dense<[1, 0]> : tensor<2xi32>) {addr_space = 0 : i32} : !llvm.array<2 x i32>
// CHECK: llvm.mlir.global internal constant @[[GLOB2:.+]](dense<{{[[][[]}}3, 2], [1, 0{{[]][]]}}> : tensor<2x2xi32>) {addr_space = 0 : i32} : !llvm.array<2 x array<2 x i32>>
@ -90,7 +142,7 @@ func.func @convertArray(%arg0 : i1, %arg1: !hw.array<2xi32>, %arg2: i32, %arg3:
// CHECK-NEXT: }
// CHECK: @convertConstArray
func.func @convertConstArray(%arg0 : i1) {
func.func @convertConstArray(%arg0 : i1, %arg1 : i32) {
// COM: Test: simple constant array converted to constant global
// CHECK: %[[VAL_2:.*]] = llvm.mlir.addressof @[[GLOB1]] : !llvm.ptr
// CHECK-NEXT: %[[VAL_3:.*]] = llvm.load %[[VAL_2]] : !llvm.ptr -> !llvm.array<2 x i32>
@ -117,6 +169,15 @@ func.func @convertConstArray(%arg0 : i1) {
// CHECK-NEXT: {{%.+}} = llvm.load %[[VAL_9]] : !llvm.ptr -> !llvm.array<2 x struct<(i1, i32)>>
%4 = hw.aggregate_constant [[0 : i32, 1 : i1], [2 : i32, 0 : i1]] : !hw.array<2x!hw.struct<a: i32, b: i1>>
// CHECK-NEXT: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[ZEXT0:.*]] = llvm.zext %arg0 : i1 to i2
// CHECK-NEXT: %[[ALLOCA0:.*]] = llvm.alloca %[[ONE]] x !llvm.array<2 x i32> {alignment = 4 : i64} : (i32) -> !llvm.ptr
// CHECK-NEXT: llvm.store %[[VAL_3]], %[[ALLOCA0]] : !llvm.array<2 x i32>, !llvm.ptr
// CHECK-NEXT: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA0]][0, %[[ZEXT0]]] : (!llvm.ptr, i2) -> !llvm.ptr, !llvm.array<2 x i32>
// CHECK-NEXT: llvm.store %arg1, %[[GEP0]] : i32, !llvm.ptr
// CHECK-NEXT: %{{.+}} = llvm.load %[[ALLOCA0]] : !llvm.ptr -> !llvm.array<2 x i32>
%5 = hw.array_inject %0[%arg0], %arg1 : !hw.array<2xi32>, i1
return
}