[mlir] [VectorOps] Add expand/compress operations to Vector dialect
Introduces the expand and compress operations to the Vector dialect (important memory operations for sparse computations), together with a first reference implementation that lowers to the LLVM IR dialect to enable running on CPU (and other targets that support the corresponding LLVM IR intrinsics).

Reviewed By: reidtatge

Differential Revision: https://reviews.llvm.org/D84888
parent 3c0f347002 · commit e8dcf5f87d
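For orientation, a minimal round-trip sketch of the two new operations, mirroring the round-trip test added near the end of this patch (the function and SSA value names are illustrative):

```mlir
// Expand-load from %base under %mask, then compress-store the result back.
func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
  %0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
  return
}
```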
@@ -1042,6 +1042,16 @@ def LLVM_masked_scatter
                   "type($value) `,` type($mask) `into` type($ptrs)";
}

/// Create a call to Masked Expand Load intrinsic.
def LLVM_masked_expandload
    : LLVM_IntrOp<"masked.expandload", [0], [], [], 1>,
      Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;

/// Create a call to Masked Compress Store intrinsic.
def LLVM_masked_compressstore
    : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0>,
      Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;

//
// Atomic operations.
//
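In the LLVM dialect these intrinsic ops are used in generic form, as the target test at the end of this patch exercises; a small sketch (SSA names are illustrative):

```mlir
// Expand-load through a scalar pointer, then compress-store the result back.
%0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru)
    : (!llvm.ptr<float>, !llvm.vec<7 x i1>, !llvm.vec<7 x float>) -> (!llvm.vec<7 x float>)
"llvm.intr.masked.compressstore"(%0, %ptr, %mask)
    : (!llvm.vec<7 x float>, !llvm.ptr<float>, !llvm.vec<7 x i1>) -> ()
```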
@@ -1158,7 +1158,7 @@ def Vector_GatherOp :
              Variadic<VectorOfRank<[1]>>:$pass_thru)>,
    Results<(outs VectorOfRank<[1]>:$result)> {

-  let summary = "gathers elements from memory into a vector as defined by an index vector";
+  let summary = "gathers elements from memory into a vector as defined by an index vector and mask";

  let description = [{
    The gather operation gathers elements from memory into a 1-D vector as
@@ -1186,7 +1186,6 @@ def Vector_GatherOp :
    %g = vector.gather %base, %indices, %mask, %pass_thru
       : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
    ```
  }];
  let extraClassDeclaration = [{
    MemRefType getMemRefType() {
@@ -1217,7 +1216,7 @@ def Vector_ScatterOp :
              VectorOfRankAndType<[1], [I1]>:$mask,
              VectorOfRank<[1]>:$value)> {

-  let summary = "scatters elements from a vector into memory as defined by an index vector";
+  let summary = "scatters elements from a vector into memory as defined by an index vector and mask";

  let description = [{
    The scatter operation scatters elements from a 1-D vector into memory as
@@ -1265,6 +1264,108 @@ def Vector_ScatterOp :
    "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
}

def Vector_ExpandLoadOp :
  Vector_Op<"expandload">,
    Arguments<(ins AnyMemRef:$base,
               VectorOfRankAndType<[1], [I1]>:$mask,
               VectorOfRank<[1]>:$pass_thru)>,
    Results<(outs VectorOfRank<[1]>:$result)> {

  let summary = "reads elements from memory and spreads them into a vector as defined by a mask";

  let description = [{
    The expand load reads elements from memory into a 1-D vector as defined
    by a base and a 1-D mask vector. When the mask is set, the next element
    is read from memory. Otherwise, the corresponding element is taken from
    a 1-D pass-through vector. Informally the semantics are:
    ```
    index = base
    result[0] := mask[0] ? MEM[index++] : pass_thru[0]
    result[1] := mask[1] ? MEM[index++] : pass_thru[1]
    etc.
    ```
    Note that the index increment is done conditionally.

    The expand load can be used directly where applicable, or can be used
    during progressive lowering to bring other memory operations closer to
    hardware ISA support for an expand. The semantics of the operation closely
    correspond to those of the `llvm.masked.expandload`
    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).

    Example:

    ```mlir
    %0 = vector.expandload %base, %mask, %pass_thru
       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
    ```
  }];
  let extraClassDeclaration = [{
    MemRefType getMemRefType() {
      return base().getType().cast<MemRefType>();
    }
    VectorType getMaskVectorType() {
      return mask().getType().cast<VectorType>();
    }
    VectorType getPassThruVectorType() {
      return pass_thru().getType().cast<VectorType>();
    }
    VectorType getResultVectorType() {
      return result().getType().cast<VectorType>();
    }
  }];
  let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
    "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
}

def Vector_CompressStoreOp :
  Vector_Op<"compressstore">,
    Arguments<(ins AnyMemRef:$base,
               VectorOfRankAndType<[1], [I1]>:$mask,
               VectorOfRank<[1]>:$value)> {

  let summary = "writes elements selectively from a vector as defined by a mask";

  let description = [{
    The compress store operation writes elements from a 1-D vector into memory
    as defined by a base and a 1-D mask vector. When the mask is set, the
    corresponding element from the vector is written to the next position in
    memory. Otherwise, no action is taken for the element. Informally the
    semantics are:
    ```
    index = base
    if (mask[0]) MEM[index++] = value[0]
    if (mask[1]) MEM[index++] = value[1]
    etc.
    ```
    Note that the index increment is done conditionally.

    The compress store can be used directly where applicable, or can be used
    during progressive lowering to bring other memory operations closer to
    hardware ISA support for a compress. The semantics of the operation closely
    correspond to those of the `llvm.masked.compressstore`
    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).

    Example:

    ```mlir
    vector.compressstore %base, %mask, %value
      : memref<?xf32>, vector<8xi1>, vector<8xf32>
    ```
  }];
  let extraClassDeclaration = [{
    MemRefType getMemRefType() {
      return base().getType().cast<MemRefType>();
    }
    VectorType getMaskVectorType() {
      return mask().getType().cast<VectorType>();
    }
    VectorType getValueVectorType() {
      return value().getType().cast<VectorType>();
    }
  }];
  let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
    "type($base) `,` type($mask) `,` type($value)";
}

def Vector_ShapeCastOp :
  Vector_Op<"shape_cast", [NoSideEffect]>,
    Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
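To make the conditional index increment in the two descriptions above concrete, a small hypothetical 4-lane example (the memory contents, mask, and SSA names are invented for illustration):

```mlir
// Assume memory at %base holds [10., 20., 30., 40.] and %mask = [1, 0, 1, 1].
// The expand load reads the first three memory values and places them at the
// set lanes: %r = [10., pass_thru[1], 20., 30.].
%r = vector.expandload %base, %mask, %pass_thru
    : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>

// With the same mask, the compress store writes value[0], value[2], value[3]
// contiguously into the first three memory positions.
vector.compressstore %base, %mask, %value
    : memref<?xf32>, vector<4xi1>, vector<4xf32>
```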
@@ -0,0 +1,90 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

func @compress16(%base: memref<?xf32>,
                 %mask: vector<16xi1>, %value: vector<16xf32>) {
  vector.compressstore %base, %mask, %value
    : memref<?xf32>, vector<16xi1>, vector<16xf32>
  return
}

func @printmem16(%A: memref<?xf32>) {
  %c0 = constant 0: index
  %c1 = constant 1: index
  %c16 = constant 16: index
  %z = constant 0.0: f32
  %m = vector.broadcast %z : f32 to vector<16xf32>
  %mem = scf.for %i = %c0 to %c16 step %c1
    iter_args(%m_iter = %m) -> (vector<16xf32>) {
    %c = load %A[%i] : memref<?xf32>
    %i32 = index_cast %i : index to i32
    %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
    scf.yield %m_new : vector<16xf32>
  }
  vector.print %mem : vector<16xf32>
  return
}

func @entry() {
  // Set up memory.
  %c0 = constant 0: index
  %c1 = constant 1: index
  %c16 = constant 16: index
  %A = alloc(%c16) : memref<?xf32>
  %z = constant 0.0: f32
  %v = vector.broadcast %z : f32 to vector<16xf32>
  %value = scf.for %i = %c0 to %c16 step %c1
    iter_args(%v_iter = %v) -> (vector<16xf32>) {
    store %z, %A[%i] : memref<?xf32>
    %i32 = index_cast %i : index to i32
    %fi = sitofp %i32 : i32 to f32
    %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
    scf.yield %v_new : vector<16xf32>
  }

  // Set up masks.
  %f = constant 0: i1
  %t = constant 1: i1
  %none = vector.constant_mask [0] : vector<16xi1>
  %all = vector.constant_mask [16] : vector<16xi1>
  %some1 = vector.constant_mask [4] : vector<16xi1>
  %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
  %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
  %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
  %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
  %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
  %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>

  //
  // Compress store tests.
  //

  call @compress16(%A, %none, %value)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )

  call @compress16(%A, %all, %value)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )

  call @compress16(%A, %some3, %value)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK-NEXT: ( 1, 3, 7, 11, 13, 15, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )

  call @compress16(%A, %some2, %value)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK-NEXT: ( 1, 2, 3, 7, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )

  call @compress16(%A, %some1, %value)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )

  return
}
@@ -0,0 +1,82 @@
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

func @expand16(%base: memref<?xf32>,
               %mask: vector<16xi1>,
               %pass_thru: vector<16xf32>) -> vector<16xf32> {
  %e = vector.expandload %base, %mask, %pass_thru
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %e : vector<16xf32>
}

func @entry() {
  // Set up memory.
  %c0 = constant 0: index
  %c1 = constant 1: index
  %c16 = constant 16: index
  %A = alloc(%c16) : memref<?xf32>
  scf.for %i = %c0 to %c16 step %c1 {
    %i32 = index_cast %i : index to i32
    %fi = sitofp %i32 : i32 to f32
    store %fi, %A[%i] : memref<?xf32>
  }

  // Set up pass thru vector.
  %u = constant -7.0: f32
  %v = constant 7.7: f32
  %pass = vector.broadcast %u : f32 to vector<16xf32>

  // Set up masks.
  %f = constant 0: i1
  %t = constant 1: i1
  %none = vector.constant_mask [0] : vector<16xi1>
  %all = vector.constant_mask [16] : vector<16xi1>
  %some1 = vector.constant_mask [4] : vector<16xi1>
  %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
  %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
  %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
  %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
  %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
  %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>

  //
  // Expanding load tests.
  //

  %e1 = call @expand16(%A, %none, %pass)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
  vector.print %e1 : vector<16xf32>
  // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )

  %e2 = call @expand16(%A, %all, %pass)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
  vector.print %e2 : vector<16xf32>
  // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )

  %e3 = call @expand16(%A, %some1, %pass)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
  vector.print %e3 : vector<16xf32>
  // CHECK-NEXT: ( 0, 1, 2, 3, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )

  %e4 = call @expand16(%A, %some2, %pass)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
  vector.print %e4 : vector<16xf32>
  // CHECK-NEXT: ( -7, 0, 1, 2, -7, -7, -7, 3, -7, -7, -7, 4, -7, 5, -7, 6 )

  %e5 = call @expand16(%A, %some3, %pass)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
  vector.print %e5 : vector<16xf32>
  // CHECK-NEXT: ( -7, 0, -7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, -7, 5 )

  %4 = vector.insert %v, %pass[1] : f32 into vector<16xf32>
  %5 = vector.insert %v, %4[2] : f32 into vector<16xf32>
  %alt_pass = vector.insert %v, %5[14] : f32 into vector<16xf32>
  %e6 = call @expand16(%A, %some3, %alt_pass)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
  vector.print %e6 : vector<16xf32>
  // CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )

  return
}
@@ -11,34 +11,20 @@ func @scatter8(%base: memref<?xf32>,
  return
}

-func @printmem(%A: memref<?xf32>) {
-  %f = constant 0.0: f32
-  %0 = vector.broadcast %f : f32 to vector<8xf32>
-  %1 = constant 0: index
-  %2 = load %A[%1] : memref<?xf32>
-  %3 = vector.insert %2, %0[0] : f32 into vector<8xf32>
-  %4 = constant 1: index
-  %5 = load %A[%4] : memref<?xf32>
-  %6 = vector.insert %5, %3[1] : f32 into vector<8xf32>
-  %7 = constant 2: index
-  %8 = load %A[%7] : memref<?xf32>
-  %9 = vector.insert %8, %6[2] : f32 into vector<8xf32>
-  %10 = constant 3: index
-  %11 = load %A[%10] : memref<?xf32>
-  %12 = vector.insert %11, %9[3] : f32 into vector<8xf32>
-  %13 = constant 4: index
-  %14 = load %A[%13] : memref<?xf32>
-  %15 = vector.insert %14, %12[4] : f32 into vector<8xf32>
-  %16 = constant 5: index
-  %17 = load %A[%16] : memref<?xf32>
-  %18 = vector.insert %17, %15[5] : f32 into vector<8xf32>
-  %19 = constant 6: index
-  %20 = load %A[%19] : memref<?xf32>
-  %21 = vector.insert %20, %18[6] : f32 into vector<8xf32>
-  %22 = constant 7: index
-  %23 = load %A[%22] : memref<?xf32>
-  %24 = vector.insert %23, %21[7] : f32 into vector<8xf32>
-  vector.print %24 : vector<8xf32>
+func @printmem8(%A: memref<?xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c8 = constant 8: index
+  %z = constant 0.0: f32
+  %m = vector.broadcast %z : f32 to vector<8xf32>
+  %mem = scf.for %i = %c0 to %c8 step %c1
+    iter_args(%m_iter = %m) -> (vector<8xf32>) {
+    %c = load %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32>
+    scf.yield %m_new : vector<8xf32>
+  }
+  vector.print %mem : vector<8xf32>
  return
}
@@ -104,31 +90,27 @@ func @entry() {
  vector.print %idx : vector<8xi32>
  // CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 )

-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )

  call @scatter8(%A, %idx, %none, %val)
    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()

-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )

  call @scatter8(%A, %idx, %some, %val)
    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()

-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 )

  call @scatter8(%A, %idx, %more, %val)
    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()

-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 )

  call @scatter8(%A, %idx, %all, %val)
    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()

-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 )

  return
@@ -134,11 +134,9 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
  return success();
}

-// Helper that returns vector of pointers given a base and an index vector.
-LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                             LLVMTypeConverter &typeConverter, Location loc,
-                             Value memref, Value indices, MemRefType memRefType,
-                             VectorType vType, Type iType, Value &ptrs) {
+// Helper that returns the base address of a memref.
+LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
+                      Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
@@ -149,13 +147,31 @@ LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
+  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
+  return success();
+}

-  // Create a vector of pointers from base and indices.
-  MemRefDescriptor memRefDescriptor(memref);
-  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
-  int64_t size = vType.getDimSize(0);
-  auto pType = memRefDescriptor.getElementType();
-  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
+// Helper that returns a pointer given a memref base.
+LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
+                         Value memref, MemRefType memRefType, Value &ptr) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
+  return success();
+}
+
+// Helper that returns vector of pointers given a memref base and an index
+// vector.
+LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                             Value memref, Value indices, MemRefType memRefType,
+                             VectorType vType, Type iType, Value &ptrs) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}
@@ -305,9 +321,8 @@ public:
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), gather.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
@@ -344,9 +359,8 @@ public:
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), scatter.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
@@ -357,6 +371,60 @@ public:
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExpandLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto expand = cast<vector::ExpandLoadOp>(op);
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
                          ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCompressStoreOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto compress = cast<vector::CompressStoreOp>(op);
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        op, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
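The net effect of these two patterns is a single `llvm.getelementptr` that materializes the base pointer, followed by the corresponding intrinsic call, roughly as the conversion test below checks (type spellings follow the LLVM dialect at this revision; the SSA names are illustrative):

```mlir
%p = llvm.getelementptr %base[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
%e = "llvm.intr.masked.expandload"(%p, %mask, %pass)
    : (!llvm.ptr<float>, !llvm.vec<11 x i1>, !llvm.vec<11 x float>) -> !llvm.vec<11 x float>
"llvm.intr.masked.compressstore"(%e, %p, %mask)
    : (!llvm.vec<11 x float>, !llvm.ptr<float>, !llvm.vec<11 x i1>) -> ()
```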
@@ -1274,7 +1342,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
                  VectorTransferConversion<TransferWriteOp>,
                  VectorTypeCastOpConversion,
                  VectorGatherOpConversion,
-                  VectorScatterOpConversion>(ctx, converter);
+                  VectorScatterOpConversion,
+                  VectorExpandLoadOpConversion,
+                  VectorCompressStoreOpConversion>(ctx, converter);
  // clang-format on
}
@@ -1898,6 +1898,41 @@ static LogicalResult verify(ScatterOp op) {
  return success();
}

//===----------------------------------------------------------------------===//
// ExpandLoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExpandLoadOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType passVType = op.getPassThruVectorType();
  VectorType resVType = op.getResultVectorType();

  if (resVType.getElementType() != op.getMemRefType().getElementType())
    return op.emitOpError("base and result element type should match");

  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected result dim to match mask dim");
  if (resVType != passVType)
    return op.emitOpError("expected pass_thru of same type as result type");
  return success();
}

//===----------------------------------------------------------------------===//
// CompressStoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(CompressStoreOp op) {
  VectorType maskVType = op.getMaskVectorType();
  VectorType valueVType = op.getValueVectorType();

  if (valueVType.getElementType() != op.getMemRefType().getElementType())
    return op.emitOpError("base and value element type should match");

  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
    return op.emitOpError("expected value dim to match mask dim");
  return success();
}

//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//
@@ -989,3 +989,23 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<float>, !llvm.vec<3 x i32>) -> !llvm.vec<3 x ptr<float>>
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<3 x float>, !llvm.vec<3 x i1> into !llvm.vec<3 x ptr<float>>
// CHECK: llvm.return

func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
  %0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
  return %0 : vector<11xf32>
}

// CHECK-LABEL: func @expand_load_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
// CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr<float>, !llvm.vec<11 x i1>, !llvm.vec<11 x float>) -> !llvm.vec<11 x float>
// CHECK: llvm.return %[[E]] : !llvm.vec<11 x float>

func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
  vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
  return
}

// CHECK-LABEL: func @compress_store_op
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x float>, !llvm.ptr<float>, !llvm.vec<11 x i1>) -> ()
// CHECK: llvm.return
@@ -1240,3 +1240,38 @@ func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
  // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
}

// -----

func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
  // expected-error@+1 {{'vector.expandload' op base and result element type should match}}
  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}

// -----

func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
  // expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
}

// -----

func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
  // expected-error@+1 {{'vector.expandload' op expected pass_thru of same type as result type}}
  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
}

// -----

func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
  // expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
  vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
}

// -----

func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
  // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
  vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
}
@@ -379,3 +379,12 @@ func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask:
  vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
  return
}

// CHECK-LABEL: @expand_and_compress
func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
  // CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  %0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  // CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
  vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
  return
}
@@ -237,8 +237,8 @@ llvm.func @matrix_intrinsics(%A: !llvm.vec<64 x float>, %B: !llvm.vec<48 x float
  llvm.return
}

-// CHECK-LABEL: @masked_intrinsics
-llvm.func @masked_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
+// CHECK-LABEL: @masked_load_store_intrinsics
+llvm.func @masked_load_store_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
  // CHECK: call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef)
  %a = llvm.intr.masked.load %A, %mask { alignment = 1: i32} :
    (!llvm.ptr<vec<7 x float>>, !llvm.vec<7 x i1>) -> !llvm.vec<7 x float>
@@ -265,6 +265,17 @@ llvm.func @masked_gather_scatter_intrinsics(%M: !llvm.vec<7 x ptr<float>>, %mask
  llvm.return
}

// CHECK-LABEL: @masked_expand_compress_intrinsics
llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr<float>, %mask: !llvm.vec<7 x i1>, %passthru: !llvm.vec<7 x float>) {
  // CHECK: call <7 x float> @llvm.masked.expandload.v7f32(float* %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
  %0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru)
    : (!llvm.ptr<float>, !llvm.vec<7 x i1>, !llvm.vec<7 x float>) -> (!llvm.vec<7 x float>)
  // CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, float* %{{.*}}, <7 x i1> %{{.*}})
  "llvm.intr.masked.compressstore"(%0, %ptr, %mask)
    : (!llvm.vec<7 x float>, !llvm.ptr<float>, !llvm.vec<7 x i1>) -> ()
  llvm.return
}

// CHECK-LABEL: @memcpy_test
llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.ptr<i8>, %arg3: !llvm.ptr<i8>) {
  // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})