[mlir][PDLL] Add an initial frontend for PDLL

This is a new pattern rewrite frontend designed from the ground
up to support MLIR constructs, and to target PDL. This frontend
language was proposed in https://llvm.discourse.group/t/rfc-pdll-a-new-declarative-rewrite-frontend-for-mlir/4798

This commit starts sketching out the base structure of the
frontend, and is intended to be a minimal starting point for
building up the language. It essentially contains support for
defining a pattern, variables, and erasing an operation. The
features mentioned in the proposal RFC (including IDE support)
will be added incrementally in followup commits.

I intend to upstream the documentation for the language in a
followup when a bit more of the pieces have been landed.

Differential Revision: https://reviews.llvm.org/D115093
This commit is contained in:
River Riddle 2021-12-16 01:48:19 +00:00
parent 8f1ea2e85c
commit 11d26bd143
33 changed files with 4425 additions and 1 deletions

View File

@ -0,0 +1,52 @@
//===- Context.h - PDLL AST Context -----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_PDLL_AST_CONTEXT_H_
#define MLIR_TOOLS_PDLL_AST_CONTEXT_H_
#include "mlir/Support/StorageUniquer.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
namespace mlir {
namespace pdll {
namespace ast {
/// This class represents the main context of the PDLL AST. It handles
/// allocating all of the AST constructs, and manages all state necessary for
/// the AST.
class Context {
public:
Context();
Context(const Context &) = delete;
Context &operator=(const Context &) = delete;
/// Return the allocator owned by this context.
llvm::BumpPtrAllocator &getAllocator() { return allocator; }
/// Return the storage uniquer used for AST types.
StorageUniquer &getTypeUniquer() { return typeUniquer; }
/// Return the diagnostic engine of this context.
DiagnosticEngine &getDiagEngine() { return diagEngine; }
private:
/// The diagnostic engine of this AST context.
DiagnosticEngine diagEngine;
/// The allocator used for AST nodes, and other entities allocated within the
/// context.
llvm::BumpPtrAllocator allocator;
/// The uniquer used for creating AST types.
StorageUniquer typeUniquer;
};
} // namespace ast
} // namespace pdll
} // namespace mlir
#endif // MLIR_TOOLS_PDLL_AST_CONTEXT_H_

View File

@ -0,0 +1,182 @@
//===- Diagnostic.h - PDLL AST Diagnostics ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_
#define MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_
#include <string>
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/Support/SourceMgr.h"
namespace mlir {
namespace pdll {
namespace ast {
class DiagnosticEngine;
//===----------------------------------------------------------------------===//
// Diagnostic
//===----------------------------------------------------------------------===//
/// This class provides a simple implementation of a PDLL diagnostic.
class Diagnostic {
public:
using Severity = llvm::SourceMgr::DiagKind;
/// Return the severity of this diagnostic.
Severity getSeverity() const { return severity; }
/// Return the message of this diagnostic.
StringRef getMessage() const { return message; }
/// Return the location of this diagnostic.
llvm::SMRange getLocation() const { return location; }
/// Return the notes of this diagnostic.
auto getNotes() const { return llvm::make_pointee_range(notes); }
/// Attach a note to this diagnostic.
Diagnostic &attachNote(const Twine &msg,
Optional<llvm::SMRange> noteLoc = llvm::None) {
assert(getSeverity() != Severity::DK_Note &&
"cannot attach a Note to a Note");
notes.emplace_back(
new Diagnostic(Severity::DK_Note, noteLoc.getValueOr(location), msg));
return *notes.back();
}
/// Allow an inflight diagnostic to be converted to 'failure', otherwise
/// 'success' if this is an empty diagnostic.
operator LogicalResult() const { return failure(); }
private:
Diagnostic(Severity severity, llvm::SMRange loc, const Twine &msg)
: severity(severity), message(msg.str()), location(loc) {}
// Allow access to the constructor.
friend DiagnosticEngine;
/// The severity of this diagnostic.
Severity severity;
/// The message held by this diagnostic.
std::string message;
/// The raw location of this diagnostic.
llvm::SMRange location;
/// Any additional note diagnostics attached to this diagnostic.
std::vector<std::unique_ptr<Diagnostic>> notes;
};
//===----------------------------------------------------------------------===//
// InFlightDiagnostic
//===----------------------------------------------------------------------===//
/// This class represents a diagnostic that is inflight and set to be reported.
/// This allows for last minute modifications of the diagnostic before it is
/// emitted by a DiagnosticEngine.
class InFlightDiagnostic {
public:
InFlightDiagnostic() = default;
InFlightDiagnostic(InFlightDiagnostic &&rhs)
: owner(rhs.owner), impl(std::move(rhs.impl)) {
// Reset the rhs diagnostic.
rhs.impl.reset();
rhs.abandon();
}
~InFlightDiagnostic() {
if (isInFlight())
report();
}
/// Access the internal diagnostic.
Diagnostic &operator*() { return *impl; }
Diagnostic *operator->() { return &*impl; }
/// Reports the diagnostic to the engine.
void report();
/// Abandons this diagnostic so that it will no longer be reported.
void abandon() { owner = nullptr; }
/// Allow an inflight diagnostic to be converted to 'failure', otherwise
/// 'success' if this is an empty diagnostic.
operator LogicalResult() const { return failure(isActive()); }
private:
InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete;
InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete;
InFlightDiagnostic(DiagnosticEngine *owner, Diagnostic &&rhs)
: owner(owner), impl(std::move(rhs)) {}
/// Returns true if the diagnostic is still active, i.e. it has a live
/// diagnostic.
bool isActive() const { return impl.hasValue(); }
/// Returns true if the diagnostic is still in flight to be reported.
bool isInFlight() const { return owner; }
// Allow access to the constructor.
friend DiagnosticEngine;
/// The engine that this diagnostic is to report to.
DiagnosticEngine *owner = nullptr;
/// The raw diagnostic that is inflight to be reported.
Optional<Diagnostic> impl;
};
//===----------------------------------------------------------------------===//
// DiagnosticEngine
//===----------------------------------------------------------------------===//
/// This class manages the construction and emission of PDLL diagnostics.
class DiagnosticEngine {
public:
/// A function used to handle diagnostics emitted by the engine.
using HandlerFn = llvm::unique_function<void(Diagnostic &)>;
/// Emit an error to the diagnostic engine.
InFlightDiagnostic emitError(llvm::SMRange loc, const Twine &msg) {
return InFlightDiagnostic(
this, Diagnostic(Diagnostic::Severity::DK_Error, loc, msg));
}
InFlightDiagnostic emitWarning(llvm::SMRange loc, const Twine &msg) {
return InFlightDiagnostic(
this, Diagnostic(Diagnostic::Severity::DK_Warning, loc, msg));
}
/// Report the given diagnostic.
void report(Diagnostic &&diagnostic) {
if (handler)
handler(diagnostic);
}
/// Get the current handler function of this diagnostic engine.
const HandlerFn &getHandlerFn() const { return handler; }
/// Take the current handler function, resetting the current handler to null.
HandlerFn takeHandlerFn() {
HandlerFn oldHandler = std::move(handler);
handler = {};
return oldHandler;
}
/// Set the handler function for this diagnostic engine.
void setHandlerFn(HandlerFn &&newHandler) { handler = std::move(newHandler); }
private:
/// The registered diagnostic handler function.
HandlerFn handler;
};
} // namespace ast
} // namespace pdll
} // namespace mlir
#endif // MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_

View File

@ -0,0 +1,690 @@
//===- Nodes.h --------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
#define MLIR_TOOLS_PDLL_AST_NODES_H_
#include "mlir/Support/LLVM.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
namespace pdll {
namespace ast {
class Context;
class Decl;
class Expr;
class OpNameDecl;
class VariableDecl;
//===----------------------------------------------------------------------===//
// Name
//===----------------------------------------------------------------------===//
/// This class provides a convenient API for interacting with source names. It
/// contains a string name as well as the source location for that name.
struct Name {
static const Name &create(Context &ctx, StringRef name,
llvm::SMRange location);
/// Return the raw string name.
StringRef getName() const { return name; }
/// Get the location of this name.
llvm::SMRange getLoc() const { return location; }
private:
Name() = delete;
Name(const Name &) = delete;
Name &operator=(const Name &) = delete;
Name(StringRef name, llvm::SMRange location)
: name(name), location(location) {}
/// The string name of the decl.
StringRef name;
/// The location of the decl name.
llvm::SMRange location;
};
//===----------------------------------------------------------------------===//
// DeclScope
//===----------------------------------------------------------------------===//
/// This class represents a scope for named AST decls. A scope determines the
/// visibility and lifetime of a named declaration.
class DeclScope {
public:
/// Create a new scope with an optional parent scope.
DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
/// Return the parent scope of this scope, or nullptr if there is no parent.
DeclScope *getParentScope() { return parent; }
const DeclScope *getParentScope() const { return parent; }
/// Return all of the decls within this scope.
auto getDecls() const { return llvm::make_second_range(decls); }
/// Add a new decl to the scope.
void add(Decl *decl);
/// Lookup a decl with the given name starting from this scope. Returns
/// nullptr if no decl could be found.
Decl *lookup(StringRef name);
template <typename T> T *lookup(StringRef name) {
return dyn_cast_or_null<T>(lookup(name));
}
const Decl *lookup(StringRef name) const {
return const_cast<DeclScope *>(this)->lookup(name);
}
template <typename T> const T *lookup(StringRef name) const {
return dyn_cast_or_null<T>(lookup(name));
}
private:
/// The parent scope, or null if this is a top-level scope.
DeclScope *parent;
/// The decls defined within this scope.
llvm::StringMap<Decl *> decls;
};
//===----------------------------------------------------------------------===//
// Node
//===----------------------------------------------------------------------===//
/// This class represents a base AST node. All AST nodes are derived from this
/// class, and it contains many of the base functionality for interacting with
/// nodes.
class Node {
public:
/// This CRTP class provides several utilies when defining new AST nodes.
template <typename T, typename BaseT> class NodeBase : public BaseT {
public:
using Base = NodeBase<T, BaseT>;
/// Provide type casting support.
static bool classof(const Node *node) {
return node->getTypeID() == TypeID::get<T>();
}
protected:
template <typename... Args>
explicit NodeBase(llvm::SMRange loc, Args &&...args)
: BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
};
/// Return the type identifier of this node.
TypeID getTypeID() const { return typeID; }
/// Return the location of this node.
llvm::SMRange getLoc() const { return loc; }
/// Print this node to the given stream.
void print(raw_ostream &os) const;
protected:
Node(TypeID typeID, llvm::SMRange loc) : typeID(typeID), loc(loc) {}
private:
/// A unique type identifier for this node.
TypeID typeID;
/// The location of this node.
llvm::SMRange loc;
};
//===----------------------------------------------------------------------===//
// Stmt
//===----------------------------------------------------------------------===//
/// This class represents a base AST Statement node.
class Stmt : public Node {
public:
using Node::Node;
/// Provide type casting support.
static bool classof(const Node *node);
};
//===----------------------------------------------------------------------===//
// CompoundStmt
//===----------------------------------------------------------------------===//
/// This statement represents a compound statement, which contains a collection
/// of other statements.
class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
private llvm::TrailingObjects<CompoundStmt, Stmt *> {
public:
static CompoundStmt *create(Context &ctx, llvm::SMRange location,
ArrayRef<Stmt *> children);
/// Return the children of this compound statement.
MutableArrayRef<Stmt *> getChildren() {
return {getTrailingObjects<Stmt *>(), numChildren};
}
ArrayRef<Stmt *> getChildren() const {
return const_cast<CompoundStmt *>(this)->getChildren();
}
ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
private:
CompoundStmt(llvm::SMRange location, unsigned numChildren)
: Base(location), numChildren(numChildren) {}
/// The number of held children statements.
unsigned numChildren;
// Allow access to various privates.
friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
};
//===----------------------------------------------------------------------===//
// LetStmt
//===----------------------------------------------------------------------===//
/// This statement represents a `let` statement in PDLL. This statement is used
/// to define variables.
class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
public:
static LetStmt *create(Context &ctx, llvm::SMRange loc,
VariableDecl *varDecl);
/// Return the variable defined by this statement.
VariableDecl *getVarDecl() const { return varDecl; }
private:
LetStmt(llvm::SMRange loc, VariableDecl *varDecl)
: Base(loc), varDecl(varDecl) {}
/// The variable defined by this statement.
VariableDecl *varDecl;
};
//===----------------------------------------------------------------------===//
// OpRewriteStmt
//===----------------------------------------------------------------------===//
/// This class represents a base operation rewrite statement. Operation rewrite
/// statements perform a set of transformations on a given root operation.
class OpRewriteStmt : public Stmt {
public:
/// Provide type casting support.
static bool classof(const Node *node);
/// Return the root operation of this rewrite.
Expr *getRootOpExpr() const { return rootOp; }
protected:
OpRewriteStmt(TypeID typeID, llvm::SMRange loc, Expr *rootOp)
: Stmt(typeID, loc), rootOp(rootOp) {}
protected:
/// The root operation being rewritten.
Expr *rootOp;
};
//===----------------------------------------------------------------------===//
// EraseStmt
/// This statement represents the `erase` statement in PDLL. This statement
/// erases the given root operation, corresponding roughly to the
/// PatternRewriter::eraseOp API.
class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
public:
static EraseStmt *create(Context &ctx, llvm::SMRange loc, Expr *rootOp);
private:
EraseStmt(llvm::SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
};
//===----------------------------------------------------------------------===//
// Expr
//===----------------------------------------------------------------------===//
/// This class represents a base AST Expression node.
class Expr : public Stmt {
public:
/// Return the type of this expression.
Type getType() const { return type; }
/// Provide type casting support.
static bool classof(const Node *node);
protected:
Expr(TypeID typeID, llvm::SMRange loc, Type type)
: Stmt(typeID, loc), type(type) {}
private:
/// The type of this expression.
Type type;
};
//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//
/// This expression represents a reference to a Decl node.
class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
public:
static DeclRefExpr *create(Context &ctx, llvm::SMRange loc, Decl *decl,
Type type);
/// Get the decl referenced by this expression.
Decl *getDecl() const { return decl; }
private:
DeclRefExpr(llvm::SMRange loc, Decl *decl, Type type)
: Base(loc, type), decl(decl) {}
/// The decl referenced by this expression.
Decl *decl;
};
//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//
/// This expression represents a named member or field access of a given parent
/// expression.
class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
public:
static MemberAccessExpr *create(Context &ctx, llvm::SMRange loc,
const Expr *parentExpr, StringRef memberName,
Type type);
/// Get the parent expression of this access.
const Expr *getParentExpr() const { return parentExpr; }
/// Return the name of the member being accessed.
StringRef getMemberName() const { return memberName; }
private:
MemberAccessExpr(llvm::SMRange loc, const Expr *parentExpr,
StringRef memberName, Type type)
: Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
/// The parent expression of this access.
const Expr *parentExpr;
/// The name of the member being accessed from the parent.
StringRef memberName;
};
//===----------------------------------------------------------------------===//
// Decl
//===----------------------------------------------------------------------===//
/// This class represents the base Decl node.
class Decl : public Node {
public:
/// Return the name of the decl, or nullptr if it doesn't have one.
const Name *getName() const { return name; }
/// Provide type casting support.
static bool classof(const Node *node);
protected:
Decl(TypeID typeID, llvm::SMRange loc, const Name *name = nullptr)
: Node(typeID, loc), name(name) {}
private:
/// The name of the decl. This is optional for some decls, such as
/// PatternDecl.
const Name *name;
};
//===----------------------------------------------------------------------===//
// ConstraintDecl
//===----------------------------------------------------------------------===//
/// This class represents the base of all AST Constraint decls. Constraints
/// apply matcher conditions to, and define the type of PDLL variables.
class ConstraintDecl : public Decl {
public:
/// Provide type casting support.
static bool classof(const Node *node);
protected:
ConstraintDecl(TypeID typeID, llvm::SMRange loc, const Name *name = nullptr)
: Decl(typeID, loc, name) {}
};
/// This class represents a reference to a constraint, and contains a constraint
/// and the location of the reference.
struct ConstraintRef {
ConstraintRef(const ConstraintDecl *constraint, llvm::SMRange refLoc)
: constraint(constraint), referenceLoc(refLoc) {}
explicit ConstraintRef(const ConstraintDecl *constraint)
: ConstraintRef(constraint, constraint->getLoc()) {}
const ConstraintDecl *constraint;
llvm::SMRange referenceLoc;
};
//===----------------------------------------------------------------------===//
// CoreConstraintDecl
//===----------------------------------------------------------------------===//
/// This class represents the base of all "core" constraints. Core constraints
/// are those that generally represent a concrete IR construct, such as
/// `Type`s or `Value`s.
class CoreConstraintDecl : public ConstraintDecl {
public:
/// Provide type casting support.
static bool classof(const Node *node);
protected:
CoreConstraintDecl(TypeID typeID, llvm::SMRange loc,
const Name *name = nullptr)
: ConstraintDecl(typeID, loc, name) {}
};
//===----------------------------------------------------------------------===//
// AttrConstraintDecl
/// The class represents an Attribute constraint, and constrains a variable to
/// be an Attribute.
class AttrConstraintDecl
: public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
public:
static AttrConstraintDecl *create(Context &ctx, llvm::SMRange loc,
Expr *typeExpr = nullptr);
/// Return the optional type the attribute is constrained to.
Expr *getTypeExpr() { return typeExpr; }
const Expr *getTypeExpr() const { return typeExpr; }
protected:
AttrConstraintDecl(llvm::SMRange loc, Expr *typeExpr)
: Base(loc), typeExpr(typeExpr) {}
/// An optional type that the attribute is constrained to.
Expr *typeExpr;
};
//===----------------------------------------------------------------------===//
// OpConstraintDecl
/// The class represents an Operation constraint, and constrains a variable to
/// be an Operation.
class OpConstraintDecl
: public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
public:
static OpConstraintDecl *create(Context &ctx, llvm::SMRange loc,
const OpNameDecl *nameDecl = nullptr);
/// Return the name of the operation, or None if there isn't one.
Optional<StringRef> getName() const;
/// Return the declaration of the operation name.
const OpNameDecl *getNameDecl() const { return nameDecl; }
protected:
explicit OpConstraintDecl(llvm::SMRange loc, const OpNameDecl *nameDecl)
: Base(loc), nameDecl(nameDecl) {}
/// The operation name of this constraint.
const OpNameDecl *nameDecl;
};
//===----------------------------------------------------------------------===//
// TypeConstraintDecl
/// The class represents a Type constraint, and constrains a variable to be a
/// Type.
class TypeConstraintDecl
: public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
public:
static TypeConstraintDecl *create(Context &ctx, llvm::SMRange loc);
protected:
using Base::Base;
};
//===----------------------------------------------------------------------===//
// TypeRangeConstraintDecl
/// The class represents a TypeRange constraint, and constrains a variable to be
/// a TypeRange.
class TypeRangeConstraintDecl
: public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
public:
static TypeRangeConstraintDecl *create(Context &ctx, llvm::SMRange loc);
protected:
using Base::Base;
};
//===----------------------------------------------------------------------===//
// ValueConstraintDecl
/// The class represents a Value constraint, and constrains a variable to be a
/// Value.
class ValueConstraintDecl
: public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
public:
static ValueConstraintDecl *create(Context &ctx, llvm::SMRange loc,
Expr *typeExpr);
/// Return the optional type the value is constrained to.
Expr *getTypeExpr() { return typeExpr; }
const Expr *getTypeExpr() const { return typeExpr; }
protected:
ValueConstraintDecl(llvm::SMRange loc, Expr *typeExpr)
: Base(loc), typeExpr(typeExpr) {}
/// An optional type that the value is constrained to.
Expr *typeExpr;
};
//===----------------------------------------------------------------------===//
// ValueRangeConstraintDecl
/// The class represents a ValueRange constraint, and constrains a variable to
/// be a ValueRange.
class ValueRangeConstraintDecl
: public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
public:
static ValueRangeConstraintDecl *create(Context &ctx, llvm::SMRange loc,
Expr *typeExpr);
/// Return the optional type the value range is constrained to.
Expr *getTypeExpr() { return typeExpr; }
const Expr *getTypeExpr() const { return typeExpr; }
protected:
ValueRangeConstraintDecl(llvm::SMRange loc, Expr *typeExpr)
: Base(loc), typeExpr(typeExpr) {}
/// An optional type that the value range is constrained to.
Expr *typeExpr;
};
//===----------------------------------------------------------------------===//
// OpNameDecl
//===----------------------------------------------------------------------===//
/// This Decl represents an OperationName.
class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
public:
static OpNameDecl *create(Context &ctx, const Name &name);
static OpNameDecl *create(Context &ctx, llvm::SMRange loc);
/// Return the name of this operation, or none if the name is unknown.
Optional<StringRef> getName() const {
const Name *name = Decl::getName();
return name ? Optional<StringRef>(name->getName()) : llvm::None;
}
private:
explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {}
explicit OpNameDecl(llvm::SMRange loc) : Base(loc) {}
};
//===----------------------------------------------------------------------===//
// PatternDecl
//===----------------------------------------------------------------------===//
/// This Decl represents a single Pattern.
class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
public:
static PatternDecl *create(Context &ctx, llvm::SMRange location,
const Name *name, Optional<uint16_t> benefit,
bool hasBoundedRecursion,
const CompoundStmt *body);
/// Return the benefit of this pattern if specified, or None.
Optional<uint16_t> getBenefit() const { return benefit; }
/// Return if this pattern has bounded rewrite recursion.
bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
/// Return the body of this pattern.
const CompoundStmt *getBody() const { return patternBody; }
/// Return the root rewrite statement of this pattern.
const OpRewriteStmt *getRootRewriteStmt() const {
return cast<OpRewriteStmt>(patternBody->getChildren().back());
}
private:
PatternDecl(llvm::SMRange loc, const Name *name, Optional<uint16_t> benefit,
bool hasBoundedRecursion, const CompoundStmt *body)
: Base(loc, name), benefit(benefit),
hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
/// The benefit of the pattern if it was explicitly specified, None otherwise.
Optional<uint16_t> benefit;
/// If the pattern has properly bounded rewrite recursion or not.
bool hasBoundedRecursion;
/// The compound statement representing the body of the pattern.
const CompoundStmt *patternBody;
};
//===----------------------------------------------------------------------===//
// VariableDecl
//===----------------------------------------------------------------------===//
/// This Decl represents the definition of a PDLL variable.
class VariableDecl final
: public Node::NodeBase<VariableDecl, Decl>,
private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
public:
static VariableDecl *create(Context &ctx, const Name &name, Type type,
Expr *initExpr,
ArrayRef<ConstraintRef> constraints);
/// Return the constraints of this variable.
MutableArrayRef<ConstraintRef> getConstraints() {
return {getTrailingObjects<ConstraintRef>(), numConstraints};
}
ArrayRef<ConstraintRef> getConstraints() const {
return const_cast<VariableDecl *>(this)->getConstraints();
}
/// Return the initializer expression of this statement, or nullptr if there
/// was no initializer.
Expr *getInitExpr() const { return initExpr; }
/// Return the name of the decl.
const Name &getName() const { return *Decl::getName(); }
/// Return the type of the decl.
Type getType() const { return type; }
private:
VariableDecl(const Name &name, Type type, Expr *initExpr,
unsigned numConstraints)
: Base(name.getLoc(), &name), type(type), initExpr(initExpr),
numConstraints(numConstraints) {}
/// The type of the variable.
Type type;
/// The optional initializer expression of this statement.
Expr *initExpr;
/// The number of constraints attached to this variable.
unsigned numConstraints;
/// Allow access to various internals.
friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
};
//===----------------------------------------------------------------------===//
// Module
//===----------------------------------------------------------------------===//
/// This class represents a top-level AST module.
class Module final : public Node::NodeBase<Module, Node>,
private llvm::TrailingObjects<Module, Decl *> {
public:
static Module *create(Context &ctx, llvm::SMLoc loc,
ArrayRef<Decl *> children);
/// Return the children of this module.
MutableArrayRef<Decl *> getChildren() {
return {getTrailingObjects<Decl *>(), numChildren};
}
ArrayRef<Decl *> getChildren() const {
return const_cast<Module *>(this)->getChildren();
}
private:
Module(llvm::SMLoc loc, unsigned numChildren)
: Base(llvm::SMRange{loc, loc}), numChildren(numChildren) {}
/// The number of decls held by this module.
unsigned numChildren;
/// Allow access to various internals.
friend llvm::TrailingObjects<Module, Decl *>;
};
//===----------------------------------------------------------------------===//
// Defered Method Definitions
//===----------------------------------------------------------------------===//
inline bool Decl::classof(const Node *node) {
return isa<ConstraintDecl, OpNameDecl, PatternDecl, VariableDecl>(node);
}
inline bool ConstraintDecl::classof(const Node *node) {
return isa<CoreConstraintDecl>(node);
}
inline bool CoreConstraintDecl::classof(const Node *node) {
return isa<AttrConstraintDecl, OpConstraintDecl, TypeConstraintDecl,
TypeRangeConstraintDecl, ValueConstraintDecl,
ValueRangeConstraintDecl>(node);
}
inline bool Expr::classof(const Node *node) {
return isa<DeclRefExpr, MemberAccessExpr>(node);
}
inline bool OpRewriteStmt::classof(const Node *node) {
return isa<EraseStmt>(node);
}
inline bool Stmt::classof(const Node *node) {
return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node);
}
} // namespace ast
} // namespace pdll
} // namespace mlir
#endif // MLIR_TOOLS_PDLL_AST_NODES_H_

View File

@ -0,0 +1,257 @@
//===- Types.h --------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_PDLL_AST_TYPES_H_
#define MLIR_TOOLS_PDLL_AST_TYPES_H_
#include "mlir/Support/LLVM.h"
#include "mlir/Support/StorageUniquer.h"
namespace mlir {
namespace pdll {
namespace ast {
class Context;
namespace detail {
struct AttributeTypeStorage;
struct ConstraintTypeStorage;
struct OperationTypeStorage;
struct RangeTypeStorage;
struct TypeTypeStorage;
struct ValueTypeStorage;
} // namespace detail
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
class Type {
public:
/// This class represents the internal storage of the Type class.
struct Storage;
/// This class provides several utilities when defining derived type classes.
template <typename ImplT, typename BaseT = Type>
class TypeBase : public BaseT {
public:
using Base = TypeBase<ImplT, BaseT>;
using ImplTy = ImplT;
using BaseT::BaseT;
/// Provide type casting support.
static bool classof(Type type) {
return type.getTypeID() == TypeID::get<ImplTy>();
}
};
Type(Storage *impl = nullptr) : impl(impl) {}
Type(const Type &other) = default;
bool operator==(const Type &other) const { return impl == other.impl; }
bool operator!=(const Type &other) const { return !(*this == other); }
explicit operator bool() const { return impl; }
/// Provide type casting support.
template <typename U> bool isa() const {
assert(impl && "isa<> used on a null type.");
return U::classof(*this);
}
template <typename U, typename V, typename... Others> bool isa() const {
return isa<U>() || isa<V, Others...>();
}
template <typename U> U dyn_cast() const {
return isa<U>() ? U(impl) : U(nullptr);
}
template <typename U> U dyn_cast_or_null() const {
return (impl && isa<U>()) ? U(impl) : U(nullptr);
}
template <typename U> U cast() const {
assert(isa<U>());
return U(impl);
}
/// Return the internal storage instance of this type.
Storage *getImpl() const { return impl; }
/// Return the TypeID instance of this type.
TypeID getTypeID() const;
/// Print this type to the given stream.
void print(raw_ostream &os) const;
/// Try to refine this type with the one provided. Given two compatible types,
/// this will return a merged type contains as much detail from the two types.
/// For example, if refining two operation types and one contains a name,
/// while the other doesn't, the refined type contains the name. If the two
/// types are incompatible, null is returned.
Type refineWith(Type other) const;
protected:
/// Return the internal storage instance of this type reinterpreted as the
/// given derived storage type.
template <typename T> const T *getImplAs() const {
return static_cast<const T *>(impl);
}
private:
Storage *impl;
};
inline llvm::hash_code hash_value(Type type) {
return DenseMapInfo<Type::Storage *>::getHashValue(type.getImpl());
}
inline raw_ostream &operator<<(raw_ostream &os, Type type) {
type.print(os);
return os;
}
//===----------------------------------------------------------------------===//
// AttributeType
//===----------------------------------------------------------------------===//
/// This class represents a PDLL type that corresponds to an mlir::Attribute.
class AttributeType : public Type::TypeBase<detail::AttributeTypeStorage> {
public:
using Base::Base;
/// Return an instance of the Attribute type.
static AttributeType get(Context &context);
};
//===----------------------------------------------------------------------===//
// ConstraintType
//===----------------------------------------------------------------------===//
/// This class represents a PDLL type that corresponds to a constraint. This
/// type has no MLIR C++ API correspondance.
class ConstraintType : public Type::TypeBase<detail::ConstraintTypeStorage> {
public:
using Base::Base;
/// Return an instance of the Constraint type.
static ConstraintType get(Context &context);
};
//===----------------------------------------------------------------------===//
// OperationType
//===----------------------------------------------------------------------===//
/// This class represents a PDLL type that corresponds to an mlir::Operation.
class OperationType : public Type::TypeBase<detail::OperationTypeStorage> {
public:
using Base::Base;
/// Return an instance of the Operation type with an optional operation name.
/// If no name is provided, this type may refer to any operation.
static OperationType get(Context &context,
Optional<StringRef> name = llvm::None);
/// Return the name of this operation type, or None if it doesn't have on.
Optional<StringRef> getName() const;
};
//===----------------------------------------------------------------------===//
// RangeType
//===----------------------------------------------------------------------===//
/// This class represents a PDLL type that corresponds to a range of elements
/// with a given element type.
class RangeType : public Type::TypeBase<detail::RangeTypeStorage> {
public:
using Base::Base;
/// Return an instance of the Range type with the given element type.
static RangeType get(Context &context, Type elementType);
/// Return the element type of this range.
Type getElementType() const;
};
//===----------------------------------------------------------------------===//
// TypeRangeType
/// This class represents a PDLL type that corresponds to an mlir::TypeRange.
class TypeRangeType : public RangeType {
public:
using RangeType::RangeType;
/// Provide type casting support.
static bool classof(Type type);
/// Return an instance of the TypeRange type.
static TypeRangeType get(Context &context);
};
//===----------------------------------------------------------------------===//
// ValueRangeType
/// This class represents a PDLL type that corresponds to an mlir::ValueRange.
class ValueRangeType : public RangeType {
public:
using RangeType::RangeType;
/// Provide type casting support.
static bool classof(Type type);
/// Return an instance of the ValueRange type.
static ValueRangeType get(Context &context);
};
//===----------------------------------------------------------------------===//
// TypeType
//===----------------------------------------------------------------------===//
/// This class represents a PDLL type that corresponds to an mlir::Type.
class TypeType : public Type::TypeBase<detail::TypeTypeStorage> {
public:
using Base::Base;
/// Return an instance of the Type type.
static TypeType get(Context &context);
};
//===----------------------------------------------------------------------===//
// ValueType
//===----------------------------------------------------------------------===//
/// This class represents a PDLL type that corresponds to an mlir::Value.
class ValueType : public Type::TypeBase<detail::ValueTypeStorage> {
public:
using Base::Base;
/// Return an instance of the Value type.
static ValueType get(Context &context);
};
} // namespace ast
} // namespace pdll
} // namespace mlir
namespace llvm {
template <> struct DenseMapInfo<mlir::pdll::ast::Type> {
static mlir::pdll::ast::Type getEmptyKey() {
void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::pdll::ast::Type(
static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
}
static mlir::pdll::ast::Type getTombstoneKey() {
void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::pdll::ast::Type(
static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
}
static unsigned getHashValue(mlir::pdll::ast::Type val) {
return llvm::hash_value(val.getImpl());
}
static bool isEqual(mlir::pdll::ast::Type lhs, mlir::pdll::ast::Type rhs) {
return lhs == rhs;
}
};
} // namespace llvm
#endif // MLIR_TOOLS_PDLL_AST_TYPES_H_

View File

@ -0,0 +1,33 @@
//===- Parser.h - MLIR PDLL Frontend Parser ---------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_PDLL_PARSER_PARSER_H_
#define MLIR_TOOLS_PDLL_PARSER_PARSER_H_
#include <memory>
#include "mlir/Support/LogicalResult.h"
namespace llvm {
class SourceMgr;
} // namespace llvm
namespace mlir {
namespace pdll {
namespace ast {
class Context;
class Module;
} // namespace ast
/// Parse an AST module from the main file of the given source manager.
FailureOr<ast::Module *> parsePDLAST(ast::Context &ctx,
llvm::SourceMgr &sourceMgr);
} // namespace pdll
} // namespace mlir
#endif // MLIR_TOOLS_PDLL_PARSER_PARSER_H_

View File

@ -1,2 +1,3 @@
add_subdirectory(mlir-lsp-server)
add_subdirectory(mlir-reduce)
add_subdirectory(PDLL)

View File

@ -0,0 +1,10 @@
add_mlir_library(MLIRPDLLAST
Context.cpp
Diagnostic.cpp
NodePrinter.cpp
Nodes.cpp
Types.cpp
LINK_LIBS PUBLIC
MLIRSupport
)

View File

@ -0,0 +1,23 @@
//===- Context.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Context.h"
#include "TypeDetail.h"
using namespace mlir;
using namespace mlir::pdll::ast;
Context::Context() {
typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::TypeTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::ValueTypeStorage>();
typeUniquer.registerParametricStorageType<detail::OperationTypeStorage>();
typeUniquer.registerParametricStorageType<detail::RangeTypeStorage>();
}

View File

@ -0,0 +1,26 @@
//===- Diagnostic.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
using namespace mlir;
using namespace mlir::pdll::ast;
//===----------------------------------------------------------------------===//
// InFlightDiagnostic
//===----------------------------------------------------------------------===//
void InFlightDiagnostic::report() {
// If this diagnostic is still inflight and it hasn't been abandoned, then
// report it.
if (isInFlight()) {
owner->report(std::move(*impl));
owner = nullptr;
}
impl.reset();
}

View File

@ -0,0 +1,266 @@
//===- NodePrinter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
using namespace mlir;
using namespace mlir::pdll::ast;
//===----------------------------------------------------------------------===//
// NodePrinter
//===----------------------------------------------------------------------===//
namespace {
class NodePrinter {
public:
NodePrinter(raw_ostream &os) : os(os) {}
/// Print the given type to the stream.
void print(Type type);
/// Print the given node to the stream.
void print(const Node *node);
private:
/// Print a range containing children of a node.
template <typename RangeT,
std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
* = nullptr>
void printChildren(RangeT &&range) {
if (llvm::empty(range))
return;
// Print the first N-1 elements with a prefix of "|-".
auto it = std::begin(range);
for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
print(*it);
// Print the last element.
elementIndentStack.back() = true;
print(*it);
}
template <typename RangeT, typename... OthersT,
std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
* = nullptr>
void printChildren(RangeT &&range, OthersT &&...others) {
printChildren(ArrayRef<const Node *>({range, others...}));
}
/// Print a range containing children of a node, nesting the children under
/// the given label.
template <typename RangeT>
void printChildren(StringRef label, RangeT &&range) {
if (llvm::empty(range))
return;
elementIndentStack.reserve(elementIndentStack.size() + 1);
llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true);
printIndent();
os << label << "`\n";
elementIndentStack.push_back(/*isLastElt*/ false);
printChildren(std::forward<RangeT>(range));
elementIndentStack.pop_back();
}
/// Print the given derived node to the stream.
void printImpl(const CompoundStmt *stmt);
void printImpl(const EraseStmt *stmt);
void printImpl(const LetStmt *stmt);
void printImpl(const DeclRefExpr *expr);
void printImpl(const MemberAccessExpr *expr);
void printImpl(const AttrConstraintDecl *decl);
void printImpl(const OpConstraintDecl *decl);
void printImpl(const TypeConstraintDecl *decl);
void printImpl(const TypeRangeConstraintDecl *decl);
void printImpl(const ValueConstraintDecl *decl);
void printImpl(const ValueRangeConstraintDecl *decl);
void printImpl(const OpNameDecl *decl);
void printImpl(const PatternDecl *decl);
void printImpl(const VariableDecl *decl);
void printImpl(const Module *module);
/// Print the current indent stack.
void printIndent() {
if (elementIndentStack.empty())
return;
for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
os << (isLastElt ? " " : " |");
os << (elementIndentStack.back() ? " `" : " |");
}
/// The raw output stream.
raw_ostream &os;
/// A stack of indents and a flag indicating if the current element being
/// printed at that indent is the last element.
SmallVector<bool> elementIndentStack;
};
} // namespace
void NodePrinter::print(Type type) {
// Protect against invalid inputs.
if (!type) {
os << "Type<NULL>";
return;
}
TypeSwitch<Type>(type)
.Case([&](AttributeType) { os << "Attr"; })
.Case([&](ConstraintType) { os << "Constraint"; })
.Case([&](OperationType type) {
os << "Op";
if (Optional<StringRef> name = type.getName())
os << "<" << *name << ">";
})
.Case([&](RangeType type) {
print(type.getElementType());
os << "Range";
})
.Case([&](TypeType) { os << "Type"; })
.Case([&](ValueType) { os << "Value"; })
.Default([](Type) { llvm_unreachable("unknown AST type"); });
}
void NodePrinter::print(const Node *node) {
printIndent();
os << "-";
elementIndentStack.push_back(/*isLastElt*/ false);
TypeSwitch<const Node *>(node)
.Case<
// Statements.
const CompoundStmt, const EraseStmt, const LetStmt,
// Expressions.
const DeclRefExpr, const MemberAccessExpr,
// Decls.
const AttrConstraintDecl, const OpConstraintDecl,
const TypeConstraintDecl, const TypeRangeConstraintDecl,
const ValueConstraintDecl, const ValueRangeConstraintDecl,
const OpNameDecl, const PatternDecl, const VariableDecl,
const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
.Default([](const Node *) { llvm_unreachable("unknown AST node"); });
elementIndentStack.pop_back();
}
void NodePrinter::printImpl(const CompoundStmt *stmt) {
os << "CompoundStmt " << stmt << "\n";
printChildren(stmt->getChildren());
}
void NodePrinter::printImpl(const EraseStmt *stmt) {
os << "EraseStmt " << stmt << "\n";
printChildren(stmt->getRootOpExpr());
}
void NodePrinter::printImpl(const LetStmt *stmt) {
os << "LetStmt " << stmt << "\n";
printChildren(stmt->getVarDecl());
}
void NodePrinter::printImpl(const DeclRefExpr *expr) {
os << "DeclRefExpr " << expr << " Type<";
print(expr->getType());
os << ">\n";
printChildren(expr->getDecl());
}
void NodePrinter::printImpl(const MemberAccessExpr *expr) {
os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
<< "> Type<";
print(expr->getType());
os << ">\n";
printChildren(expr->getParentExpr());
}
void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
os << "AttrConstraintDecl " << decl << "\n";
if (const auto *typeExpr = decl->getTypeExpr())
printChildren(typeExpr);
}
void NodePrinter::printImpl(const OpConstraintDecl *decl) {
os << "OpConstraintDecl " << decl << "\n";
printChildren(decl->getNameDecl());
}
void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
os << "TypeConstraintDecl " << decl << "\n";
}
void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
os << "TypeRangeConstraintDecl " << decl << "\n";
}
void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
os << "ValueConstraintDecl " << decl << "\n";
if (const auto *typeExpr = decl->getTypeExpr())
printChildren(typeExpr);
}
void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
os << "ValueRangeConstraintDecl " << decl << "\n";
if (const auto *typeExpr = decl->getTypeExpr())
printChildren(typeExpr);
}
void NodePrinter::printImpl(const OpNameDecl *decl) {
os << "OpNameDecl " << decl;
if (Optional<StringRef> name = decl->getName())
os << " Name<" << name << ">";
os << "\n";
}
void NodePrinter::printImpl(const PatternDecl *decl) {
os << "PatternDecl " << decl;
if (const Name *name = decl->getName())
os << " Name<" << name->getName() << ">";
if (Optional<uint16_t> benefit = decl->getBenefit())
os << " Benefit<" << *benefit << ">";
if (decl->hasBoundedRewriteRecursion())
os << " Recursion";
os << "\n";
printChildren(decl->getBody());
}
void NodePrinter::printImpl(const VariableDecl *decl) {
os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
<< "> Type<";
print(decl->getType());
os << ">\n";
if (Expr *initExpr = decl->getInitExpr())
printChildren(initExpr);
auto constraints =
llvm::map_range(decl->getConstraints(),
[](const ConstraintRef &ref) { return ref.constraint; });
printChildren("Constraints", constraints);
}
void NodePrinter::printImpl(const Module *module) {
os << "Module " << module << "\n";
printChildren(module->getChildren());
}
//===----------------------------------------------------------------------===//
// Entry point
//===----------------------------------------------------------------------===//
void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }

View File

@ -0,0 +1,231 @@
//===- Nodes.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
using namespace mlir::pdll::ast;
/// Copy a string reference into the context with a null terminator.
static StringRef copyStringWithNull(Context &ctx, StringRef str) {
if (str.empty())
return str;
char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
std::copy(str.begin(), str.end(), data);
data[str.size()] = 0;
return StringRef(data, str.size());
}
//===----------------------------------------------------------------------===//
// Name
//===----------------------------------------------------------------------===//
const Name &Name::create(Context &ctx, StringRef name, llvm::SMRange location) {
return *new (ctx.getAllocator().Allocate<Name>())
Name(copyStringWithNull(ctx, name), location);
}
//===----------------------------------------------------------------------===//
// DeclScope
//===----------------------------------------------------------------------===//
void DeclScope::add(Decl *decl) {
const Name *name = decl->getName();
assert(name && "expected a named decl");
assert(!decls.count(name->getName()) && "decl with this name already exists");
decls.try_emplace(name->getName(), decl);
}
Decl *DeclScope::lookup(StringRef name) {
if (Decl *decl = decls.lookup(name))
return decl;
return parent ? parent->lookup(name) : nullptr;
}
//===----------------------------------------------------------------------===//
// CompoundStmt
//===----------------------------------------------------------------------===//
CompoundStmt *CompoundStmt::create(Context &ctx, llvm::SMRange loc,
ArrayRef<Stmt *> children) {
unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
std::uninitialized_copy(children.begin(), children.end(),
stmt->getChildren().begin());
return stmt;
}
//===----------------------------------------------------------------------===//
// LetStmt
//===----------------------------------------------------------------------===//
LetStmt *LetStmt::create(Context &ctx, llvm::SMRange loc,
VariableDecl *varDecl) {
return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
}
//===----------------------------------------------------------------------===//
// OpRewriteStmt
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// EraseStmt
EraseStmt *EraseStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp) {
return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
}
//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//
DeclRefExpr *DeclRefExpr::create(Context &ctx, llvm::SMRange loc, Decl *decl,
Type type) {
return new (ctx.getAllocator().Allocate<DeclRefExpr>())
DeclRefExpr(loc, decl, type);
}
//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//
MemberAccessExpr *MemberAccessExpr::create(Context &ctx, llvm::SMRange loc,
const Expr *parentExpr,
StringRef memberName, Type type) {
return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
}
//===----------------------------------------------------------------------===//
// AttrConstraintDecl
//===----------------------------------------------------------------------===//
AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, llvm::SMRange loc,
Expr *typeExpr) {
return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
AttrConstraintDecl(loc, typeExpr);
}
//===----------------------------------------------------------------------===//
// OpConstraintDecl
//===----------------------------------------------------------------------===//
OpConstraintDecl *OpConstraintDecl::create(Context &ctx, llvm::SMRange loc,
const OpNameDecl *nameDecl) {
if (!nameDecl)
nameDecl = OpNameDecl::create(ctx, llvm::SMRange());
return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
OpConstraintDecl(loc, nameDecl);
}
Optional<StringRef> OpConstraintDecl::getName() const {
return getNameDecl()->getName();
}
//===----------------------------------------------------------------------===//
// TypeConstraintDecl
//===----------------------------------------------------------------------===//
TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
llvm::SMRange loc) {
return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
TypeConstraintDecl(loc);
}
//===----------------------------------------------------------------------===//
// TypeRangeConstraintDecl
//===----------------------------------------------------------------------===//
TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
llvm::SMRange loc) {
return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
TypeRangeConstraintDecl(loc);
}
//===----------------------------------------------------------------------===//
// ValueConstraintDecl
//===----------------------------------------------------------------------===//
ValueConstraintDecl *
ValueConstraintDecl::create(Context &ctx, llvm::SMRange loc, Expr *typeExpr) {
return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
ValueConstraintDecl(loc, typeExpr);
}
//===----------------------------------------------------------------------===//
// ValueRangeConstraintDecl
//===----------------------------------------------------------------------===//
ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
llvm::SMRange loc,
Expr *typeExpr) {
return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
ValueRangeConstraintDecl(loc, typeExpr);
}
//===----------------------------------------------------------------------===//
// OpNameDecl
//===----------------------------------------------------------------------===//
OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
}
OpNameDecl *OpNameDecl::create(Context &ctx, llvm::SMRange loc) {
return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
}
//===----------------------------------------------------------------------===//
// PatternDecl
//===----------------------------------------------------------------------===//
PatternDecl *PatternDecl::create(Context &ctx, llvm::SMRange loc,
const Name *name, Optional<uint16_t> benefit,
bool hasBoundedRecursion,
const CompoundStmt *body) {
return new (ctx.getAllocator().Allocate<PatternDecl>())
PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
}
//===----------------------------------------------------------------------===//
// VariableDecl
//===----------------------------------------------------------------------===//
VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
Expr *initExpr,
ArrayRef<ConstraintRef> constraints) {
unsigned allocSize =
VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
VariableDecl *varDecl =
new (rawData) VariableDecl(name, type, initExpr, constraints.size());
std::uninitialized_copy(constraints.begin(), constraints.end(),
varDecl->getConstraints().begin());
return varDecl;
}
//===----------------------------------------------------------------------===//
// Module
//===----------------------------------------------------------------------===//
Module *Module::create(Context &ctx, llvm::SMLoc loc,
ArrayRef<Decl *> children) {
unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
Module *module = new (rawData) Module(loc, children.size());
std::uninitialized_copy(children.begin(), children.end(),
module->getChildren().begin());
return module;
}

View File

@ -0,0 +1,116 @@
//===- TypeDetail.h ---------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
#define LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
#include "mlir/Tools/PDLL/AST/Types.h"
namespace mlir {
namespace pdll {
namespace ast {
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
struct Type::Storage : public StorageUniquer::BaseStorage {
Storage(TypeID typeID) : typeID(typeID) {}
/// The type identifier for the derived type class.
TypeID typeID;
};
namespace detail {
/// A utility CRTP base class that defines many of the necessary utilities for
/// defining a PDLL AST Type.
template <typename ConcreteT, typename KeyT = void>
struct TypeStorageBase : public Type::Storage {
using KeyTy = KeyT;
using Base = TypeStorageBase<ConcreteT, KeyT>;
/// Construct an instance with the given storage allocator.
static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
const KeyTy &key) {
return new (alloc.allocate<ConcreteT>()) ConcreteT(key);
}
/// Utility methods required by the storage allocator.
bool operator==(const KeyTy &key) const { return this->key == key; }
/// Return the key value of this storage class.
const KeyTy &getValue() const { return key; }
protected:
TypeStorageBase(KeyTy key)
: Type::Storage(TypeID::get<ConcreteT>()), key(key) {}
KeyTy key;
};
/// A specialization of the storage base for singleton types.
template <typename ConcreteT>
struct TypeStorageBase<ConcreteT, void> : public Type::Storage {
using Base = TypeStorageBase<ConcreteT, void>;
protected:
TypeStorageBase() : Type::Storage(TypeID::get<ConcreteT>()) {}
};
//===----------------------------------------------------------------------===//
// AttributeType
//===----------------------------------------------------------------------===//
struct AttributeTypeStorage : public TypeStorageBase<AttributeTypeStorage> {};
//===----------------------------------------------------------------------===//
// ConstraintType
//===----------------------------------------------------------------------===//
struct ConstraintTypeStorage : public TypeStorageBase<ConstraintTypeStorage> {};
//===----------------------------------------------------------------------===//
// OperationType
//===----------------------------------------------------------------------===//
struct OperationTypeStorage
: public TypeStorageBase<OperationTypeStorage, StringRef> {
using Base::Base;
static OperationTypeStorage *
construct(StorageUniquer::StorageAllocator &alloc, StringRef key) {
return new (alloc.allocate<OperationTypeStorage>())
OperationTypeStorage(alloc.copyInto(key));
}
};
//===----------------------------------------------------------------------===//
// RangeType
//===----------------------------------------------------------------------===//
struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> {
using Base::Base;
};
//===----------------------------------------------------------------------===//
// TypeType
//===----------------------------------------------------------------------===//
struct TypeTypeStorage : public TypeStorageBase<TypeTypeStorage> {};
//===----------------------------------------------------------------------===//
// ValueType
//===----------------------------------------------------------------------===//
struct ValueTypeStorage : public TypeStorageBase<ValueTypeStorage> {};
} // namespace detail
} // namespace ast
} // namespace pdll
} // namespace mlir
#endif // LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_

View File

@ -0,0 +1,124 @@
//===- Types.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Types.h"
#include "TypeDetail.h"
#include "mlir/Tools/PDLL/AST/Context.h"
using namespace mlir;
using namespace mlir::pdll::ast;
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
TypeID Type::getTypeID() const { return impl->typeID; }
Type Type::refineWith(Type other) const {
if (*this == other)
return *this;
// Operation types are compatible if the operation names don't conflict.
if (auto opTy = dyn_cast<OperationType>()) {
auto otherOpTy = other.dyn_cast<ast::OperationType>();
if (!otherOpTy)
return nullptr;
if (!otherOpTy.getName())
return *this;
if (!opTy.getName())
return other;
return nullptr;
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// AttributeType
//===----------------------------------------------------------------------===//
AttributeType AttributeType::get(Context &context) {
return context.getTypeUniquer().get<ImplTy>();
}
//===----------------------------------------------------------------------===//
// ConstraintType
//===----------------------------------------------------------------------===//
ConstraintType ConstraintType::get(Context &context) {
return context.getTypeUniquer().get<ImplTy>();
}
//===----------------------------------------------------------------------===//
// OperationType
//===----------------------------------------------------------------------===//
OperationType OperationType::get(Context &context, Optional<StringRef> name) {
return context.getTypeUniquer().get<ImplTy>(
/*initFn=*/function_ref<void(ImplTy *)>(), name.getValueOr(""));
}
Optional<StringRef> OperationType::getName() const {
StringRef name = getImplAs<ImplTy>()->getValue();
return name.empty() ? Optional<StringRef>() : Optional<StringRef>(name);
}
//===----------------------------------------------------------------------===//
// RangeType
//===----------------------------------------------------------------------===//
RangeType RangeType::get(Context &context, Type elementType) {
return context.getTypeUniquer().get<ImplTy>(
/*initFn=*/function_ref<void(ImplTy *)>(), elementType);
}
Type RangeType::getElementType() const {
return getImplAs<ImplTy>()->getValue();
}
//===----------------------------------------------------------------------===//
// TypeRangeType
bool TypeRangeType::classof(Type type) {
RangeType range = type.dyn_cast<RangeType>();
return range && range.getElementType().isa<TypeType>();
}
TypeRangeType TypeRangeType::get(Context &context) {
return RangeType::get(context, TypeType::get(context)).cast<TypeRangeType>();
}
//===----------------------------------------------------------------------===//
// ValueRangeType
bool ValueRangeType::classof(Type type) {
RangeType range = type.dyn_cast<RangeType>();
return range && range.getElementType().isa<ValueType>();
}
ValueRangeType ValueRangeType::get(Context &context) {
return RangeType::get(context, ValueType::get(context))
.cast<ValueRangeType>();
}
//===----------------------------------------------------------------------===//
// TypeType
//===----------------------------------------------------------------------===//
TypeType TypeType::get(Context &context) {
return context.getTypeUniquer().get<ImplTy>();
}
//===----------------------------------------------------------------------===//
// ValueType
//===----------------------------------------------------------------------===//
ValueType ValueType::get(Context &context) {
return context.getTypeUniquer().get<ImplTy>();
}

View File

@ -0,0 +1,2 @@
add_subdirectory(AST)
add_subdirectory(Parser)

View File

@ -0,0 +1,8 @@
add_mlir_library(MLIRPDLLParser
Lexer.cpp
Parser.cpp
LINK_LIBS PUBLIC
MLIRPDLLAST
MLIRSupport
)

View File

@ -0,0 +1,366 @@
//===- Lexer.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Lexer.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using namespace mlir::pdll;
//===----------------------------------------------------------------------===//
// Token
//===----------------------------------------------------------------------===//
std::string Token::getStringValue() const {
assert(getKind() == string || getKind() == string_block);
// Start by dropping the quotes.
StringRef bytes = getSpelling().drop_front().drop_back();
if (is(string_block)) bytes = bytes.drop_front().drop_back();
std::string result;
result.reserve(bytes.size());
for (unsigned i = 0, e = bytes.size(); i != e;) {
auto c = bytes[i++];
if (c != '\\') {
result.push_back(c);
continue;
}
assert(i + 1 <= e && "invalid string should be caught by lexer");
auto c1 = bytes[i++];
switch (c1) {
case '"':
case '\\':
result.push_back(c1);
continue;
case 'n':
result.push_back('\n');
continue;
case 't':
result.push_back('\t');
continue;
default:
break;
}
assert(i + 1 <= e && "invalid string should be caught by lexer");
auto c2 = bytes[i++];
assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
}
return result;
}
//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//
Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine)
: srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false) {
curBufferID = mgr.getMainFileID();
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
curPtr = curBuffer.begin();
// If the diag engine has no handler, add a default that emits to the
// SourceMgr.
if (!diagEngine.getHandlerFn()) {
diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
diag.getMessage());
for (const ast::Diagnostic &note : diag.getNotes())
srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
note.getMessage());
});
addedHandlerToDiagEngine = true;
}
}
Lexer::~Lexer() {
if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr);
}
LogicalResult Lexer::pushInclude(StringRef filename) {
std::string includedFile;
int bufferID = srcMgr.AddIncludeFile(
filename.str(), llvm::SMLoc::getFromPointer(curPtr), includedFile);
if (!bufferID) return failure();
curBufferID = bufferID;
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
curPtr = curBuffer.begin();
return success();
}
Token Lexer::emitError(llvm::SMRange loc, const Twine &msg) {
diagEngine.emitError(loc, msg);
return formToken(Token::error, loc.Start.getPointer());
}
Token Lexer::emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
llvm::SMRange noteLoc, const Twine &note) {
diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
return formToken(Token::error, loc.Start.getPointer());
}
Token Lexer::emitError(const char *loc, const Twine &msg) {
return emitError(llvm::SMRange(llvm::SMLoc::getFromPointer(loc),
llvm::SMLoc::getFromPointer(loc + 1)),
msg);
}
int Lexer::getNextChar() {
char curChar = *curPtr++;
switch (curChar) {
default:
return static_cast<unsigned char>(curChar);
case 0: {
// A nul character in the stream is either the end of the current buffer
// or a random nul in the file. Disambiguate that here.
if (curPtr - 1 != curBuffer.end()) return 0;
// Otherwise, return end of file.
--curPtr;
return EOF;
}
case '\n':
case '\r':
// Handle the newline character by ignoring it and incrementing the line
// count. However, be careful about 'dos style' files with \n\r in them.
// Only treat a \n\r or \r\n as a single line.
if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
++curPtr;
return '\n';
}
}
Token Lexer::lexToken() {
while (true) {
const char *tokStart = curPtr;
// This always consumes at least one character.
int curChar = getNextChar();
switch (curChar) {
default:
// Handle identifiers: [a-zA-Z_]
if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart);
// Unknown character, emit an error.
return emitError(tokStart, "unexpected character");
case EOF: {
// Return EOF denoting the end of lexing.
Token eof = formToken(Token::eof, tokStart);
// Check to see if we are in an included file.
llvm::SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
if (parentIncludeLoc.isValid()) {
curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
curPtr = parentIncludeLoc.getPointer();
}
return eof;
}
// Lex punctuation.
case '-':
if (*curPtr == '>') {
++curPtr;
return formToken(Token::arrow, tokStart);
}
return emitError(tokStart, "unexpected character");
case ':':
return formToken(Token::colon, tokStart);
case ',':
return formToken(Token::comma, tokStart);
case '.':
return formToken(Token::dot, tokStart);
case '=':
if (*curPtr == '>') {
++curPtr;
return formToken(Token::equal_arrow, tokStart);
}
return formToken(Token::equal, tokStart);
case ';':
return formToken(Token::semicolon, tokStart);
case '[':
if (*curPtr == '{') {
++curPtr;
return lexString(tokStart, /*isStringBlock=*/true);
}
return formToken(Token::l_square, tokStart);
case ']':
return formToken(Token::r_square, tokStart);
case '<':
return formToken(Token::less, tokStart);
case '>':
return formToken(Token::greater, tokStart);
case '{':
return formToken(Token::l_brace, tokStart);
case '}':
return formToken(Token::r_brace, tokStart);
case '(':
return formToken(Token::l_paren, tokStart);
case ')':
return formToken(Token::r_paren, tokStart);
case '/':
if (*curPtr == '/') {
lexComment();
continue;
}
return emitError(tokStart, "unexpected character");
// Ignore whitespace characters.
case 0:
case ' ':
case '\t':
case '\n':
return lexToken();
case '#':
return lexDirective(tokStart);
case '"':
return lexString(tokStart, /*isStringBlock=*/false);
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
return lexNumber(tokStart);
}
}
}
/// Skip a comment line, starting with a '//'.
void Lexer::lexComment() {
// Advance over the second '/' in a '//' comment.
assert(*curPtr == '/');
++curPtr;
while (true) {
switch (*curPtr++) {
case '\n':
case '\r':
// Newline is end of comment.
return;
case 0:
// If this is the end of the buffer, end the comment.
if (curPtr - 1 == curBuffer.end()) {
--curPtr;
return;
}
LLVM_FALLTHROUGH;
default:
// Skip over other characters.
break;
}
}
}
Token Lexer::lexDirective(const char *tokStart) {
// Match the rest with an identifier regex: [0-9a-zA-Z_]*
while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
StringRef str(tokStart, curPtr - tokStart);
return Token(Token::directive, str);
}
Token Lexer::lexIdentifier(const char *tokStart) {
// Match the rest of the identifier regex: [0-9a-zA-Z_]*
while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
// Check to see if this identifier is a keyword.
StringRef str(tokStart, curPtr - tokStart);
Token::Kind kind = StringSwitch<Token::Kind>(str)
.Case("attr", Token::kw_attr)
.Case("Attr", Token::kw_Attr)
.Case("erase", Token::kw_erase)
.Case("let", Token::kw_let)
.Case("Constraint", Token::kw_Constraint)
.Case("op", Token::kw_op)
.Case("Op", Token::kw_Op)
.Case("OpName", Token::kw_OpName)
.Case("Pattern", Token::kw_Pattern)
.Case("replace", Token::kw_replace)
.Case("rewrite", Token::kw_rewrite)
.Case("type", Token::kw_type)
.Case("Type", Token::kw_Type)
.Case("TypeRange", Token::kw_TypeRange)
.Case("Value", Token::kw_Value)
.Case("ValueRange", Token::kw_ValueRange)
.Case("with", Token::kw_with)
.Case("_", Token::underscore)
.Default(Token::identifier);
return Token(kind, str);
}
Token Lexer::lexNumber(const char *tokStart) {
assert(isdigit(curPtr[-1]));
// Handle the normal decimal case.
while (isdigit(*curPtr)) ++curPtr;
return formToken(Token::integer, tokStart);
}
Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
while (true) {
switch (*curPtr++) {
case '"':
// If this is a string block, we only end the string when we encounter a
// `}]`.
if (!isStringBlock) return formToken(Token::string, tokStart);
continue;
case '}':
// If this is a string block, we only end the string when we encounter a
// `}]`.
if (!isStringBlock || *curPtr != ']') continue;
++curPtr;
return formToken(Token::string_block, tokStart);
case 0:
// If this is a random nul character in the middle of a string, just
// include it. If it is the end of file, then it is an error.
if (curPtr - 1 != curBuffer.end()) continue;
LLVM_FALLTHROUGH;
case '\n':
case '\v':
case '\f':
// String blocks allow multiple lines.
if (!isStringBlock)
return emitError(curPtr - 1, "expected '\"' in string literal");
continue;
case '\\':
// Handle explicitly a few escapes.
if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
*curPtr == 't') {
++curPtr;
} else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
// Support \xx for two hex digits.
curPtr += 2;
} else {
return emitError(curPtr - 1, "unknown escape in string literal");
}
continue;
default:
continue;
}
}
}

View File

@ -0,0 +1,220 @@
//===- Lexer.h - MLIR PDLL Frontend Lexer -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef LIB_TOOLS_PDLL_PARSER_LEXER_H_
#define LIB_TOOLS_PDLL_PARSER_LEXER_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
namespace llvm {
class SourceMgr;
} // namespace llvm
namespace mlir {
struct LogicalResult;
namespace pdll {
namespace ast {
class DiagnosticEngine;
} // namespace ast
//===----------------------------------------------------------------------===//
// Token
//===----------------------------------------------------------------------===//
class Token {
public:
enum Kind {
// Markers.
eof,
error,
// Keywords.
KW_BEGIN,
// Dependent keywords, i.e. those that are treated as keywords depending on
// the current parser context.
KW_DEPENDENT_BEGIN,
kw_attr,
kw_op,
kw_type,
KW_DEPENDENT_END,
// General keywords.
kw_Attr,
kw_erase,
kw_let,
kw_Constraint,
kw_Op,
kw_OpName,
kw_Pattern,
kw_replace,
kw_rewrite,
kw_Type,
kw_TypeRange,
kw_Value,
kw_ValueRange,
kw_with,
KW_END,
// Punctuation.
arrow,
colon,
comma,
dot,
equal,
equal_arrow,
semicolon,
// Paired punctuation.
less,
greater,
l_brace,
r_brace,
l_paren,
r_paren,
l_square,
r_square,
underscore,
// Tokens.
directive,
identifier,
integer,
string_block,
string
};
Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
/// Given a token containing a string literal, return its value, including
/// removing the quote characters and unescaping the contents of the string.
std::string getStringValue() const;
/// Returns true if the current token is a string literal.
bool isString() const { return isAny(Token::string, Token::string_block); }
/// Returns true if the current token is a keyword.
bool isKeyword() const {
return kind > Token::KW_BEGIN && kind < Token::KW_END;
}
/// Returns true if the current token is a keyword in a dependent context, and
/// in any other situation (e.g. variable names) may be treated as an
/// identifier.
bool isDependentKeyword() const {
return kind > Token::KW_DEPENDENT_BEGIN && kind < Token::KW_DEPENDENT_END;
}
/// Return the bytes that make up this token.
StringRef getSpelling() const { return spelling; }
/// Return the kind of this token.
Kind getKind() const { return kind; }
/// Return true if this token is one of the specified kinds.
bool isAny(Kind k1, Kind k2) const { return is(k1) || is(k2); }
template <typename... T>
bool isAny(Kind k1, Kind k2, Kind k3, T... others) const {
return is(k1) || isAny(k2, k3, others...);
}
/// Return if the token does not have the given kind.
bool isNot(Kind k) const { return k != kind; }
template <typename... T> bool isNot(Kind k1, Kind k2, T... others) const {
return !isAny(k1, k2, others...);
}
/// Return if the token has the given kind.
bool is(Kind K) const { return kind == K; }
/// Return a location for the start of this token.
llvm::SMLoc getStartLoc() const {
return llvm::SMLoc::getFromPointer(spelling.data());
}
/// Return a location at the end of this token.
llvm::SMLoc getEndLoc() const {
return llvm::SMLoc::getFromPointer(spelling.data() + spelling.size());
}
/// Return a location for the range of this token.
llvm::SMRange getLoc() const {
return llvm::SMRange(getStartLoc(), getEndLoc());
}
private:
/// Discriminator that indicates the kind of token this is.
Kind kind;
/// A reference to the entire token contents; this is always a pointer into
/// a memory buffer owned by the source manager.
StringRef spelling;
};
//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//
class Lexer {
public:
Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine);
~Lexer();
/// Return a reference to the source manager used by the lexer.
llvm::SourceMgr &getSourceMgr() { return srcMgr; }
/// Return a reference to the diagnostic engine used by the lexer.
ast::DiagnosticEngine &getDiagEngine() { return diagEngine; }
/// Push an include of the given file. This will cause the lexer to start
/// processing the provided file. Returns failure if the file could not be
/// opened, success otherwise.
LogicalResult pushInclude(StringRef filename);
/// Lex the next token and return it.
Token lexToken();
/// Change the position of the lexer cursor. The next token we lex will start
/// at the designated point in the input.
void resetPointer(const char *newPointer) { curPtr = newPointer; }
/// Emit an error to the lexer with the given location and message.
Token emitError(llvm::SMRange loc, const Twine &msg);
Token emitError(const char *loc, const Twine &msg);
Token emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
llvm::SMRange noteLoc, const Twine &note);
private:
Token formToken(Token::Kind kind, const char *tokStart) {
return Token(kind, StringRef(tokStart, curPtr - tokStart));
}
/// Return the next character in the stream.
int getNextChar();
/// Lex methods.
void lexComment();
Token lexDirective(const char *tokStart);
Token lexIdentifier(const char *tokStart);
Token lexNumber(const char *tokStart);
Token lexString(const char *tokStart, bool isStringBlock);
llvm::SourceMgr &srcMgr;
int curBufferID;
StringRef curBuffer;
const char *curPtr;
/// The engine used to emit diagnostics during lexing/parsing.
ast::DiagnosticEngine &diagEngine;
/// A flag indicating if we added a default diagnostic handler to the provided
/// diagEngine.
bool addedHandlerToDiagEngine;
};
} // namespace pdll
} // namespace mlir
#endif // LIB_TOOLS_PDLL_PARSER_LEXER_H_

File diff suppressed because it is too large Load Diff

View File

@ -78,6 +78,7 @@ set(MLIR_TEST_DEPENDS
mlir-linalg-ods-yaml-gen
mlir-lsp-server
mlir-opt
mlir-pdll
mlir-reduce
mlir-tblgen
mlir-translate

View File

@ -21,7 +21,7 @@ config.name = 'MLIR'
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test']
config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test', '.pdll']
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
@ -68,6 +68,7 @@ tools = [
'mlir-cpu-runner',
'mlir-linalg-ods-yaml-gen',
'mlir-reduce',
'mlir-pdll',
]
# The following tools are optional

View File

@ -0,0 +1,23 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
// CHECK: unknown directive `#foo`
#foo
// -----
//===----------------------------------------------------------------------===//
// Include
//===----------------------------------------------------------------------===//
// CHECK: expected string file name after `include` directive
#include <>
// -----
// CHECK: unable to open include file `unknown_file.pdll`
#include "unknown_file.pdll"
// -----
// CHECK: expected include filename to end with `.pdll`
#include "unknown_file.foo"

View File

@ -0,0 +1,62 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
//===----------------------------------------------------------------------===//
// Reference Expr
//===----------------------------------------------------------------------===//
Pattern {
// CHECK: expected identifier constraint
let foo = Foo: ;
}
// -----
Pattern {
// CHECK: undefined reference to `bar`
let foo = bar;
}
// -----
Pattern FooPattern {
erase _: Op;
}
Pattern {
// CHECK: invalid reference to `FooPattern`
let foo = FooPattern;
}
// -----
Pattern {
// CHECK: expected `:` after `_` variable
let foo = _;
}
// -----
Pattern {
// CHECK: expected identifier constraint
let foo = _: ;
}
// -----
//===----------------------------------------------------------------------===//
// Member Access Expr
//===----------------------------------------------------------------------===//
Pattern {
// CHECK: expected identifier or numeric member name
let root: Op;
erase root.<>;
}
// -----
Pattern {
// CHECK: invalid member access `unknown_result` on expression of type `Op`
let root: Op;
erase root.unknown_result;
}

View File

@ -0,0 +1,15 @@
// RUN: mlir-pdll %s -I %S | FileCheck %s
Pattern BeforeIncludedPattern {
erase _: Op;
}
#include "include/included.pdll"
Pattern AfterIncludedPattern {
erase _: Op;
}
// CHECK: PatternDecl {{.*}} Name<BeforeIncludedPattern>
// CHECK: PatternDecl {{.*}} Name<IncludedPattern>
// CHECK: PatternDecl {{.*}} Name<AfterIncludedPattern>

View File

@ -0,0 +1,5 @@
// This file is included by 'include.pdll' as part of testing include files.
Pattern IncludedPattern {
erase _: Op;
}

View File

@ -0,0 +1,2 @@
config.suffixes = ['.pdll']
config.excludes = ['include']

View File

@ -0,0 +1,26 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
// CHECK: expected `{` to start pattern body
Pattern }
// -----
// CHECK: :6:9: error: `Foo` has already been defined
// CHECK: :5:9: note: see previous definition here
Pattern Foo { erase root: Op; }
Pattern Foo { erase root: Op; }
// -----
// CHECK: expected Pattern body to terminate with an operation rewrite statement
Pattern {
let value: Value;
}
// -----
// CHECK: Pattern body was terminated by an operation rewrite statement, but found trailing statements
Pattern {
erase root: Op;
let value: Value;
}

View File

@ -0,0 +1,17 @@
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
// CHECK: Module
// CHECK: `-PatternDecl
// CHECK: `-CompoundStmt
// CHECK: `-EraseStmt
Pattern {
erase _: Op;
}
// -----
// CHECK: Module
// CHECK: `-PatternDecl {{.*}} Name<NamedPattern>
Pattern NamedPattern {
erase _: Op;
}

View File

@ -0,0 +1,222 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
// CHECK: expected top-level declaration, such as a `Pattern`
10
// -----
Pattern {
// CHECK: expected `;` after statement
erase _: Op
}
// -----
//===----------------------------------------------------------------------===//
// `erase`
//===----------------------------------------------------------------------===//
Pattern {
// CHECK: expected expression
erase;
}
// -----
Pattern {
// CHECK: expected `Op` expression
erase _: Attr;
}
// -----
//===----------------------------------------------------------------------===//
// `let`
//===----------------------------------------------------------------------===//
Pattern {
// CHECK: expected identifier after `let` to name a new variable
let 5;
}
// -----
Pattern {
// CHECK: `_` may only be used to define "inline" variables
let _;
}
// -----
Pattern {
// CHECK: expected expression
let foo: Attr<>;
}
// -----
Pattern {
// CHECK: expected expression of `Type` in type constraint
let foo: Attr<_: Attr>;
}
// -----
Pattern {
// CHECK: expected `>` after variable type constraint
let foo: Attr<_: Type{};
}
// -----
Pattern {
// CHECK: the type of this variable has already been constrained
let foo: [Attr<_: Type>, Attr<_: Type];
}
// -----
Pattern {
// CHECK: expected `.` after dialect namespace
let foo: Op<builtin>;
}
// -----
Pattern {
// CHECK: expected operation name after dialect namespace
let foo: Op<builtin.>;
}
// -----
Pattern {
// CHECK: expected `>` after operation name
let foo: Op<builtin.func<;
}
// -----
Pattern {
// CHECK: expected expression
let foo: Value<>;
}
// -----
Pattern {
// CHECK: expected expression of `Type` in type constraint
let foo: Value<_: Attr>;
}
// -----
Pattern {
// CHECK: expected `>` after variable type constraint
let foo: Value<_: Type{};
}
// -----
Pattern {
// CHECK: the type of this variable has already been constrained
let foo: [Value<_: Type>, Value<_: Type];
}
// -----
Pattern {
// CHECK: expected expression
let foo: ValueRange<10>;
}
// -----
Pattern {
// CHECK: expected expression of `TypeRange` in type constraint
let foo: ValueRange<_: Type>;
}
// -----
Pattern {
// CHECK: expected `>` after variable type constraint
let foo: ValueRange<_: Type{};
}
// -----
Pattern {
// CHECK: the type of this variable has already been constrained
let foo: [ValueRange<_: Type>, ValueRange<_: Type];
}
// -----
Pattern {
// CHECK: unknown reference to constraint `UnknownConstraint`
let foo: UnknownConstraint;
}
// -----
Pattern Foo {
erase root: Op;
}
Pattern {
// CHECK: invalid reference to non-constraint
let foo: Foo;
}
// -----
Pattern {
// CHECK: constraint type `Attr` is incompatible with the previously inferred type `Value`
let foo: [Value, Attr];
}
// -----
Pattern {
// CHECK: expected `]` after constraint list
let foo: [Attr[];
}
// -----
Pattern {
// CHECK: expected expression
let foo: Attr = ;
}
// -----
Pattern {
// CHECK: type constraints are not permitted on variables with initializers
let foo: ValueRange<_: Type> = _: Op;
}
// -----
Pattern {
// CHECK: unable to infer type for variable `foo`
// CHECK: note: the type of a variable must be inferable from the constraint list or the initializer
let foo;
}
// -----
Pattern {
// CHECK: unable to convert expression of type `Attr` to the expected type of `Value`
let foo: Value = _: Attr;
}
// -----
Pattern {
// CHECK: :7:7: error: `foo` has already been defined
// CHECK: :6:7: note: see previous definition here
let foo: Attr;
let foo: Attr;
}

View File

@ -0,0 +1,155 @@
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
//===----------------------------------------------------------------------===//
// CompoundStmt
//===----------------------------------------------------------------------===//
// CHECK: Module
// CHECK: CompoundStmt
// CHECK: |-LetStmt
// CHECK: `-EraseStmt
Pattern {
let root: Op;
erase root;
}
// -----
//===----------------------------------------------------------------------===//
// EraseStmt
//===----------------------------------------------------------------------===//
// CHECK: Module
// CHECK: EraseStmt
// CHECK: `-DeclRefExpr {{.*}} Type<Op>
Pattern {
erase _: Op;
}
// -----
//===----------------------------------------------------------------------===//
// LetStmt
//===----------------------------------------------------------------------===//
// CHECK: Module
// CHECK: LetStmt
// CHECK: `-VariableDecl {{.*}} Name<attrVar> Type<Attr>
// CHECK: `Constraints`
// CHECK: `-AttrConstraintDecl
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Op>
// CHECK: `Constraints`
// CHECK: `-OpConstraintDecl
// CHECK: `-OpNameDecl
Pattern {
let attrVar: Attr;
let var: Op;
erase var;
}
// -----
// Check for proper refinement between constraint types.
// CHECK: Module
// CHECK: LetStmt
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Op<dialect.op>>
// CHECK: `Constraints`
// CHECK: `-OpConstraintDecl
// CHECK: `-OpNameDecl
// CHECK: `-OpConstraintDecl
// CHECK: `-OpNameDecl {{.*}} Name<dialect.op>
Pattern {
let var: [Op, Op<dialect.op>];
erase var;
}
// -----
// Check for proper conversion between initializer and constraint type.
// CHECK: Module
// CHECK: LetStmt
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Op<dialect.op>>
// CHECK: `-DeclRefExpr {{.*}} Type<Op<dialect.op>>
// CHECK: `-VariableDecl {{.*}} Name<input>
// CHECK: `Constraints`
// CHECK: `-OpConstraintDecl
// CHECK: `-OpNameDecl
Pattern {
let input: Op<dialect.op>;
let var: Op = input;
erase var;
}
// -----
// Check for proper conversion between initializer and constraint type.
// CHECK: Module
// CHECK: LetStmt
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Value>
// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<Value>
// CHECK: `-DeclRefExpr {{.*}} Type<Op<dialect.op>>
// CHECK: `-VariableDecl {{.*}} Name<input>
// CHECK: `Constraints`
// CHECK: `-ValueConstraintDecl
Pattern {
let input: Op<dialect.op>;
let var: Value = input;
erase _: Op;
}
// -----
// Check for proper conversion between initializer and constraint type.
// CHECK: Module
// CHECK: LetStmt
// CHECK: `-VariableDecl {{.*}} Name<var> Type<ValueRange>
// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<ValueRange>
// CHECK: `-DeclRefExpr {{.*}} Type<Op<dialect.op>>
// CHECK: `-VariableDecl {{.*}} Name<input>
// CHECK: `Constraints`
// CHECK: `-ValueRangeConstraintDecl
Pattern {
let input: Op<dialect.op>;
let var: ValueRange = input;
erase _: Op;
}
// -----
// Check for proper handling of type constraints.
// CHECK: Module
// CHECK: LetStmt
// CHECK: `-VariableDecl {{.*}} Name<var> Type<Value>
// CHECK: `Constraints`
// CHECK: `-ValueConstraintDecl
// CHECK: `-DeclRefExpr {{.*}} Type<Type>
// CHECK: `-VariableDecl {{.*}} Name<_> Type<Type>
// CHECK: `Constraints`
// CHECK: `-TypeConstraintDecl
Pattern {
let var: Value<_: Type>;
erase _: Op;
}
// -----
// Check for proper handling of type constraints.
// CHECK: Module
// CHECK: LetStmt
// CHECK: `-VariableDecl {{.*}} Name<var> Type<ValueRange>
// CHECK: `Constraints`
// CHECK: `-ValueRangeConstraintDecl
// CHECK: `-DeclRefExpr {{.*}} Type<TypeRange>
// CHECK: `-VariableDecl {{.*}} Name<_> Type<TypeRange>
// CHECK: `Constraints`
// CHECK: `-TypeRangeConstraintDecl
Pattern {
let var: ValueRange<_: TypeRange>;
erase _: Op;
}

View File

@ -1,6 +1,7 @@
add_subdirectory(mlir-cpu-runner)
add_subdirectory(mlir-lsp-server)
add_subdirectory(mlir-opt)
add_subdirectory(mlir-pdll)
add_subdirectory(mlir-reduce)
add_subdirectory(mlir-shlib)
add_subdirectory(mlir-spirv-cpu-runner)

View File

@ -0,0 +1,16 @@
set(LIBS
MLIRPDLLAST
MLIRPDLLParser
)
add_llvm_tool(mlir-pdll
mlir-pdll.cpp
DEPENDS
${LIBS}
)
target_link_libraries(mlir-pdll PRIVATE ${LIBS})
llvm_update_compile_flags(mlir-pdll)
mlir_check_all_link_libraries(mlir-pdll)

View File

@ -0,0 +1,111 @@
//===- mlir-pdll.cpp - MLIR PDLL frontend -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/ToolUtilities.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
using namespace mlir::pdll;
//===----------------------------------------------------------------------===//
// main
//===----------------------------------------------------------------------===//
/// The desired output type.
enum class OutputType {
AST,
};
static LogicalResult
processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
OutputType outputType, std::vector<std::string> &includeDirs) {
llvm::SourceMgr sourceMgr;
sourceMgr.setIncludeDirs(includeDirs);
sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), llvm::SMLoc());
ast::Context astContext;
FailureOr<ast::Module *> module = parsePDLAST(astContext, sourceMgr);
if (failed(module))
return failure();
switch (outputType) {
case OutputType::AST:
(*module)->print(os);
break;
}
return success();
}
int main(int argc, char **argv) {
llvm::cl::opt<std::string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::value_desc("filename"));
llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
llvm::cl::list<std::string> includeDirs(
"I", llvm::cl::desc("Directory of include files"),
llvm::cl::value_desc("directory"), llvm::cl::Prefix);
llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));
llvm::cl::opt<enum OutputType> outputType(
"x", llvm::cl::init(OutputType::AST),
llvm::cl::desc("The type of output desired"),
llvm::cl::values(clEnumValN(OutputType::AST, "ast",
"generate the AST for the input file")));
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv, "PDLL Frontend");
// Set up the input file.
std::string errorMessage;
std::unique_ptr<llvm::MemoryBuffer> inputFile =
openInputFile(inputFilename, &errorMessage);
if (!inputFile) {
llvm::errs() << errorMessage << "\n";
return 1;
}
// Set up the output file.
std::unique_ptr<llvm::ToolOutputFile> outputFile =
openOutputFile(outputFilename, &errorMessage);
if (!outputFile) {
llvm::errs() << errorMessage << "\n";
return 1;
}
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs);
};
if (splitInputFile) {
if (failed(splitAndProcessBuffer(std::move(inputFile), processFn,
outputFile->os())))
return 1;
} else if (failed(processFn(std::move(inputFile), outputFile->os()))) {
return 1;
}
outputFile->keep();
return 0;
}

View File

@ -7923,3 +7923,51 @@ cc_binary(
"//mlir/test:TestDialect",
],
)
cc_library(
name = "PDLLAST",
srcs = glob(
[
"lib/Tools/PDLL/AST/*.cpp",
"lib/Tools/PDLL/AST/*.h",
],
),
hdrs = glob(["include/mlir/Tools/PDLL/AST/*.h"]),
includes = ["include"],
deps = [
"//llvm:Support",
"//mlir:Support",
],
)
cc_library(
name = "PDLLParser",
srcs = glob(
[
"lib/Tools/PDLL/Parser/*.cpp",
"lib/Tools/PDLL/Parser/*.h",
],
),
hdrs = glob(["include/mlir/Tools/PDLL/Parser/*.h"]),
includes = ["include"],
deps = [
":PDLLAST",
":Support",
":TableGen",
"//llvm:Support",
],
)
cc_binary(
name = "mlir-pdll",
srcs = [
"tools/mlir-pdll/mlir-pdll.cpp",
],
deps = [
":PDLLAST",
":PDLLParser",
":Support",
"//llvm:Support",
"//llvm:config",
],
)