255 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			255 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			C++
		
	
	
	
| //===-- lib/Semantics/check-case.cpp --------------------------------------===//
 | |
| //
 | |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | |
| // See https://llvm.org/LICENSE.txt for license information.
 | |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | |
| //
 | |
| //===----------------------------------------------------------------------===//
 | |
| 
 | |
| #include "check-case.h"
 | |
| #include "flang/Common/idioms.h"
 | |
| #include "flang/Common/reference.h"
 | |
| #include "flang/Common/template.h"
 | |
| #include "flang/Evaluate/fold.h"
 | |
| #include "flang/Evaluate/type.h"
 | |
| #include "flang/Parser/parse-tree.h"
 | |
| #include "flang/Semantics/semantics.h"
 | |
| #include "flang/Semantics/tools.h"
 | |
| #include <tuple>
 | |
| 
 | |
| namespace Fortran::semantics {
 | |
| 
 | |
| template <typename T> class CaseValues {
 | |
| public:
 | |
|   CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
 | |
|       : context_{c}, caseExprType_{t} {}
 | |
| 
 | |
|   void Check(const std::list<parser::CaseConstruct::Case> &cases) {
 | |
|     for (const parser::CaseConstruct::Case &c : cases) {
 | |
|       AddCase(c);
 | |
|     }
 | |
|     if (!hasErrors_) {
 | |
|       cases_.sort(Comparator{});
 | |
|       if (!AreCasesDisjoint()) { // C1149
 | |
|         ReportConflictingCases();
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| 
 | |
| private:
 | |
|   using Value = evaluate::Scalar<T>;
 | |
| 
 | |
|   void AddCase(const parser::CaseConstruct::Case &c) {
 | |
|     const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
 | |
|     const parser::CaseStmt &caseStmt{stmt.statement};
 | |
|     const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
 | |
|     std::visit(
 | |
|         common::visitors{
 | |
|             [&](const std::list<parser::CaseValueRange> &ranges) {
 | |
|               for (const auto &range : ranges) {
 | |
|                 auto pair{ComputeBounds(range)};
 | |
|                 if (pair.first && pair.second && *pair.first > *pair.second) {
 | |
|                   context_.Say(stmt.source,
 | |
|                       "CASE has lower bound greater than upper bound"_en_US);
 | |
|                 } else {
 | |
|                   if constexpr (T::category == TypeCategory::Logical) { // C1148
 | |
|                     if ((pair.first || pair.second) &&
 | |
|                         (!pair.first || !pair.second ||
 | |
|                             *pair.first != *pair.second)) {
 | |
|                       context_.Say(stmt.source,
 | |
|                           "CASE range is not allowed for LOGICAL"_err_en_US);
 | |
|                     }
 | |
|                   }
 | |
|                   cases_.emplace_back(stmt);
 | |
|                   cases_.back().lower = std::move(pair.first);
 | |
|                   cases_.back().upper = std::move(pair.second);
 | |
|                 }
 | |
|               }
 | |
|             },
 | |
|             [&](const parser::Default &) { cases_.emplace_front(stmt); },
 | |
|         },
 | |
|         selector.u);
 | |
|   }
 | |
| 
 | |
|   std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
 | |
|     const parser::Expr &expr{caseValue.thing.thing.value()};
 | |
|     auto *x{expr.typedExpr.get()};
 | |
|     if (x && x->v) { // C1147
 | |
|       auto type{x->v->GetType()};
 | |
|       if (type && type->category() == caseExprType_.category() &&
 | |
|           (type->category() != TypeCategory::Character ||
 | |
|               type->kind() == caseExprType_.kind())) {
 | |
|         x->v = evaluate::Fold(context_.foldingContext(),
 | |
|             evaluate::ConvertToType(T::GetType(), std::move(*x->v)));
 | |
|         if (x->v) {
 | |
|           if (auto value{evaluate::GetScalarConstantValue<T>(*x->v)}) {
 | |
|             return *value;
 | |
|           }
 | |
|         }
 | |
|         context_.Say(
 | |
|             expr.source, "CASE value must be a constant scalar"_err_en_US);
 | |
|       } else {
 | |
|         std::string typeStr{type ? type->AsFortran() : "typeless"s};
 | |
|         context_.Say(expr.source,
 | |
|             "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
 | |
|             typeStr, caseExprType_.AsFortran());
 | |
|       }
 | |
|       hasErrors_ = true;
 | |
|     }
 | |
|     return std::nullopt;
 | |
|   }
 | |
| 
 | |
|   using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;
 | |
|   PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
 | |
|     return std::visit(common::visitors{
 | |
|                           [&](const parser::CaseValue &x) {
 | |
|                             auto value{GetValue(x)};
 | |
|                             return PairOfValues{value, value};
 | |
|                           },
 | |
|                           [&](const parser::CaseValueRange::Range &x) {
 | |
|                             std::optional<Value> lo, hi;
 | |
|                             if (x.lower) {
 | |
|                               lo = GetValue(*x.lower);
 | |
|                             }
 | |
|                             if (x.upper) {
 | |
|                               hi = GetValue(*x.upper);
 | |
|                             }
 | |
|                             if ((x.lower && !lo) || (x.upper && !hi)) {
 | |
|                               return PairOfValues{}; // error case
 | |
|                             }
 | |
|                             return PairOfValues{std::move(lo), std::move(hi)};
 | |
|                           },
 | |
|                       },
 | |
|         range.u);
 | |
|   }
 | |
| 
 | |
|   struct Case {
 | |
|     explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
 | |
|     bool IsDefault() const { return !lower && !upper; }
 | |
|     std::string AsFortran() const {
 | |
|       std::string result;
 | |
|       {
 | |
|         llvm::raw_string_ostream bs{result};
 | |
|         if (lower) {
 | |
|           evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
 | |
|           if (!upper) {
 | |
|             bs << ':';
 | |
|           } else if (*lower != *upper) {
 | |
|             evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
 | |
|           }
 | |
|           bs << ')';
 | |
|         } else if (upper) {
 | |
|           evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
 | |
|         } else {
 | |
|           bs << "DEFAULT";
 | |
|         }
 | |
|       }
 | |
|       return result;
 | |
|     }
 | |
| 
 | |
|     const parser::Statement<parser::CaseStmt> &stmt;
 | |
|     std::optional<Value> lower, upper;
 | |
|   };
 | |
| 
 | |
|   // Defines a comparator for use with std::list<>::sort().
 | |
|   // Returns true if and only if the highest value in range x is less
 | |
|   // than the least value in range y.  The DEFAULT case is arbitrarily
 | |
|   // defined to be less than all others.  When two ranges overlap,
 | |
|   // neither is less than the other.
 | |
|   struct Comparator {
 | |
|     bool operator()(const Case &x, const Case &y) const {
 | |
|       if (x.IsDefault()) {
 | |
|         return !y.IsDefault();
 | |
|       } else {
 | |
|         return x.upper && y.lower && *x.upper < *y.lower;
 | |
|       }
 | |
|     }
 | |
|   };
 | |
| 
 | |
|   bool AreCasesDisjoint() const {
 | |
|     auto endIter{cases_.end()};
 | |
|     for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
 | |
|       auto next{iter};
 | |
|       if (++next != endIter && !Comparator{}(*iter, *next)) {
 | |
|         return false;
 | |
|       }
 | |
|     }
 | |
|     return true;
 | |
|   }
 | |
| 
 | |
|   // This has quadratic time, but only runs in error cases
 | |
|   void ReportConflictingCases() {
 | |
|     for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
 | |
|       parser::Message *msg{nullptr};
 | |
|       for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
 | |
|         if (p->stmt.source.begin() < iter->stmt.source.begin() &&
 | |
|             !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
 | |
|           if (!msg) {
 | |
|             msg = &context_.Say(iter->stmt.source,
 | |
|                 "CASE %s conflicts with previous cases"_err_en_US,
 | |
|                 iter->AsFortran());
 | |
|           }
 | |
|           msg->Attach(
 | |
|               p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
 | |
|         }
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   SemanticsContext &context_;
 | |
|   const evaluate::DynamicType &caseExprType_;
 | |
|   std::list<Case> cases_;
 | |
|   bool hasErrors_{false};
 | |
| };
 | |
| 
 | |
| template <TypeCategory CAT> struct TypeVisitor {
 | |
|   using Result = bool;
 | |
|   using Types = evaluate::CategoryTypes<CAT>;
 | |
|   template <typename T> Result Test() {
 | |
|     if (T::kind == exprType.kind()) {
 | |
|       CaseValues<T>(context, exprType).Check(caseList);
 | |
|       return true;
 | |
|     } else {
 | |
|       return false;
 | |
|     }
 | |
|   }
 | |
|   SemanticsContext &context;
 | |
|   const evaluate::DynamicType &exprType;
 | |
|   const std::list<parser::CaseConstruct::Case> &caseList;
 | |
| };
 | |
| 
 | |
| void CaseChecker::Enter(const parser::CaseConstruct &construct) {
 | |
|   const auto &selectCaseStmt{
 | |
|       std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
 | |
|   const auto &selectCase{selectCaseStmt.statement};
 | |
|   const auto &selectExpr{
 | |
|       std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing};
 | |
|   const auto *x{GetExpr(selectExpr)};
 | |
|   if (!x) {
 | |
|     return; // expression semantics failed
 | |
|   }
 | |
|   if (auto exprType{x->GetType()}) {
 | |
|     const auto &caseList{
 | |
|         std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
 | |
|     switch (exprType->category()) {
 | |
|     case TypeCategory::Integer:
 | |
|       common::SearchTypes(
 | |
|           TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList});
 | |
|       return;
 | |
|     case TypeCategory::Logical:
 | |
|       CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType}
 | |
|           .Check(caseList);
 | |
|       return;
 | |
|     case TypeCategory::Character:
 | |
|       common::SearchTypes(
 | |
|           TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList});
 | |
|       return;
 | |
|     default:
 | |
|       break;
 | |
|     }
 | |
|   }
 | |
|   context_.Say(selectExpr.source,
 | |
|       "SELECT CASE expression must be integer, logical, or character"_err_en_US);
 | |
| }
 | |
| } // namespace Fortran::semantics
 |