[mlir][Vector] Revisit VectorToSCF.

Vector to SCF conversion still had issues due to the interaction with the natural alignment derived by the LLVM data layout. One traditional workaround is to allocate aligned. However, this does not always work for vector sizes that are non-powers of 2.

This revision implements a more portable mechanism where the intermediate allocation is always a memref of elemental vector type. AllocOp is extended to use the natural LLVM DataLayout alignment for non-scalar types, when the alignment is not specified in the first place.

An integration test is added that exercises the transfer to scf.for + scalar lowering with a 5x5 transposition.

Differential Revision: https://reviews.llvm.org/D87150
This commit is contained in:
Nicolas Vasilache 2020-09-04 11:43:00 -04:00
parent 7634c64b61
commit 8d64df9f13
9 changed files with 186 additions and 53 deletions

View File

@ -16,10 +16,16 @@ namespace intrinsics {
using vector_broadcast = ValueBuilder<vector::BroadcastOp>;
using vector_contract = ValueBuilder<vector::ContractionOp>;
using vector_insert = ValueBuilder<vector::InsertOp>;
using vector_fma = ValueBuilder<vector::FMAOp>;
using vector_extract = ValueBuilder<vector::ExtractOp>;
using vector_extractelement = ValueBuilder<vector::ExtractElementOp>;
using vector_extract_element = ValueBuilder<vector::ExtractElementOp>;
using vector_extract_slices = ValueBuilder<vector::ExtractSlicesOp>;
using vector_extract_strided_slice =
ValueBuilder<vector::ExtractStridedSliceOp>;
using vector_fma = ValueBuilder<vector::FMAOp>;
using vector_insert = ValueBuilder<vector::InsertOp>;
using vector_insert_element = ValueBuilder<vector::InsertElementOp>;
using vector_insert_slices = ValueBuilder<vector::InsertSlicesOp>;
using vector_insert_strided_slice = ValueBuilder<vector::InsertStridedSliceOp>;
using vector_matmul = ValueBuilder<vector::MatmulOp>;
using vector_outerproduct = ValueBuilder<vector::OuterProductOp>;
using vector_print = OperationBuilder<vector::PrintOp>;
@ -27,11 +33,6 @@ using vector_transfer_read = ValueBuilder<vector::TransferReadOp>;
using vector_transfer_write = OperationBuilder<vector::TransferWriteOp>;
using vector_transpose = ValueBuilder<vector::TransposeOp>;
using vector_type_cast = ValueBuilder<vector::TypeCastOp>;
using vector_extract_slices = ValueBuilder<vector::ExtractSlicesOp>;
using vector_insert_slices = ValueBuilder<vector::InsertSlicesOp>;
using vector_extract_strided_slice =
ValueBuilder<vector::ExtractStridedSliceOp>;
using vector_insert_strided_slice = ValueBuilder<vector::InsertStridedSliceOp>;
} // namespace intrinsics
} // namespace edsc

View File

@ -348,15 +348,21 @@ def Vector_ExtractElementOp :
%1 = vector.extractelement %0[%c : i32]: vector<16xf32>
```
}];
let assemblyFormat = [{
$vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
}];
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value source, "
"int64_t position">,
OpBuilder<
"OpBuilder &builder, OperationState &result, Value source, "
"Value position">];
let extraClassDeclaration = [{
VectorType getVectorType() {
return vector().getType().cast<VectorType>();
}
}];
let assemblyFormat = [{
$vector `[` $position `:` type($position) `]` attr-dict `:` type($vector)
}];
}
def Vector_ExtractOp :
@ -508,6 +514,17 @@ def Vector_InsertElementOp :
%1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
```
}];
let assemblyFormat = [{
$source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
type($result)
}];
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value source, "
"Value dest, int64_t position">,
OpBuilder<
"OpBuilder &builder, OperationState &result, Value source, "
"Value dest, Value position">];
let extraClassDeclaration = [{
Type getSourceType() { return source().getType(); }
VectorType getDestVectorType() {
@ -515,10 +532,6 @@ def Vector_InsertElementOp :
}
}];
let assemblyFormat = [{
$source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
type($result)
}];
}
def Vector_InsertOp :

View File

@ -0,0 +1,81 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm | \
// RUN: mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext,%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
#map0 = affine_map<(d0, d1) -> (d1, d0)>
func @print_memref_f32(memref<*xf32>)
func @alloc_2d_filled_f32(%arg0: index, %arg1: index) -> memref<?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c10 = constant 10 : index
%c100 = constant 100 : index
%0 = alloc(%arg0, %arg1) : memref<?x?xf32>
scf.for %arg5 = %c0 to %arg0 step %c1 {
scf.for %arg6 = %c0 to %arg1 step %c1 {
%arg66 = muli %arg6, %c100 : index
%tmp1 = addi %arg5, %arg66 : index
%tmp2 = index_cast %tmp1 : index to i32
%tmp3 = sitofp %tmp2 : i32 to f32
store %tmp3, %0[%arg5, %arg6] : memref<?x?xf32>
}
}
return %0 : memref<?x?xf32>
}
func @main() {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c6 = constant 6 : index
%cst = constant -4.2e+01 : f32
%0 = call @alloc_2d_filled_f32(%c6, %c6) : (index, index) -> memref<?x?xf32>
%converted = memref_cast %0 : memref<?x?xf32> to memref<*xf32>
call @print_memref_f32(%converted): (memref<*xf32>) -> ()
// CHECK: Unranked{{.*}}data =
// CHECK: [
// CHECK-SAME: [0, 100, 200, 300, 400, 500],
// CHECK-NEXT: [1, 101, 201, 301, 401, 501],
// CHECK-NEXT: [2, 102, 202, 302, 402, 502],
// CHECK-NEXT: [3, 103, 203, 303, 403, 503],
// CHECK-NEXT: [4, 104, 204, 304, 404, 504],
// CHECK-NEXT: [5, 105, 205, 305, 405, 505]]
%init = vector.transfer_read %0[%c1, %c1], %cst : memref<?x?xf32>, vector<5x5xf32>
vector.print %init : vector<5x5xf32>
// 5x5 block rooted at {1, 1}
// CHECK-NEXT: ( ( 101, 201, 301, 401, 501 ),
// CHECK-SAME: ( 102, 202, 302, 402, 502 ),
// CHECK-SAME: ( 103, 203, 303, 403, 503 ),
// CHECK-SAME: ( 104, 204, 304, 404, 504 ),
// CHECK-SAME: ( 105, 205, 305, 405, 505 ) )
%1 = vector.transfer_read %0[%c1, %c1], %cst {permutation_map = #map0} : memref<?x?xf32>, vector<5x5xf32>
vector.print %1 : vector<5x5xf32>
// Transposed 5x5 block rooted @{1, 1} in memory.
// CHECK-NEXT: ( ( 101, 102, 103, 104, 105 ),
// CHECK-SAME: ( 201, 202, 203, 204, 205 ),
// CHECK-SAME: ( 301, 302, 303, 304, 305 ),
// CHECK-SAME: ( 401, 402, 403, 404, 405 ),
// CHECK-SAME: ( 501, 502, 503, 504, 505 ) )
// Transpose-write the transposed 5x5 block @{0, 0} in memory.
vector.transfer_write %1, %0[%c0, %c0] {permutation_map = #map0} : vector<5x5xf32>, memref<?x?xf32>
%2 = vector.transfer_read %0[%c1, %c1], %cst : memref<?x?xf32>, vector<5x5xf32>
vector.print %2 : vector<5x5xf32>
// New 5x5 block rooted @{1, 1} in memory.
// Here we expect the boundaries from the original data
// (i.e. last row: 105 .. 505, last col: 501 .. 505)
// and the 4x4 subblock 202 .. 505 rooted @{0, 0} in the vector
// CHECK-NEXT: ( ( 202, 302, 402, 502, 501 ),
// CHECK-SAME: ( 203, 303, 403, 503, 502 ),
// CHECK-SAME: ( 204, 304, 404, 504, 503 ),
// CHECK-SAME: ( 205, 305, 405, 505, 504 ),
// CHECK-SAME: ( 105, 205, 305, 405, 505 ) )
dealloc %0 : memref<?x?xf32>
return
}

View File

@ -1893,11 +1893,17 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
// Adjust the allocation size to consider alignment.
if (Optional<uint64_t> alignment = allocOp.alignment()) {
accessAlignment = createIndexConstant(rewriter, loc, *alignment);
cumulativeSize = rewriter.create<LLVM::SubOp>(
loc,
rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
one);
} else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
// In the case where no alignment is specified, we may want to override
// `malloc's` behavior. `malloc` typically aligns at the size of the
// biggest scalar on a target HW. For non-scalars, use the natural
// alignment of the LLVM type given by the LLVM DataLayout.
accessAlignment =
this->getSizeInBytes(loc, memRefType.getElementType(), rewriter);
}
if (accessAlignment)
cumulativeSize =
rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment);
callArgs.push_back(cumulativeSize);
}
auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc);

View File

@ -35,8 +35,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#define ALIGNMENT_SIZE 128
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
@ -234,8 +232,7 @@ static Value setAllocAtFunctionEntry(MemRefType memRefMinorVectorType,
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
assert(scope && "Expected op to be inside automatic allocation scope");
b.setInsertionPointToStart(&scope->getRegion(0).front());
Value res = std_alloca(memRefMinorVectorType, ValueRange{},
b.getI64IntegerAttr(ALIGNMENT_SIZE));
Value res = std_alloca(memRefMinorVectorType);
return res;
}
@ -494,8 +491,10 @@ template <typename TransferOpTy>
MemRefType VectorTransferRewriter<TransferOpTy>::tmpMemRefType(
TransferOpTy transfer) const {
auto vectorType = transfer.getVectorType();
return MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
0);
return MemRefType::get(vectorType.getShape().drop_back(),
VectorType::get(vectorType.getShape().take_back(),
vectorType.getElementType()),
{}, 0);
}
/// Lowers TransferReadOp into a combination of:
@ -585,8 +584,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
steps.push_back(std_constant_index(step));
// 2. Emit alloc-copy-load-dealloc.
Value tmp = std_alloc(tmpMemRefType(transfer), ValueRange{},
rewriter.getI64IntegerAttr(ALIGNMENT_SIZE));
Value tmp = setAllocAtFunctionEntry(tmpMemRefType(transfer), transfer);
StdIndexedValue local(tmp);
Value vec = vector_type_cast(tmp);
loopNestBuilder(lbs, ubs, steps, [&](ValueRange loopIvs) {
@ -595,10 +593,15 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
if (coalescedIdx >= 0)
std::swap(ivs.back(), ivs[coalescedIdx]);
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
local(ivs) = remote(clip(transfer, memRefBoundsCapture, ivs));
SmallVector<Value, 8> indices = clip(transfer, memRefBoundsCapture, ivs);
ArrayRef<Value> indicesRef(indices), ivsRef(ivs);
Value pos =
std_index_cast(IntegerType::get(32, op->getContext()), ivsRef.back());
Value vector = vector_insert_element(remote(indicesRef),
local(ivsRef.drop_back()), pos);
local(ivsRef.drop_back()) = vector;
});
Value vectorValue = std_load(vec);
(std_dealloc(tmp)); // vexing parse
// 3. Propagate.
rewriter.replaceOp(op, vectorValue);
@ -667,8 +670,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
steps.push_back(std_constant_index(step));
// 2. Emit alloc-store-copy-dealloc.
Value tmp = std_alloc(tmpMemRefType(transfer), ValueRange{},
rewriter.getI64IntegerAttr(ALIGNMENT_SIZE));
Value tmp = setAllocAtFunctionEntry(tmpMemRefType(transfer), transfer);
StdIndexedValue local(tmp);
Value vec = vector_type_cast(tmp);
std_store(vectorValue, vec);
@ -678,10 +680,15 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
if (coalescedIdx >= 0)
std::swap(ivs.back(), ivs[coalescedIdx]);
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
remote(clip(transfer, memRefBoundsCapture, ivs)) = local(ivs);
SmallVector<Value, 8> indices = clip(transfer, memRefBoundsCapture, ivs);
ArrayRef<Value> indicesRef(indices), ivsRef(ivs);
Value pos =
std_index_cast(IntegerType::get(32, op->getContext()), ivsRef.back());
Value scalar = vector_extract_element(local(ivsRef.drop_back()), pos);
remote(indices) = scalar;
});
(std_dealloc(tmp)); // vexing parse...
// 3. Erase.
rewriter.eraseOp(op);
return success();
}

View File

@ -537,6 +537,18 @@ Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
// ExtractElementOp
//===----------------------------------------------------------------------===//
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value position) {
result.addOperands({source, position});
result.addTypes(source.getType().cast<VectorType>().getElementType());
}
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
Value source, int64_t position) {
Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
build(builder, result, source, pos);
}
static LogicalResult verify(vector::ExtractElementOp op) {
VectorType vectorType = op.getVectorType();
if (vectorType.getRank() != 1)
@ -1007,6 +1019,18 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
// InsertElementOp
//===----------------------------------------------------------------------===//
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, Value position) {
result.addOperands({source, dest, position});
result.addTypes(dest.getType());
}
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest, int64_t position) {
Value pos = builder.create<ConstantIntOp>(result.location, position, 32);
build(builder, result, source, dest, pos);
}
static LogicalResult verify(InsertElementOp op) {
auto dstVectorType = op.getDestVectorType();
if (dstVectorType.getRank() != 1)

View File

@ -130,8 +130,7 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
// CHECK-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
// CHECK-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one_1]] : !llvm.i64
// CHECK-NEXT: %[[allocsize:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm.ptr<i8>
// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<i8> to !llvm.ptr<float>
// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
@ -154,8 +153,7 @@ func @aligned_1d_alloc() -> memref<42xf32> {
// BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// BAREPTR-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
// BAREPTR-NEXT: %[[alignmentMinus1:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
// BAREPTR-NEXT: %[[allocsize:.*]] = llvm.sub %[[alignmentMinus1]], %[[one_1]] : !llvm.i64
// BAREPTR-NEXT: %[[allocsize:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64
// BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm.ptr<i8>
// BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<i8> to !llvm.ptr<float>
// BAREPTR-NEXT: llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>

View File

@ -19,7 +19,7 @@ func @materialize_read_1d() {
// CHECK: %[[FILTERED1:.*]] = select
// CHECK: {{.*}} = select
// CHECK: %[[FILTERED2:.*]] = select
// CHECK-NEXT: %{{.*}} = load {{.*}}[%[[FILTERED1]], %[[FILTERED2]]] : memref<7x42xf32>
// CHECK: %{{.*}} = load {{.*}}[%[[FILTERED1]], %[[FILTERED2]]] : memref<7x42xf32>
}
}
return
@ -58,6 +58,7 @@ func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %d
// CHECK-LABEL: func @materialize_read(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
%f0 = constant 0.0: f32
// CHECK-DAG: %[[ALLOC:.*]] = alloca() : memref<5x4xvector<3xf32>>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
@ -68,7 +69,6 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
// CHECK: %[[ALLOC:.*]] = alloc() {alignment = 128 : i64} : memref<5x4x3xf32>
// CHECK-NEXT: scf.for %[[I4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK-NEXT: scf.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-NEXT: scf.for %[[I6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
@ -97,13 +97,15 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: {{.*}} = select
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[L3:.*]] = select
// CHECK-NEXT: %[[VIDX:.*]] = index_cast %[[I4]]
//
// CHECK-NEXT: {{.*}} = load %{{.*}}[%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : memref<?x?x?x?xf32>
// CHECK-NEXT: store {{.*}}, %[[ALLOC]][%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32>
// CHECK-NEXT: %[[VEC:.*]] = load %[[ALLOC]][%[[I6]], %[[I5]]] : memref<5x4xvector<3xf32>>
// CHECK-NEXT: %[[RVEC:.*]] = vector.insertelement %25, %[[VEC]][%[[VIDX]] : i32] : vector<3xf32>
// CHECK-NEXT: store %[[RVEC]], %[[ALLOC]][%[[I6]], %[[I5]]] : memref<5x4xvector<3xf32>>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: dealloc %[[ALLOC]] : memref<5x4x3xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
@ -134,6 +136,7 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK-LABEL:func @materialize_write(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK-DAG: %[[ALLOC:.*]] = alloca() : memref<5x4xvector<3xf32>>
// CHECK-DAG: %{{.*}} = constant dense<1.000000e+00> : vector<5x4x3xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
@ -145,8 +148,7 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %{{.*}} step 4 {
// CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
// CHECK: %[[ALLOC:.*]] = alloc() {alignment = 128 : i64} : memref<5x4x3xf32>
// CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector.type_cast {{.*}} : memref<5x4x3xf32>
// CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector.type_cast {{.*}} : memref<5x4xvector<3xf32>>
// CHECK: store %{{.*}}, {{.*}} : memref<vector<5x4x3xf32>>
// CHECK-NEXT: scf.for %[[I4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK-NEXT: scf.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
@ -177,13 +179,14 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[S3:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
// CHECK-NEXT: %[[VIDX:.*]] = index_cast %[[I4]]
//
// CHECK-NEXT: {{.*}} = load {{.*}}[%[[I6]], %[[I5]], %[[I4]]] : memref<5x4x3xf32>
// CHECK: store {{.*}}, {{.*}}[%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[VEC:.*]] = load {{.*}}[%[[I6]], %[[I5]]] : memref<5x4xvector<3xf32>>
// CHECK-NEXT: %[[SCAL:.*]] = vector.extractelement %[[VEC]][%[[VIDX]] : i32] : vector<3xf32>
// CHECK: store %[[SCAL]], {{.*}}[%[[S0]], %[[S1]], %[[S2]], %[[S3]]] : memref<?x?x?x?xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: dealloc {{.*}} : memref<5x4x3xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
@ -232,7 +235,7 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<3x
%f7 = constant 7.0: f32
// CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
// CHECK-DAG: %[[alloc:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<15xf32>>
// CHECK-DAG: %[[alloc:.*]] = alloca() : memref<3xvector<15xf32>>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[dim:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
// CHECK: affine.for %[[I:.*]] = 0 to 3 {
@ -307,7 +310,7 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<3x
// FULL-UNROLL-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32>
func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<3x15xf32>) {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[alloc:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<15xf32>>
// CHECK: %[[alloc:.*]] = alloca() : memref<3xvector<15xf32>>
// CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref<vector<3x15xf32>>
// CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<3x15xf32>>
// CHECK: %[[dim:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
@ -363,7 +366,7 @@ func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vecto
// FULL-UNROLL-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32>
func @transfer_write_progressive_unmasked(%A : memref<?x?xf32>, %base: index, %vec: vector<3x15xf32>) {
// CHECK-NOT: scf.if
// CHECK-NEXT: %[[alloc:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<15xf32>>
// CHECK-NEXT: %[[alloc:.*]] = alloca() : memref<3xvector<15xf32>>
// CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref<vector<3x15xf32>>
// CHECK-NEXT: store %[[vec]], %[[vmemref]][] : memref<vector<3x15xf32>>
// CHECK-NEXT: affine.for %[[I:.*]] = 0 to 3 {
@ -416,7 +419,7 @@ func @transfer_read_minor_identity(%A : memref<?x?x?x?xf32>) -> vector<3x3xf32>
// CHECK: %[[cst:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[c2:.*]] = constant 2 : index
// CHECK: %[[cst0:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[m:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<3xf32>>
// CHECK: %[[m:.*]] = alloca() : memref<3xvector<3xf32>>
// CHECK: %[[d:.*]] = dim %[[A]], %[[c2]] : memref<?x?x?x?xf32>
// CHECK: affine.for %[[arg1:.*]] = 0 to 3 {
// CHECK: %[[cmp:.*]] = cmpi "slt", %[[arg1]], %[[d]] : index
@ -445,7 +448,7 @@ func @transfer_write_minor_identity(%A : vector<3x3xf32>, %B : memref<?x?x?x?xf3
// CHECK-SAME: %[[B:.*]]: memref<?x?x?x?xf32>)
// CHECK: %[[c0:.*]] = constant 0 : index
// CHECK: %[[c2:.*]] = constant 2 : index
// CHECK: %[[m:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<3xf32>>
// CHECK: %[[m:.*]] = alloca() : memref<3xvector<3xf32>>
// CHECK: %[[cast:.*]] = vector.type_cast %[[m]] : memref<3xvector<3xf32>> to memref<vector<3x3xf32>>
// CHECK: store %[[A]], %[[cast]][] : memref<vector<3x3xf32>>
// CHECK: %[[d:.*]] = dim %[[B]], %[[c2]] : memref<?x?x?x?xf32>

View File

@ -1089,7 +1089,7 @@ TEST_FUNC(vector_extractelement_op_i32) {
ScopedContext scope(builder, f.getLoc());
auto i32Type = builder.getI32Type();
auto vectorType = VectorType::get(/*shape=*/{8}, i32Type);
vector_extractelement(
vector_extract_element(
i32Type, std_constant(vectorType, builder.getI32VectorAttr({10})),
std_constant_int(0, i32Type));