[mlir] [VectorOps] Add expand/compress operations to Vector dialect

Introduces the expand and compress operations to the Vector dialect
(important memory operations for sparse computations), together
with a first reference implementation that lowers to the LLVM IR
dialect to enable running on CPU (and other targets that support
the corresponding LLVM IR intrinsics).

Reviewed By: reidtatge

Differential Revision: https://reviews.llvm.org/D84888
Commit: e8dcf5f87d (parent 3c0f347002)
Author: aartbik
Date: 2020-07-31 12:47:25 -07:00
11 changed files with 505 additions and 60 deletions
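As a quick illustrative sketch of the two new ops (using the syntax introduced in VectorOps.td below; the value names are placeholders), both access contiguous memory under a 1-D mask and map onto the matching LLVM masked intrinsics:

```mlir
// Enabled lanes are filled from consecutive memory locations; disabled lanes
// take the pass-through value. Lowers to llvm.intr.masked.expandload.
%x = vector.expandload %base, %mask, %pass_thru
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

// Enabled lanes are written back to consecutive memory locations.
// Lowers to llvm.intr.masked.compressstore.
vector.compressstore %base, %mask, %x
    : memref<?xf32>, vector<16xi1>, vector<16xf32>
```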


@@ -1042,6 +1042,16 @@ def LLVM_masked_scatter
"type($value) `,` type($mask) `into` type($ptrs)";
}
/// Create a call to Masked Expand Load intrinsic.
def LLVM_masked_expandload
: LLVM_IntrOp<"masked.expandload", [0], [], [], 1>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
/// Create a call to Masked Compress Store intrinsic.
def LLVM_masked_compressstore
: LLVM_IntrOp<"masked.compressstore", [], [0], [], 0>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
//
// Atomic operations.
//


@@ -1158,7 +1158,7 @@ def Vector_GatherOp :
Variadic<VectorOfRank<[1]>>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
let summary = "gathers elements from memory into a vector as defined by an index vector";
let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
let description = [{
The gather operation gathers elements from memory into a 1-D vector as
@@ -1186,7 +1186,6 @@ def Vector_GatherOp :
%g = vector.gather %base, %indices, %mask, %pass_thru
: (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
```
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -1217,7 +1216,7 @@ def Vector_ScatterOp :
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
let summary = "scatters elements from a vector into memory as defined by an index vector";
let summary = "scatters elements from a vector into memory as defined by an index vector and mask";
let description = [{
The scatter operation scatters elements from a 1-D vector into memory as
@@ -1265,6 +1264,108 @@ def Vector_ScatterOp :
"type($indices) `,` type($mask) `,` type($value) `into` type($base)";
}
def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Arguments<(ins AnyMemRef:$base,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
let description = [{
The expand load reads elements from memory into a 1-D vector as defined
by a base and a 1-D mask vector. When the mask is set, the next element
is read from memory. Otherwise, the corresponding element is taken from
a 1-D pass-through vector. Informally the semantics are:
```
index = base
result[0] := mask[0] ? MEM[index++] : pass_thru[0]
result[1] := mask[1] ? MEM[index++] : pass_thru[1]
etc.
```
Note that the index increment is done conditionally.
The expand load can be used directly where applicable, or can be used
during progressive lowering to bring other memory operations closer to
hardware ISA support for an expand. The semantics of the operation closely
correspond to those of the `llvm.masked.expandload`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
Example:
```mlir
%0 = vector.expandload %base, %mask, %pass_thru
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
```
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
}
VectorType getPassThruVectorType() {
return pass_thru().getType().cast<VectorType>();
}
VectorType getResultVectorType() {
return result().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
}
def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Arguments<(ins AnyMemRef:$base,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
let summary = "writes elements selectively from a vector as defined by a mask";
let description = [{
The compress store operation writes elements from a 1-D vector into memory
as defined by a base and a 1-D mask vector. When the mask is set, the
corresponding element from the vector is written to the next memory location. Otherwise,
no action is taken for the element. Informally the semantics are:
```
index = base
if (mask[0]) MEM[index++] = value[0]
if (mask[1]) MEM[index++] = value[1]
etc.
```
Note that the index increment is done conditionally.
The compress store can be used directly where applicable, or can be used
during progressive lowering to bring other memory operations closer to
hardware ISA support for a compress. The semantics of the operation closely
correspond to those of the `llvm.masked.compressstore`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
Example:
```mlir
vector.compressstore %base, %mask, %value
: memref<?xf32>, vector<8xi1>, vector<8xf32>
```
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
}
VectorType getValueVectorType() {
return value().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
"type($base) `,` type($mask) `,` type($value)";
}
def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [NoSideEffect]>,
Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,

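To make the conditional index increment in the two descriptions above concrete, here is a scalar-loop sketch of the expand-load semantics (illustration only, not part of this patch; it assumes `%base : memref<?xf32>`, `%mask : vector<16xi1>`, `%pass_thru : vector<16xf32>`, and uses only existing std/scf/vector ops). The compress store is the mirror image on the store side.

```mlir
%c0 = constant 0 : index
%c1 = constant 1 : index
%c16 = constant 16 : index
// Carry the partial result and the current memory index through the loop.
%res:2 = scf.for %i = %c0 to %c16 step %c1
    iter_args(%v = %pass_thru, %p = %c0) -> (vector<16xf32>, index) {
  %i32 = index_cast %i : index to i32
  %m = vector.extractelement %mask[%i32 : i32] : vector<16xi1>
  %pair:2 = scf.if %m -> (vector<16xf32>, index) {
    // Enabled lane: read the next memory element and advance the index.
    %e = load %base[%p] : memref<?xf32>
    %v2 = vector.insertelement %e, %v[%i32 : i32] : vector<16xf32>
    %p2 = addi %p, %c1 : index
    scf.yield %v2, %p2 : vector<16xf32>, index
  } else {
    // Disabled lane: keep the pass-through value, do not advance the index.
    scf.yield %v, %p : vector<16xf32>, index
  }
  scf.yield %pair#0, %pair#1 : vector<16xf32>, index
}
// %res#0 equals: vector.expandload %base, %mask, %pass_thru.
```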

@@ -0,0 +1,90 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
func @compress16(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
vector.compressstore %base, %mask, %value
: memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
func @printmem16(%A: memref<?xf32>) {
%c0 = constant 0: index
%c1 = constant 1: index
%c16 = constant 16: index
%z = constant 0.0: f32
%m = vector.broadcast %z : f32 to vector<16xf32>
%mem = scf.for %i = %c0 to %c16 step %c1
iter_args(%m_iter = %m) -> (vector<16xf32>) {
%c = load %A[%i] : memref<?xf32>
%i32 = index_cast %i : index to i32
%m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
scf.yield %m_new : vector<16xf32>
}
vector.print %mem : vector<16xf32>
return
}
func @entry() {
// Set up memory.
%c0 = constant 0: index
%c1 = constant 1: index
%c16 = constant 16: index
%A = alloc(%c16) : memref<?xf32>
%z = constant 0.0: f32
%v = vector.broadcast %z : f32 to vector<16xf32>
%value = scf.for %i = %c0 to %c16 step %c1
iter_args(%v_iter = %v) -> (vector<16xf32>) {
store %z, %A[%i] : memref<?xf32>
%i32 = index_cast %i : index to i32
%fi = sitofp %i32 : i32 to f32
%v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
scf.yield %v_new : vector<16xf32>
}
// Set up masks.
%f = constant 0: i1
%t = constant 1: i1
%none = vector.constant_mask [0] : vector<16xi1>
%all = vector.constant_mask [16] : vector<16xi1>
%some1 = vector.constant_mask [4] : vector<16xi1>
%0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
%1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
%2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
%3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
%some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
%some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
//
// Compress store tests.
//
call @compress16(%A, %none, %value)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
call @compress16(%A, %all, %value)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
call @compress16(%A, %some3, %value)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK-NEXT: ( 1, 3, 7, 11, 13, 15, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
call @compress16(%A, %some2, %value)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK-NEXT: ( 1, 2, 3, 7, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
call @compress16(%A, %some1, %value)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
return
}


@@ -0,0 +1,82 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
func @expand16(%base: memref<?xf32>,
%mask: vector<16xi1>,
%pass_thru: vector<16xf32>) -> vector<16xf32> {
%e = vector.expandload %base, %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %e : vector<16xf32>
}
func @entry() {
// Set up memory.
%c0 = constant 0: index
%c1 = constant 1: index
%c16 = constant 16: index
%A = alloc(%c16) : memref<?xf32>
scf.for %i = %c0 to %c16 step %c1 {
%i32 = index_cast %i : index to i32
%fi = sitofp %i32 : i32 to f32
store %fi, %A[%i] : memref<?xf32>
}
// Set up pass thru vector.
%u = constant -7.0: f32
%v = constant 7.7: f32
%pass = vector.broadcast %u : f32 to vector<16xf32>
// Set up masks.
%f = constant 0: i1
%t = constant 1: i1
%none = vector.constant_mask [0] : vector<16xi1>
%all = vector.constant_mask [16] : vector<16xi1>
%some1 = vector.constant_mask [4] : vector<16xi1>
%0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
%1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
%2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
%3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
%some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
%some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
//
// Expanding load tests.
//
%e1 = call @expand16(%A, %none, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %e1 : vector<16xf32>
// CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
%e2 = call @expand16(%A, %all, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %e2 : vector<16xf32>
// CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
%e3 = call @expand16(%A, %some1, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %e3 : vector<16xf32>
// CHECK-NEXT: ( 0, 1, 2, 3, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
%e4 = call @expand16(%A, %some2, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %e4 : vector<16xf32>
// CHECK-NEXT: ( -7, 0, 1, 2, -7, -7, -7, 3, -7, -7, -7, 4, -7, 5, -7, 6 )
%e5 = call @expand16(%A, %some3, %pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %e5 : vector<16xf32>
// CHECK-NEXT: ( -7, 0, -7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, -7, 5 )
%4 = vector.insert %v, %pass[1] : f32 into vector<16xf32>
%5 = vector.insert %v, %4[2] : f32 into vector<16xf32>
%alt_pass = vector.insert %v, %5[14] : f32 into vector<16xf32>
%e6 = call @expand16(%A, %some3, %alt_pass)
: (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
vector.print %e6 : vector<16xf32>
// CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )
return
}


@@ -11,34 +11,20 @@ func @scatter8(%base: memref<?xf32>,
return
}
func @printmem(%A: memref<?xf32>) {
%f = constant 0.0: f32
%0 = vector.broadcast %f : f32 to vector<8xf32>
%1 = constant 0: index
%2 = load %A[%1] : memref<?xf32>
%3 = vector.insert %2, %0[0] : f32 into vector<8xf32>
%4 = constant 1: index
%5 = load %A[%4] : memref<?xf32>
%6 = vector.insert %5, %3[1] : f32 into vector<8xf32>
%7 = constant 2: index
%8 = load %A[%7] : memref<?xf32>
%9 = vector.insert %8, %6[2] : f32 into vector<8xf32>
%10 = constant 3: index
%11 = load %A[%10] : memref<?xf32>
%12 = vector.insert %11, %9[3] : f32 into vector<8xf32>
%13 = constant 4: index
%14 = load %A[%13] : memref<?xf32>
%15 = vector.insert %14, %12[4] : f32 into vector<8xf32>
%16 = constant 5: index
%17 = load %A[%16] : memref<?xf32>
%18 = vector.insert %17, %15[5] : f32 into vector<8xf32>
%19 = constant 6: index
%20 = load %A[%19] : memref<?xf32>
%21 = vector.insert %20, %18[6] : f32 into vector<8xf32>
%22 = constant 7: index
%23 = load %A[%22] : memref<?xf32>
%24 = vector.insert %23, %21[7] : f32 into vector<8xf32>
vector.print %24 : vector<8xf32>
func @printmem8(%A: memref<?xf32>) {
%c0 = constant 0: index
%c1 = constant 1: index
%c8 = constant 8: index
%z = constant 0.0: f32
%m = vector.broadcast %z : f32 to vector<8xf32>
%mem = scf.for %i = %c0 to %c8 step %c1
iter_args(%m_iter = %m) -> (vector<8xf32>) {
%c = load %A[%i] : memref<?xf32>
%i32 = index_cast %i : index to i32
%m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32>
scf.yield %m_new : vector<8xf32>
}
vector.print %mem : vector<8xf32>
return
}
@@ -104,31 +90,27 @@ func @entry() {
vector.print %idx : vector<8xi32>
// CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 )
call @printmem(%A) : (memref<?xf32>) -> ()
call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
call @scatter8(%A, %idx, %none, %val)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
call @printmem(%A) : (memref<?xf32>) -> ()
call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
call @scatter8(%A, %idx, %some, %val)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
call @printmem(%A) : (memref<?xf32>) -> ()
call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 )
call @scatter8(%A, %idx, %more, %val)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
call @printmem(%A) : (memref<?xf32>) -> ()
call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 )
call @scatter8(%A, %idx, %all, %val)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
call @printmem(%A) : (memref<?xf32>) -> ()
call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 )
return


@@ -134,11 +134,9 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
return success();
}
// Helper that returns vector of pointers given a base and an index vector.
LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
Value memref, Value indices, MemRefType memRefType,
VectorType vType, Type iType, Value &ptrs) {
// Helper that returns the base address of a memref.
LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
Value memref, MemRefType memRefType, Value &base) {
// Inspect stride and offset structure.
//
// TODO: flat memory only for now, generalize
@@ -149,13 +147,31 @@ LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
offset != 0 || memRefType.getMemorySpace() != 0)
return failure();
base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
return success();
}
// Create a vector of pointers from base and indices.
MemRefDescriptor memRefDescriptor(memref);
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
int64_t size = vType.getDimSize(0);
auto pType = memRefDescriptor.getElementType();
auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
// Helper that returns a pointer given a memref base.
LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
Value memref, MemRefType memRefType, Value &ptr) {
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
auto pType = MemRefDescriptor(memref).getElementType();
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
return success();
}
// Helper that returns vector of pointers given a memref base and an index
// vector.
LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
Value memref, Value indices, MemRefType memRefType,
VectorType vType, Type iType, Value &ptrs) {
Value base;
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
auto pType = MemRefDescriptor(memref).getElementType();
auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
return success();
}
@@ -305,9 +321,8 @@ public:
VectorType vType = gather.getResultVectorType();
Type iType = gather.getIndicesVectorType().getElementType();
Value ptrs;
if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
adaptor.indices(), gather.getMemRefType(), vType,
iType, ptrs)))
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
gather.getMemRefType(), vType, iType, ptrs)))
return failure();
// Replace with the gather intrinsic.
@@ -344,9 +359,8 @@ public:
VectorType vType = scatter.getValueVectorType();
Type iType = scatter.getIndicesVectorType().getElementType();
Value ptrs;
if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
adaptor.indices(), scatter.getMemRefType(), vType,
iType, ptrs)))
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
scatter.getMemRefType(), vType, iType, ptrs)))
return failure();
// Replace with the scatter intrinsic.
@@ -357,6 +371,60 @@ public:
}
};
/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorExpandLoadOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
typeConverter) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto expand = cast<vector::ExpandLoadOp>(op);
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
ptr)))
return failure();
auto vType = expand.getResultVectorType();
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
op, typeConverter.convertType(vType), ptr, adaptor.mask(),
adaptor.pass_thru());
return success();
}
};
/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorCompressStoreOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
context, typeConverter) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto compress = cast<vector::CompressStoreOp>(op);
auto adaptor = vector::CompressStoreOpAdaptor(operands);
Value ptr;
if (failed(getBasePtr(rewriter, loc, adaptor.base(),
compress.getMemRefType(), ptr)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
op, adaptor.value(), ptr, adaptor.mask());
return success();
}
};
/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
@@ -1274,7 +1342,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorTransferConversion<TransferWriteOp>,
VectorTypeCastOpConversion,
VectorGatherOpConversion,
VectorScatterOpConversion>(ctx, converter);
VectorScatterOpConversion,
VectorExpandLoadOpConversion,
VectorCompressStoreOpConversion>(ctx, converter);
// clang-format on
}
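For a rough picture of what the two new conversion patterns emit, the LLVM dialect output for a 16-lane f32 case looks like the following sketch (value names are illustrative; the pointer is the aligned base taken from the memref descriptor, matching the CHECK lines in the conversion test further below):

```mlir
// Base pointer computed by getBasePtr(): a GEP with no indices.
%p = llvm.getelementptr %base[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
// vector.expandload -> masked expandload intrinsic.
%r = "llvm.intr.masked.expandload"(%p, %mask, %pass_thru)
    : (!llvm.ptr<float>, !llvm.vec<16 x i1>, !llvm.vec<16 x float>) -> !llvm.vec<16 x float>
// vector.compressstore -> masked compressstore intrinsic (value, pointer, mask).
"llvm.intr.masked.compressstore"(%value, %p, %mask)
    : (!llvm.vec<16 x float>, !llvm.ptr<float>, !llvm.vec<16 x i1>) -> ()
```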


@@ -1898,6 +1898,41 @@ static LogicalResult verify(ScatterOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(ExpandLoadOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType passVType = op.getPassThruVectorType();
VectorType resVType = op.getResultVectorType();
if (resVType.getElementType() != op.getMemRefType().getElementType())
return op.emitOpError("base and result element type should match");
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected result dim to match mask dim");
if (resVType != passVType)
return op.emitOpError("expected pass_thru of same type as result type");
return success();
}
//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(CompressStoreOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType valueVType = op.getValueVectorType();
if (valueVType.getElementType() != op.getMemRefType().getElementType())
return op.emitOpError("base and value element type should match");
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected value dim to match mask dim");
return success();
}
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//


@@ -989,3 +989,23 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<float>, !llvm.vec<3 x i32>) -> !llvm.vec<3 x ptr<float>>
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<3 x float>, !llvm.vec<3 x i1> into !llvm.vec<3 x ptr<float>>
// CHECK: llvm.return
func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
%0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
return %0 : vector<11xf32>
}
// CHECK-LABEL: func @expand_load_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
// CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr<float>, !llvm.vec<11 x i1>, !llvm.vec<11 x float>) -> !llvm.vec<11 x float>
// CHECK: llvm.return %[[E]] : !llvm.vec<11 x float>
func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
return
}
// CHECK-LABEL: func @compress_store_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x float>, !llvm.ptr<float>, !llvm.vec<11 x i1>) -> ()
// CHECK: llvm.return


@@ -1240,3 +1240,38 @@ func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
// expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
}
// -----
func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
// expected-error@+1 {{'vector.expandload' op base and result element type should match}}
%0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
// expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
%0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
// expected-error@+1 {{'vector.expandload' op expected pass_thru of same type as result type}}
%0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
}
// -----
func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
// expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
}
// -----
func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
// expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
}


@@ -379,3 +379,12 @@ func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask:
vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
return
}
// CHECK-LABEL: @expand_and_compress
func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
// CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}


@@ -237,8 +237,8 @@ llvm.func @matrix_intrinsics(%A: !llvm.vec<64 x float>, %B: !llvm.vec<48 x float
llvm.return
}
// CHECK-LABEL: @masked_intrinsics
llvm.func @masked_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
// CHECK-LABEL: @masked_load_store_intrinsics
llvm.func @masked_load_store_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
// CHECK: call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef)
%a = llvm.intr.masked.load %A, %mask { alignment = 1: i32} :
(!llvm.ptr<vec<7 x float>>, !llvm.vec<7 x i1>) -> !llvm.vec<7 x float>
@@ -265,6 +265,17 @@ llvm.func @masked_gather_scatter_intrinsics(%M: !llvm.vec<7 x ptr<float>>, %mask
llvm.return
}
// CHECK-LABEL: @masked_expand_compress_intrinsics
llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr<float>, %mask: !llvm.vec<7 x i1>, %passthru: !llvm.vec<7 x float>) {
// CHECK: call <7 x float> @llvm.masked.expandload.v7f32(float* %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
%0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru)
: (!llvm.ptr<float>, !llvm.vec<7 x i1>, !llvm.vec<7 x float>) -> (!llvm.vec<7 x float>)
// CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, float* %{{.*}}, <7 x i1> %{{.*}})
"llvm.intr.masked.compressstore"(%0, %ptr, %mask)
: (!llvm.vec<7 x float>, !llvm.ptr<float>, !llvm.vec<7 x i1>) -> ()
llvm.return
}
// CHECK-LABEL: @memcpy_test
llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.ptr<i8>, %arg3: !llvm.ptr<i8>) {
// CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})