1113 lines
40 KiB
C++
1113 lines
40 KiB
C++
//===- Parser.cpp ---------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Tools/PDLL/Parser/Parser.h"
|
|
#include "Lexer.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Tools/PDLL/AST/Context.h"
|
|
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
|
|
#include "mlir/Tools/PDLL/AST/Nodes.h"
|
|
#include "mlir/Tools/PDLL/AST/Types.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/Support/SaveAndRestore.h"
|
|
#include <string>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::pdll;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class Parser {
|
|
public:
|
|
Parser(ast::Context &ctx, llvm::SourceMgr &sourceMgr)
|
|
: ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine()),
|
|
curToken(lexer.lexToken()), curDeclScope(nullptr),
|
|
valueTy(ast::ValueType::get(ctx)),
|
|
valueRangeTy(ast::ValueRangeType::get(ctx)),
|
|
typeTy(ast::TypeType::get(ctx)),
|
|
typeRangeTy(ast::TypeRangeType::get(ctx)) {}
|
|
|
|
/// Try to parse a new module. Returns nullptr in the case of failure.
|
|
FailureOr<ast::Module *> parseModule();
|
|
|
|
private:
|
|
/// The current context of the parser. It allows for the parser to know a bit
|
|
/// about the construct it is nested within during parsing. This is used
|
|
/// specifically to provide additional verification during parsing, e.g. to
|
|
/// prevent using rewrites within a match context, matcher constraints within
|
|
/// a rewrite section, etc.
|
|
enum class ParserContext {
|
|
/// The parser is in the global context.
|
|
Global,
|
|
/// The parser is currently within the matcher portion of a Pattern, which
|
|
/// is allows a terminal operation rewrite statement but no other rewrite
|
|
/// transformations.
|
|
PatternMatch,
|
|
};
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Parsing
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// Push a new decl scope onto the lexer.
|
|
ast::DeclScope *pushDeclScope() {
|
|
ast::DeclScope *newScope =
|
|
new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope);
|
|
return (curDeclScope = newScope);
|
|
}
|
|
void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; }
|
|
|
|
/// Pop the last decl scope from the lexer.
|
|
void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
|
|
|
|
/// Parse the body of an AST module.
|
|
LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);
|
|
|
|
/// Try to convert the given expression to `type`. Returns failure and emits
|
|
/// an error if a conversion is not viable. On failure, `noteAttachFn` is
|
|
/// invoked to attach notes to the emitted error diagnostic. On success,
|
|
/// `expr` is updated to the expression used to convert to `type`.
|
|
LogicalResult convertExpressionTo(
|
|
ast::Expr *&expr, ast::Type type,
|
|
function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Directives
|
|
|
|
LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
|
|
LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Decls
|
|
|
|
FailureOr<ast::Decl *> parseTopLevelDecl();
|
|
FailureOr<ast::Decl *> parsePatternDecl();
|
|
|
|
/// Check to see if a decl has already been defined with the given name, if
|
|
/// one has emit and error and return failure. Returns success otherwise.
|
|
LogicalResult checkDefineNamedDecl(const ast::Name &name);
|
|
|
|
/// Try to define a variable decl with the given components, returns the
|
|
/// variable on success.
|
|
FailureOr<ast::VariableDecl *>
|
|
defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
|
|
ast::Expr *initExpr,
|
|
ArrayRef<ast::ConstraintRef> constraints);
|
|
FailureOr<ast::VariableDecl *>
|
|
defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type,
|
|
ArrayRef<ast::ConstraintRef> constraints);
|
|
|
|
/// Parse the constraint reference list for a variable decl.
|
|
LogicalResult parseVariableDeclConstraintList(
|
|
SmallVectorImpl<ast::ConstraintRef> &constraints);
|
|
|
|
/// Parse the expression used within a type constraint, e.g. Attr<type-expr>.
|
|
FailureOr<ast::Expr *> parseTypeConstraintExpr();
|
|
|
|
/// Try to parse a single reference to a constraint. `typeConstraint` is the
|
|
/// location of a previously parsed type constraint for the entity that will
|
|
/// be constrained by the parsed constraint. `existingConstraints` are any
|
|
/// existing constraints that have already been parsed for the same entity
|
|
/// that will be constrained by this constraint.
|
|
FailureOr<ast::ConstraintRef>
|
|
parseConstraint(Optional<llvm::SMRange> &typeConstraint,
|
|
ArrayRef<ast::ConstraintRef> existingConstraints);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Exprs
|
|
|
|
FailureOr<ast::Expr *> parseExpr();
|
|
|
|
/// Identifier expressions.
|
|
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, llvm::SMRange loc);
|
|
FailureOr<ast::Expr *> parseIdentifierExpr();
|
|
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
|
|
FailureOr<ast::OpNameDecl *> parseOperationName();
|
|
FailureOr<ast::OpNameDecl *> parseWrappedOperationName();
|
|
FailureOr<ast::Expr *> parseUnderscoreExpr();
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Stmts
|
|
|
|
FailureOr<ast::Stmt *> parseStmt(bool expectTerminalSemicolon = true);
|
|
FailureOr<ast::CompoundStmt *> parseCompoundStmt();
|
|
FailureOr<ast::EraseStmt *> parseEraseStmt();
|
|
FailureOr<ast::LetStmt *> parseLetStmt();
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Creation+Analysis
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Decls
|
|
|
|
/// Try to create a pattern decl with the given components, returning the
|
|
/// Pattern on success.
|
|
FailureOr<ast::PatternDecl *> createPatternDecl(llvm::SMRange loc,
|
|
const ast::Name *name,
|
|
Optional<uint16_t> benefit,
|
|
bool hasBoundedRecursion,
|
|
ast::CompoundStmt *body);
|
|
|
|
/// Try to create a variable decl with the given components, returning the
|
|
/// Variable on success.
|
|
FailureOr<ast::VariableDecl *>
|
|
createVariableDecl(StringRef name, llvm::SMRange loc, ast::Expr *initializer,
|
|
ArrayRef<ast::ConstraintRef> constraints);
|
|
|
|
/// Validate the constraints used to constraint a variable decl.
|
|
/// `inferredType` is the type of the variable inferred by the constraints
|
|
/// within the list, and is updated to the most refined type as determined by
|
|
/// the constraints. Returns success if the constraint list is valid, failure
|
|
/// otherwise.
|
|
LogicalResult
|
|
validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
|
|
ast::Type &inferredType);
|
|
/// Validate a single reference to a constraint. `inferredType` contains the
|
|
/// currently inferred variabled type and is refined within the type defined
|
|
/// by the constraint. Returns success if the constraint is valid, failure
|
|
/// otherwise.
|
|
LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
|
|
ast::Type &inferredType);
|
|
LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
|
|
LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Exprs
|
|
|
|
FailureOr<ast::DeclRefExpr *> createDeclRefExpr(llvm::SMRange loc,
|
|
ast::Decl *decl);
|
|
FailureOr<ast::DeclRefExpr *>
|
|
createInlineVariableExpr(ast::Type type, StringRef name, llvm::SMRange loc,
|
|
ArrayRef<ast::ConstraintRef> constraints);
|
|
FailureOr<ast::MemberAccessExpr *>
|
|
createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
|
|
llvm::SMRange loc);
|
|
|
|
/// Validate the member access `name` into the given parent expression. On
|
|
/// success, this also returns the type of the member accessed.
|
|
FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
|
|
StringRef name, llvm::SMRange loc);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Stmts
|
|
|
|
FailureOr<ast::EraseStmt *> createEraseStmt(llvm::SMRange loc,
|
|
ast::Expr *rootOp);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Lexer Utilities
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// If the current token has the specified kind, consume it and return true.
|
|
/// If not, return false.
|
|
bool consumeIf(Token::Kind kind) {
|
|
if (curToken.isNot(kind))
|
|
return false;
|
|
consumeToken(kind);
|
|
return true;
|
|
}
|
|
|
|
/// Advance the current lexer onto the next token.
|
|
void consumeToken() {
|
|
assert(curToken.isNot(Token::eof, Token::error) &&
|
|
"shouldn't advance past EOF or errors");
|
|
curToken = lexer.lexToken();
|
|
}
|
|
|
|
/// Advance the current lexer onto the next token, asserting what the expected
|
|
/// current token is. This is preferred to the above method because it leads
|
|
/// to more self-documenting code with better checking.
|
|
void consumeToken(Token::Kind kind) {
|
|
assert(curToken.is(kind) && "consumed an unexpected token");
|
|
consumeToken();
|
|
}
|
|
|
|
/// Consume the specified token if present and return success. On failure,
|
|
/// output a diagnostic and return failure.
|
|
LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
|
|
if (curToken.getKind() != kind)
|
|
return emitError(curToken.getLoc(), msg);
|
|
consumeToken();
|
|
return success();
|
|
}
|
|
LogicalResult emitError(llvm::SMRange loc, const Twine &msg) {
|
|
lexer.emitError(loc, msg);
|
|
return failure();
|
|
}
|
|
LogicalResult emitError(const Twine &msg) {
|
|
return emitError(curToken.getLoc(), msg);
|
|
}
|
|
LogicalResult emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
|
|
llvm::SMRange noteLoc, const Twine ¬e) {
|
|
lexer.emitErrorAndNote(loc, msg, noteLoc, note);
|
|
return failure();
|
|
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Fields
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// The owning AST context.
|
|
ast::Context &ctx;
|
|
|
|
/// The lexer of this parser.
|
|
Lexer lexer;
|
|
|
|
/// The current token within the lexer.
|
|
Token curToken;
|
|
|
|
/// The most recently defined decl scope.
|
|
ast::DeclScope *curDeclScope;
|
|
llvm::SpecificBumpPtrAllocator<ast::DeclScope> scopeAllocator;
|
|
|
|
/// The current context of the parser.
|
|
ParserContext parserContext = ParserContext::Global;
|
|
|
|
/// Cached types to simplify verification and expression creation.
|
|
ast::Type valueTy, valueRangeTy;
|
|
ast::Type typeTy, typeRangeTy;
|
|
};
|
|
} // namespace
|
|
|
|
/// Parse a complete module. All top-level decls are parsed within a dedicated
/// module-level decl scope, which is always popped before returning.
FailureOr<ast::Module *> Parser::parseModule() {
  llvm::SMLoc startLoc = curToken.getStartLoc();

  pushDeclScope();
  SmallVector<ast::Decl *> topLevelDecls;
  LogicalResult bodyResult = parseModuleBody(topLevelDecls);
  popDeclScope();

  if (failed(bodyResult))
    return failure();
  return ast::Module::create(ctx, startLoc, topLevelDecls);
}
|
|
|
|
/// Parse directives and top-level decls until EOF, appending every parsed
/// decl to `decls`. Parsing stops at the first failure.
LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
  while (!curToken.is(Token::eof)) {
    if (curToken.is(Token::directive)) {
      // Directives (e.g. `#include`) append any decls they produce directly.
      if (failed(parseDirective(decls)))
        return failure();
    } else {
      FailureOr<ast::Decl *> topLevelDecl = parseTopLevelDecl();
      if (failed(topLevelDecl))
        return failure();
      decls.push_back(*topLevelDecl);
    }
  }
  return success();
}
|
|
|
|
/// Attempt to make `expr` usable where a value of `type` is expected. On
/// success `expr` may be replaced by a wrapping expression (e.g. a `$results`
/// member access); on failure an error is emitted, with any notes supplied by
/// `noteAttachFn`.
LogicalResult Parser::convertExpressionTo(
    ast::Expr *&expr, ast::Type type,
    function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
  ast::Type exprType = expr->getType();

  // Identical types never need a conversion.
  if (exprType == type)
    return success();

  // Emits the generic conversion error and lets the caller attach notes.
  auto emitConvertError = [&]() -> ast::InFlightDiagnostic {
    ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitError(
        expr->getLoc(), llvm::formatv("unable to convert expression of type "
                                      "`{0}` to the expected type of "
                                      "`{1}`",
                                      exprType, type));
    if (noteAttachFn)
      noteAttachFn(*diag);
    return diag;
  };

  if (exprType.dyn_cast<ast::OperationType>()) {
    // Converting to another operation type is only allowed when the target
    // type carries no name, i.e. it is the more general operation type.
    if (auto targetOpType = type.dyn_cast<ast::OperationType>()) {
      if (!targetOpType.getName())
        return success();
      return emitConvertError();
    }

    // An operation can stand in for a Value or a ValueRange by accessing its
    // `$results`.
    if (type == valueTy || type == valueRangeTy) {
      expr = ast::MemberAccessExpr::create(ctx, expr->getLoc(), expr,
                                           "$results", type);
      return success();
    }
    return emitConvertError();
  }

  // FIXME: Decide how to allow/support converting a single result to multiple,
  // and multiple to a single result. For now, we just allow Single->Range,
  // but this isn't something really supported in the PDL dialect. We should
  // figure out some way to support both.
  bool exprIsValueLike = exprType == valueTy || exprType == valueRangeTy;
  bool targetIsValueLike = type == valueTy || type == valueRangeTy;
  if (exprIsValueLike && targetIsValueLike)
    return success();

  bool exprIsTypeLike = exprType == typeTy || exprType == typeRangeTy;
  bool targetIsTypeLike = type == typeTy || type == typeRangeTy;
  if (exprIsTypeLike && targetIsTypeLike)
    return success();

  return emitConvertError();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Directives
|
|
|
|
/// Dispatch on the spelling of the current directive token. `#include` is the
/// only directive currently recognized; anything else is an error.
LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
  StringRef directiveName = curToken.getSpelling();
  if (directiveName != "#include")
    return emitError("unknown directive `" + directiveName + "`");
  return parseInclude(decls);
}
|
|
|
|
/// Parse an `#include "file.pdll"` directive. The named file is lexed and its
/// top-level decls are appended to `decls` as if they were part of the
/// current module.
LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
  llvm::SMRange directiveLoc = curToken.getLoc();
  consumeToken(Token::directive);

  // The directive must be followed by a string literal naming the file.
  if (!curToken.isString())
    return emitError(directiveLoc,
                     "expected string file name after `include` directive");
  llvm::SMRange fileLoc = curToken.getLoc();
  std::string filenameStr = curToken.getStringValue();
  StringRef filename = filenameStr;
  consumeToken();

  // Only other PDLL files can be included.
  if (!filename.endswith(".pdll"))
    return emitError(fileLoc, "expected include filename to end with `.pdll`");

  if (failed(lexer.pushInclude(filename)))
    return emitError(fileLoc,
                     "unable to open include file `" + filename + "`");

  // consumeToken() above already lexed the token that follows the directive
  // in the including file; stash it, parse the nested file, then restore it.
  Token resumeToken = curToken;
  curToken = lexer.lexToken();
  LogicalResult nestedResult = parseModuleBody(decls);
  curToken = resumeToken;
  return nestedResult;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Decls
|
|
|
|
/// Parse a single top-level declaration. Named decls are registered in the
/// current scope after checking that the name is not already taken.
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
  // `Pattern` is currently the only kind of top-level declaration.
  if (curToken.isNot(Token::kw_Pattern))
    return emitError("expected top-level declaration, such as a `Pattern`");

  FailureOr<ast::Decl *> decl = parsePatternDecl();
  if (failed(decl))
    return failure();

  // Anonymous decls are not added to any scope.
  const ast::Name *declName = (*decl)->getName();
  if (!declName)
    return decl;

  if (failed(checkDefineNamedDecl(*declName)))
    return failure();
  curDeclScope->add(*decl);
  return decl;
}
|
|
|
|
/// Parse `Pattern name? { body }`. The body must contain exactly one operation
/// rewrite statement, and it must be the final statement.
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
  llvm::SMRange patternLoc = curToken.getLoc();
  consumeToken(Token::kw_Pattern);
  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
                                              ParserContext::PatternMatch);

  // The pattern name is an optional identifier directly after `Pattern`.
  const ast::Name *patternName = nullptr;
  if (curToken.is(Token::identifier)) {
    patternName =
        &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
    consumeToken(Token::identifier);
  }

  // TODO: Parse any pattern metadata.
  Optional<uint16_t> benefit;
  bool hasBoundedRecursion = false;

  if (curToken.isNot(Token::l_brace))
    return emitError("expected `{` to start pattern body");
  FailureOr<ast::CompoundStmt *> parsedBody = parseCompoundStmt();
  if (failed(parsedBody))
    return failure();
  ast::CompoundStmt *body = *parsedBody;

  // Locate the first operation rewrite statement in the body.
  auto stmtIt = body->begin();
  auto stmtEnd = body->end();
  while (stmtIt != stmtEnd && !isa<ast::OpRewriteStmt>(*stmtIt))
    ++stmtIt;

  if (stmtIt == stmtEnd)
    return emitError(patternLoc,
                     "expected Pattern body to terminate with an operation "
                     "rewrite statement, such as `erase`");

  // Nothing may follow the rewrite statement.
  auto trailingIt = std::next(stmtIt);
  if (trailingIt != stmtEnd)
    return emitError((*trailingIt)->getLoc(),
                     "Pattern body was terminated by an operation "
                     "rewrite statement, but found trailing statements");

  return createPatternDecl(patternLoc, patternName, benefit,
                           hasBoundedRecursion, body);
}
|
|
|
|
/// Parse `<expr>` following a constraint keyword. The caller must have checked
/// that the current token is `<`.
FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
  consumeToken(Token::less);

  FailureOr<ast::Expr *> typeExpr = parseExpr();
  if (failed(typeExpr))
    return failure();
  if (failed(parseToken(Token::greater,
                        "expected `>` after variable type constraint")))
    return failure();
  return typeExpr;
}
|
|
|
|
/// Fail with an error (plus a note at the earlier definition) if `name` is
/// already visible from the current decl scope.
LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
  assert(curDeclScope && "defining decl outside of a decl scope");

  ast::Decl *existingDecl = curDeclScope->lookup(name.getName());
  if (!existingDecl)
    return success();

  return emitErrorAndNote(
      name.getLoc(), "`" + name.getName() + "` has already been defined",
      existingDecl->getName()->getLoc(), "see previous definition here");
}
|
|
|
|
/// Create a variable decl and, unless its name is empty or `_`, register it in
/// the current decl scope after checking for redefinition.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
                           ast::Type type, ast::Expr *initExpr,
                           ArrayRef<ast::ConstraintRef> constraints) {
  assert(curDeclScope && "defining variable outside of decl scope");
  const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);

  // Empty and `_` names denote special variables that are local to their
  // definition point, so they never enter the scope.
  bool entersScope = !name.empty() && name != "_";
  if (entersScope && failed(checkDefineNamedDecl(nameDecl)))
    return failure();

  auto *varDecl =
      ast::VariableDecl::create(ctx, nameDecl, type, initExpr, constraints);
  if (entersScope)
    curDeclScope->add(varDecl);
  return varDecl;
}
|
|
|
|
/// Convenience overload for defining a variable that has no initializer
/// expression.
FailureOr<ast::VariableDecl *>
Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc,
                           ast::Type type,
                           ArrayRef<ast::ConstraintRef> constraints) {
  ast::Expr *noInitializer = nullptr;
  return defineVariableDecl(name, nameLoc, type, noInitializer, constraints);
}
|
|
|
|
/// Parse either a single constraint or a bracketed, comma-separated list of
/// constraints, appending each one to `constraints`.
LogicalResult Parser::parseVariableDeclConstraintList(
    SmallVectorImpl<ast::ConstraintRef> &constraints) {
  // Location of the type constraint seen so far; shared by every element so
  // that parseConstraint can diagnose a second type constraint.
  Optional<llvm::SMRange> typeConstraint;
  auto parseOneConstraint = [&]() -> LogicalResult {
    FailureOr<ast::ConstraintRef> parsed =
        parseConstraint(typeConstraint, constraints);
    if (failed(parsed))
      return failure();
    constraints.push_back(*parsed);
    return success();
  };

  // A lone constraint without brackets.
  if (!consumeIf(Token::l_square))
    return parseOneConstraint();

  // A `[c1, c2, ...]` list.
  for (;;) {
    if (failed(parseOneConstraint()))
      return failure();
    if (!consumeIf(Token::comma))
      break;
  }
  return parseToken(Token::r_square, "expected `]` after constraint list");
}
|
|
|
|
/// Parse one constraint reference: a builtin constraint keyword (optionally
/// with a `<...>` type or operation name) or the name of a user constraint.
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<llvm::SMRange> &typeConstraint,
                        ArrayRef<ast::ConstraintRef> existingConstraints) {
  // Parse an optional `<type-expr>` suffix into `typeExpr` (left null when
  // absent). Only one type constraint may apply to an entity, so a second one
  // is diagnosed against the location of the first.
  auto parseOptionalTypeConstraint =
      [&](ast::Expr *&typeExpr) -> LogicalResult {
    typeExpr = nullptr;
    if (curToken.isNot(Token::less))
      return success();
    if (typeConstraint)
      return emitErrorAndNote(
          curToken.getLoc(),
          "the type of this variable has already been constrained",
          *typeConstraint, "see previous constraint location here");

    FailureOr<ast::Expr *> parsedExpr = parseTypeConstraintExpr();
    if (failed(parsedExpr))
      return failure();
    typeExpr = *parsedExpr;
    typeConstraint = typeExpr->getLoc();
    return success();
  };

  llvm::SMRange loc = curToken.getLoc();
  switch (curToken.getKind()) {
  case Token::kw_Attr: {
    consumeToken(Token::kw_Attr);
    ast::Expr *typeExpr;
    if (failed(parseOptionalTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::AttrConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_Op: {
    consumeToken(Token::kw_Op);
    FailureOr<ast::OpNameDecl *> opName = parseWrappedOperationName();
    if (failed(opName))
      return failure();
    return ast::ConstraintRef(ast::OpConstraintDecl::create(ctx, loc, *opName),
                              loc);
  }
  case Token::kw_Type:
    consumeToken(Token::kw_Type);
    return ast::ConstraintRef(ast::TypeConstraintDecl::create(ctx, loc), loc);
  case Token::kw_TypeRange:
    consumeToken(Token::kw_TypeRange);
    return ast::ConstraintRef(ast::TypeRangeConstraintDecl::create(ctx, loc),
                              loc);
  case Token::kw_Value: {
    consumeToken(Token::kw_Value);
    ast::Expr *typeExpr;
    if (failed(parseOptionalTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::ValueConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::kw_ValueRange: {
    consumeToken(Token::kw_ValueRange);
    ast::Expr *typeExpr;
    if (failed(parseOptionalTypeConstraint(typeExpr)))
      return failure();
    return ast::ConstraintRef(
        ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
  }
  case Token::identifier: {
    StringRef constraintName = curToken.getSpelling();
    consumeToken(Token::identifier);

    // The identifier must name a visible decl, and that decl must be a
    // constraint.
    ast::Decl *cstDecl = curDeclScope->lookup<ast::Decl>(constraintName);
    if (!cstDecl)
      return emitError(loc, "unknown reference to constraint `" +
                                constraintName + "`");

    auto *cst = dyn_cast<ast::ConstraintDecl>(cstDecl);
    if (!cst)
      return emitErrorAndNote(
          loc, "invalid reference to non-constraint", cstDecl->getLoc(),
          "see the definition of `" + constraintName + "` here");
    return ast::ConstraintRef(cst, loc);
  }
  default:
    break;
  }
  return emitError(loc, "expected identifier constraint");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Exprs
|
|
|
|
/// Parse an expression: an `_` inline variable, or an identifier expression
/// followed by any number of `.member` accesses.
FailureOr<ast::Expr *> Parser::parseExpr() {
  if (curToken.is(Token::underscore))
    return parseUnderscoreExpr();

  // Identifiers are currently the only primary expressions.
  if (curToken.isNot(Token::identifier))
    return emitError("expected expression");
  FailureOr<ast::Expr *> expr = parseIdentifierExpr();

  // Fold trailing member accesses left to right, stopping on the first error.
  while (succeeded(expr) && curToken.is(Token::dot))
    expr = parseMemberAccessExpr(*expr);
  return expr;
}
|
|
|
|
/// Resolve `name` in the current scope and build a reference expression to it.
FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
                                                llvm::SMRange loc) {
  if (ast::Decl *decl = curDeclScope->lookup(name))
    return createDeclRefExpr(loc, decl);
  return emitError(loc, "undefined reference to `" + name + "`");
}
|
|
|
|
/// Parse an identifier: either a reference to an existing decl, or, when
/// followed by `:`, an inline variable definition with a constraint list.
FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
  StringRef name = curToken.getSpelling();
  llvm::SMRange nameLoc = curToken.getLoc();
  consumeToken();

  // No `:` means this is a plain reference.
  if (!consumeIf(Token::colon))
    return parseDeclRefExpr(name, nameLoc);

  // `name: constraints` defines a new variable inline.
  SmallVector<ast::ConstraintRef> constraints;
  if (failed(parseVariableDeclConstraintList(constraints)))
    return failure();
  ast::Type inferredType;
  if (failed(validateVariableConstraints(constraints, inferredType)))
    return failure();
  return createInlineVariableExpr(inferredType, name, nameLoc, constraints);
}
|
|
|
|
/// Parse `.member` applied to `parentExpr`. The member may be spelled as an
/// identifier, an integer, or a keyword.
FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
  llvm::SMRange dotLoc = curToken.getLoc();
  consumeToken(Token::dot);

  bool isValidMemberName =
      curToken.isAny(Token::identifier, Token::integer) || curToken.isKeyword();
  if (!isValidMemberName)
    return emitError(dotLoc, "expected identifier or numeric member name");

  StringRef memberName = curToken.getSpelling();
  consumeToken();
  return createMemberAccessExpr(parentExpr, memberName, dotLoc);
}
|
|
|
|
/// Parse an operation name of the form `dialect.op(.suffix)*`. If the current
/// token cannot start a name, an OpNameDecl with an empty range is returned.
FailureOr<ast::OpNameDecl *> Parser::parseOperationName() {
  llvm::SMRange nameLoc = curToken.getLoc();

  // Without a dialect namespace there is no operation name at all.
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return ast::OpNameDecl::create(ctx, llvm::SMRange());

  StringRef dialectName = curToken.getSpelling();
  consumeToken();

  if (failed(parseToken(Token::dot, "expected `.` after dialect namespace")))
    return failure();
  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
    return emitError("expected operation name after dialect namespace");

  // Build the full name as a view over the source buffer, starting at the
  // dialect name and counting the `.` plus each following token.
  // NOTE(review): the length arithmetic assumes the tokens are contiguous in
  // the buffer; whitespace between them would yield a truncated name.
  const char *nameStart = dialectName.data();
  size_t nameLength = dialectName.size() + 1;
  do {
    nameLength += curToken.getSpelling().size();
    nameLoc.End = curToken.getEndLoc();
    consumeToken();
  } while (curToken.isAny(Token::identifier, Token::dot) ||
           curToken.isKeyword());

  StringRef fullName(nameStart, nameLength);
  return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, fullName, nameLoc));
}
|
|
|
|
/// Parse an optional `<dialect.op>` operation name. If no `<` is present, an
/// OpNameDecl with an empty range is returned.
FailureOr<ast::OpNameDecl *> Parser::parseWrappedOperationName() {
  if (curToken.isNot(Token::less))
    return ast::OpNameDecl::create(ctx, llvm::SMRange());
  consumeToken(Token::less);

  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName();
  if (failed(opNameDecl) ||
      failed(parseToken(Token::greater, "expected `>` after operation name")))
    return failure();
  return opNameDecl;
}
|
|
|
|
/// Parse `_ : constraints`, an anonymous inline variable. The constraint list
/// is mandatory.
FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
  StringRef name = curToken.getSpelling();
  llvm::SMRange nameLoc = curToken.getLoc();
  consumeToken(Token::underscore);

  SmallVector<ast::ConstraintRef> constraints;
  ast::Type inferredType;
  if (failed(parseToken(Token::colon, "expected `:` after `_` variable")) ||
      failed(parseVariableDeclConstraintList(constraints)) ||
      failed(validateVariableConstraints(constraints, inferredType)))
    return failure();
  return createInlineVariableExpr(inferredType, name, nameLoc, constraints);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Stmts
|
|
|
|
/// Parse a single statement (`erase`, `let`, or an expression statement),
/// optionally requiring a trailing `;`.
FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
  FailureOr<ast::Stmt *> stmt;
  Token::Kind kind = curToken.getKind();
  if (kind == Token::kw_erase)
    stmt = parseEraseStmt();
  else if (kind == Token::kw_let)
    stmt = parseLetStmt();
  else
    stmt = parseExpr();

  if (failed(stmt))
    return failure();
  if (expectTerminalSemicolon &&
      failed(parseToken(Token::semicolon, "expected `;` after statement")))
    return failure();
  return stmt;
}
|
|
|
|
/// Parse `{ stmt* }`. The statements are parsed in a fresh decl scope that is
/// popped on both success and failure.
FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
  llvm::SMLoc startLoc = curToken.getStartLoc();
  consumeToken(Token::l_brace);

  pushDeclScope();
  SmallVector<ast::Stmt *> stmts;
  while (!curToken.is(Token::r_brace)) {
    FailureOr<ast::Stmt *> stmt = parseStmt();
    if (failed(stmt)) {
      popDeclScope();
      return failure();
    }
    stmts.push_back(*stmt);
  }
  popDeclScope();

  // The statement's range covers everything from `{` through `}`.
  llvm::SMRange bodyRange(startLoc, curToken.getEndLoc());
  consumeToken(Token::r_brace);
  return ast::CompoundStmt::create(ctx, bodyRange, stmts);
}
|
|
|
|
/// Parse `erase <expr>`, where the expression is the root operation to erase.
FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
  llvm::SMRange eraseLoc = curToken.getLoc();
  consumeToken(Token::kw_erase);

  FailureOr<ast::Expr *> rootOp = parseExpr();
  if (succeeded(rootOp))
    return createEraseStmt(eraseLoc, *rootOp);
  return failure();
}
|
|
|
|
/// Parse `let name (: constraints)? (= expr)?`. Constraints that carry an
/// explicit type expression are rejected when an initializer is present.
FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
  llvm::SMRange letLoc = curToken.getLoc();
  consumeToken(Token::kw_let);

  // The variable name may be an identifier or a dependent keyword.
  llvm::SMRange varLoc = curToken.getLoc();
  bool isValidName =
      curToken.is(Token::identifier) || curToken.isDependentKeyword();
  if (!isValidName) {
    // `_` is reserved for inline variables.
    if (curToken.is(Token::underscore))
      return emitError(varLoc,
                       "`_` may only be used to define \"inline\" variables");
    return emitError(varLoc,
                     "expected identifier after `let` to name a new variable");
  }
  StringRef varName = curToken.getSpelling();
  consumeToken();

  SmallVector<ast::ConstraintRef> constraints;
  if (consumeIf(Token::colon) &&
      failed(parseVariableDeclConstraintList(constraints)))
    return failure();

  ast::Expr *initializer = nullptr;
  if (consumeIf(Token::equal)) {
    FailureOr<ast::Expr *> initExpr = parseExpr();
    if (failed(initExpr))
      return failure();
    initializer = *initExpr;

    // Attr/Value/ValueRange constraints with a `<type>` cannot be combined
    // with an initializer.
    for (const ast::ConstraintRef &constraint : constraints) {
      LogicalResult result =
          TypeSwitch<const ast::Node *, LogicalResult>(constraint.constraint)
              .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
                    ast::ValueRangeConstraintDecl>([&](const auto *cst) {
                if (!cst->getTypeExpr())
                  return success();
                return emitError(
                    constraint.referenceLoc,
                    "type constraints are not permitted on variables with "
                    "initializers");
              })
              .Default(success());
      if (failed(result))
        return failure();
    }
  }

  FailureOr<ast::VariableDecl *> varDecl =
      createVariableDecl(varName, varLoc, initializer, constraints);
  if (failed(varDecl))
    return failure();
  return ast::LetStmt::create(ctx, letLoc, *varDecl);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Creation+Analysis
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Decls
|
|
|
|
/// Create a PatternDecl from already-parsed components: the optional name,
/// the optional benefit, whether bounded recursion was requested, and the
/// pattern body. No additional validation is performed here.
FailureOr<ast::PatternDecl *>
Parser::createPatternDecl(llvm::SMRange loc, const ast::Name *name,
                          Optional<uint16_t> benefit, bool hasBoundedRecursion,
                          ast::CompoundStmt *body) {
  auto *patternDecl = ast::PatternDecl::create(ctx, loc, name, benefit,
                                               hasBoundedRecursion, body);
  return patternDecl;
}
|
|
|
|
/// Create a VariableDecl named `name` at `loc`.
///
/// The variable's type is inferred from `constraints` and/or `initializer`.
/// When both provide a type, the initializer's type is first merged
/// (refined) with the constraint-derived type. If merging is not possible,
/// the initializer is converted to the constraint-derived type.
///
/// Returns failure if the constraints are invalid or mutually incompatible,
/// if the initializer cannot be converted, if no type can be inferred at all,
/// or if defining the variable fails.
FailureOr<ast::VariableDecl *>
Parser::createVariableDecl(StringRef name, llvm::SMRange loc,
                           ast::Expr *initializer,
                           ArrayRef<ast::ConstraintRef> constraints) {
  // Seed the type from the constraint list; this may leave it null.
  ast::Type varType;
  if (failed(validateVariableConstraints(constraints, varType)))
    return failure();

  if (!initializer) {
    // Without an initializer, the constraints are the only source of a type.
    if (!varType) {
      return emitErrorAndNote(
          loc, "unable to infer type for variable `" + name + "`", loc,
          "the type of a variable must be inferable from the constraint "
          "list or the initializer");
    }
  } else if (!varType) {
    // No constraint-derived type: adopt the initializer's type directly.
    varType = initializer->getType();
  } else if (ast::Type refinedType =
                 varType.refineWith(initializer->getType())) {
    varType = refinedType;
  } else if (failed(convertExpressionTo(initializer, varType))) {
    // The initializer is neither compatible with nor convertible to the
    // constrained type.
    return failure();
  }

  // Try to define a variable with the given name and resolved type.
  return defineVariableDecl(name, loc, varType, initializer, constraints);
}
|
|
|
|
/// Validate each constraint in `constraints` in order, refining
/// `inferredType` as each one is processed. Stops at the first invalid or
/// incompatible constraint and returns failure for it.
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                                    ast::Type &inferredType) {
  for (size_t idx = 0, count = constraints.size(); idx != count; ++idx) {
    if (failed(validateVariableConstraint(constraints[idx], inferredType)))
      return failure();
  }
  return success();
}
|
|
|
|
/// Validate a single constraint reference applied to a variable, and fold the
/// type it implies into `inferredType`.
///
/// Attribute, Value, and ValueRange constraints may carry a type expression.
/// For Attribute and Value it must be of `Type`; for ValueRange it must be of
/// `TypeRange`. If `inferredType` is still null, it is set to the
/// constraint's type; otherwise the two must be refinable into a single type.
///
/// Returns failure (emitting a diagnostic) if either check fails.
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
                                                 ast::Type &inferredType) {
  const auto *decl = ref.constraint;
  ast::Type impliedType;

  if (const auto *attrCst = dyn_cast<ast::AttrConstraintDecl>(decl)) {
    const ast::Expr *typeExpr = attrCst->getTypeExpr();
    if (typeExpr && failed(validateTypeConstraintExpr(typeExpr)))
      return failure();
    impliedType = ast::AttributeType::get(ctx);
  } else if (const auto *opCst = dyn_cast<ast::OpConstraintDecl>(decl)) {
    impliedType = ast::OperationType::get(ctx, opCst->getName());
  } else if (isa<ast::TypeConstraintDecl>(decl)) {
    impliedType = typeTy;
  } else if (isa<ast::TypeRangeConstraintDecl>(decl)) {
    impliedType = typeRangeTy;
  } else if (const auto *valueCst = dyn_cast<ast::ValueConstraintDecl>(decl)) {
    const ast::Expr *typeExpr = valueCst->getTypeExpr();
    if (typeExpr && failed(validateTypeConstraintExpr(typeExpr)))
      return failure();
    impliedType = valueTy;
  } else if (const auto *rangeCst =
                 dyn_cast<ast::ValueRangeConstraintDecl>(decl)) {
    const ast::Expr *typeExpr = rangeCst->getTypeExpr();
    if (typeExpr && failed(validateTypeRangeConstraintExpr(typeExpr)))
      return failure();
    impliedType = valueRangeTy;
  } else {
    llvm_unreachable("unknown constraint type");
  }

  // Nothing inferred yet: adopt this constraint's type as-is.
  if (!inferredType) {
    inferredType = impliedType;
    return success();
  }

  // Otherwise the new type must be compatible with what was inferred so far.
  ast::Type refined = inferredType.refineWith(impliedType);
  if (!refined) {
    return emitError(ref.referenceLoc,
                     llvm::formatv("constraint type `{0}` is incompatible "
                                   "with the previously inferred type `{1}`",
                                   impliedType, inferredType));
  }
  inferredType = refined;
  return success();
}
|
|
|
|
/// Check that `typeExpr`, used as the type operand of a constraint, is an
/// expression of the single-type `Type` type. Emits an error otherwise.
LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) {
  if (typeExpr->getType() == typeTy)
    return success();
  return emitError(typeExpr->getLoc(),
                   "expected expression of `Type` in type constraint");
}
|
|
|
|
/// Check that `typeExpr`, used as the type operand of a range constraint, is
/// an expression of the `TypeRange` type. Emits an error otherwise.
LogicalResult
Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
  if (typeExpr->getType() == typeRangeTy)
    return success();
  return emitError(typeExpr->getLoc(),
                   "expected expression of `TypeRange` in type constraint");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Exprs
|
|
|
|
/// Create a reference expression to `decl` at `loc`. Only variable decls may
/// be referenced; the expression takes the variable's type. Any other kind of
/// decl produces an "invalid reference" error.
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(llvm::SMRange loc,
                                                        ast::Decl *decl) {
  auto *varDecl = dyn_cast<ast::VariableDecl>(decl);
  if (!varDecl) {
    return emitError(loc, "invalid reference to `" +
                              decl->getName()->getName() + "`");
  }
  return ast::DeclRefExpr::create(ctx, loc, decl, varDecl->getType());
}
|
|
|
|
/// Define a new variable `name` of type `type` at `loc`, constrained by
/// `constraints`, and return an expression that references it. The reference
/// shares the variable's location and type.
FailureOr<ast::DeclRefExpr *>
Parser::createInlineVariableExpr(ast::Type type, StringRef name,
                                 llvm::SMRange loc,
                                 ArrayRef<ast::ConstraintRef> constraints) {
  FailureOr<ast::VariableDecl *> inlineVar =
      defineVariableDecl(name, loc, type, constraints);
  if (succeeded(inlineVar))
    return ast::DeclRefExpr::create(ctx, loc, *inlineVar, type);
  return failure();
}
|
|
|
|
/// Create an access of member `name` on `parentExpr` at `loc`. The member must
/// be valid for the parent's type (see validateMemberAccess), which also
/// determines the type of the resulting expression.
FailureOr<ast::MemberAccessExpr *>
Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
                               llvm::SMRange loc) {
  FailureOr<ast::Type> resultType = validateMemberAccess(parentExpr, name, loc);
  if (succeeded(resultType)) {
    return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name,
                                         *resultType);
  }
  return failure();
}
|
|
|
|
/// Determine the type of member `name` accessed on `parentExpr`.
///
/// Currently the only supported member is `$results` on an expression of
/// `Op` (OperationType) type, which yields a ValueRange. Any other access
/// emits an "invalid member access" error at `loc` and returns failure.
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
                                                  StringRef name,
                                                  llvm::SMRange loc) {
  ast::Type parentType = parentExpr->getType();
  // Only the kind of the parent type matters here, so test it with `isa`
  // rather than binding an (unused) OperationType variable.
  if (parentType.isa<ast::OperationType>()) {
    // $results is a special member access representing all of the results.
    // TODO: Should we have special AST expressions for these? How does the
    // user reference these in the language itself?
    if (name == "$results")
      return valueRangeTy;
  }
  return emitError(
      loc,
      llvm::formatv("invalid member access `{0}` on expression of type `{1}`",
                    name, parentType));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Stmts
|
|
|
|
/// Create an EraseStmt at `loc` that erases `rootOp`. `rootOp` must be an
/// expression of `Op` (OperationType) type; otherwise an error is emitted at
/// the operand's location.
FailureOr<ast::EraseStmt *> Parser::createEraseStmt(llvm::SMRange loc,
                                                    ast::Expr *rootOp) {
  const bool isOperation = rootOp->getType().isa<ast::OperationType>();
  if (isOperation)
    return ast::EraseStmt::create(ctx, loc, rootOp);
  return emitError(rootOp->getLoc(), "expected `Op` expression");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parser
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse the PDLL source managed by `sourceMgr` into an AST module, using
/// `ctx` as the AST context (its diagnostic engine receives lexer
/// diagnostics). Returns failure if the module could not be parsed.
FailureOr<ast::Module *> mlir::pdll::parsePDLAST(ast::Context &ctx,
                                                 llvm::SourceMgr &sourceMgr) {
  return Parser(ctx, sourceMgr).parseModule();
}
|