[flang][fir] Add array value operations.
We lower expressions with rank > 0 to a set of high-level array operations. These operations are then analyzed and refined to more primitve operations in subsequent pass(es). This patch upstreams these array operations and some other helper ops. Authors: Eric Schweitz, Rajan Walia, Kiran Chandramohan, et.al. https://github.com/flang-compiler/f18-llvm-project/pull/565 Differential Revision: https://reviews.llvm.org/D97421
This commit is contained in:
parent
e890fffcab
commit
67360decc3
|
|
@ -19,7 +19,6 @@ include "mlir/Interfaces/CallInterfaces.td"
|
|||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
include "flang/Optimizer/Dialect/FIRTypes.td"
|
||||
|
||||
// Base class for FIR operations.
|
||||
|
|
@ -1495,9 +1494,263 @@ def fir_BoxTypeDescOp : fir_SimpleOneResultOp<"box_tdesc", [NoSideEffect]> {
|
|||
let results = (outs fir_TypeDescType);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Array value operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def fir_ArrayLoadOp : fir_Op<"array_load", [AttrSizedOperandSegments]> {
|
||||
|
||||
let summary = "Load an array as a value.";
|
||||
|
||||
let description = [{
|
||||
Load an entire array as a single SSA value.
|
||||
|
||||
```fortran
|
||||
real :: a(o:n,p:m)
|
||||
...
|
||||
... = ... a ...
|
||||
```
|
||||
|
||||
One can use `fir.array_load` to produce an ssa-value that captures an
|
||||
immutable value of the entire array `a`, as in the Fortran array expression
|
||||
shown above. Subsequent changes to the memory containing the array do not
|
||||
alter its composite value. This operation let's one load an array as a
|
||||
value while applying a runtime shape, shift, or slice to the memory
|
||||
reference, and its semantics guarantee immutability.
|
||||
|
||||
```mlir
|
||||
%s = fir.shape_shift %o, %n, %p, %m : (index, index, index, index) -> !fir.shape<2>
|
||||
// load the entire array 'a'
|
||||
%v = fir.array_load %a(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
|
||||
// a fir.store here into array %a does not change %v
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<AnyRefOrBox, "", [MemRead]>:$memref,
|
||||
Optional<AnyShapeOrShiftType>:$shape,
|
||||
Optional<fir_SliceType>:$slice,
|
||||
Variadic<AnyIntegerType>:$lenParams
|
||||
);
|
||||
|
||||
let results = (outs fir_SequenceType);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$memref (`(`$shape^`)`)? (`[`$slice^`]`)? (`typeparams` $lenParams^)? attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
std::vector<mlir::Value> getExtents();
|
||||
}];
|
||||
}
|
||||
|
||||
def fir_ArrayFetchOp : fir_Op<"array_fetch", [NoSideEffect]> {
|
||||
|
||||
let summary = "Fetch the value of an element of an array value";
|
||||
|
||||
let description = [{
|
||||
Fetch the value of an element in an array value.
|
||||
|
||||
```fortran
|
||||
real :: a(n,m)
|
||||
...
|
||||
... a ...
|
||||
... a(r,s+1) ...
|
||||
```
|
||||
|
||||
One can use `fir.array_fetch` to fetch the (implied) value of `a(i,j)` in
|
||||
an array expression as shown above. It can also be used to extract the
|
||||
element `a(r,s+1)` in the second expression.
|
||||
|
||||
```mlir
|
||||
%s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
|
||||
// load the entire array 'a'
|
||||
%v = fir.array_load %a(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
|
||||
// fetch the value of one of the array value's elements
|
||||
%1 = fir.array_fetch %v, %i, %j : (!fir.array<?x?xf32>, index, index) -> f32
|
||||
```
|
||||
|
||||
It is only possible to use `array_fetch` on an `array_load` result value.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
fir_SequenceType:$sequence,
|
||||
Variadic<AnyCoordinateType>:$indices
|
||||
);
|
||||
|
||||
let results = (outs AnyType:$element);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$sequence `,` $indices attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let verifier = [{
|
||||
auto arrTy = sequence().getType().cast<fir::SequenceType>();
|
||||
if (indices().size() != arrTy.getDimension())
|
||||
return emitOpError("number of indices != dimension of array");
|
||||
if (element().getType() != arrTy.getEleTy())
|
||||
return emitOpError("return type does not match array");
|
||||
if (!isa<fir::ArrayLoadOp>(sequence().getDefiningOp()))
|
||||
return emitOpError("argument #0 must be result of fir.array_load");
|
||||
return mlir::success();
|
||||
}];
|
||||
}
|
||||
|
||||
def fir_ArrayUpdateOp : fir_Op<"array_update", [NoSideEffect]> {
|
||||
|
||||
let summary = "Update the value of an element of an array value";
|
||||
|
||||
let description = [{
|
||||
Updates the value of an element in an array value. A new array value is
|
||||
returned where all element values of the input array are identical except
|
||||
for the selected element which is the value passed in the update.
|
||||
|
||||
```fortran
|
||||
real :: a(n,m)
|
||||
...
|
||||
a = ...
|
||||
```
|
||||
|
||||
One can use `fir.array_update` to update the (implied) value of `a(i,j)`
|
||||
in an array expression as shown above.
|
||||
|
||||
```mlir
|
||||
%s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
|
||||
// load the entire array 'a'
|
||||
%v = fir.array_load %a(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
|
||||
// update the value of one of the array value's elements
|
||||
// %r_{ij} = %f if (i,j) = (%i,%j), %v_{ij} otherwise
|
||||
%r = fir.array_update %v, %f, %i, %j : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
|
||||
fir.array_merge_store %v, %r to %a : !fir.ref<!fir.array<?x?xf32>>
|
||||
```
|
||||
|
||||
An array value update behaves as if a mapping function from the indices
|
||||
to the new value has been added, replacing the previous mapping. These
|
||||
mappings can be added to the ssa-value, but will not be materialized in
|
||||
memory until the `fir.array_merge_store` is performed.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
fir_SequenceType:$sequence,
|
||||
AnyType:$merge,
|
||||
Variadic<AnyCoordinateType>:$indices
|
||||
);
|
||||
|
||||
let results = (outs fir_SequenceType);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$sequence `,` $merge `,` $indices attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let verifier = [{
|
||||
auto arrTy = sequence().getType().cast<fir::SequenceType>();
|
||||
if (merge().getType() != arrTy.getEleTy())
|
||||
return emitOpError("merged value does not have element type");
|
||||
if (indices().size() != arrTy.getDimension())
|
||||
return emitOpError("number of indices != dimension of array");
|
||||
return mlir::success();
|
||||
}];
|
||||
}
|
||||
|
||||
def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store", [
|
||||
TypesMatchWith<"type of 'original' matches element type of 'memref'",
|
||||
"memref", "original",
|
||||
"fir::dyn_cast_ptrOrBoxEleTy($_self)">,
|
||||
TypesMatchWith<"type of 'sequence' matches element type of 'memref'",
|
||||
"memref", "sequence",
|
||||
"fir::dyn_cast_ptrOrBoxEleTy($_self)">]> {
|
||||
|
||||
let summary = "Store merged array value to memory.";
|
||||
|
||||
let description = [{
|
||||
Store a merged array value to memory.
|
||||
|
||||
```fortran
|
||||
real :: a(n,m)
|
||||
...
|
||||
a = ...
|
||||
```
|
||||
|
||||
One can use `fir.array_merge_store` to merge/copy the value of `a` in an
|
||||
array expression as shown above.
|
||||
|
||||
```mlir
|
||||
%v = fir.array_load %a(%shape) : ...
|
||||
%r = fir.array_update %v, %f, %i, %j : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
|
||||
fir.array_merge_store %v, %r to %a : !fir.ref<!fir.array<?x?xf32>>
|
||||
```
|
||||
|
||||
This operation merges the original loaded array value, `%v`, with the
|
||||
chained updates, `%r`, and stores the result to the array at address, `%a`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
fir_SequenceType:$original,
|
||||
fir_SequenceType:$sequence,
|
||||
Arg<AnyRefOrBox, "", [MemWrite]>:$memref
|
||||
);
|
||||
|
||||
let assemblyFormat = "$original `,` $sequence `to` $memref attr-dict `:` type($memref)";
|
||||
|
||||
let verifier = [{
|
||||
if (!isa<ArrayLoadOp>(original().getDefiningOp()))
|
||||
return emitOpError("operand #0 must be result of a fir.array_load op");
|
||||
return mlir::success();
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Record and array type operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def fir_ArrayCoorOp : fir_Op<"array_coor",
|
||||
[NoSideEffect, AttrSizedOperandSegments]> {
|
||||
|
||||
let summary = "Find the coordinate of an element of an array";
|
||||
|
||||
let description = [{
|
||||
Compute the location of an element in an array when the shape of the
|
||||
array is only known at runtime.
|
||||
|
||||
This operation is intended to capture all the runtime values needed to
|
||||
compute the address of an array reference in a single high-level op. Given
|
||||
the following Fortran input:
|
||||
|
||||
```fortran
|
||||
real :: a(n,m)
|
||||
...
|
||||
... a(i,j) ...
|
||||
```
|
||||
|
||||
One can use `fir.array_coor` to determine the address of `a(i,j)`.
|
||||
|
||||
```mlir
|
||||
%s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
|
||||
%1 = fir.array_coor %a(%s) %i, %j : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyRefOrBox:$memref,
|
||||
Optional<AnyShapeOrShiftType>:$shape,
|
||||
Optional<fir_SliceType>:$slice,
|
||||
Variadic<AnyCoordinateType>:$indices,
|
||||
Variadic<AnyIntegerType>:$lenParams
|
||||
);
|
||||
|
||||
let results = (outs fir_ReferenceType);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams` $lenParams^)? attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
|
||||
|
||||
let summary = "Finds the coordinate (location) of a value in memory";
|
||||
|
||||
let description = [{
|
||||
|
|
@ -1674,18 +1927,218 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> {
|
|||
}
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "StringRef":$fieldName, "Type":$recTy,
|
||||
CArg<"ValueRange", "{}">:$operands),
|
||||
let builders = [OpBuilderDAG<(ins "llvm::StringRef":$fieldName,
|
||||
"mlir::Type":$recTy, CArg<"mlir::ValueRange","{}">:$operands),
|
||||
[{
|
||||
$_state.addAttribute(fieldAttrName(), $_builder.getStringAttr(fieldName));
|
||||
$_state.addAttribute(fieldAttrName(),
|
||||
$_builder.getStringAttr(fieldName));
|
||||
$_state.addAttribute(typeAttrName(), TypeAttr::get(recTy));
|
||||
$_state.addOperands(operands);
|
||||
}]>];
|
||||
}]
|
||||
>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static constexpr llvm::StringRef fieldAttrName() { return "field_id"; }
|
||||
static constexpr llvm::StringRef typeAttrName() { return "on_type"; }
|
||||
llvm::StringRef getFieldName() { return field_id(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def fir_ShapeOp : fir_Op<"shape", [NoSideEffect]> {
|
||||
|
||||
let summary = "generate an abstract shape vector of type `!fir.shape`";
|
||||
|
||||
let description = [{
|
||||
The arguments are an ordered list of integral type values that define the
|
||||
runtime extent of each dimension of an array. The shape information is
|
||||
given in the same row-to-column order as Fortran. This abstract shape value
|
||||
must be applied to a reified object, so all shape information must be
|
||||
specified. The extent must be nonnegative.
|
||||
|
||||
```mlir
|
||||
%d = fir.shape %row_sz, %col_sz : (index, index) -> !fir.shape<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyIntegerType>:$extents);
|
||||
|
||||
let results = (outs fir_ShapeType);
|
||||
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let verifier = [{
|
||||
auto size = extents().size();
|
||||
auto shapeTy = getType().dyn_cast<fir::ShapeType>();
|
||||
assert(shapeTy && "must be a shape type");
|
||||
if (shapeTy.getRank() != size)
|
||||
return emitOpError("shape type rank mismatch");
|
||||
return mlir::success();
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
std::vector<mlir::Value> getExtents() {
|
||||
return {extents().begin(), extents().end()};
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def fir_ShapeShiftOp : fir_Op<"shape_shift", [NoSideEffect]> {
|
||||
|
||||
let summary = [{
|
||||
generate an abstract shape and shift vector of type `!fir.shapeshift`
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
The arguments are an ordered list of integral type values that is a multiple
|
||||
of 2 in length. Each such pair is defined as: the lower bound and the
|
||||
extent for that dimension. The shifted shape information is given in the
|
||||
same row-to-column order as Fortran. This abstract shifted shape value must
|
||||
be applied to a reified object, so all shifted shape information must be
|
||||
specified. The extent must be nonnegative.
|
||||
|
||||
```mlir
|
||||
%d = fir.shape_shift %lo, %extent : (index, index) -> !fir.shapeshift<1>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyIntegerType>:$pairs);
|
||||
|
||||
let results = (outs fir_ShapeShiftType);
|
||||
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let verifier = [{
|
||||
auto size = pairs().size();
|
||||
if (size < 2 || size > 16 * 2)
|
||||
return emitOpError("incorrect number of args");
|
||||
if (size % 2 != 0)
|
||||
return emitOpError("requires a multiple of 2 args");
|
||||
auto shapeTy = getType().dyn_cast<fir::ShapeShiftType>();
|
||||
assert(shapeTy && "must be a shape shift type");
|
||||
if (shapeTy.getRank() * 2 != size)
|
||||
return emitOpError("shape type rank mismatch");
|
||||
return mlir::success();
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Logically unzip the origins from the extent values.
|
||||
std::vector<mlir::Value> getOrigins() {
|
||||
std::vector<mlir::Value> result;
|
||||
for (auto i : llvm::enumerate(pairs()))
|
||||
if (!(i.index() & 1))
|
||||
result.push_back(i.value());
|
||||
return result;
|
||||
}
|
||||
|
||||
// Logically unzip the extents from the origin values.
|
||||
std::vector<mlir::Value> getExtents() {
|
||||
std::vector<mlir::Value> result;
|
||||
for (auto i : llvm::enumerate(pairs()))
|
||||
if (i.index() & 1)
|
||||
result.push_back(i.value());
|
||||
return result;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def fir_ShiftOp : fir_Op<"shift", [NoSideEffect]> {
|
||||
|
||||
let summary = "generate an abstract shift vector of type `!fir.shift`";
|
||||
|
||||
let description = [{
|
||||
The arguments are an ordered list of integral type values that define the
|
||||
runtime lower bound of each dimension of an array. The shape information is
|
||||
given in the same row-to-column order as Fortran. This abstract shift value
|
||||
must be applied to a reified object, so all shift information must be
|
||||
specified.
|
||||
|
||||
```mlir
|
||||
%d = fir.shift %row_lb, %col_lb : (index, index) -> !fir.shift<2>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyIntegerType>:$origins);
|
||||
|
||||
let results = (outs fir_ShiftType);
|
||||
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let verifier = [{
|
||||
auto size = origins().size();
|
||||
auto shiftTy = getType().dyn_cast<fir::ShiftType>();
|
||||
assert(shiftTy && "must be a shift type");
|
||||
if (shiftTy.getRank() != size)
|
||||
return emitOpError("shift type rank mismatch");
|
||||
return mlir::success();
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
std::vector<mlir::Value> getOrigins() {
|
||||
return {origins().begin(), origins().end()};
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def fir_SliceOp : fir_Op<"slice", [NoSideEffect, AttrSizedOperandSegments]> {
|
||||
|
||||
let summary = "generate an abstract slice vector of type `!fir.slice`";
|
||||
|
||||
let description = [{
|
||||
The array slicing arguments are an ordered list of integral type values
|
||||
that must be a multiple of 3 in length. Each such triple is defined as:
|
||||
the lower bound, the upper bound, and the stride for that dimension, as in
|
||||
Fortran syntax. Both bounds are inclusive. The array slice information is
|
||||
given in the same row-to-column order as Fortran. This abstract slice value
|
||||
must be applied to a reified object, so all slice information must be
|
||||
specified. The extent must be nonnegative and the stride must not be zero.
|
||||
|
||||
```mlir
|
||||
%d = fir.slice %lo, %hi, %step : (index, index, index) -> !fir.slice<1>
|
||||
```
|
||||
|
||||
To support generalized slicing of Fortran's dynamic derived types, a slice
|
||||
op can be given a component path (narrowing from the product type of the
|
||||
original array to the specific elemental type of the sliced projection).
|
||||
|
||||
```mlir
|
||||
%fld = fir.field_index component, !fir.type<t{...component:ct...}>
|
||||
%d = fir.slice %lo, %hi, %step path %fld : (index, index, index, !fir.field) -> !fir.slice<1>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyCoordinateType>:$triples,
|
||||
Variadic<AnyComponentType>:$fields
|
||||
);
|
||||
|
||||
let results = (outs fir_SliceType);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$triples (`path` $fields^)? attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let verifier = [{
|
||||
auto size = triples().size();
|
||||
if (size < 3 || size > 16 * 3)
|
||||
return emitOpError("incorrect number of args for triple");
|
||||
if (size % 3 != 0)
|
||||
return emitOpError("requires a multiple of 3 args");
|
||||
auto sliceTy = getType().dyn_cast<fir::SliceType>();
|
||||
assert(sliceTy && "must be a slice type");
|
||||
if (sliceTy.getRank() * 3 != size)
|
||||
return emitOpError("slice type rank mismatch");
|
||||
return mlir::success();
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
unsigned getOutRank() { return getOutputRank(triples()); }
|
||||
static unsigned getOutputRank(mlir::ValueRange triples);
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef OPTIMIZER_DIALECT_FIRTYPE_H
|
||||
#define OPTIMIZER_DIALECT_FIRTYPE_H
|
||||
#ifndef FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
|
||||
#define FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
|
@ -23,7 +23,8 @@
|
|||
namespace llvm {
|
||||
class raw_ostream;
|
||||
class StringRef;
|
||||
template <typename> class ArrayRef;
|
||||
template <typename>
|
||||
class ArrayRef;
|
||||
class hash_code;
|
||||
} // namespace llvm
|
||||
|
||||
|
|
@ -80,6 +81,10 @@ bool isa_aggregate(mlir::Type t);
|
|||
/// not a memory reference type, then returns a null `Type`.
|
||||
mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
|
||||
|
||||
/// Extract the `Type` pointed to from a FIR memory reference or box type. If
|
||||
/// `t` is not a memory reference or box type, then returns a null `Type`.
|
||||
mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t);
|
||||
|
||||
/// Is `t` a FIR Real or MLIR Float type?
|
||||
inline bool isa_real(mlir::Type t) {
|
||||
return t.isa<fir::RealType>() || t.isa<mlir::FloatType>();
|
||||
|
|
@ -125,4 +130,4 @@ inline bool singleIndirectionLevel(mlir::Type ty) {
|
|||
|
||||
} // namespace fir
|
||||
|
||||
#endif // OPTIMIZER_DIALECT_FIRTYPE_H
|
||||
#endif // FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
|
||||
|
|
|
|||
|
|
@ -5,6 +5,10 @@
|
|||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "flang/Optimizer/Dialect/FIROps.h"
|
||||
#include "flang/Optimizer/Dialect/FIRAttr.h"
|
||||
|
|
@ -115,6 +119,90 @@ mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) {
|
|||
return HeapType::get(intype);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArrayCoorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static mlir::LogicalResult verify(fir::ArrayCoorOp op) {
|
||||
auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
|
||||
auto arrTy = eleTy.dyn_cast<fir::SequenceType>();
|
||||
if (!arrTy)
|
||||
return op.emitOpError("must be a reference to an array");
|
||||
auto arrDim = arrTy.getDimension();
|
||||
|
||||
if (auto shapeOp = op.shape()) {
|
||||
auto shapeTy = shapeOp.getType();
|
||||
unsigned shapeTyRank = 0;
|
||||
if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) {
|
||||
shapeTyRank = s.getRank();
|
||||
} else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) {
|
||||
shapeTyRank = ss.getRank();
|
||||
} else {
|
||||
auto s = shapeTy.cast<fir::ShiftType>();
|
||||
shapeTyRank = s.getRank();
|
||||
if (!op.memref().getType().isa<fir::BoxType>())
|
||||
return op.emitOpError("shift can only be provided with fir.box memref");
|
||||
}
|
||||
if (arrDim && arrDim != shapeTyRank)
|
||||
return op.emitOpError("rank of dimension mismatched");
|
||||
if (shapeTyRank != op.indices().size())
|
||||
return op.emitOpError("number of indices do not match dim rank");
|
||||
}
|
||||
|
||||
if (auto sliceOp = op.slice())
|
||||
if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>())
|
||||
if (sliceTy.getRank() != arrDim)
|
||||
return op.emitOpError("rank of dimension in slice mismatched");
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArrayLoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() {
|
||||
if (auto sh = shape())
|
||||
if (auto *op = sh.getDefiningOp()) {
|
||||
if (auto shOp = dyn_cast<fir::ShapeOp>(op))
|
||||
return shOp.getExtents();
|
||||
return cast<fir::ShapeShiftOp>(op).getExtents();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
static mlir::LogicalResult verify(fir::ArrayLoadOp op) {
|
||||
auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
|
||||
auto arrTy = eleTy.dyn_cast<fir::SequenceType>();
|
||||
if (!arrTy)
|
||||
return op.emitOpError("must be a reference to an array");
|
||||
auto arrDim = arrTy.getDimension();
|
||||
|
||||
if (auto shapeOp = op.shape()) {
|
||||
auto shapeTy = shapeOp.getType();
|
||||
unsigned shapeTyRank = 0;
|
||||
if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) {
|
||||
shapeTyRank = s.getRank();
|
||||
} else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) {
|
||||
shapeTyRank = ss.getRank();
|
||||
} else {
|
||||
auto s = shapeTy.cast<fir::ShiftType>();
|
||||
shapeTyRank = s.getRank();
|
||||
if (!op.memref().getType().isa<fir::BoxType>())
|
||||
return op.emitOpError("shift can only be provided with fir.box memref");
|
||||
}
|
||||
if (arrDim && arrDim != shapeTyRank)
|
||||
return op.emitOpError("rank of dimension mismatched");
|
||||
}
|
||||
|
||||
if (auto sliceOp = op.slice())
|
||||
if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>())
|
||||
if (sliceTy.getRank() != arrDim)
|
||||
return op.emitOpError("rank of dimension in slice mismatched");
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BoxAddrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
|||
|
|
@ -223,6 +223,19 @@ mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
|
|||
.Default([](mlir::Type) { return mlir::Type{}; });
|
||||
}
|
||||
|
||||
mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
|
||||
return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
|
||||
.Case<fir::ReferenceType, fir::PointerType, fir::HeapType>(
|
||||
[](auto p) { return p.getEleTy(); })
|
||||
.Case<fir::BoxType>([](auto p) {
|
||||
auto eleTy = p.getEleTy();
|
||||
if (auto ty = fir::dyn_cast_ptrEleTy(eleTy))
|
||||
return ty;
|
||||
return eleTy;
|
||||
})
|
||||
.Default([](mlir::Type) { return mlir::Type{}; });
|
||||
}
|
||||
|
||||
} // namespace fir
|
||||
|
||||
namespace {
|
||||
|
|
|
|||
|
|
@ -618,5 +618,17 @@ func @test_misc_ops(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m : index, %n : inde
|
|||
|
||||
// CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32>
|
||||
%arr2 = fir.zero_bits !fir.array<10xi32>
|
||||
|
||||
// CHECK: [[SHAPE:%.*]] = fir.shape_shift [[INDXM:%.*]], [[INDXN:%.*]], [[INDXO:%.*]], [[INDXP:%.*]] : (index, index, index, index) -> !fir.shapeshift<2>
|
||||
// CHECK: [[AV1:%.*]] = fir.array_load [[ARR1]]([[SHAPE]]) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
|
||||
// CHECK: [[FVAL:%.*]] = fir.array_fetch [[AV1]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, index, index) -> f32
|
||||
// CHECK: [[AV2:%.*]] = fir.array_update [[AV1]], [[FVAL]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
|
||||
// CHECK: fir.array_merge_store [[AV1]], [[AV2]] to [[ARR1]] : !fir.ref<!fir.array<?x?xf32>>
|
||||
%s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
|
||||
%av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
|
||||
%f = fir.array_fetch %av1, %i10, %j20 : (!fir.array<?x?xf32>, index, index) -> f32
|
||||
%av2 = fir.array_update %av1, %f, %i10, %j20 : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
|
||||
fir.array_merge_store %av1, %av2 to %arr1 : !fir.ref<!fir.array<?x?xf32>>
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue