[fir] Add fir.select_case conversion

The `fir.select_case` operation is converted to a if-then-else ladder.

Conversion of `fir.select_case` operation with character is not
implemented yet.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: kiranchandramohan, mehdi_amini

Differential Revision: https://reviews.llvm.org/D113484

Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
Valentin Clement 2021-11-11 15:00:18 +01:00
parent 9534e361ea
commit 39f4ef8146
No known key found for this signature in database
GPG Key ID: 086D54783C928776
5 changed files with 324 additions and 3 deletions

View File

@ -498,6 +498,8 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
llvm::Optional<mlir::OperandRange> getCompareOperands(unsigned cond);
llvm::Optional<llvm::ArrayRef<mlir::Value>> getCompareOperands(
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
llvm::Optional<mlir::ValueRange> getCompareOperands(
mlir::ValueRange operands, unsigned cond);
llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands(
llvm::ArrayRef<mlir::Value> operands, unsigned cond);

View File

@ -13,6 +13,7 @@
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "PassDetail.h"
#include "flang/ISO_Fortran_binding.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
@ -40,6 +41,13 @@ genConstantIndex(mlir::Location loc, mlir::Type ity,
return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
}
static Block *createBlock(mlir::ConversionPatternRewriter &rewriter,
mlir::Block *insertBefore) {
assert(insertBefore && "expected valid insertion block");
return rewriter.createBlock(insertBefore->getParent(),
mlir::Region::iterator(insertBefore));
}
namespace {
/// FIR conversion pattern template
template <typename FromOp>
@ -695,6 +703,122 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
}
};
void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
Optional<mlir::ValueRange> destOps,
mlir::ConversionPatternRewriter &rewriter,
mlir::Block *newBlock) {
if (destOps.hasValue())
rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp, dest, destOps.getValue(),
newBlock, mlir::ValueRange());
else
rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp, dest, newBlock);
}
template <typename A, typename B>
void genBrOp(A caseOp, mlir::Block *dest, Optional<B> destOps,
mlir::ConversionPatternRewriter &rewriter) {
if (destOps.hasValue())
rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(caseOp, destOps.getValue(),
dest);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(caseOp, llvm::None, dest);
}
void genCaseLadderStep(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
Optional<mlir::ValueRange> destOps,
mlir::ConversionPatternRewriter &rewriter) {
auto *thisBlock = rewriter.getInsertionBlock();
auto *newBlock = createBlock(rewriter, dest);
rewriter.setInsertionPointToEnd(thisBlock);
genCondBrOp(loc, cmp, dest, destOps, rewriter, newBlock);
rewriter.setInsertionPointToEnd(newBlock);
}
/// Conversion of `fir.select_case`
///
/// The `fir.select_case` operation is converted to a if-then-else ladder.
/// Depending on the case condition type, one or several comparison and
/// conditional branching can be generated.
///
/// A a point value case such as `case(4)`, a lower bound case such as
/// `case(5:)` or an upper bound case such as `case(:3)` are converted to a
/// simple comparison between the selector value and the constant value in the
/// case. The block associated with the case condition is then executed if
/// the comparison succeed otherwise it branch to the next block with the
/// comparison for the the next case conditon.
///
/// A closed interval case condition such as `case(7:10)` is converted with a
/// first comparison and conditional branching for the lower bound. If
/// successful, it branch to a second block with the comparison for the
/// upper bound in the same case condition.
///
/// TODO: lowering of CHARACTER type cases is not handled yet.
struct SelectCaseOpConversion : public FIROpConversion<fir::SelectCaseOp> {
using FIROpConversion::FIROpConversion;
mlir::LogicalResult
matchAndRewrite(fir::SelectCaseOp caseOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
unsigned conds = caseOp.getNumConditions();
llvm::ArrayRef<mlir::Attribute> cases = caseOp.getCases().getValue();
// Type can be CHARACTER, INTEGER, or LOGICAL (C1145)
LLVM_ATTRIBUTE_UNUSED auto ty = caseOp.getSelector().getType();
if (ty.isa<fir::CharacterType>())
return rewriter.notifyMatchFailure(caseOp,
"conversion of fir.select_case with "
"character type not implemented yet");
mlir::Value selector = caseOp.getSelector(adaptor.getOperands());
auto loc = caseOp.getLoc();
for (unsigned t = 0; t != conds; ++t) {
mlir::Block *dest = caseOp.getSuccessor(t);
llvm::Optional<mlir::ValueRange> destOps =
caseOp.getSuccessorOperands(adaptor.getOperands(), t);
llvm::Optional<mlir::ValueRange> cmpOps =
*caseOp.getCompareOperands(adaptor.getOperands(), t);
mlir::Value caseArg = *(cmpOps.getValue().begin());
mlir::Attribute attr = cases[t];
if (attr.isa<fir::PointIntervalAttr>()) {
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
loc, mlir::LLVM::ICmpPredicate::eq, selector, caseArg);
genCaseLadderStep(loc, cmp, dest, destOps, rewriter);
continue;
}
if (attr.isa<fir::LowerBoundAttr>()) {
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
loc, mlir::LLVM::ICmpPredicate::sle, caseArg, selector);
genCaseLadderStep(loc, cmp, dest, destOps, rewriter);
continue;
}
if (attr.isa<fir::UpperBoundAttr>()) {
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
loc, mlir::LLVM::ICmpPredicate::sle, selector, caseArg);
genCaseLadderStep(loc, cmp, dest, destOps, rewriter);
continue;
}
if (attr.isa<fir::ClosedIntervalAttr>()) {
auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
loc, mlir::LLVM::ICmpPredicate::sle, caseArg, selector);
auto *thisBlock = rewriter.getInsertionBlock();
auto *newBlock1 = createBlock(rewriter, dest);
auto *newBlock2 = createBlock(rewriter, dest);
rewriter.setInsertionPointToEnd(thisBlock);
rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp, newBlock1, newBlock2);
rewriter.setInsertionPointToEnd(newBlock1);
mlir::Value caseArg0 = *(cmpOps.getValue().begin() + 1);
auto cmp0 = rewriter.create<mlir::LLVM::ICmpOp>(
loc, mlir::LLVM::ICmpPredicate::sle, selector, caseArg0);
genCondBrOp(loc, cmp0, dest, destOps, rewriter, newBlock2);
rewriter.setInsertionPointToEnd(newBlock2);
continue;
}
assert(attr.isa<mlir::UnitAttr>());
assert((t + 1 == conds) && "unit must be last");
genBrOp(caseOp, dest, destOps, rewriter);
}
return success();
}
};
template <typename OP>
void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
typename OP::Adaptor adaptor,
@ -1233,9 +1357,9 @@ public:
DivcOpConversion, ExtractValueOpConversion, HasValueOpConversion,
GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
LoadOpConversion, NegcOpConversion, MulcOpConversion,
SelectOpConversion, SelectRankOpConversion, StoreOpConversion,
SubcOpConversion, UndefOpConversion, UnreachableOpConversion,
ZeroOpConversion>(typeConverter);
SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
StoreOpConversion, SubcOpConversion, UndefOpConversion,
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);

View File

@ -2297,6 +2297,16 @@ fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
}
llvm::Optional<mlir::ValueRange>
fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands,
unsigned cond) {
auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getCompareOffsetAttr());
auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getOperandSegmentSizeAttr());
return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
}
llvm::Optional<mlir::MutableOperandRange>
fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
@ -2313,6 +2323,16 @@ fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
llvm::Optional<mlir::ValueRange>
fir::SelectCaseOp::getSuccessorOperands(mlir::ValueRange operands,
unsigned oper) {
auto a =
(*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getOperandSegmentSizeAttr());
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
// parser for fir.select_case Op
static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser,
mlir::OperationState &result) {

View File

@ -2,6 +2,9 @@
// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" --verify-diagnostics %s
// Test `fir.zero` conversion failure with aggregate type.
// Not implemented yet.
func @zero_aggregate() {
// expected-error@+1{{failed to legalize operation 'fir.zero_bits'}}
%a = fir.zero_bits !fir.array<10xf32>
@ -27,3 +30,23 @@ func @dispatch(%arg0: !fir.box<!fir.type<derived3{f:f32}>>) {
fir.dispatch_table @dispatch_tbl {
fir.dt_entry "method", @method_impl
}
// -----
// Test `fir.select_case` conversion failure with character type.
// Not implemented yet.
func @select_case_charachter(%arg0: !fir.char<2, 10>, %arg1: !fir.char<2, 10>, %arg2: !fir.char<2, 10>) {
// expected-error@+1{{failed to legalize operation 'fir.select_case'}}
fir.select_case %arg0 : !fir.char<2, 10> [#fir.point, %arg1, ^bb1,
#fir.point, %arg2, ^bb2,
unit, ^bb3]
^bb1:
%c1_i32 = arith.constant 1 : i32
br ^bb3
^bb2:
%c2_i32 = arith.constant 2 : i32
br ^bb3
^bb3:
return
}

View File

@ -961,3 +961,155 @@ func @alloca_array_with_holes(%0 : index, %1 : index) -> !fir.ref<!fir.array<4x?
// CHECK: [[PROD3:%.*]] = llvm.mul [[PROD2]], [[B]] : i64
// CHECK: [[RES:%.*]] = llvm.alloca [[PROD3]] x i32 {in_type = !fir.array<4x?x3x?x5xi32>
// CHECK: llvm.return [[RES]] : !llvm.ptr<i32>
// -----
// Test `fir.select_case` operation conversion with INTEGER.
func @select_case_integer(%arg0: !fir.ref<i32>) -> i32 {
%2 = fir.load %arg0 : !fir.ref<i32>
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c4_i32 = arith.constant 4 : i32
%c5_i32 = arith.constant 5 : i32
%c7_i32 = arith.constant 7 : i32
%c8_i32 = arith.constant 8 : i32
%c15_i32 = arith.constant 15 : i32
%c21_i32 = arith.constant 21 : i32
fir.select_case %2 : i32 [#fir.upper, %c1_i32, ^bb1,
#fir.point, %c2_i32, ^bb2,
#fir.interval, %c4_i32, %c5_i32, ^bb4,
#fir.point, %c7_i32, ^bb5,
#fir.interval, %c8_i32, %c15_i32, ^bb5,
#fir.lower, %c21_i32, ^bb5,
unit, ^bb3]
^bb1: // pred: ^bb0
%c1_i32_0 = arith.constant 1 : i32
fir.store %c1_i32_0 to %arg0 : !fir.ref<i32>
br ^bb6
^bb2: // pred: ^bb0
%c2_i32_1 = arith.constant 2 : i32
fir.store %c2_i32_1 to %arg0 : !fir.ref<i32>
br ^bb6
^bb3: // pred: ^bb0
%c0_i32 = arith.constant 0 : i32
fir.store %c0_i32 to %arg0 : !fir.ref<i32>
br ^bb6
^bb4: // pred: ^bb0
%c4_i32_2 = arith.constant 4 : i32
fir.store %c4_i32_2 to %arg0 : !fir.ref<i32>
br ^bb6
^bb5: // 3 preds: ^bb0, ^bb0, ^bb0
%c7_i32_3 = arith.constant 7 : i32
fir.store %c7_i32_3 to %arg0 : !fir.ref<i32>
br ^bb6
^bb6: // 5 preds: ^bb1, ^bb2, ^bb3, ^bb4, ^bb5
%3 = fir.load %arg0 : !fir.ref<i32>
return %3 : i32
}
// CHECK-LABEL: llvm.func @select_case_integer(
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<i32>) -> i32 {
// CHECK: %[[SELECT_VALUE:.*]] = llvm.load %[[ARG0]] : !llvm.ptr<i32>
// CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: %[[CST5:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK: %[[CST7:.*]] = llvm.mlir.constant(7 : i32) : i32
// CHECK: %[[CST8:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: %[[CST15:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: %[[CST21:.*]] = llvm.mlir.constant(21 : i32) : i32
// Check for upper bound `case (:1)`
// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[SELECT_VALUE]], %[[CST1]] : i32
// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb2, ^bb1
// CHECK-LABEL: ^bb1:
// Check for point value `case (2)`
// CHECK: %[[CMP_EQ:.*]] = llvm.icmp "eq" %[[SELECT_VALUE]], %[[CST2]] : i32
// CHECK: llvm.cond_br %[[CMP_EQ]], ^bb4, ^bb3
// Block ^bb1 in original FIR code.
// CHECK-LABEL: ^bb2:
// CHECK: llvm.br ^bb{{.*}}
// CHECK-LABEL: ^bb3:
// Check for the lower bound for the interval `case (4:5)`
// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[CST4]], %[[SELECT_VALUE]] : i32
// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb[[UPPERBOUND5:.*]], ^bb7
// Block ^bb2 in original FIR code.
// CHECK-LABEL: ^bb4:
// CHECK: llvm.br ^bb{{.*}}
// Block ^bb3 in original FIR code.
// CHECK-LABEL: ^bb5:
// CHECK: llvm.br ^bb{{.*}}
// CHECK: ^bb[[UPPERBOUND5]]:
// Check for the upper bound for the interval `case (4:5)`
// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[SELECT_VALUE]], %[[CST5]] : i32
// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb8, ^bb7
// CHECK-LABEL: ^bb7:
// Check for the point value 7 in `case (7,8:15,21:)`
// CHECK: %[[CMP_EQ:.*]] = llvm.icmp "eq" %[[SELECT_VALUE]], %[[CST7]] : i32
// CHECK: llvm.cond_br %[[CMP_EQ]], ^bb13, ^bb9
// Block ^bb4 in original FIR code.
// CHECK-LABEL: ^bb8:
// CHECK: llvm.br ^bb{{.*}}
// CHECK-LABEL: ^bb9:
// Check for lower bound 8 in `case (7,8:15,21:)`
// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[CST8]], %[[SELECT_VALUE]] : i32
// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb[[INTERVAL8_15:.*]], ^bb11
// CHECK: ^bb[[INTERVAL8_15]]:
// Check for upper bound 15 in `case (7,8:15,21:)`
// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[SELECT_VALUE]], %[[CST15]] : i32
// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb13, ^bb11
// CHECK-LABEL: ^bb11:
// Check for lower bound 21 in `case (7,8:15,21:)`
// CHECK: %[[CMP_SLE:.*]] = llvm.icmp "sle" %[[CST21]], %[[SELECT_VALUE]] : i32
// CHECK: llvm.cond_br %[[CMP_SLE]], ^bb13, ^bb12
// CHECK-LABEL: ^bb12:
// CHECK: llvm.br ^bb5
// Block ^bb5 in original FIR code.
// CHECK-LABEL: ^bb13:
// CHECK: llvm.br ^bb14
// Block ^bb6 in original FIR code.
// CHECK-LABEL: ^bb14:
// CHECK: %[[RET:.*]] = llvm.load %[[ARG0:.*]] : !llvm.ptr<i32>
// CHECK: llvm.return %[[RET]] : i32
// -----
// Test `fir.select_case` operation conversion with LOGICAL.
func @select_case_logical(%arg0: !fir.ref<!fir.logical<4>>) {
%1 = fir.load %arg0 : !fir.ref<!fir.logical<4>>
%2 = fir.convert %1 : (!fir.logical<4>) -> i1
%false = arith.constant false
%true = arith.constant true
fir.select_case %2 : i1 [#fir.point, %false, ^bb1,
#fir.point, %true, ^bb2,
unit, ^bb3]
^bb1:
%c1_i32 = arith.constant 1 : i32
br ^bb3
^bb2:
%c2_i32 = arith.constant 2 : i32
br ^bb3
^bb3:
return
}
// CHECK-LABEL: llvm.func @select_case_logical(
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<i32>
// CHECK: %[[LOAD_ARG0:.*]] = llvm.load %[[ARG0]] : !llvm.ptr<i32>
// CHECK: %[[SELECT_VALUE:.*]] = llvm.trunc %[[LOAD_ARG0]] : i32 to i1
// CHECK: %[[CST_FALSE:.*]] = llvm.mlir.constant(false) : i1
// CHECK: %[[CST_TRUE:.*]] = llvm.mlir.constant(true) : i1
// CHECK: %[[CMPEQ:.*]] = llvm.icmp "eq" %[[SELECT_VALUE]], %[[CST_FALSE]] : i1
// CHECK: llvm.cond_br %[[CMPEQ]], ^bb2, ^bb1
// CHECK-LABEL: ^bb1:
// CHECK: %[[CMPEQ:.*]] = llvm.icmp "eq" %[[SELECT_VALUE]], %[[CST_TRUE]] : i1
// CHECK: llvm.cond_br %[[CMPEQ]], ^bb4, ^bb3
// CHECK-LABEL: ^bb2:
// CHECK: llvm.br ^bb5
// CHECK-LABEL: ^bb3:
// CHECK: llvm.br ^bb5
// CHECK-LABEL: ^bb4:
// CHECK: llvm.br ^bb5
// CHECK-LABEL: ^bb5:
// CHECK: llvm.return