[mlir][PDLL] Add an initial frontend for PDLL
This is a new pattern rewrite frontend designed from the ground up to support MLIR constructs, and to target PDL. This frontend language was proposed in https://llvm.discourse.group/t/rfc-pdll-a-new-declarative-rewrite-frontend-for-mlir/4798 This commit starts sketching out the base structure of the frontend, and is intended to be a minimal starting point for building up the language. It essentially contains support for defining a pattern, variables, and erasing an operation. The features mentioned in the proposal RFC (including IDE support) will be added incrementally in followup commits. I intend to upstream the documentation for the language in a followup when a bit more of the pieces have been landed. Differential Revision: https://reviews.llvm.org/D115093
This commit is contained in:
parent
8f1ea2e85c
commit
11d26bd143
|
|
@ -0,0 +1,52 @@
|
|||
//===- Context.h - PDLL AST Context -----------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TOOLS_PDLL_AST_CONTEXT_H_
|
||||
#define MLIR_TOOLS_PDLL_AST_CONTEXT_H_
|
||||
|
||||
#include "mlir/Support/StorageUniquer.h"
|
||||
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace pdll {
|
||||
namespace ast {
|
||||
/// This class represents the main context of the PDLL AST. It handles
|
||||
/// allocating all of the AST constructs, and manages all state necessary for
|
||||
/// the AST.
|
||||
class Context {
|
||||
public:
|
||||
Context();
|
||||
Context(const Context &) = delete;
|
||||
Context &operator=(const Context &) = delete;
|
||||
|
||||
/// Return the allocator owned by this context.
|
||||
llvm::BumpPtrAllocator &getAllocator() { return allocator; }
|
||||
|
||||
/// Return the storage uniquer used for AST types.
|
||||
StorageUniquer &getTypeUniquer() { return typeUniquer; }
|
||||
|
||||
/// Return the diagnostic engine of this context.
|
||||
DiagnosticEngine &getDiagEngine() { return diagEngine; }
|
||||
|
||||
private:
|
||||
/// The diagnostic engine of this AST context.
|
||||
DiagnosticEngine diagEngine;
|
||||
|
||||
/// The allocator used for AST nodes, and other entities allocated within the
|
||||
/// context.
|
||||
llvm::BumpPtrAllocator allocator;
|
||||
|
||||
/// The uniquer used for creating AST types.
|
||||
StorageUniquer typeUniquer;
|
||||
};
|
||||
|
||||
} // namespace ast
|
||||
} // namespace pdll
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_TOOLS_PDLL_AST_CONTEXT_H_
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
//===- Diagnostic.h - PDLL AST Diagnostics ----------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_
|
||||
#define MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/FunctionExtras.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace pdll {
|
||||
namespace ast {
|
||||
class DiagnosticEngine;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Diagnostic
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class provides a simple implementation of a PDLL diagnostic.
|
||||
class Diagnostic {
|
||||
public:
|
||||
using Severity = llvm::SourceMgr::DiagKind;
|
||||
|
||||
/// Return the severity of this diagnostic.
|
||||
Severity getSeverity() const { return severity; }
|
||||
|
||||
/// Return the message of this diagnostic.
|
||||
StringRef getMessage() const { return message; }
|
||||
|
||||
/// Return the location of this diagnostic.
|
||||
llvm::SMRange getLocation() const { return location; }
|
||||
|
||||
/// Return the notes of this diagnostic.
|
||||
auto getNotes() const { return llvm::make_pointee_range(notes); }
|
||||
|
||||
/// Attach a note to this diagnostic.
|
||||
Diagnostic &attachNote(const Twine &msg,
|
||||
Optional<llvm::SMRange> noteLoc = llvm::None) {
|
||||
assert(getSeverity() != Severity::DK_Note &&
|
||||
"cannot attach a Note to a Note");
|
||||
notes.emplace_back(
|
||||
new Diagnostic(Severity::DK_Note, noteLoc.getValueOr(location), msg));
|
||||
return *notes.back();
|
||||
}
|
||||
|
||||
/// Allow an inflight diagnostic to be converted to 'failure', otherwise
|
||||
/// 'success' if this is an empty diagnostic.
|
||||
operator LogicalResult() const { return failure(); }
|
||||
|
||||
private:
|
||||
Diagnostic(Severity severity, llvm::SMRange loc, const Twine &msg)
|
||||
: severity(severity), message(msg.str()), location(loc) {}
|
||||
|
||||
// Allow access to the constructor.
|
||||
friend DiagnosticEngine;
|
||||
|
||||
/// The severity of this diagnostic.
|
||||
Severity severity;
|
||||
/// The message held by this diagnostic.
|
||||
std::string message;
|
||||
/// The raw location of this diagnostic.
|
||||
llvm::SMRange location;
|
||||
/// Any additional note diagnostics attached to this diagnostic.
|
||||
std::vector<std::unique_ptr<Diagnostic>> notes;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InFlightDiagnostic
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a diagnostic that is inflight and set to be reported.
|
||||
/// This allows for last minute modifications of the diagnostic before it is
|
||||
/// emitted by a DiagnosticEngine.
|
||||
class InFlightDiagnostic {
|
||||
public:
|
||||
InFlightDiagnostic() = default;
|
||||
InFlightDiagnostic(InFlightDiagnostic &&rhs)
|
||||
: owner(rhs.owner), impl(std::move(rhs.impl)) {
|
||||
// Reset the rhs diagnostic.
|
||||
rhs.impl.reset();
|
||||
rhs.abandon();
|
||||
}
|
||||
~InFlightDiagnostic() {
|
||||
if (isInFlight())
|
||||
report();
|
||||
}
|
||||
|
||||
/// Access the internal diagnostic.
|
||||
Diagnostic &operator*() { return *impl; }
|
||||
Diagnostic *operator->() { return &*impl; }
|
||||
|
||||
/// Reports the diagnostic to the engine.
|
||||
void report();
|
||||
|
||||
/// Abandons this diagnostic so that it will no longer be reported.
|
||||
void abandon() { owner = nullptr; }
|
||||
|
||||
/// Allow an inflight diagnostic to be converted to 'failure', otherwise
|
||||
/// 'success' if this is an empty diagnostic.
|
||||
operator LogicalResult() const { return failure(isActive()); }
|
||||
|
||||
private:
|
||||
InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete;
|
||||
InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete;
|
||||
InFlightDiagnostic(DiagnosticEngine *owner, Diagnostic &&rhs)
|
||||
: owner(owner), impl(std::move(rhs)) {}
|
||||
|
||||
/// Returns true if the diagnostic is still active, i.e. it has a live
|
||||
/// diagnostic.
|
||||
bool isActive() const { return impl.hasValue(); }
|
||||
|
||||
/// Returns true if the diagnostic is still in flight to be reported.
|
||||
bool isInFlight() const { return owner; }
|
||||
|
||||
// Allow access to the constructor.
|
||||
friend DiagnosticEngine;
|
||||
|
||||
/// The engine that this diagnostic is to report to.
|
||||
DiagnosticEngine *owner = nullptr;
|
||||
|
||||
/// The raw diagnostic that is inflight to be reported.
|
||||
Optional<Diagnostic> impl;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DiagnosticEngine
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class manages the construction and emission of PDLL diagnostics.
|
||||
class DiagnosticEngine {
|
||||
public:
|
||||
/// A function used to handle diagnostics emitted by the engine.
|
||||
using HandlerFn = llvm::unique_function<void(Diagnostic &)>;
|
||||
|
||||
/// Emit an error to the diagnostic engine.
|
||||
InFlightDiagnostic emitError(llvm::SMRange loc, const Twine &msg) {
|
||||
return InFlightDiagnostic(
|
||||
this, Diagnostic(Diagnostic::Severity::DK_Error, loc, msg));
|
||||
}
|
||||
InFlightDiagnostic emitWarning(llvm::SMRange loc, const Twine &msg) {
|
||||
return InFlightDiagnostic(
|
||||
this, Diagnostic(Diagnostic::Severity::DK_Warning, loc, msg));
|
||||
}
|
||||
|
||||
/// Report the given diagnostic.
|
||||
void report(Diagnostic &&diagnostic) {
|
||||
if (handler)
|
||||
handler(diagnostic);
|
||||
}
|
||||
|
||||
/// Get the current handler function of this diagnostic engine.
|
||||
const HandlerFn &getHandlerFn() const { return handler; }
|
||||
|
||||
/// Take the current handler function, resetting the current handler to null.
|
||||
HandlerFn takeHandlerFn() {
|
||||
HandlerFn oldHandler = std::move(handler);
|
||||
handler = {};
|
||||
return oldHandler;
|
||||
}
|
||||
|
||||
/// Set the handler function for this diagnostic engine.
|
||||
void setHandlerFn(HandlerFn &&newHandler) { handler = std::move(newHandler); }
|
||||
|
||||
private:
|
||||
/// The registered diagnostic handler function.
|
||||
HandlerFn handler;
|
||||
};
|
||||
|
||||
} // namespace ast
|
||||
} // namespace pdll
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_
|
||||
|
|
@ -0,0 +1,690 @@
|
|||
//===- Nodes.h --------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
|
||||
#define MLIR_TOOLS_PDLL_AST_NODES_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Tools/PDLL/AST/Types.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/TrailingObjects.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace pdll {
|
||||
namespace ast {
|
||||
class Context;
|
||||
class Decl;
|
||||
class Expr;
|
||||
class OpNameDecl;
|
||||
class VariableDecl;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Name
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class provides a convenient API for interacting with source names. It
|
||||
/// contains a string name as well as the source location for that name.
|
||||
struct Name {
|
||||
static const Name &create(Context &ctx, StringRef name,
|
||||
llvm::SMRange location);
|
||||
|
||||
/// Return the raw string name.
|
||||
StringRef getName() const { return name; }
|
||||
|
||||
/// Get the location of this name.
|
||||
llvm::SMRange getLoc() const { return location; }
|
||||
|
||||
private:
|
||||
Name() = delete;
|
||||
Name(const Name &) = delete;
|
||||
Name &operator=(const Name &) = delete;
|
||||
Name(StringRef name, llvm::SMRange location)
|
||||
: name(name), location(location) {}
|
||||
|
||||
/// The string name of the decl.
|
||||
StringRef name;
|
||||
/// The location of the decl name.
|
||||
llvm::SMRange location;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DeclScope
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a scope for named AST decls. A scope determines the
|
||||
/// visibility and lifetime of a named declaration.
|
||||
class DeclScope {
|
||||
public:
|
||||
/// Create a new scope with an optional parent scope.
|
||||
DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
|
||||
|
||||
/// Return the parent scope of this scope, or nullptr if there is no parent.
|
||||
DeclScope *getParentScope() { return parent; }
|
||||
const DeclScope *getParentScope() const { return parent; }
|
||||
|
||||
/// Return all of the decls within this scope.
|
||||
auto getDecls() const { return llvm::make_second_range(decls); }
|
||||
|
||||
/// Add a new decl to the scope.
|
||||
void add(Decl *decl);
|
||||
|
||||
/// Lookup a decl with the given name starting from this scope. Returns
|
||||
/// nullptr if no decl could be found.
|
||||
Decl *lookup(StringRef name);
|
||||
template <typename T> T *lookup(StringRef name) {
|
||||
return dyn_cast_or_null<T>(lookup(name));
|
||||
}
|
||||
const Decl *lookup(StringRef name) const {
|
||||
return const_cast<DeclScope *>(this)->lookup(name);
|
||||
}
|
||||
template <typename T> const T *lookup(StringRef name) const {
|
||||
return dyn_cast_or_null<T>(lookup(name));
|
||||
}
|
||||
|
||||
private:
|
||||
/// The parent scope, or null if this is a top-level scope.
|
||||
DeclScope *parent;
|
||||
/// The decls defined within this scope.
|
||||
llvm::StringMap<Decl *> decls;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Node
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a base AST node. All AST nodes are derived from this
|
||||
/// class, and it contains many of the base functionality for interacting with
|
||||
/// nodes.
|
||||
class Node {
|
||||
public:
|
||||
/// This CRTP class provides several utilies when defining new AST nodes.
|
||||
template <typename T, typename BaseT> class NodeBase : public BaseT {
|
||||
public:
|
||||
using Base = NodeBase<T, BaseT>;
|
||||
|
||||
/// Provide type casting support.
|
||||
static bool classof(const Node *node) {
|
||||
return node->getTypeID() == TypeID::get<T>();
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename... Args>
|
||||
explicit NodeBase(llvm::SMRange loc, Args &&...args)
|
||||
: BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
|
||||
};
|
||||
|
||||
/// Return the type identifier of this node.
|
||||
TypeID getTypeID() const { return typeID; }
|
||||
|
||||
/// Return the location of this node.
|
||||
llvm::SMRange getLoc() const { return loc; }
|
||||
|
||||
/// Print this node to the given stream.
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
protected:
|
||||
Node(TypeID typeID, llvm::SMRange loc) : typeID(typeID), loc(loc) {}
|
||||
|
||||
private:
|
||||
/// A unique type identifier for this node.
|
||||
TypeID typeID;
|
||||
|
||||
/// The location of this node.
|
||||
llvm::SMRange loc;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Stmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a base AST Statement node.
|
||||
class Stmt : public Node {
|
||||
public:
|
||||
using Node::Node;
|
||||
|
||||
/// Provide type casting support.
|
||||
static bool classof(const Node *node);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CompoundStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This statement represents a compound statement, which contains a collection
|
||||
/// of other statements.
|
||||
class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
|
||||
private llvm::TrailingObjects<CompoundStmt, Stmt *> {
|
||||
public:
|
||||
static CompoundStmt *create(Context &ctx, llvm::SMRange location,
|
||||
ArrayRef<Stmt *> children);
|
||||
|
||||
/// Return the children of this compound statement.
|
||||
MutableArrayRef<Stmt *> getChildren() {
|
||||
return {getTrailingObjects<Stmt *>(), numChildren};
|
||||
}
|
||||
ArrayRef<Stmt *> getChildren() const {
|
||||
return const_cast<CompoundStmt *>(this)->getChildren();
|
||||
}
|
||||
ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
|
||||
ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
|
||||
|
||||
private:
|
||||
CompoundStmt(llvm::SMRange location, unsigned numChildren)
|
||||
: Base(location), numChildren(numChildren) {}
|
||||
|
||||
/// The number of held children statements.
|
||||
unsigned numChildren;
|
||||
|
||||
// Allow access to various privates.
|
||||
friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LetStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This statement represents a `let` statement in PDLL. This statement is used
|
||||
/// to define variables.
|
||||
class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
|
||||
public:
|
||||
static LetStmt *create(Context &ctx, llvm::SMRange loc,
|
||||
VariableDecl *varDecl);
|
||||
|
||||
/// Return the variable defined by this statement.
|
||||
VariableDecl *getVarDecl() const { return varDecl; }
|
||||
|
||||
private:
|
||||
LetStmt(llvm::SMRange loc, VariableDecl *varDecl)
|
||||
: Base(loc), varDecl(varDecl) {}
|
||||
|
||||
/// The variable defined by this statement.
|
||||
VariableDecl *varDecl;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpRewriteStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a base operation rewrite statement. Operation rewrite
|
||||
/// statements perform a set of transformations on a given root operation.
|
||||
class OpRewriteStmt : public Stmt {
|
||||
public:
|
||||
/// Provide type casting support.
|
||||
static bool classof(const Node *node);
|
||||
|
||||
/// Return the root operation of this rewrite.
|
||||
Expr *getRootOpExpr() const { return rootOp; }
|
||||
|
||||
protected:
|
||||
OpRewriteStmt(TypeID typeID, llvm::SMRange loc, Expr *rootOp)
|
||||
: Stmt(typeID, loc), rootOp(rootOp) {}
|
||||
|
||||
protected:
|
||||
/// The root operation being rewritten.
|
||||
Expr *rootOp;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EraseStmt
|
||||
|
||||
/// This statement represents the `erase` statement in PDLL. This statement
|
||||
/// erases the given root operation, corresponding roughly to the
|
||||
/// PatternRewriter::eraseOp API.
|
||||
class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
|
||||
public:
|
||||
static EraseStmt *create(Context &ctx, llvm::SMRange loc, Expr *rootOp);
|
||||
|
||||
private:
|
||||
EraseStmt(llvm::SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Expr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a base AST Expression node.
|
||||
class Expr : public Stmt {
|
||||
public:
|
||||
/// Return the type of this expression.
|
||||
Type getType() const { return type; }
|
||||
|
||||
/// Provide type casting support.
|
||||
static bool classof(const Node *node);
|
||||
|
||||
protected:
|
||||
Expr(TypeID typeID, llvm::SMRange loc, Type type)
|
||||
: Stmt(typeID, loc), type(type) {}
|
||||
|
||||
private:
|
||||
/// The type of this expression.
|
||||
Type type;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DeclRefExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This expression represents a reference to a Decl node.
|
||||
class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
|
||||
public:
|
||||
static DeclRefExpr *create(Context &ctx, llvm::SMRange loc, Decl *decl,
|
||||
Type type);
|
||||
|
||||
/// Get the decl referenced by this expression.
|
||||
Decl *getDecl() const { return decl; }
|
||||
|
||||
private:
|
||||
DeclRefExpr(llvm::SMRange loc, Decl *decl, Type type)
|
||||
: Base(loc, type), decl(decl) {}
|
||||
|
||||
/// The decl referenced by this expression.
|
||||
Decl *decl;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemberAccessExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This expression represents a named member or field access of a given parent
|
||||
/// expression.
|
||||
class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
|
||||
public:
|
||||
static MemberAccessExpr *create(Context &ctx, llvm::SMRange loc,
|
||||
const Expr *parentExpr, StringRef memberName,
|
||||
Type type);
|
||||
|
||||
/// Get the parent expression of this access.
|
||||
const Expr *getParentExpr() const { return parentExpr; }
|
||||
|
||||
/// Return the name of the member being accessed.
|
||||
StringRef getMemberName() const { return memberName; }
|
||||
|
||||
private:
|
||||
MemberAccessExpr(llvm::SMRange loc, const Expr *parentExpr,
|
||||
StringRef memberName, Type type)
|
||||
: Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
|
||||
|
||||
/// The parent expression of this access.
|
||||
const Expr *parentExpr;
|
||||
|
||||
/// The name of the member being accessed from the parent.
|
||||
StringRef memberName;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Decl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents the base Decl node.
|
||||
class Decl : public Node {
|
||||
public:
|
||||
/// Return the name of the decl, or nullptr if it doesn't have one.
|
||||
const Name *getName() const { return name; }
|
||||
|
||||
/// Provide type casting support.
|
||||
static bool classof(const Node *node);
|
||||
|
||||
protected:
|
||||
Decl(TypeID typeID, llvm::SMRange loc, const Name *name = nullptr)
|
||||
: Node(typeID, loc), name(name) {}
|
||||
|
||||
private:
|
||||
/// The name of the decl. This is optional for some decls, such as
|
||||
/// PatternDecl.
|
||||
const Name *name;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents the base of all AST Constraint decls. Constraints
|
||||
/// apply matcher conditions to, and define the type of PDLL variables.
|
||||
class ConstraintDecl : public Decl {
|
||||
public:
|
||||
/// Provide type casting support.
|
||||
static bool classof(const Node *node);
|
||||
|
||||
protected:
|
||||
ConstraintDecl(TypeID typeID, llvm::SMRange loc, const Name *name = nullptr)
|
||||
: Decl(typeID, loc, name) {}
|
||||
};
|
||||
|
||||
/// This class represents a reference to a constraint, and contains a constraint
|
||||
/// and the location of the reference.
|
||||
struct ConstraintRef {
|
||||
ConstraintRef(const ConstraintDecl *constraint, llvm::SMRange refLoc)
|
||||
: constraint(constraint), referenceLoc(refLoc) {}
|
||||
explicit ConstraintRef(const ConstraintDecl *constraint)
|
||||
: ConstraintRef(constraint, constraint->getLoc()) {}
|
||||
|
||||
const ConstraintDecl *constraint;
|
||||
llvm::SMRange referenceLoc;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CoreConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents the base of all "core" constraints. Core constraints
|
||||
/// are those that generally represent a concrete IR construct, such as
|
||||
/// `Type`s or `Value`s.
|
||||
class CoreConstraintDecl : public ConstraintDecl {
|
||||
public:
|
||||
/// Provide type casting support.
|
||||
static bool classof(const Node *node);
|
||||
|
||||
protected:
|
||||
CoreConstraintDecl(TypeID typeID, llvm::SMRange loc,
|
||||
const Name *name = nullptr)
|
||||
: ConstraintDecl(typeID, loc, name) {}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttrConstraintDecl
|
||||
|
||||
/// The class represents an Attribute constraint, and constrains a variable to
|
||||
/// be an Attribute.
|
||||
class AttrConstraintDecl
|
||||
: public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
|
||||
public:
|
||||
static AttrConstraintDecl *create(Context &ctx, llvm::SMRange loc,
|
||||
Expr *typeExpr = nullptr);
|
||||
|
||||
/// Return the optional type the attribute is constrained to.
|
||||
Expr *getTypeExpr() { return typeExpr; }
|
||||
const Expr *getTypeExpr() const { return typeExpr; }
|
||||
|
||||
protected:
|
||||
AttrConstraintDecl(llvm::SMRange loc, Expr *typeExpr)
|
||||
: Base(loc), typeExpr(typeExpr) {}
|
||||
|
||||
/// An optional type that the attribute is constrained to.
|
||||
Expr *typeExpr;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpConstraintDecl
|
||||
|
||||
/// The class represents an Operation constraint, and constrains a variable to
|
||||
/// be an Operation.
|
||||
class OpConstraintDecl
|
||||
: public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
|
||||
public:
|
||||
static OpConstraintDecl *create(Context &ctx, llvm::SMRange loc,
|
||||
const OpNameDecl *nameDecl = nullptr);
|
||||
|
||||
/// Return the name of the operation, or None if there isn't one.
|
||||
Optional<StringRef> getName() const;
|
||||
|
||||
/// Return the declaration of the operation name.
|
||||
const OpNameDecl *getNameDecl() const { return nameDecl; }
|
||||
|
||||
protected:
|
||||
explicit OpConstraintDecl(llvm::SMRange loc, const OpNameDecl *nameDecl)
|
||||
: Base(loc), nameDecl(nameDecl) {}
|
||||
|
||||
/// The operation name of this constraint.
|
||||
const OpNameDecl *nameDecl;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeConstraintDecl
|
||||
|
||||
/// The class represents a Type constraint, and constrains a variable to be a
|
||||
/// Type.
|
||||
class TypeConstraintDecl
|
||||
: public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
|
||||
public:
|
||||
static TypeConstraintDecl *create(Context &ctx, llvm::SMRange loc);
|
||||
|
||||
protected:
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeRangeConstraintDecl
|
||||
|
||||
/// The class represents a TypeRange constraint, and constrains a variable to be
|
||||
/// a TypeRange.
|
||||
class TypeRangeConstraintDecl
|
||||
: public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
|
||||
public:
|
||||
static TypeRangeConstraintDecl *create(Context &ctx, llvm::SMRange loc);
|
||||
|
||||
protected:
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueConstraintDecl
|
||||
|
||||
/// The class represents a Value constraint, and constrains a variable to be a
|
||||
/// Value.
|
||||
class ValueConstraintDecl
|
||||
: public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
|
||||
public:
|
||||
static ValueConstraintDecl *create(Context &ctx, llvm::SMRange loc,
|
||||
Expr *typeExpr);
|
||||
|
||||
/// Return the optional type the value is constrained to.
|
||||
Expr *getTypeExpr() { return typeExpr; }
|
||||
const Expr *getTypeExpr() const { return typeExpr; }
|
||||
|
||||
protected:
|
||||
ValueConstraintDecl(llvm::SMRange loc, Expr *typeExpr)
|
||||
: Base(loc), typeExpr(typeExpr) {}
|
||||
|
||||
/// An optional type that the value is constrained to.
|
||||
Expr *typeExpr;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueRangeConstraintDecl
|
||||
|
||||
/// The class represents a ValueRange constraint, and constrains a variable to
|
||||
/// be a ValueRange.
|
||||
class ValueRangeConstraintDecl
|
||||
: public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
|
||||
public:
|
||||
static ValueRangeConstraintDecl *create(Context &ctx, llvm::SMRange loc,
|
||||
Expr *typeExpr);
|
||||
|
||||
/// Return the optional type the value range is constrained to.
|
||||
Expr *getTypeExpr() { return typeExpr; }
|
||||
const Expr *getTypeExpr() const { return typeExpr; }
|
||||
|
||||
protected:
|
||||
ValueRangeConstraintDecl(llvm::SMRange loc, Expr *typeExpr)
|
||||
: Base(loc), typeExpr(typeExpr) {}
|
||||
|
||||
/// An optional type that the value range is constrained to.
|
||||
Expr *typeExpr;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpNameDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This Decl represents an OperationName.
|
||||
class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
|
||||
public:
|
||||
static OpNameDecl *create(Context &ctx, const Name &name);
|
||||
static OpNameDecl *create(Context &ctx, llvm::SMRange loc);
|
||||
|
||||
/// Return the name of this operation, or none if the name is unknown.
|
||||
Optional<StringRef> getName() const {
|
||||
const Name *name = Decl::getName();
|
||||
return name ? Optional<StringRef>(name->getName()) : llvm::None;
|
||||
}
|
||||
|
||||
private:
|
||||
explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {}
|
||||
explicit OpNameDecl(llvm::SMRange loc) : Base(loc) {}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This Decl represents a single Pattern.
|
||||
class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
|
||||
public:
|
||||
static PatternDecl *create(Context &ctx, llvm::SMRange location,
|
||||
const Name *name, Optional<uint16_t> benefit,
|
||||
bool hasBoundedRecursion,
|
||||
const CompoundStmt *body);
|
||||
|
||||
/// Return the benefit of this pattern if specified, or None.
|
||||
Optional<uint16_t> getBenefit() const { return benefit; }
|
||||
|
||||
/// Return if this pattern has bounded rewrite recursion.
|
||||
bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
|
||||
|
||||
/// Return the body of this pattern.
|
||||
const CompoundStmt *getBody() const { return patternBody; }
|
||||
|
||||
/// Return the root rewrite statement of this pattern.
|
||||
const OpRewriteStmt *getRootRewriteStmt() const {
|
||||
return cast<OpRewriteStmt>(patternBody->getChildren().back());
|
||||
}
|
||||
|
||||
private:
|
||||
PatternDecl(llvm::SMRange loc, const Name *name, Optional<uint16_t> benefit,
|
||||
bool hasBoundedRecursion, const CompoundStmt *body)
|
||||
: Base(loc, name), benefit(benefit),
|
||||
hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
|
||||
|
||||
/// The benefit of the pattern if it was explicitly specified, None otherwise.
|
||||
Optional<uint16_t> benefit;
|
||||
|
||||
/// If the pattern has properly bounded rewrite recursion or not.
|
||||
bool hasBoundedRecursion;
|
||||
|
||||
/// The compound statement representing the body of the pattern.
|
||||
const CompoundStmt *patternBody;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VariableDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This Decl represents the definition of a PDLL variable.
|
||||
class VariableDecl final
|
||||
: public Node::NodeBase<VariableDecl, Decl>,
|
||||
private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
|
||||
public:
|
||||
static VariableDecl *create(Context &ctx, const Name &name, Type type,
|
||||
Expr *initExpr,
|
||||
ArrayRef<ConstraintRef> constraints);
|
||||
|
||||
/// Return the constraints of this variable.
|
||||
MutableArrayRef<ConstraintRef> getConstraints() {
|
||||
return {getTrailingObjects<ConstraintRef>(), numConstraints};
|
||||
}
|
||||
ArrayRef<ConstraintRef> getConstraints() const {
|
||||
return const_cast<VariableDecl *>(this)->getConstraints();
|
||||
}
|
||||
|
||||
/// Return the initializer expression of this statement, or nullptr if there
|
||||
/// was no initializer.
|
||||
Expr *getInitExpr() const { return initExpr; }
|
||||
|
||||
/// Return the name of the decl.
|
||||
const Name &getName() const { return *Decl::getName(); }
|
||||
|
||||
/// Return the type of the decl.
|
||||
Type getType() const { return type; }
|
||||
|
||||
private:
|
||||
VariableDecl(const Name &name, Type type, Expr *initExpr,
|
||||
unsigned numConstraints)
|
||||
: Base(name.getLoc(), &name), type(type), initExpr(initExpr),
|
||||
numConstraints(numConstraints) {}
|
||||
|
||||
/// The type of the variable.
|
||||
Type type;
|
||||
|
||||
/// The optional initializer expression of this statement.
|
||||
Expr *initExpr;
|
||||
|
||||
/// The number of constraints attached to this variable.
|
||||
unsigned numConstraints;
|
||||
|
||||
/// Allow access to various internals.
|
||||
friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a top-level AST module.
|
||||
class Module final : public Node::NodeBase<Module, Node>,
|
||||
private llvm::TrailingObjects<Module, Decl *> {
|
||||
public:
|
||||
static Module *create(Context &ctx, llvm::SMLoc loc,
|
||||
ArrayRef<Decl *> children);
|
||||
|
||||
/// Return the children of this module.
|
||||
MutableArrayRef<Decl *> getChildren() {
|
||||
return {getTrailingObjects<Decl *>(), numChildren};
|
||||
}
|
||||
ArrayRef<Decl *> getChildren() const {
|
||||
return const_cast<Module *>(this)->getChildren();
|
||||
}
|
||||
|
||||
private:
|
||||
Module(llvm::SMLoc loc, unsigned numChildren)
|
||||
: Base(llvm::SMRange{loc, loc}), numChildren(numChildren) {}
|
||||
|
||||
/// The number of decls held by this module.
|
||||
unsigned numChildren;
|
||||
|
||||
/// Allow access to various internals.
|
||||
friend llvm::TrailingObjects<Module, Decl *>;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Defered Method Definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
inline bool Decl::classof(const Node *node) {
|
||||
return isa<ConstraintDecl, OpNameDecl, PatternDecl, VariableDecl>(node);
|
||||
}
|
||||
|
||||
inline bool ConstraintDecl::classof(const Node *node) {
|
||||
return isa<CoreConstraintDecl>(node);
|
||||
}
|
||||
|
||||
inline bool CoreConstraintDecl::classof(const Node *node) {
|
||||
return isa<AttrConstraintDecl, OpConstraintDecl, TypeConstraintDecl,
|
||||
TypeRangeConstraintDecl, ValueConstraintDecl,
|
||||
ValueRangeConstraintDecl>(node);
|
||||
}
|
||||
|
||||
inline bool Expr::classof(const Node *node) {
|
||||
return isa<DeclRefExpr, MemberAccessExpr>(node);
|
||||
}
|
||||
|
||||
inline bool OpRewriteStmt::classof(const Node *node) {
|
||||
return isa<EraseStmt>(node);
|
||||
}
|
||||
|
||||
inline bool Stmt::classof(const Node *node) {
|
||||
return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node);
|
||||
}
|
||||
|
||||
} // namespace ast
|
||||
} // namespace pdll
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_TOOLS_PDLL_AST_NODES_H_
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
//===- Types.h --------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TOOLS_PDLL_AST_TYPES_H_
|
||||
#define MLIR_TOOLS_PDLL_AST_TYPES_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/StorageUniquer.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace pdll {
|
||||
namespace ast {
|
||||
class Context;
|
||||
|
||||
namespace detail {
|
||||
struct AttributeTypeStorage;
|
||||
struct ConstraintTypeStorage;
|
||||
struct OperationTypeStorage;
|
||||
struct RangeTypeStorage;
|
||||
struct TypeTypeStorage;
|
||||
struct ValueTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class Type {
|
||||
public:
|
||||
/// This class represents the internal storage of the Type class.
|
||||
struct Storage;
|
||||
|
||||
/// This class provides several utilities when defining derived type classes.
|
||||
template <typename ImplT, typename BaseT = Type>
|
||||
class TypeBase : public BaseT {
|
||||
public:
|
||||
using Base = TypeBase<ImplT, BaseT>;
|
||||
using ImplTy = ImplT;
|
||||
using BaseT::BaseT;
|
||||
|
||||
/// Provide type casting support.
|
||||
static bool classof(Type type) {
|
||||
return type.getTypeID() == TypeID::get<ImplTy>();
|
||||
}
|
||||
};
|
||||
|
||||
Type(Storage *impl = nullptr) : impl(impl) {}
|
||||
Type(const Type &other) = default;
|
||||
|
||||
bool operator==(const Type &other) const { return impl == other.impl; }
|
||||
bool operator!=(const Type &other) const { return !(*this == other); }
|
||||
explicit operator bool() const { return impl; }
|
||||
|
||||
/// Provide type casting support.
|
||||
template <typename U> bool isa() const {
|
||||
assert(impl && "isa<> used on a null type.");
|
||||
return U::classof(*this);
|
||||
}
|
||||
template <typename U, typename V, typename... Others> bool isa() const {
|
||||
return isa<U>() || isa<V, Others...>();
|
||||
}
|
||||
template <typename U> U dyn_cast() const {
|
||||
return isa<U>() ? U(impl) : U(nullptr);
|
||||
}
|
||||
template <typename U> U dyn_cast_or_null() const {
|
||||
return (impl && isa<U>()) ? U(impl) : U(nullptr);
|
||||
}
|
||||
template <typename U> U cast() const {
|
||||
assert(isa<U>());
|
||||
return U(impl);
|
||||
}
|
||||
|
||||
/// Return the internal storage instance of this type.
|
||||
Storage *getImpl() const { return impl; }
|
||||
|
||||
/// Return the TypeID instance of this type.
|
||||
TypeID getTypeID() const;
|
||||
|
||||
/// Print this type to the given stream.
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
/// Try to refine this type with the one provided. Given two compatible types,
|
||||
/// this will return a merged type contains as much detail from the two types.
|
||||
/// For example, if refining two operation types and one contains a name,
|
||||
/// while the other doesn't, the refined type contains the name. If the two
|
||||
/// types are incompatible, null is returned.
|
||||
Type refineWith(Type other) const;
|
||||
|
||||
protected:
|
||||
/// Return the internal storage instance of this type reinterpreted as the
|
||||
/// given derived storage type.
|
||||
template <typename T> const T *getImplAs() const {
|
||||
return static_cast<const T *>(impl);
|
||||
}
|
||||
|
||||
private:
|
||||
Storage *impl;
|
||||
};
|
||||
|
||||
inline llvm::hash_code hash_value(Type type) {
|
||||
return DenseMapInfo<Type::Storage *>::getHashValue(type.getImpl());
|
||||
}
|
||||
|
||||
inline raw_ostream &operator<<(raw_ostream &os, Type type) {
|
||||
type.print(os);
|
||||
return os;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttributeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a PDLL type that corresponds to an mlir::Attribute.
|
||||
class AttributeType : public Type::TypeBase<detail::AttributeTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Return an instance of the Attribute type.
|
||||
static AttributeType get(Context &context);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstraintType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a PDLL type that corresponds to a constraint. This
|
||||
/// type has no MLIR C++ API correspondance.
|
||||
class ConstraintType : public Type::TypeBase<detail::ConstraintTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Return an instance of the Constraint type.
|
||||
static ConstraintType get(Context &context);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OperationType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a PDLL type that corresponds to an mlir::Operation.
|
||||
class OperationType : public Type::TypeBase<detail::OperationTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Return an instance of the Operation type with an optional operation name.
|
||||
/// If no name is provided, this type may refer to any operation.
|
||||
static OperationType get(Context &context,
|
||||
Optional<StringRef> name = llvm::None);
|
||||
|
||||
/// Return the name of this operation type, or None if it doesn't have on.
|
||||
Optional<StringRef> getName() const;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RangeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a PDLL type that corresponds to a range of elements
|
||||
/// with a given element type.
|
||||
class RangeType : public Type::TypeBase<detail::RangeTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Return an instance of the Range type with the given element type.
|
||||
static RangeType get(Context &context, Type elementType);
|
||||
|
||||
/// Return the element type of this range.
|
||||
Type getElementType() const;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeRangeType
|
||||
|
||||
/// This class represents a PDLL type that corresponds to an mlir::TypeRange.
|
||||
class TypeRangeType : public RangeType {
|
||||
public:
|
||||
using RangeType::RangeType;
|
||||
|
||||
/// Provide type casting support.
|
||||
static bool classof(Type type);
|
||||
|
||||
/// Return an instance of the TypeRange type.
|
||||
static TypeRangeType get(Context &context);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueRangeType
|
||||
|
||||
/// This class represents a PDLL type that corresponds to an mlir::ValueRange.
|
||||
class ValueRangeType : public RangeType {
|
||||
public:
|
||||
using RangeType::RangeType;
|
||||
|
||||
/// Provide type casting support.
|
||||
static bool classof(Type type);
|
||||
|
||||
/// Return an instance of the ValueRange type.
|
||||
static ValueRangeType get(Context &context);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a PDLL type that corresponds to an mlir::Type.
|
||||
class TypeType : public Type::TypeBase<detail::TypeTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Return an instance of the Type type.
|
||||
static TypeType get(Context &context);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a PDLL type that corresponds to an mlir::Value.
|
||||
class ValueType : public Type::TypeBase<detail::ValueTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Return an instance of the Value type.
|
||||
static ValueType get(Context &context);
|
||||
};
|
||||
|
||||
} // namespace ast
|
||||
} // namespace pdll
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
template <> struct DenseMapInfo<mlir::pdll::ast::Type> {
|
||||
static mlir::pdll::ast::Type getEmptyKey() {
|
||||
void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::pdll::ast::Type(
|
||||
static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
|
||||
}
|
||||
static mlir::pdll::ast::Type getTombstoneKey() {
|
||||
void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::pdll::ast::Type(
|
||||
static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::pdll::ast::Type val) {
|
||||
return llvm::hash_value(val.getImpl());
|
||||
}
|
||||
static bool isEqual(mlir::pdll::ast::Type lhs, mlir::pdll::ast::Type rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
#endif // MLIR_TOOLS_PDLL_AST_TYPES_H_
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
//===- Parser.h - MLIR PDLL Frontend Parser ---------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TOOLS_PDLL_PARSER_PARSER_H_
|
||||
#define MLIR_TOOLS_PDLL_PARSER_PARSER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace llvm {
|
||||
class SourceMgr;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
namespace pdll {
|
||||
namespace ast {
|
||||
class Context;
|
||||
class Module;
|
||||
} // namespace ast
|
||||
|
||||
/// Parse an AST module from the main file of the given source manager.
|
||||
FailureOr<ast::Module *> parsePDLAST(ast::Context &ctx,
|
||||
llvm::SourceMgr &sourceMgr);
|
||||
} // namespace pdll
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_TOOLS_PDLL_PARSER_PARSER_H_
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
add_subdirectory(mlir-lsp-server)
|
||||
add_subdirectory(mlir-reduce)
|
||||
add_subdirectory(PDLL)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
add_mlir_library(MLIRPDLLAST
|
||||
Context.cpp
|
||||
Diagnostic.cpp
|
||||
NodePrinter.cpp
|
||||
Nodes.cpp
|
||||
Types.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRSupport
|
||||
)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
//===- Context.cpp --------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Tools/PDLL/AST/Context.h"
|
||||
#include "TypeDetail.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::pdll::ast;
|
||||
|
||||
Context::Context() {
|
||||
typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>();
|
||||
typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>();
|
||||
typeUniquer.registerSingletonStorageType<detail::TypeTypeStorage>();
|
||||
typeUniquer.registerSingletonStorageType<detail::ValueTypeStorage>();
|
||||
|
||||
typeUniquer.registerParametricStorageType<detail::OperationTypeStorage>();
|
||||
typeUniquer.registerParametricStorageType<detail::RangeTypeStorage>();
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
//===- Diagnostic.cpp -----------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::pdll::ast;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InFlightDiagnostic
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void InFlightDiagnostic::report() {
|
||||
// If this diagnostic is still inflight and it hasn't been abandoned, then
|
||||
// report it.
|
||||
if (isInFlight()) {
|
||||
owner->report(std::move(*impl));
|
||||
owner = nullptr;
|
||||
}
|
||||
impl.reset();
|
||||
}
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
//===- NodePrinter.cpp ----------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Tools/PDLL/AST/Context.h"
|
||||
#include "mlir/Tools/PDLL/AST/Nodes.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/SaveAndRestore.h"
|
||||
#include "llvm/Support/ScopedPrinter.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::pdll::ast;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NodePrinter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class NodePrinter {
|
||||
public:
|
||||
NodePrinter(raw_ostream &os) : os(os) {}
|
||||
|
||||
/// Print the given type to the stream.
|
||||
void print(Type type);
|
||||
|
||||
/// Print the given node to the stream.
|
||||
void print(const Node *node);
|
||||
|
||||
private:
|
||||
/// Print a range containing children of a node.
|
||||
template <typename RangeT,
|
||||
std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
|
||||
* = nullptr>
|
||||
void printChildren(RangeT &&range) {
|
||||
if (llvm::empty(range))
|
||||
return;
|
||||
|
||||
// Print the first N-1 elements with a prefix of "|-".
|
||||
auto it = std::begin(range);
|
||||
for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
|
||||
print(*it);
|
||||
|
||||
// Print the last element.
|
||||
elementIndentStack.back() = true;
|
||||
print(*it);
|
||||
}
|
||||
template <typename RangeT, typename... OthersT,
|
||||
std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
|
||||
* = nullptr>
|
||||
void printChildren(RangeT &&range, OthersT &&...others) {
|
||||
printChildren(ArrayRef<const Node *>({range, others...}));
|
||||
}
|
||||
/// Print a range containing children of a node, nesting the children under
|
||||
/// the given label.
|
||||
template <typename RangeT>
|
||||
void printChildren(StringRef label, RangeT &&range) {
|
||||
if (llvm::empty(range))
|
||||
return;
|
||||
elementIndentStack.reserve(elementIndentStack.size() + 1);
|
||||
llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true);
|
||||
|
||||
printIndent();
|
||||
os << label << "`\n";
|
||||
elementIndentStack.push_back(/*isLastElt*/ false);
|
||||
printChildren(std::forward<RangeT>(range));
|
||||
elementIndentStack.pop_back();
|
||||
}
|
||||
|
||||
/// Print the given derived node to the stream.
|
||||
void printImpl(const CompoundStmt *stmt);
|
||||
void printImpl(const EraseStmt *stmt);
|
||||
void printImpl(const LetStmt *stmt);
|
||||
|
||||
void printImpl(const DeclRefExpr *expr);
|
||||
void printImpl(const MemberAccessExpr *expr);
|
||||
|
||||
void printImpl(const AttrConstraintDecl *decl);
|
||||
void printImpl(const OpConstraintDecl *decl);
|
||||
void printImpl(const TypeConstraintDecl *decl);
|
||||
void printImpl(const TypeRangeConstraintDecl *decl);
|
||||
void printImpl(const ValueConstraintDecl *decl);
|
||||
void printImpl(const ValueRangeConstraintDecl *decl);
|
||||
void printImpl(const OpNameDecl *decl);
|
||||
void printImpl(const PatternDecl *decl);
|
||||
void printImpl(const VariableDecl *decl);
|
||||
void printImpl(const Module *module);
|
||||
|
||||
/// Print the current indent stack.
|
||||
void printIndent() {
|
||||
if (elementIndentStack.empty())
|
||||
return;
|
||||
|
||||
for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
|
||||
os << (isLastElt ? " " : " |");
|
||||
os << (elementIndentStack.back() ? " `" : " |");
|
||||
}
|
||||
|
||||
/// The raw output stream.
|
||||
raw_ostream &os;
|
||||
|
||||
/// A stack of indents and a flag indicating if the current element being
|
||||
/// printed at that indent is the last element.
|
||||
SmallVector<bool> elementIndentStack;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void NodePrinter::print(Type type) {
|
||||
// Protect against invalid inputs.
|
||||
if (!type) {
|
||||
os << "Type<NULL>";
|
||||
return;
|
||||
}
|
||||
|
||||
TypeSwitch<Type>(type)
|
||||
.Case([&](AttributeType) { os << "Attr"; })
|
||||
.Case([&](ConstraintType) { os << "Constraint"; })
|
||||
.Case([&](OperationType type) {
|
||||
os << "Op";
|
||||
if (Optional<StringRef> name = type.getName())
|
||||
os << "<" << *name << ">";
|
||||
})
|
||||
.Case([&](RangeType type) {
|
||||
print(type.getElementType());
|
||||
os << "Range";
|
||||
})
|
||||
.Case([&](TypeType) { os << "Type"; })
|
||||
.Case([&](ValueType) { os << "Value"; })
|
||||
.Default([](Type) { llvm_unreachable("unknown AST type"); });
|
||||
}
|
||||
|
||||
void NodePrinter::print(const Node *node) {
|
||||
printIndent();
|
||||
os << "-";
|
||||
|
||||
elementIndentStack.push_back(/*isLastElt*/ false);
|
||||
TypeSwitch<const Node *>(node)
|
||||
.Case<
|
||||
// Statements.
|
||||
const CompoundStmt, const EraseStmt, const LetStmt,
|
||||
|
||||
// Expressions.
|
||||
const DeclRefExpr, const MemberAccessExpr,
|
||||
|
||||
// Decls.
|
||||
const AttrConstraintDecl, const OpConstraintDecl,
|
||||
const TypeConstraintDecl, const TypeRangeConstraintDecl,
|
||||
const ValueConstraintDecl, const ValueRangeConstraintDecl,
|
||||
const OpNameDecl, const PatternDecl, const VariableDecl,
|
||||
|
||||
const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
|
||||
.Default([](const Node *) { llvm_unreachable("unknown AST node"); });
|
||||
elementIndentStack.pop_back();
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const CompoundStmt *stmt) {
|
||||
os << "CompoundStmt " << stmt << "\n";
|
||||
printChildren(stmt->getChildren());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const EraseStmt *stmt) {
|
||||
os << "EraseStmt " << stmt << "\n";
|
||||
printChildren(stmt->getRootOpExpr());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const LetStmt *stmt) {
|
||||
os << "LetStmt " << stmt << "\n";
|
||||
printChildren(stmt->getVarDecl());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const DeclRefExpr *expr) {
|
||||
os << "DeclRefExpr " << expr << " Type<";
|
||||
print(expr->getType());
|
||||
os << ">\n";
|
||||
printChildren(expr->getDecl());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const MemberAccessExpr *expr) {
|
||||
os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
|
||||
<< "> Type<";
|
||||
print(expr->getType());
|
||||
os << ">\n";
|
||||
printChildren(expr->getParentExpr());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
|
||||
os << "AttrConstraintDecl " << decl << "\n";
|
||||
if (const auto *typeExpr = decl->getTypeExpr())
|
||||
printChildren(typeExpr);
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const OpConstraintDecl *decl) {
|
||||
os << "OpConstraintDecl " << decl << "\n";
|
||||
printChildren(decl->getNameDecl());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
|
||||
os << "TypeConstraintDecl " << decl << "\n";
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
|
||||
os << "TypeRangeConstraintDecl " << decl << "\n";
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
|
||||
os << "ValueConstraintDecl " << decl << "\n";
|
||||
if (const auto *typeExpr = decl->getTypeExpr())
|
||||
printChildren(typeExpr);
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
|
||||
os << "ValueRangeConstraintDecl " << decl << "\n";
|
||||
if (const auto *typeExpr = decl->getTypeExpr())
|
||||
printChildren(typeExpr);
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const OpNameDecl *decl) {
|
||||
os << "OpNameDecl " << decl;
|
||||
if (Optional<StringRef> name = decl->getName())
|
||||
os << " Name<" << name << ">";
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const PatternDecl *decl) {
|
||||
os << "PatternDecl " << decl;
|
||||
if (const Name *name = decl->getName())
|
||||
os << " Name<" << name->getName() << ">";
|
||||
if (Optional<uint16_t> benefit = decl->getBenefit())
|
||||
os << " Benefit<" << *benefit << ">";
|
||||
if (decl->hasBoundedRewriteRecursion())
|
||||
os << " Recursion";
|
||||
|
||||
os << "\n";
|
||||
printChildren(decl->getBody());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const VariableDecl *decl) {
|
||||
os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
|
||||
<< "> Type<";
|
||||
print(decl->getType());
|
||||
os << ">\n";
|
||||
if (Expr *initExpr = decl->getInitExpr())
|
||||
printChildren(initExpr);
|
||||
|
||||
auto constraints =
|
||||
llvm::map_range(decl->getConstraints(),
|
||||
[](const ConstraintRef &ref) { return ref.constraint; });
|
||||
printChildren("Constraints", constraints);
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const Module *module) {
|
||||
os << "Module " << module << "\n";
|
||||
printChildren(module->getChildren());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Entry point
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
|
||||
|
||||
void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
|
||||
|
|
@ -0,0 +1,231 @@
|
|||
//===- Nodes.cpp ----------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Tools/PDLL/AST/Nodes.h"
|
||||
#include "mlir/Tools/PDLL/AST/Context.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::pdll::ast;
|
||||
|
||||
/// Copy a string reference into the context with a null terminator.
|
||||
static StringRef copyStringWithNull(Context &ctx, StringRef str) {
|
||||
if (str.empty())
|
||||
return str;
|
||||
|
||||
char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
|
||||
std::copy(str.begin(), str.end(), data);
|
||||
data[str.size()] = 0;
|
||||
return StringRef(data, str.size());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Name
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
const Name &Name::create(Context &ctx, StringRef name, llvm::SMRange location) {
|
||||
return *new (ctx.getAllocator().Allocate<Name>())
|
||||
Name(copyStringWithNull(ctx, name), location);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DeclScope
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void DeclScope::add(Decl *decl) {
|
||||
const Name *name = decl->getName();
|
||||
assert(name && "expected a named decl");
|
||||
assert(!decls.count(name->getName()) && "decl with this name already exists");
|
||||
decls.try_emplace(name->getName(), decl);
|
||||
}
|
||||
|
||||
Decl *DeclScope::lookup(StringRef name) {
|
||||
if (Decl *decl = decls.lookup(name))
|
||||
return decl;
|
||||
return parent ? parent->lookup(name) : nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CompoundStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
CompoundStmt *CompoundStmt::create(Context &ctx, llvm::SMRange loc,
|
||||
ArrayRef<Stmt *> children) {
|
||||
unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
|
||||
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
|
||||
|
||||
CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
|
||||
std::uninitialized_copy(children.begin(), children.end(),
|
||||
stmt->getChildren().begin());
|
||||
return stmt;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LetStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LetStmt *LetStmt::create(Context &ctx, llvm::SMRange loc,
|
||||
VariableDecl *varDecl) {
|
||||
return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpRewriteStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EraseStmt
|
||||
|
||||
EraseStmt *EraseStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp) {
|
||||
return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DeclRefExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DeclRefExpr *DeclRefExpr::create(Context &ctx, llvm::SMRange loc, Decl *decl,
|
||||
Type type) {
|
||||
return new (ctx.getAllocator().Allocate<DeclRefExpr>())
|
||||
DeclRefExpr(loc, decl, type);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemberAccessExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MemberAccessExpr *MemberAccessExpr::create(Context &ctx, llvm::SMRange loc,
|
||||
const Expr *parentExpr,
|
||||
StringRef memberName, Type type) {
|
||||
return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
|
||||
loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttrConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, llvm::SMRange loc,
|
||||
Expr *typeExpr) {
|
||||
return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
|
||||
AttrConstraintDecl(loc, typeExpr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpConstraintDecl *OpConstraintDecl::create(Context &ctx, llvm::SMRange loc,
|
||||
const OpNameDecl *nameDecl) {
|
||||
if (!nameDecl)
|
||||
nameDecl = OpNameDecl::create(ctx, llvm::SMRange());
|
||||
|
||||
return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
|
||||
OpConstraintDecl(loc, nameDecl);
|
||||
}
|
||||
|
||||
Optional<StringRef> OpConstraintDecl::getName() const {
|
||||
return getNameDecl()->getName();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
|
||||
llvm::SMRange loc) {
|
||||
return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
|
||||
TypeConstraintDecl(loc);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeRangeConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
|
||||
llvm::SMRange loc) {
|
||||
return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
|
||||
TypeRangeConstraintDecl(loc);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ValueConstraintDecl *
|
||||
ValueConstraintDecl::create(Context &ctx, llvm::SMRange loc, Expr *typeExpr) {
|
||||
return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
|
||||
ValueConstraintDecl(loc, typeExpr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueRangeConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
|
||||
llvm::SMRange loc,
|
||||
Expr *typeExpr) {
|
||||
return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
|
||||
ValueRangeConstraintDecl(loc, typeExpr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpNameDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
|
||||
return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
|
||||
}
|
||||
OpNameDecl *OpNameDecl::create(Context &ctx, llvm::SMRange loc) {
|
||||
return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternDecl *PatternDecl::create(Context &ctx, llvm::SMRange loc,
|
||||
const Name *name, Optional<uint16_t> benefit,
|
||||
bool hasBoundedRecursion,
|
||||
const CompoundStmt *body) {
|
||||
return new (ctx.getAllocator().Allocate<PatternDecl>())
|
||||
PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VariableDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
|
||||
Expr *initExpr,
|
||||
ArrayRef<ConstraintRef> constraints) {
|
||||
unsigned allocSize =
|
||||
VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
|
||||
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
|
||||
|
||||
VariableDecl *varDecl =
|
||||
new (rawData) VariableDecl(name, type, initExpr, constraints.size());
|
||||
std::uninitialized_copy(constraints.begin(), constraints.end(),
|
||||
varDecl->getConstraints().begin());
|
||||
return varDecl;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Module *Module::create(Context &ctx, llvm::SMLoc loc,
|
||||
ArrayRef<Decl *> children) {
|
||||
unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
|
||||
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
|
||||
|
||||
Module *module = new (rawData) Module(loc, children.size());
|
||||
std::uninitialized_copy(children.begin(), children.end(),
|
||||
module->getChildren().begin());
|
||||
return module;
|
||||
}
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
//===- TypeDetail.h ---------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
|
||||
#define LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
|
||||
|
||||
#include "mlir/Tools/PDLL/AST/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace pdll {
|
||||
namespace ast {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct Type::Storage : public StorageUniquer::BaseStorage {
|
||||
Storage(TypeID typeID) : typeID(typeID) {}
|
||||
|
||||
/// The type identifier for the derived type class.
|
||||
TypeID typeID;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// A utility CRTP base class that defines many of the necessary utilities for
|
||||
/// defining a PDLL AST Type.
|
||||
template <typename ConcreteT, typename KeyT = void>
|
||||
struct TypeStorageBase : public Type::Storage {
|
||||
using KeyTy = KeyT;
|
||||
using Base = TypeStorageBase<ConcreteT, KeyT>;
|
||||
|
||||
/// Construct an instance with the given storage allocator.
|
||||
static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
|
||||
const KeyTy &key) {
|
||||
return new (alloc.allocate<ConcreteT>()) ConcreteT(key);
|
||||
}
|
||||
|
||||
/// Utility methods required by the storage allocator.
|
||||
bool operator==(const KeyTy &key) const { return this->key == key; }
|
||||
|
||||
/// Return the key value of this storage class.
|
||||
const KeyTy &getValue() const { return key; }
|
||||
|
||||
protected:
|
||||
TypeStorageBase(KeyTy key)
|
||||
: Type::Storage(TypeID::get<ConcreteT>()), key(key) {}
|
||||
|
||||
KeyTy key;
|
||||
};
|
||||
/// A specialization of the storage base for singleton types.
|
||||
template <typename ConcreteT>
|
||||
struct TypeStorageBase<ConcreteT, void> : public Type::Storage {
|
||||
using Base = TypeStorageBase<ConcreteT, void>;
|
||||
|
||||
protected:
|
||||
TypeStorageBase() : Type::Storage(TypeID::get<ConcreteT>()) {}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttributeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct AttributeTypeStorage : public TypeStorageBase<AttributeTypeStorage> {};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstraintType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct ConstraintTypeStorage : public TypeStorageBase<ConstraintTypeStorage> {};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OperationType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct OperationTypeStorage
|
||||
: public TypeStorageBase<OperationTypeStorage, StringRef> {
|
||||
using Base::Base;
|
||||
|
||||
static OperationTypeStorage *
|
||||
construct(StorageUniquer::StorageAllocator &alloc, StringRef key) {
|
||||
return new (alloc.allocate<OperationTypeStorage>())
|
||||
OperationTypeStorage(alloc.copyInto(key));
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RangeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> {
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct TypeTypeStorage : public TypeStorageBase<TypeTypeStorage> {};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct ValueTypeStorage : public TypeStorageBase<ValueTypeStorage> {};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace ast
|
||||
} // namespace pdll
|
||||
} // namespace mlir
|
||||
|
||||
#endif // LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
//===- Types.cpp ----------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Tools/PDLL/AST/Types.h"
|
||||
#include "TypeDetail.h"
|
||||
#include "mlir/Tools/PDLL/AST/Context.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::pdll::ast;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TypeID Type::getTypeID() const { return impl->typeID; }
|
||||
|
||||
Type Type::refineWith(Type other) const {
|
||||
if (*this == other)
|
||||
return *this;
|
||||
|
||||
// Operation types are compatible if the operation names don't conflict.
|
||||
if (auto opTy = dyn_cast<OperationType>()) {
|
||||
auto otherOpTy = other.dyn_cast<ast::OperationType>();
|
||||
if (!otherOpTy)
|
||||
return nullptr;
|
||||
if (!otherOpTy.getName())
|
||||
return *this;
|
||||
if (!opTy.getName())
|
||||
return other;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttributeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AttributeType AttributeType::get(Context &context) {
|
||||
return context.getTypeUniquer().get<ImplTy>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstraintType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConstraintType ConstraintType::get(Context &context) {
|
||||
return context.getTypeUniquer().get<ImplTy>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OperationType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperationType OperationType::get(Context &context, Optional<StringRef> name) {
|
||||
return context.getTypeUniquer().get<ImplTy>(
|
||||
/*initFn=*/function_ref<void(ImplTy *)>(), name.getValueOr(""));
|
||||
}
|
||||
|
||||
Optional<StringRef> OperationType::getName() const {
|
||||
StringRef name = getImplAs<ImplTy>()->getValue();
|
||||
return name.empty() ? Optional<StringRef>() : Optional<StringRef>(name);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RangeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
RangeType RangeType::get(Context &context, Type elementType) {
|
||||
return context.getTypeUniquer().get<ImplTy>(
|
||||
/*initFn=*/function_ref<void(ImplTy *)>(), elementType);
|
||||
}
|
||||
|
||||
Type RangeType::getElementType() const {
|
||||
return getImplAs<ImplTy>()->getValue();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeRangeType
|
||||
|
||||
bool TypeRangeType::classof(Type type) {
|
||||
RangeType range = type.dyn_cast<RangeType>();
|
||||
return range && range.getElementType().isa<TypeType>();
|
||||
}
|
||||
|
||||
TypeRangeType TypeRangeType::get(Context &context) {
|
||||
return RangeType::get(context, TypeType::get(context)).cast<TypeRangeType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueRangeType
|
||||
|
||||
bool ValueRangeType::classof(Type type) {
|
||||
RangeType range = type.dyn_cast<RangeType>();
|
||||
return range && range.getElementType().isa<ValueType>();
|
||||
}
|
||||
|
||||
ValueRangeType ValueRangeType::get(Context &context) {
|
||||
return RangeType::get(context, ValueType::get(context))
|
||||
.cast<ValueRangeType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TypeType TypeType::get(Context &context) {
|
||||
return context.getTypeUniquer().get<ImplTy>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ValueType ValueType::get(Context &context) {
|
||||
return context.getTypeUniquer().get<ImplTy>();
|
||||
}
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
add_subdirectory(AST)
|
||||
add_subdirectory(Parser)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
add_mlir_library(MLIRPDLLParser
|
||||
Lexer.cpp
|
||||
Parser.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRPDLLAST
|
||||
MLIRSupport
|
||||
)
|
||||
|
|
@ -0,0 +1,366 @@
|
|||
//===- Lexer.cpp ----------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "Lexer.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::pdll;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Token
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string Token::getStringValue() const {
|
||||
assert(getKind() == string || getKind() == string_block);
|
||||
|
||||
// Start by dropping the quotes.
|
||||
StringRef bytes = getSpelling().drop_front().drop_back();
|
||||
if (is(string_block)) bytes = bytes.drop_front().drop_back();
|
||||
|
||||
std::string result;
|
||||
result.reserve(bytes.size());
|
||||
for (unsigned i = 0, e = bytes.size(); i != e;) {
|
||||
auto c = bytes[i++];
|
||||
if (c != '\\') {
|
||||
result.push_back(c);
|
||||
continue;
|
||||
}
|
||||
|
||||
assert(i + 1 <= e && "invalid string should be caught by lexer");
|
||||
auto c1 = bytes[i++];
|
||||
switch (c1) {
|
||||
case '"':
|
||||
case '\\':
|
||||
result.push_back(c1);
|
||||
continue;
|
||||
case 'n':
|
||||
result.push_back('\n');
|
||||
continue;
|
||||
case 't':
|
||||
result.push_back('\t');
|
||||
continue;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
assert(i + 1 <= e && "invalid string should be caught by lexer");
|
||||
auto c2 = bytes[i++];
|
||||
|
||||
assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
|
||||
result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Lexer
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine)
|
||||
: srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false) {
|
||||
curBufferID = mgr.getMainFileID();
|
||||
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
|
||||
curPtr = curBuffer.begin();
|
||||
|
||||
// If the diag engine has no handler, add a default that emits to the
|
||||
// SourceMgr.
|
||||
if (!diagEngine.getHandlerFn()) {
|
||||
diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
|
||||
srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
|
||||
diag.getMessage());
|
||||
for (const ast::Diagnostic ¬e : diag.getNotes())
|
||||
srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
|
||||
note.getMessage());
|
||||
});
|
||||
addedHandlerToDiagEngine = true;
|
||||
}
|
||||
}
|
||||
|
||||
Lexer::~Lexer() {
|
||||
if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr);
|
||||
}
|
||||
|
||||
LogicalResult Lexer::pushInclude(StringRef filename) {
|
||||
std::string includedFile;
|
||||
int bufferID = srcMgr.AddIncludeFile(
|
||||
filename.str(), llvm::SMLoc::getFromPointer(curPtr), includedFile);
|
||||
if (!bufferID) return failure();
|
||||
|
||||
curBufferID = bufferID;
|
||||
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
|
||||
curPtr = curBuffer.begin();
|
||||
return success();
|
||||
}
|
||||
|
||||
Token Lexer::emitError(llvm::SMRange loc, const Twine &msg) {
|
||||
diagEngine.emitError(loc, msg);
|
||||
return formToken(Token::error, loc.Start.getPointer());
|
||||
}
|
||||
Token Lexer::emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
|
||||
llvm::SMRange noteLoc, const Twine ¬e) {
|
||||
diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
|
||||
return formToken(Token::error, loc.Start.getPointer());
|
||||
}
|
||||
Token Lexer::emitError(const char *loc, const Twine &msg) {
|
||||
return emitError(llvm::SMRange(llvm::SMLoc::getFromPointer(loc),
|
||||
llvm::SMLoc::getFromPointer(loc + 1)),
|
||||
msg);
|
||||
}
|
||||
|
||||
int Lexer::getNextChar() {
|
||||
char curChar = *curPtr++;
|
||||
switch (curChar) {
|
||||
default:
|
||||
return static_cast<unsigned char>(curChar);
|
||||
case 0: {
|
||||
// A nul character in the stream is either the end of the current buffer
|
||||
// or a random nul in the file. Disambiguate that here.
|
||||
if (curPtr - 1 != curBuffer.end()) return 0;
|
||||
|
||||
// Otherwise, return end of file.
|
||||
--curPtr;
|
||||
return EOF;
|
||||
}
|
||||
case '\n':
|
||||
case '\r':
|
||||
// Handle the newline character by ignoring it and incrementing the line
|
||||
// count. However, be careful about 'dos style' files with \n\r in them.
|
||||
// Only treat a \n\r or \r\n as a single line.
|
||||
if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
|
||||
++curPtr;
|
||||
return '\n';
|
||||
}
|
||||
}
|
||||
|
||||
Token Lexer::lexToken() {
|
||||
while (true) {
|
||||
const char *tokStart = curPtr;
|
||||
|
||||
// This always consumes at least one character.
|
||||
int curChar = getNextChar();
|
||||
switch (curChar) {
|
||||
default:
|
||||
// Handle identifiers: [a-zA-Z_]
|
||||
if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart);
|
||||
|
||||
// Unknown character, emit an error.
|
||||
return emitError(tokStart, "unexpected character");
|
||||
case EOF: {
|
||||
// Return EOF denoting the end of lexing.
|
||||
Token eof = formToken(Token::eof, tokStart);
|
||||
|
||||
// Check to see if we are in an included file.
|
||||
llvm::SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
|
||||
if (parentIncludeLoc.isValid()) {
|
||||
curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
|
||||
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
|
||||
curPtr = parentIncludeLoc.getPointer();
|
||||
}
|
||||
|
||||
return eof;
|
||||
}
|
||||
|
||||
// Lex punctuation.
|
||||
case '-':
|
||||
if (*curPtr == '>') {
|
||||
++curPtr;
|
||||
return formToken(Token::arrow, tokStart);
|
||||
}
|
||||
return emitError(tokStart, "unexpected character");
|
||||
case ':':
|
||||
return formToken(Token::colon, tokStart);
|
||||
case ',':
|
||||
return formToken(Token::comma, tokStart);
|
||||
case '.':
|
||||
return formToken(Token::dot, tokStart);
|
||||
case '=':
|
||||
if (*curPtr == '>') {
|
||||
++curPtr;
|
||||
return formToken(Token::equal_arrow, tokStart);
|
||||
}
|
||||
return formToken(Token::equal, tokStart);
|
||||
case ';':
|
||||
return formToken(Token::semicolon, tokStart);
|
||||
case '[':
|
||||
if (*curPtr == '{') {
|
||||
++curPtr;
|
||||
return lexString(tokStart, /*isStringBlock=*/true);
|
||||
}
|
||||
return formToken(Token::l_square, tokStart);
|
||||
case ']':
|
||||
return formToken(Token::r_square, tokStart);
|
||||
|
||||
case '<':
|
||||
return formToken(Token::less, tokStart);
|
||||
case '>':
|
||||
return formToken(Token::greater, tokStart);
|
||||
case '{':
|
||||
return formToken(Token::l_brace, tokStart);
|
||||
case '}':
|
||||
return formToken(Token::r_brace, tokStart);
|
||||
case '(':
|
||||
return formToken(Token::l_paren, tokStart);
|
||||
case ')':
|
||||
return formToken(Token::r_paren, tokStart);
|
||||
case '/':
|
||||
if (*curPtr == '/') {
|
||||
lexComment();
|
||||
continue;
|
||||
}
|
||||
return emitError(tokStart, "unexpected character");
|
||||
|
||||
// Ignore whitespace characters.
|
||||
case 0:
|
||||
case ' ':
|
||||
case '\t':
|
||||
case '\n':
|
||||
return lexToken();
|
||||
|
||||
case '#':
|
||||
return lexDirective(tokStart);
|
||||
case '"':
|
||||
return lexString(tokStart, /*isStringBlock=*/false);
|
||||
|
||||
case '0':
|
||||
case '1':
|
||||
case '2':
|
||||
case '3':
|
||||
case '4':
|
||||
case '5':
|
||||
case '6':
|
||||
case '7':
|
||||
case '8':
|
||||
case '9':
|
||||
return lexNumber(tokStart);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Skip a comment line, starting with a '//'.
|
||||
void Lexer::lexComment() {
|
||||
// Advance over the second '/' in a '//' comment.
|
||||
assert(*curPtr == '/');
|
||||
++curPtr;
|
||||
|
||||
while (true) {
|
||||
switch (*curPtr++) {
|
||||
case '\n':
|
||||
case '\r':
|
||||
// Newline is end of comment.
|
||||
return;
|
||||
case 0:
|
||||
// If this is the end of the buffer, end the comment.
|
||||
if (curPtr - 1 == curBuffer.end()) {
|
||||
--curPtr;
|
||||
return;
|
||||
}
|
||||
LLVM_FALLTHROUGH;
|
||||
default:
|
||||
// Skip over other characters.
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Token Lexer::lexDirective(const char *tokStart) {
|
||||
// Match the rest with an identifier regex: [0-9a-zA-Z_]*
|
||||
while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
|
||||
|
||||
StringRef str(tokStart, curPtr - tokStart);
|
||||
return Token(Token::directive, str);
|
||||
}
|
||||
|
||||
Token Lexer::lexIdentifier(const char *tokStart) {
|
||||
// Match the rest of the identifier regex: [0-9a-zA-Z_]*
|
||||
while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
|
||||
|
||||
// Check to see if this identifier is a keyword.
|
||||
StringRef str(tokStart, curPtr - tokStart);
|
||||
Token::Kind kind = StringSwitch<Token::Kind>(str)
|
||||
.Case("attr", Token::kw_attr)
|
||||
.Case("Attr", Token::kw_Attr)
|
||||
.Case("erase", Token::kw_erase)
|
||||
.Case("let", Token::kw_let)
|
||||
.Case("Constraint", Token::kw_Constraint)
|
||||
.Case("op", Token::kw_op)
|
||||
.Case("Op", Token::kw_Op)
|
||||
.Case("OpName", Token::kw_OpName)
|
||||
.Case("Pattern", Token::kw_Pattern)
|
||||
.Case("replace", Token::kw_replace)
|
||||
.Case("rewrite", Token::kw_rewrite)
|
||||
.Case("type", Token::kw_type)
|
||||
.Case("Type", Token::kw_Type)
|
||||
.Case("TypeRange", Token::kw_TypeRange)
|
||||
.Case("Value", Token::kw_Value)
|
||||
.Case("ValueRange", Token::kw_ValueRange)
|
||||
.Case("with", Token::kw_with)
|
||||
.Case("_", Token::underscore)
|
||||
.Default(Token::identifier);
|
||||
return Token(kind, str);
|
||||
}
|
||||
|
||||
Token Lexer::lexNumber(const char *tokStart) {
|
||||
assert(isdigit(curPtr[-1]));
|
||||
|
||||
// Handle the normal decimal case.
|
||||
while (isdigit(*curPtr)) ++curPtr;
|
||||
|
||||
return formToken(Token::integer, tokStart);
|
||||
}
|
||||
|
||||
Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
|
||||
while (true) {
|
||||
switch (*curPtr++) {
|
||||
case '"':
|
||||
// If this is a string block, we only end the string when we encounter a
|
||||
// `}]`.
|
||||
if (!isStringBlock) return formToken(Token::string, tokStart);
|
||||
continue;
|
||||
case '}':
|
||||
// If this is a string block, we only end the string when we encounter a
|
||||
// `}]`.
|
||||
if (!isStringBlock || *curPtr != ']') continue;
|
||||
++curPtr;
|
||||
return formToken(Token::string_block, tokStart);
|
||||
case 0:
|
||||
// If this is a random nul character in the middle of a string, just
|
||||
// include it. If it is the end of file, then it is an error.
|
||||
if (curPtr - 1 != curBuffer.end()) continue;
|
||||
LLVM_FALLTHROUGH;
|
||||
case '\n':
|
||||
case '\v':
|
||||
case '\f':
|
||||
// String blocks allow multiple lines.
|
||||
if (!isStringBlock)
|
||||
return emitError(curPtr - 1, "expected '\"' in string literal");
|
||||
continue;
|
||||
|
||||
case '\\':
|
||||
// Handle explicitly a few escapes.
|
||||
if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
|
||||
*curPtr == 't') {
|
||||
++curPtr;
|
||||
} else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
|
||||
// Support \xx for two hex digits.
|
||||
curPtr += 2;
|
||||
} else {
|
||||
return emitError(curPtr - 1, "unknown escape in string literal");
|
||||
}
|
||||
continue;
|
||||
|
||||
default:
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,220 @@
|
|||
//===- Lexer.h - MLIR PDLL Frontend Lexer -----------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LIB_TOOLS_PDLL_PARSER_LEXER_H_
|
||||
#define LIB_TOOLS_PDLL_PARSER_LEXER_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
|
||||
namespace llvm {
|
||||
class SourceMgr;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
struct LogicalResult;
|
||||
|
||||
namespace pdll {
|
||||
namespace ast {
|
||||
class DiagnosticEngine;
|
||||
} // namespace ast
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Token
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class Token {
|
||||
public:
|
||||
enum Kind {
|
||||
// Markers.
|
||||
eof,
|
||||
error,
|
||||
|
||||
// Keywords.
|
||||
KW_BEGIN,
|
||||
// Dependent keywords, i.e. those that are treated as keywords depending on
|
||||
// the current parser context.
|
||||
KW_DEPENDENT_BEGIN,
|
||||
kw_attr,
|
||||
kw_op,
|
||||
kw_type,
|
||||
KW_DEPENDENT_END,
|
||||
|
||||
// General keywords.
|
||||
kw_Attr,
|
||||
kw_erase,
|
||||
kw_let,
|
||||
kw_Constraint,
|
||||
kw_Op,
|
||||
kw_OpName,
|
||||
kw_Pattern,
|
||||
kw_replace,
|
||||
kw_rewrite,
|
||||
kw_Type,
|
||||
kw_TypeRange,
|
||||
kw_Value,
|
||||
kw_ValueRange,
|
||||
kw_with,
|
||||
KW_END,
|
||||
|
||||
// Punctuation.
|
||||
arrow,
|
||||
colon,
|
||||
comma,
|
||||
dot,
|
||||
equal,
|
||||
equal_arrow,
|
||||
semicolon,
|
||||
// Paired punctuation.
|
||||
less,
|
||||
greater,
|
||||
l_brace,
|
||||
r_brace,
|
||||
l_paren,
|
||||
r_paren,
|
||||
l_square,
|
||||
r_square,
|
||||
underscore,
|
||||
|
||||
// Tokens.
|
||||
directive,
|
||||
identifier,
|
||||
integer,
|
||||
string_block,
|
||||
string
|
||||
};
|
||||
Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
|
||||
|
||||
/// Given a token containing a string literal, return its value, including
|
||||
/// removing the quote characters and unescaping the contents of the string.
|
||||
std::string getStringValue() const;
|
||||
|
||||
/// Returns true if the current token is a string literal.
|
||||
bool isString() const { return isAny(Token::string, Token::string_block); }
|
||||
|
||||
/// Returns true if the current token is a keyword.
|
||||
bool isKeyword() const {
|
||||
return kind > Token::KW_BEGIN && kind < Token::KW_END;
|
||||
}
|
||||
|
||||
/// Returns true if the current token is a keyword in a dependent context, and
|
||||
/// in any other situation (e.g. variable names) may be treated as an
|
||||
/// identifier.
|
||||
bool isDependentKeyword() const {
|
||||
return kind > Token::KW_DEPENDENT_BEGIN && kind < Token::KW_DEPENDENT_END;
|
||||
}
|
||||
|
||||
/// Return the bytes that make up this token.
|
||||
StringRef getSpelling() const { return spelling; }
|
||||
|
||||
/// Return the kind of this token.
|
||||
Kind getKind() const { return kind; }
|
||||
|
||||
/// Return true if this token is one of the specified kinds.
|
||||
bool isAny(Kind k1, Kind k2) const { return is(k1) || is(k2); }
|
||||
template <typename... T>
|
||||
bool isAny(Kind k1, Kind k2, Kind k3, T... others) const {
|
||||
return is(k1) || isAny(k2, k3, others...);
|
||||
}
|
||||
|
||||
/// Return if the token does not have the given kind.
|
||||
bool isNot(Kind k) const { return k != kind; }
|
||||
template <typename... T> bool isNot(Kind k1, Kind k2, T... others) const {
|
||||
return !isAny(k1, k2, others...);
|
||||
}
|
||||
|
||||
/// Return if the token has the given kind.
|
||||
bool is(Kind K) const { return kind == K; }
|
||||
|
||||
/// Return a location for the start of this token.
|
||||
llvm::SMLoc getStartLoc() const {
|
||||
return llvm::SMLoc::getFromPointer(spelling.data());
|
||||
}
|
||||
/// Return a location at the end of this token.
|
||||
llvm::SMLoc getEndLoc() const {
|
||||
return llvm::SMLoc::getFromPointer(spelling.data() + spelling.size());
|
||||
}
|
||||
/// Return a location for the range of this token.
|
||||
llvm::SMRange getLoc() const {
|
||||
return llvm::SMRange(getStartLoc(), getEndLoc());
|
||||
}
|
||||
|
||||
private:
|
||||
/// Discriminator that indicates the kind of token this is.
|
||||
Kind kind;
|
||||
|
||||
/// A reference to the entire token contents; this is always a pointer into
|
||||
/// a memory buffer owned by the source manager.
|
||||
StringRef spelling;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Lexer
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class Lexer {
|
||||
public:
|
||||
Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine);
|
||||
~Lexer();
|
||||
|
||||
/// Return a reference to the source manager used by the lexer.
|
||||
llvm::SourceMgr &getSourceMgr() { return srcMgr; }
|
||||
|
||||
/// Return a reference to the diagnostic engine used by the lexer.
|
||||
ast::DiagnosticEngine &getDiagEngine() { return diagEngine; }
|
||||
|
||||
/// Push an include of the given file. This will cause the lexer to start
|
||||
/// processing the provided file. Returns failure if the file could not be
|
||||
/// opened, success otherwise.
|
||||
LogicalResult pushInclude(StringRef filename);
|
||||
|
||||
/// Lex the next token and return it.
|
||||
Token lexToken();
|
||||
|
||||
/// Change the position of the lexer cursor. The next token we lex will start
|
||||
/// at the designated point in the input.
|
||||
void resetPointer(const char *newPointer) { curPtr = newPointer; }
|
||||
|
||||
/// Emit an error to the lexer with the given location and message.
|
||||
Token emitError(llvm::SMRange loc, const Twine &msg);
|
||||
Token emitError(const char *loc, const Twine &msg);
|
||||
Token emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
|
||||
llvm::SMRange noteLoc, const Twine ¬e);
|
||||
|
||||
private:
|
||||
Token formToken(Token::Kind kind, const char *tokStart) {
|
||||
return Token(kind, StringRef(tokStart, curPtr - tokStart));
|
||||
}
|
||||
|
||||
/// Return the next character in the stream.
|
||||
int getNextChar();
|
||||
|
||||
/// Lex methods.
|
||||
void lexComment();
|
||||
Token lexDirective(const char *tokStart);
|
||||
Token lexIdentifier(const char *tokStart);
|
||||
Token lexNumber(const char *tokStart);
|
||||
Token lexString(const char *tokStart, bool isStringBlock);
|
||||
|
||||
llvm::SourceMgr &srcMgr;
|
||||
int curBufferID;
|
||||
StringRef curBuffer;
|
||||
const char *curPtr;
|
||||
|
||||
/// The engine used to emit diagnostics during lexing/parsing.
|
||||
ast::DiagnosticEngine &diagEngine;
|
||||
|
||||
/// A flag indicating if we added a default diagnostic handler to the provided
|
||||
/// diagEngine.
|
||||
bool addedHandlerToDiagEngine;
|
||||
};
|
||||
} // namespace pdll
|
||||
} // namespace mlir
|
||||
|
||||
#endif // LIB_TOOLS_PDLL_PARSER_LEXER_H_
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -78,6 +78,7 @@ set(MLIR_TEST_DEPENDS
|
|||
mlir-linalg-ods-yaml-gen
|
||||
mlir-lsp-server
|
||||
mlir-opt
|
||||
mlir-pdll
|
||||
mlir-reduce
|
||||
mlir-tblgen
|
||||
mlir-translate
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ config.name = 'MLIR'
|
|||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||
|
||||
# suffixes: A list of file extensions to treat as test files.
|
||||
config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test']
|
||||
config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test', '.pdll']
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
|
@ -68,6 +68,7 @@ tools = [
|
|||
'mlir-cpu-runner',
|
||||
'mlir-linalg-ods-yaml-gen',
|
||||
'mlir-reduce',
|
||||
'mlir-pdll',
|
||||
]
|
||||
|
||||
# The following tools are optional
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: unknown directive `#foo`
|
||||
#foo
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Include
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: expected string file name after `include` directive
|
||||
#include <>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: unable to open include file `unknown_file.pdll`
|
||||
#include "unknown_file.pdll"
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: expected include filename to end with `.pdll`
|
||||
#include "unknown_file.foo"
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reference Expr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected identifier constraint
|
||||
let foo = Foo: ;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: undefined reference to `bar`
|
||||
let foo = bar;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern FooPattern {
|
||||
erase _: Op;
|
||||
}
|
||||
|
||||
Pattern {
|
||||
// CHECK: invalid reference to `FooPattern`
|
||||
let foo = FooPattern;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `:` after `_` variable
|
||||
let foo = _;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected identifier constraint
|
||||
let foo = _: ;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Member Access Expr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected identifier or numeric member name
|
||||
let root: Op;
|
||||
erase root.<>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: invalid member access `unknown_result` on expression of type `Op`
|
||||
let root: Op;
|
||||
erase root.unknown_result;
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
// RUN: mlir-pdll %s -I %S | FileCheck %s
|
||||
|
||||
Pattern BeforeIncludedPattern {
|
||||
erase _: Op;
|
||||
}
|
||||
|
||||
#include "include/included.pdll"
|
||||
|
||||
Pattern AfterIncludedPattern {
|
||||
erase _: Op;
|
||||
}
|
||||
|
||||
// CHECK: PatternDecl {{.*}} Name<BeforeIncludedPattern>
|
||||
// CHECK: PatternDecl {{.*}} Name<IncludedPattern>
|
||||
// CHECK: PatternDecl {{.*}} Name<AfterIncludedPattern>
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
// This file is included by 'include.pdll' as part of testing include files.
|
||||
|
||||
Pattern IncludedPattern {
|
||||
erase _: Op;
|
||||
}
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
config.suffixes = ['.pdll']
|
||||
config.excludes = ['include']
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: expected `{` to start pattern body
|
||||
Pattern }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: :6:9: error: `Foo` has already been defined
|
||||
// CHECK: :5:9: note: see previous definition here
|
||||
Pattern Foo { erase root: Op; }
|
||||
Pattern Foo { erase root: Op; }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: expected Pattern body to terminate with an operation rewrite statement
|
||||
Pattern {
|
||||
let value: Value;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: Pattern body was terminated by an operation rewrite statement, but found trailing statements
|
||||
Pattern {
|
||||
erase root: Op;
|
||||
let value: Value;
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-PatternDecl
|
||||
// CHECK: `-CompoundStmt
|
||||
// CHECK: `-EraseStmt
|
||||
Pattern {
|
||||
erase _: Op;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-PatternDecl {{.*}} Name<NamedPattern>
|
||||
Pattern NamedPattern {
|
||||
erase _: Op;
|
||||
}
|
||||
|
|
@ -0,0 +1,222 @@
|
|||
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: expected top-level declaration, such as a `Pattern`
|
||||
10
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `;` after statement
|
||||
erase _: Op
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// `erase`
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected expression
|
||||
erase;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `Op` expression
|
||||
erase _: Attr;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// `let`
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected identifier after `let` to name a new variable
|
||||
let 5;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: `_` may only be used to define "inline" variables
|
||||
let _;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected expression
|
||||
let foo: Attr<>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected expression of `Type` in type constraint
|
||||
let foo: Attr<_: Attr>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `>` after variable type constraint
|
||||
let foo: Attr<_: Type{};
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: the type of this variable has already been constrained
|
||||
let foo: [Attr<_: Type>, Attr<_: Type];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `.` after dialect namespace
|
||||
let foo: Op<builtin>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected operation name after dialect namespace
|
||||
let foo: Op<builtin.>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `>` after operation name
|
||||
let foo: Op<builtin.func<;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected expression
|
||||
let foo: Value<>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected expression of `Type` in type constraint
|
||||
let foo: Value<_: Attr>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `>` after variable type constraint
|
||||
let foo: Value<_: Type{};
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: the type of this variable has already been constrained
|
||||
let foo: [Value<_: Type>, Value<_: Type];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected expression
|
||||
let foo: ValueRange<10>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected expression of `TypeRange` in type constraint
|
||||
let foo: ValueRange<_: Type>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `>` after variable type constraint
|
||||
let foo: ValueRange<_: Type{};
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: the type of this variable has already been constrained
|
||||
let foo: [ValueRange<_: Type>, ValueRange<_: Type];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: unknown reference to constraint `UnknownConstraint`
|
||||
let foo: UnknownConstraint;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern Foo {
|
||||
erase root: Op;
|
||||
}
|
||||
|
||||
Pattern {
|
||||
// CHECK: invalid reference to non-constraint
|
||||
let foo: Foo;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: constraint type `Attr` is incompatible with the previously inferred type `Value`
|
||||
let foo: [Value, Attr];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `]` after constraint list
|
||||
let foo: [Attr[];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected expression
|
||||
let foo: Attr = ;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: type constraints are not permitted on variables with initializers
|
||||
let foo: ValueRange<_: Type> = _: Op;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: unable to infer type for variable `foo`
|
||||
// CHECK: note: the type of a variable must be inferable from the constraint list or the initializer
|
||||
let foo;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: unable to convert expression of type `Attr` to the expected type of `Value`
|
||||
let foo: Value = _: Attr;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: :7:7: error: `foo` has already been defined
|
||||
// CHECK: :6:7: note: see previous definition here
|
||||
let foo: Attr;
|
||||
let foo: Attr;
|
||||
}
|
||||
|
|
@ -0,0 +1,155 @@
|
|||
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CompoundStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: CompoundStmt
|
||||
// CHECK: |-LetStmt
|
||||
// CHECK: `-EraseStmt
|
||||
Pattern {
|
||||
let root: Op;
|
||||
erase root;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EraseStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: EraseStmt
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<Op>
|
||||
Pattern {
|
||||
erase _: Op;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LetStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: LetStmt
|
||||
// CHECK: `-VariableDecl {{.*}} Name<attrVar> Type<Attr>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-AttrConstraintDecl
|
||||
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Op>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-OpConstraintDecl
|
||||
// CHECK: `-OpNameDecl
|
||||
Pattern {
|
||||
let attrVar: Attr;
|
||||
let var: Op;
|
||||
erase var;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check for proper refinement between constraint types.
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: LetStmt
|
||||
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Op<dialect.op>>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-OpConstraintDecl
|
||||
// CHECK: `-OpNameDecl
|
||||
// CHECK: `-OpConstraintDecl
|
||||
// CHECK: `-OpNameDecl {{.*}} Name<dialect.op>
|
||||
Pattern {
|
||||
let var: [Op, Op<dialect.op>];
|
||||
erase var;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check for proper conversion between initializer and constraint type.
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: LetStmt
|
||||
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Op<dialect.op>>
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<Op<dialect.op>>
|
||||
// CHECK: `-VariableDecl {{.*}} Name<input>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-OpConstraintDecl
|
||||
// CHECK: `-OpNameDecl
|
||||
Pattern {
|
||||
let input: Op<dialect.op>;
|
||||
let var: Op = input;
|
||||
erase var;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check for proper conversion between initializer and constraint type.
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: LetStmt
|
||||
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Value>
|
||||
// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<Value>
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<Op<dialect.op>>
|
||||
// CHECK: `-VariableDecl {{.*}} Name<input>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-ValueConstraintDecl
|
||||
Pattern {
|
||||
let input: Op<dialect.op>;
|
||||
let var: Value = input;
|
||||
erase _: Op;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check for proper conversion between initializer and constraint type.
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: LetStmt
|
||||
// CHECK: `-VariableDecl {{.*}} Name<var> Type<ValueRange>
|
||||
// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<ValueRange>
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<Op<dialect.op>>
|
||||
// CHECK: `-VariableDecl {{.*}} Name<input>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-ValueRangeConstraintDecl
|
||||
Pattern {
|
||||
let input: Op<dialect.op>;
|
||||
let var: ValueRange = input;
|
||||
erase _: Op;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check for proper handling of type constraints.
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: LetStmt
|
||||
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Value>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-ValueConstraintDecl
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<Type>
|
||||
// CHECK: `-VariableDecl {{.*}} Name<_> Type<Type>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-TypeConstraintDecl
|
||||
Pattern {
|
||||
let var: Value<_: Type>;
|
||||
erase _: Op;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check for proper handling of type constraints.
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: LetStmt
|
||||
// CHECK: `-VariableDecl {{.*}} Name<var> Type<ValueRange>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-ValueRangeConstraintDecl
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<TypeRange>
|
||||
// CHECK: `-VariableDecl {{.*}} Name<_> Type<TypeRange>
|
||||
// CHECK: `Constraints`
|
||||
// CHECK: `-TypeRangeConstraintDecl
|
||||
Pattern {
|
||||
let var: ValueRange<_: TypeRange>;
|
||||
erase _: Op;
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
add_subdirectory(mlir-cpu-runner)
|
||||
add_subdirectory(mlir-lsp-server)
|
||||
add_subdirectory(mlir-opt)
|
||||
add_subdirectory(mlir-pdll)
|
||||
add_subdirectory(mlir-reduce)
|
||||
add_subdirectory(mlir-shlib)
|
||||
add_subdirectory(mlir-spirv-cpu-runner)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
set(LIBS
|
||||
MLIRPDLLAST
|
||||
MLIRPDLLParser
|
||||
)
|
||||
|
||||
add_llvm_tool(mlir-pdll
|
||||
mlir-pdll.cpp
|
||||
|
||||
DEPENDS
|
||||
${LIBS}
|
||||
)
|
||||
|
||||
target_link_libraries(mlir-pdll PRIVATE ${LIBS})
|
||||
llvm_update_compile_flags(mlir-pdll)
|
||||
|
||||
mlir_check_all_link_libraries(mlir-pdll)
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
//===- mlir-pdll.cpp - MLIR PDLL frontend -----------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/ToolUtilities.h"
|
||||
#include "mlir/Tools/PDLL/AST/Context.h"
|
||||
#include "mlir/Tools/PDLL/AST/Nodes.h"
|
||||
#include "mlir/Tools/PDLL/Parser/Parser.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::pdll;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// main
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The desired output type.
|
||||
enum class OutputType {
|
||||
AST,
|
||||
};
|
||||
|
||||
static LogicalResult
|
||||
processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
|
||||
OutputType outputType, std::vector<std::string> &includeDirs) {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.setIncludeDirs(includeDirs);
|
||||
sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), llvm::SMLoc());
|
||||
|
||||
ast::Context astContext;
|
||||
FailureOr<ast::Module *> module = parsePDLAST(astContext, sourceMgr);
|
||||
if (failed(module))
|
||||
return failure();
|
||||
|
||||
switch (outputType) {
|
||||
case OutputType::AST:
|
||||
(*module)->print(os);
|
||||
break;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
llvm::cl::opt<std::string> inputFilename(
|
||||
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
|
||||
llvm::cl::value_desc("filename"));
|
||||
|
||||
llvm::cl::opt<std::string> outputFilename(
|
||||
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
llvm::cl::list<std::string> includeDirs(
|
||||
"I", llvm::cl::desc("Directory of include files"),
|
||||
llvm::cl::value_desc("directory"), llvm::cl::Prefix);
|
||||
|
||||
llvm::cl::opt<bool> splitInputFile(
|
||||
"split-input-file",
|
||||
llvm::cl::desc("Split the input file into pieces and process each "
|
||||
"chunk independently"),
|
||||
llvm::cl::init(false));
|
||||
llvm::cl::opt<enum OutputType> outputType(
|
||||
"x", llvm::cl::init(OutputType::AST),
|
||||
llvm::cl::desc("The type of output desired"),
|
||||
llvm::cl::values(clEnumValN(OutputType::AST, "ast",
|
||||
"generate the AST for the input file")));
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, "PDLL Frontend");
|
||||
|
||||
// Set up the input file.
|
||||
std::string errorMessage;
|
||||
std::unique_ptr<llvm::MemoryBuffer> inputFile =
|
||||
openInputFile(inputFilename, &errorMessage);
|
||||
if (!inputFile) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Set up the output file.
|
||||
std::unique_ptr<llvm::ToolOutputFile> outputFile =
|
||||
openOutputFile(outputFilename, &errorMessage);
|
||||
if (!outputFile) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// The split-input-file mode is a very specific mode that slices the file
|
||||
// up into small pieces and checks each independently.
|
||||
auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
|
||||
raw_ostream &os) {
|
||||
return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs);
|
||||
};
|
||||
if (splitInputFile) {
|
||||
if (failed(splitAndProcessBuffer(std::move(inputFile), processFn,
|
||||
outputFile->os())))
|
||||
return 1;
|
||||
} else if (failed(processFn(std::move(inputFile), outputFile->os()))) {
|
||||
return 1;
|
||||
}
|
||||
outputFile->keep();
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -7923,3 +7923,51 @@ cc_binary(
|
|||
"//mlir/test:TestDialect",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "PDLLAST",
|
||||
srcs = glob(
|
||||
[
|
||||
"lib/Tools/PDLL/AST/*.cpp",
|
||||
"lib/Tools/PDLL/AST/*.h",
|
||||
],
|
||||
),
|
||||
hdrs = glob(["include/mlir/Tools/PDLL/AST/*.h"]),
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
"//llvm:Support",
|
||||
"//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "PDLLParser",
|
||||
srcs = glob(
|
||||
[
|
||||
"lib/Tools/PDLL/Parser/*.cpp",
|
||||
"lib/Tools/PDLL/Parser/*.h",
|
||||
],
|
||||
),
|
||||
hdrs = glob(["include/mlir/Tools/PDLL/Parser/*.h"]),
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":PDLLAST",
|
||||
":Support",
|
||||
":TableGen",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "mlir-pdll",
|
||||
srcs = [
|
||||
"tools/mlir-pdll/mlir-pdll.cpp",
|
||||
],
|
||||
deps = [
|
||||
":PDLLAST",
|
||||
":PDLLParser",
|
||||
":Support",
|
||||
"//llvm:Support",
|
||||
"//llvm:config",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue