forked from OSchip/llvm-project
[mlir] [VectorOps] Add expand/compress operations to Vector dialect
Introduces the expand and compress operations to the Vector dialect (important memory operations for sparse computations), together with a first reference implementation that lowers to the LLVM IR dialect to enable running on CPU (and other targets that support the corresponding LLVM IR intrinsics).

Reviewed By: reidtatge

Differential Revision: https://reviews.llvm.org/D84888
This commit is contained in:
parent 3c0f347002
commit e8dcf5f87d
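For orientation before the diff itself: at the Vector dialect level the two new ops are plain 1-D masked memory operations, and `-convert-vector-to-llvm` rewrites them to the `llvm.intr.masked.expandload` and `llvm.intr.masked.compressstore` intrinsic ops added below. The following sketch is illustrative only; the function name is made up and not part of this commit, while the types mirror the tests further down.

```mlir
// Hypothetical round-trip: densely read the masked elements of %base into a
// vector, then write them back out densely under the same mask.
func @expand_then_compress(%base: memref<?xf32>,
                           %mask: vector<16xi1>,
                           %pass_thru: vector<16xf32>) {
  // For each set mask bit i, lane i receives the next consecutive memory
  // element; lanes with a clear mask bit take the value from %pass_thru.
  %x = vector.expandload %base, %mask, %pass_thru
      : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  // For each set mask bit i, %x[i] is written to the next consecutive
  // memory location; nothing is written for clear mask bits.
  vector.compressstore %base, %mask, %x
      : memref<?xf32>, vector<16xi1>, vector<16xf32>
  return
}
```

Running such IR through `mlir-opt -convert-vector-to-llvm` produces the two LLVM dialect intrinsic calls exercised by the vector-to-llvm.mlir test near the end of the diff.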
@@ -1042,6 +1042,16 @@ def LLVM_masked_scatter
                      "type($value) `,` type($mask) `into` type($ptrs)";
 }
 
+/// Create a call to Masked Expand Load intrinsic.
+def LLVM_masked_expandload
+    : LLVM_IntrOp<"masked.expandload", [0], [], [], 1>,
+      Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+/// Create a call to Masked Compress Store intrinsic.
+def LLVM_masked_compressstore
+    : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0>,
+      Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
 //
 // Atomic operations.
 //
@@ -1158,7 +1158,7 @@ def Vector_GatherOp :
               Variadic<VectorOfRank<[1]>>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
 
-  let summary = "gathers elements from memory into a vector as defined by an index vector";
+  let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
 
   let description = [{
     The gather operation gathers elements from memory into a 1-D vector as
@@ -1186,7 +1186,6 @@ def Vector_GatherOp :
     %g = vector.gather %base, %indices, %mask, %pass_thru
        : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
     ```
-
   }];
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -1217,7 +1216,7 @@ def Vector_ScatterOp :
               VectorOfRankAndType<[1], [I1]>:$mask,
               VectorOfRank<[1]>:$value)> {
 
-  let summary = "scatters elements from a vector into memory as defined by an index vector";
+  let summary = "scatters elements from a vector into memory as defined by an index vector and mask";
 
   let description = [{
     The scatter operation scatters elements from a 1-D vector into memory as
@@ -1265,6 +1264,108 @@ def Vector_ScatterOp :
       "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
 }
 
+def Vector_ExpandLoadOp :
+  Vector_Op<"expandload">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$pass_thru)>,
+    Results<(outs VectorOfRank<[1]>:$result)> {
+
+  let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
+
+  let description = [{
+    The expand load reads elements from memory into a 1-D vector as defined
+    by a base and a 1-D mask vector. When the mask is set, the next element
+    is read from memory. Otherwise, the corresponding element is taken from
+    a 1-D pass-through vector. Informally the semantics are:
+    ```
+    index = base
+    result[0] := mask[0] ? MEM[index++] : pass_thru[0]
+    result[1] := mask[1] ? MEM[index++] : pass_thru[1]
+    etc.
+    ```
+    Note that the index increment is done conditionally.
+
+    The expand load can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for an expand. The semantics of the operation closely
+    correspond to those of the `llvm.masked.expandload`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
+
+    Example:
+
+    ```mlir
+    %0 = vector.expandload %base, %mask, %pass_thru
+       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getPassThruVectorType() {
+      return pass_thru().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+    "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+}
+
+def Vector_CompressStoreOp :
+  Vector_Op<"compressstore">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$value)> {
+
+  let summary = "writes elements selectively from a vector as defined by a mask";
+
+  let description = [{
+    The compress store operation writes elements from a 1-D vector into memory
+    as defined by a base and a 1-D mask vector. When the mask is set, the
+    corresponding element from the vector is written next to memory. Otherwise,
+    no action is taken for the element. Informally the semantics are:
+    ```
+    index = base
+    if (mask[0]) MEM[index++] = value[0]
+    if (mask[1]) MEM[index++] = value[1]
+    etc.
+    ```
+    Note that the index increment is done conditionally.
+
+    The compress store can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for a compress. The semantics of the operation closely
+    correspond to those of the `llvm.masked.compressstore`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
+
+    Example:
+
+    ```mlir
+    vector.compressstore %base, %mask, %value
+      : memref<?xf32>, vector<8xi1>, vector<8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getValueVectorType() {
+      return value().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
+    "type($base) `,` type($mask) `,` type($value)";
+}
+
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [NoSideEffect]>,
     Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
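The "conditional index increment" called out in both descriptions above is what distinguishes these ops from an ordinary masked load/store: the memory side is always contiguous and only advances on set mask bits, while the vector lanes stay in place. As a reference point only, here is a hedged scalar sketch of the expand-load semantics written with the same scf/std ops used by the tests below; the function name and loop structure are illustrative, not part of the commit and not how the lowering is implemented.

```mlir
// Illustrative scalar reference for vector.expandload on vector<16xf32>
// (hypothetical helper, not from the commit): %idx only advances when the
// mask bit is set, so memory is consumed densely while lanes stay in place.
func @expandload_reference(%base: memref<?xf32>, %mask: vector<16xi1>,
                           %pass_thru: vector<16xf32>) -> vector<16xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c16 = constant 16 : index
  %res, %end = scf.for %i = %c0 to %c16 step %c1
      iter_args(%acc = %pass_thru, %idx = %c0) -> (vector<16xf32>, index) {
    %i32 = index_cast %i : index to i32
    %m = vector.extractelement %mask[%i32 : i32] : vector<16xi1>
    %acc_new, %idx_new = scf.if %m -> (vector<16xf32>, index) {
      // Mask set: read the next memory element and advance the index.
      %v = load %base[%idx] : memref<?xf32>
      %a = vector.insertelement %v, %acc[%i32 : i32] : vector<16xf32>
      %n = addi %idx, %c1 : index
      scf.yield %a, %n : vector<16xf32>, index
    } else {
      // Mask clear: keep the pass-through lane, do not advance the index.
      scf.yield %acc, %idx : vector<16xf32>, index
    }
    scf.yield %acc_new, %idx_new : vector<16xf32>, index
  }
  return %res : vector<16xf32>
}
```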
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @compress16(%base: memref<?xf32>,
+                 %mask: vector<16xi1>, %value: vector<16xf32>) {
+  vector.compressstore %base, %mask, %value
+    : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+func @printmem16(%A: memref<?xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %z = constant 0.0: f32
+  %m = vector.broadcast %z : f32 to vector<16xf32>
+  %mem = scf.for %i = %c0 to %c16 step %c1
+    iter_args(%m_iter = %m) -> (vector<16xf32>) {
+    %c = load %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
+    scf.yield %m_new : vector<16xf32>
+  }
+  vector.print %mem : vector<16xf32>
+  return
+}
+
+func @entry() {
+  // Set up memory.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %A = alloc(%c16) : memref<?xf32>
+  %z = constant 0.0: f32
+  %v = vector.broadcast %z : f32 to vector<16xf32>
+  %value = scf.for %i = %c0 to %c16 step %c1
+    iter_args(%v_iter = %v) -> (vector<16xf32>) {
+    store %z, %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
+    scf.yield %v_new : vector<16xf32>
+  }
+
+  // Set up masks.
+  %f = constant 0: i1
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<16xi1>
+  %all = vector.constant_mask [16] : vector<16xi1>
+  %some1 = vector.constant_mask [4] : vector<16xi1>
+  %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
+  %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
+  %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
+  %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
+  %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
+  %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
+
+  //
+  // Compress store tests.
+  //
+
+  call @compress16(%A, %none, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  call @compress16(%A, %all, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @compress16(%A, %some3, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 1, 3, 7, 11, 13, 15, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @compress16(%A, %some2, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 1, 2, 3, 7, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @compress16(%A, %some1, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  return
+}
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @expand16(%base: memref<?xf32>,
+               %mask: vector<16xi1>,
+               %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %e = vector.expandload %base, %mask, %pass_thru
+    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %e : vector<16xf32>
+}
+
+func @entry() {
+  // Set up memory.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %A = alloc(%c16) : memref<?xf32>
+  scf.for %i = %c0 to %c16 step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    store %fi, %A[%i] : memref<?xf32>
+  }
+
+  // Set up pass thru vector.
+  %u = constant -7.0: f32
+  %v = constant 7.7: f32
+  %pass = vector.broadcast %u : f32 to vector<16xf32>
+
+  // Set up masks.
+  %f = constant 0: i1
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<16xi1>
+  %all = vector.constant_mask [16] : vector<16xi1>
+  %some1 = vector.constant_mask [4] : vector<16xi1>
+  %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
+  %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
+  %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
+  %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
+  %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
+  %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
+
+  //
+  // Expanding load tests.
+  //
+
+  %e1 = call @expand16(%A, %none, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e1 : vector<16xf32>
+  // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %e2 = call @expand16(%A, %all, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e2 : vector<16xf32>
+  // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  %e3 = call @expand16(%A, %some1, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e3 : vector<16xf32>
+  // CHECK-NEXT: ( 0, 1, 2, 3, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %e4 = call @expand16(%A, %some2, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e4 : vector<16xf32>
+  // CHECK-NEXT: ( -7, 0, 1, 2, -7, -7, -7, 3, -7, -7, -7, 4, -7, 5, -7, 6 )
+
+  %e5 = call @expand16(%A, %some3, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e5 : vector<16xf32>
+  // CHECK-NEXT: ( -7, 0, -7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, -7, 5 )
+
+  %4 = vector.insert %v, %pass[1] : f32 into vector<16xf32>
+  %5 = vector.insert %v, %4[2] : f32 into vector<16xf32>
+  %alt_pass = vector.insert %v, %5[14] : f32 into vector<16xf32>
+  %e6 = call @expand16(%A, %some3, %alt_pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e6 : vector<16xf32>
+  // CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )
+
+  return
+}
@@ -11,34 +11,20 @@ func @scatter8(%base: memref<?xf32>,
   return
 }
 
-func @printmem(%A: memref<?xf32>) {
-  %f = constant 0.0: f32
-  %0 = vector.broadcast %f : f32 to vector<8xf32>
-  %1 = constant 0: index
-  %2 = load %A[%1] : memref<?xf32>
-  %3 = vector.insert %2, %0[0] : f32 into vector<8xf32>
-  %4 = constant 1: index
-  %5 = load %A[%4] : memref<?xf32>
-  %6 = vector.insert %5, %3[1] : f32 into vector<8xf32>
-  %7 = constant 2: index
-  %8 = load %A[%7] : memref<?xf32>
-  %9 = vector.insert %8, %6[2] : f32 into vector<8xf32>
-  %10 = constant 3: index
-  %11 = load %A[%10] : memref<?xf32>
-  %12 = vector.insert %11, %9[3] : f32 into vector<8xf32>
-  %13 = constant 4: index
-  %14 = load %A[%13] : memref<?xf32>
-  %15 = vector.insert %14, %12[4] : f32 into vector<8xf32>
-  %16 = constant 5: index
-  %17 = load %A[%16] : memref<?xf32>
-  %18 = vector.insert %17, %15[5] : f32 into vector<8xf32>
-  %19 = constant 6: index
-  %20 = load %A[%19] : memref<?xf32>
-  %21 = vector.insert %20, %18[6] : f32 into vector<8xf32>
-  %22 = constant 7: index
-  %23 = load %A[%22] : memref<?xf32>
-  %24 = vector.insert %23, %21[7] : f32 into vector<8xf32>
-  vector.print %24 : vector<8xf32>
+func @printmem8(%A: memref<?xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c8 = constant 8: index
+  %z = constant 0.0: f32
+  %m = vector.broadcast %z : f32 to vector<8xf32>
+  %mem = scf.for %i = %c0 to %c8 step %c1
+    iter_args(%m_iter = %m) -> (vector<8xf32>) {
+    %c = load %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32>
+    scf.yield %m_new : vector<8xf32>
+  }
+  vector.print %mem : vector<8xf32>
   return
 }
 
@@ -104,31 +90,27 @@ func @entry() {
   vector.print %idx : vector<8xi32>
   // CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 )
 
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
 
   call @scatter8(%A, %idx, %none, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
 
   call @scatter8(%A, %idx, %some, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 )
 
   call @scatter8(%A, %idx, %more, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 )
 
   call @scatter8(%A, %idx, %all, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 )
 
   return
@@ -134,11 +134,9 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
   return success();
 }
 
-// Helper that returns vector of pointers given a base and an index vector.
-LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                             LLVMTypeConverter &typeConverter, Location loc,
-                             Value memref, Value indices, MemRefType memRefType,
-                             VectorType vType, Type iType, Value &ptrs) {
+// Helper that returns the base address of a memref.
+LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
+                      Value memref, MemRefType memRefType, Value &base) {
   // Inspect stride and offset structure.
   //
   // TODO: flat memory only for now, generalize
@@ -149,13 +147,31 @@ LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
       offset != 0 || memRefType.getMemorySpace() != 0)
     return failure();
-
-  // Create a vector of pointers from base and indices.
-  MemRefDescriptor memRefDescriptor(memref);
-  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
-  int64_t size = vType.getDimSize(0);
-  auto pType = memRefDescriptor.getElementType();
-  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
+  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
+  return success();
+}
+
+// Helper that returns a pointer given a memref base.
+LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
+                         Value memref, MemRefType memRefType, Value &ptr) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
+  return success();
+}
+
+// Helper that returns vector of pointers given a memref base and an index
+// vector.
+LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                             Value memref, Value indices, MemRefType memRefType,
+                             VectorType vType, Type iType, Value &ptrs) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
   return success();
 }
@@ -305,9 +321,8 @@ public:
     VectorType vType = gather.getResultVectorType();
     Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), gather.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              gather.getMemRefType(), vType, iType, ptrs)))
       return failure();
 
     // Replace with the gather intrinsic.
@@ -344,9 +359,8 @@ public:
     VectorType vType = scatter.getValueVectorType();
     Type iType = scatter.getIndicesVectorType().getElementType();
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), scatter.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              scatter.getMemRefType(), vType, iType, ptrs)))
       return failure();
 
     // Replace with the scatter intrinsic.
@@ -357,6 +371,60 @@ public:
   }
 };
 
+/// Conversion pattern for a vector.expandload.
+class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorExpandLoadOpConversion(MLIRContext *context,
+                                        LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto expand = cast<vector::ExpandLoadOp>(op);
+    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
+
+    Value ptr;
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
+                          ptr)))
+      return failure();
+
+    auto vType = expand.getResultVectorType();
+    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
+        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
+        adaptor.pass_thru());
+    return success();
+  }
+};
+
+/// Conversion pattern for a vector.compressstore.
+class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorCompressStoreOpConversion(MLIRContext *context,
+                                           LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
+                             context, typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto compress = cast<vector::CompressStoreOp>(op);
+    auto adaptor = vector::CompressStoreOpAdaptor(operands);
+
+    Value ptr;
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
+                          compress.getMemRefType(), ptr)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
+        op, adaptor.value(), ptr, adaptor.mask());
+    return success();
+  }
+};
+
 /// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
@@ -1274,7 +1342,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
                   VectorTransferConversion<TransferWriteOp>,
                   VectorTypeCastOpConversion,
                   VectorGatherOpConversion,
-                  VectorScatterOpConversion>(ctx, converter);
+                  VectorScatterOpConversion,
+                  VectorExpandLoadOpConversion,
+                  VectorCompressStoreOpConversion>(ctx, converter);
   // clang-format on
 }
 
@@ -1898,6 +1898,41 @@ static LogicalResult verify(ScatterOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ExpandLoadOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ExpandLoadOp op) {
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType passVType = op.getPassThruVectorType();
+  VectorType resVType = op.getResultVectorType();
+
+  if (resVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and result element type should match");
+
+  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+    return op.emitOpError("expected result dim to match mask dim");
+  if (resVType != passVType)
+    return op.emitOpError("expected pass_thru of same type as result type");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CompressStoreOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(CompressStoreOp op) {
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType valueVType = op.getValueVectorType();
+
+  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and value element type should match");
+
+  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+    return op.emitOpError("expected value dim to match mask dim");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
@@ -989,3 +989,23 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
 // CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<float>, !llvm.vec<3 x i32>) -> !llvm.vec<3 x ptr<float>>
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<3 x float>, !llvm.vec<3 x i1> into !llvm.vec<3 x ptr<float>>
 // CHECK: llvm.return
+
+func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
+  %0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
+  return %0 : vector<11xf32>
+}
+
+// CHECK-LABEL: func @expand_load_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
+// CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr<float>, !llvm.vec<11 x i1>, !llvm.vec<11 x float>) -> !llvm.vec<11 x float>
+// CHECK: llvm.return %[[E]] : !llvm.vec<11 x float>
+
+func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
+  vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
+  return
+}
+
+// CHECK-LABEL: func @compress_store_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x float>, !llvm.ptr<float>, !llvm.vec<11 x i1>) -> ()
+// CHECK: llvm.return
@@ -1240,3 +1240,38 @@ func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
   // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
   vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
 }
+
+// -----
+
+func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  // expected-error@+1 {{'vector.expandload' op base and result element type should match}}
+  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
+  // expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
+  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
+  // expected-error@+1 {{'vector.expandload' op expected pass_thru of same type as result type}}
+  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
+}
+
+// -----
+
+func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
+  vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
+  vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
+}
@@ -379,3 +379,12 @@ func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask:
   vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
   return
 }
+
+// CHECK-LABEL: @expand_and_compress
+func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+  // CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
@@ -237,8 +237,8 @@ llvm.func @matrix_intrinsics(%A: !llvm.vec<64 x float>, %B: !llvm.vec<48 x float
   llvm.return
 }
 
-// CHECK-LABEL: @masked_intrinsics
-llvm.func @masked_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
+// CHECK-LABEL: @masked_load_store_intrinsics
+llvm.func @masked_load_store_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
   // CHECK: call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef)
   %a = llvm.intr.masked.load %A, %mask { alignment = 1: i32} :
     (!llvm.ptr<vec<7 x float>>, !llvm.vec<7 x i1>) -> !llvm.vec<7 x float>
@@ -265,6 +265,17 @@ llvm.func @masked_gather_scatter_intrinsics(%M: !llvm.vec<7 x ptr<float>>, %mask
   llvm.return
 }
 
+// CHECK-LABEL: @masked_expand_compress_intrinsics
+llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr<float>, %mask: !llvm.vec<7 x i1>, %passthru: !llvm.vec<7 x float>) {
+  // CHECK: call <7 x float> @llvm.masked.expandload.v7f32(float* %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+  %0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru)
+    : (!llvm.ptr<float>, !llvm.vec<7 x i1>, !llvm.vec<7 x float>) -> (!llvm.vec<7 x float>)
+  // CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, float* %{{.*}}, <7 x i1> %{{.*}})
+  "llvm.intr.masked.compressstore"(%0, %ptr, %mask)
+    : (!llvm.vec<7 x float>, !llvm.ptr<float>, !llvm.vec<7 x i1>) -> ()
+  llvm.return
+}
+
 // CHECK-LABEL: @memcpy_test
 llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.ptr<i8>, %arg3: !llvm.ptr<i8>) {
   // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})