[HWToSMT] Return an unbound value for OOB hw.array_inject (#8794)

This is a follow up to 186dcc8b0e to fix OOB behavior.

h/t @fzi-hielscher for pointing out the issue
This commit is contained in:
Hideto Ueno 2025-07-29 17:00:06 -07:00 committed by GitHub
parent a11a834d96
commit 9e27c4ddae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 5 deletions

View File

@ -193,6 +193,8 @@ def ArrayInjectOp : HWOp<"array_inject", [
`element` value, and returns the updated array value as a result. The index
must be exactly `ceil(log2(length(input)))` bits wide. The element type
must match the input array's element type.
If the `index` is out of bounds, the result is undefined.
}];
let arguments = (ins
ArrayType:$input,

View File

@ -169,6 +169,7 @@ struct ArrayInjectOpConversion : OpConversionPattern<ArrayInjectOp> {
if (!arrType)
return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
Value oobVal = mlir::smt::DeclareFunOp::create(rewriter, loc, arrType);
// Check if the index is within bounds
Value numElementsVal = mlir::smt::BVConstantOp::create(
rewriter, loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
@ -181,10 +182,8 @@ struct ArrayInjectOpConversion : OpConversionPattern<ArrayInjectOp> {
rewriter, loc, adaptor.getInput(), adaptor.getIndex(),
adaptor.getElement());
// Return the original array if out of bounds, otherwise return the new
// array
rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, stored,
adaptor.getInput());
// Return unbounded array if out of bounds
rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, stored, oobVal);
return success();
}
};

View File

@ -29,10 +29,11 @@ hw.module @modB(in %in: i32, out out: i32) {
// CHECK-LABEL: func.func @inject
// CHECK-SAME: (%[[ARR:.+]]: !smt.array<[!smt.bv<2> -> !smt.bv<8>]>, %[[IDX:.+]]: !smt.bv<2>, %[[VAL:.+]]: !smt.bv<8>)
hw.module @inject(in %arr: !hw.array<3xi8>, in %index: i2, in %v: i8, out out: !hw.array<3xi8>) {
// CHECK-NEXT: %[[OOB:.+]] = smt.declare_fun : !smt.array<[!smt.bv<2> -> !smt.bv<8>]>
// CHECK-NEXT: %[[C2:.+]] = smt.bv.constant #smt.bv<-2> : !smt.bv<2>
// CHECK-NEXT: %[[CMP:.+]] = smt.bv.cmp ule %[[IDX]], %[[C2]] : !smt.bv<2>
// CHECK-NEXT: %[[STORED:.+]] = smt.array.store %[[ARR]][%[[IDX]]], %[[VAL]] : !smt.array<[!smt.bv<2> -> !smt.bv<8>]>
// CHECK-NEXT: %[[RESULT:.+]] = smt.ite %[[CMP]], %[[STORED]], %[[ARR]] : !smt.array<[!smt.bv<2> -> !smt.bv<8>]>
// CHECK-NEXT: %[[RESULT:.+]] = smt.ite %[[CMP]], %[[STORED]], %[[OOB]] : !smt.array<[!smt.bv<2> -> !smt.bv<8>]>
// CHECK-NEXT: return %[[RESULT]] : !smt.array<[!smt.bv<2> -> !smt.bv<8>]>
%arr_injected = hw.array_inject %arr[%index], %v : !hw.array<3xi8>, i2
hw.output %arr_injected : !hw.array<3xi8>