[mlir][PDLL] Add an initial frontend for PDLL
This is a new pattern rewrite frontend designed from the ground up to support MLIR constructs, and to target PDL. This frontend language was proposed in https://llvm.discourse.group/t/rfc-pdll-a-new-declarative-rewrite-frontend-for-mlir/4798 This commit starts sketching out the base structure of the frontend, and is intended to be a minimal starting point for building up the language. It essentially contains support for defining a pattern, variables, and erasing an operation. The features mentioned in the proposal RFC (including IDE support) will be added incrementally in followup commits. I intend to upstream the documentation for the language in a followup when a bit more of the pieces have been landed. Differential Revision: https://reviews.llvm.org/D115093
This commit is contained in:
		
							parent
							
								
									8f1ea2e85c
								
							
						
					
					
						commit
						11d26bd143
					
				| 
						 | 
				
			
			@ -0,0 +1,52 @@
 | 
			
		|||
//===- Context.h - PDLL AST Context -----------------------------*- C++ -*-===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#ifndef MLIR_TOOLS_PDLL_AST_CONTEXT_H_
 | 
			
		||||
#define MLIR_TOOLS_PDLL_AST_CONTEXT_H_
 | 
			
		||||
 | 
			
		||||
#include "mlir/Support/StorageUniquer.h"
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace pdll {
 | 
			
		||||
namespace ast {
 | 
			
		||||
/// This class represents the main context of the PDLL AST. It handles
 | 
			
		||||
/// allocating all of the AST constructs, and manages all state necessary for
 | 
			
		||||
/// the AST.
 | 
			
		||||
class Context {
 | 
			
		||||
public:
 | 
			
		||||
  Context();
 | 
			
		||||
  Context(const Context &) = delete;
 | 
			
		||||
  Context &operator=(const Context &) = delete;
 | 
			
		||||
 | 
			
		||||
  /// Return the allocator owned by this context.
 | 
			
		||||
  llvm::BumpPtrAllocator &getAllocator() { return allocator; }
 | 
			
		||||
 | 
			
		||||
  /// Return the storage uniquer used for AST types.
 | 
			
		||||
  StorageUniquer &getTypeUniquer() { return typeUniquer; }
 | 
			
		||||
 | 
			
		||||
  /// Return the diagnostic engine of this context.
 | 
			
		||||
  DiagnosticEngine &getDiagEngine() { return diagEngine; }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /// The diagnostic engine of this AST context.
 | 
			
		||||
  DiagnosticEngine diagEngine;
 | 
			
		||||
 | 
			
		||||
  /// The allocator used for AST nodes, and other entities allocated within the
 | 
			
		||||
  /// context.
 | 
			
		||||
  llvm::BumpPtrAllocator allocator;
 | 
			
		||||
 | 
			
		||||
  /// The uniquer used for creating AST types.
 | 
			
		||||
  StorageUniquer typeUniquer;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace ast
 | 
			
		||||
} // namespace pdll
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // MLIR_TOOLS_PDLL_AST_CONTEXT_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,182 @@
 | 
			
		|||
//===- Diagnostic.h - PDLL AST Diagnostics ----------------------*- C++ -*-===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#ifndef MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_
 | 
			
		||||
#define MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include "mlir/Support/LLVM.h"
 | 
			
		||||
#include "mlir/Support/LogicalResult.h"
 | 
			
		||||
#include "llvm/ADT/FunctionExtras.h"
 | 
			
		||||
#include "llvm/Support/SourceMgr.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace pdll {
 | 
			
		||||
namespace ast {
 | 
			
		||||
class DiagnosticEngine;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Diagnostic
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class provides a simple implementation of a PDLL diagnostic.
 | 
			
		||||
class Diagnostic {
 | 
			
		||||
public:
 | 
			
		||||
  using Severity = llvm::SourceMgr::DiagKind;
 | 
			
		||||
 | 
			
		||||
  /// Return the severity of this diagnostic.
 | 
			
		||||
  Severity getSeverity() const { return severity; }
 | 
			
		||||
 | 
			
		||||
  /// Return the message of this diagnostic.
 | 
			
		||||
  StringRef getMessage() const { return message; }
 | 
			
		||||
 | 
			
		||||
  /// Return the location of this diagnostic.
 | 
			
		||||
  llvm::SMRange getLocation() const { return location; }
 | 
			
		||||
 | 
			
		||||
  /// Return the notes of this diagnostic.
 | 
			
		||||
  auto getNotes() const { return llvm::make_pointee_range(notes); }
 | 
			
		||||
 | 
			
		||||
  /// Attach a note to this diagnostic.
 | 
			
		||||
  Diagnostic &attachNote(const Twine &msg,
 | 
			
		||||
                         Optional<llvm::SMRange> noteLoc = llvm::None) {
 | 
			
		||||
    assert(getSeverity() != Severity::DK_Note &&
 | 
			
		||||
           "cannot attach a Note to a Note");
 | 
			
		||||
    notes.emplace_back(
 | 
			
		||||
        new Diagnostic(Severity::DK_Note, noteLoc.getValueOr(location), msg));
 | 
			
		||||
    return *notes.back();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Allow an inflight diagnostic to be converted to 'failure', otherwise
 | 
			
		||||
  /// 'success' if this is an empty diagnostic.
 | 
			
		||||
  operator LogicalResult() const { return failure(); }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  Diagnostic(Severity severity, llvm::SMRange loc, const Twine &msg)
 | 
			
		||||
      : severity(severity), message(msg.str()), location(loc) {}
 | 
			
		||||
 | 
			
		||||
  // Allow access to the constructor.
 | 
			
		||||
  friend DiagnosticEngine;
 | 
			
		||||
 | 
			
		||||
  /// The severity of this diagnostic.
 | 
			
		||||
  Severity severity;
 | 
			
		||||
  /// The message held by this diagnostic.
 | 
			
		||||
  std::string message;
 | 
			
		||||
  /// The raw location of this diagnostic.
 | 
			
		||||
  llvm::SMRange location;
 | 
			
		||||
  /// Any additional note diagnostics attached to this diagnostic.
 | 
			
		||||
  std::vector<std::unique_ptr<Diagnostic>> notes;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// InFlightDiagnostic
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a diagnostic that is inflight and set to be reported.
 | 
			
		||||
/// This allows for last minute modifications of the diagnostic before it is
 | 
			
		||||
/// emitted by a DiagnosticEngine.
 | 
			
		||||
class InFlightDiagnostic {
 | 
			
		||||
public:
 | 
			
		||||
  InFlightDiagnostic() = default;
 | 
			
		||||
  InFlightDiagnostic(InFlightDiagnostic &&rhs)
 | 
			
		||||
      : owner(rhs.owner), impl(std::move(rhs.impl)) {
 | 
			
		||||
    // Reset the rhs diagnostic.
 | 
			
		||||
    rhs.impl.reset();
 | 
			
		||||
    rhs.abandon();
 | 
			
		||||
  }
 | 
			
		||||
  ~InFlightDiagnostic() {
 | 
			
		||||
    if (isInFlight())
 | 
			
		||||
      report();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Access the internal diagnostic.
 | 
			
		||||
  Diagnostic &operator*() { return *impl; }
 | 
			
		||||
  Diagnostic *operator->() { return &*impl; }
 | 
			
		||||
 | 
			
		||||
  /// Reports the diagnostic to the engine.
 | 
			
		||||
  void report();
 | 
			
		||||
 | 
			
		||||
  /// Abandons this diagnostic so that it will no longer be reported.
 | 
			
		||||
  void abandon() { owner = nullptr; }
 | 
			
		||||
 | 
			
		||||
  /// Allow an inflight diagnostic to be converted to 'failure', otherwise
 | 
			
		||||
  /// 'success' if this is an empty diagnostic.
 | 
			
		||||
  operator LogicalResult() const { return failure(isActive()); }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete;
 | 
			
		||||
  InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete;
 | 
			
		||||
  InFlightDiagnostic(DiagnosticEngine *owner, Diagnostic &&rhs)
 | 
			
		||||
      : owner(owner), impl(std::move(rhs)) {}
 | 
			
		||||
 | 
			
		||||
  /// Returns true if the diagnostic is still active, i.e. it has a live
 | 
			
		||||
  /// diagnostic.
 | 
			
		||||
  bool isActive() const { return impl.hasValue(); }
 | 
			
		||||
 | 
			
		||||
  /// Returns true if the diagnostic is still in flight to be reported.
 | 
			
		||||
  bool isInFlight() const { return owner; }
 | 
			
		||||
 | 
			
		||||
  // Allow access to the constructor.
 | 
			
		||||
  friend DiagnosticEngine;
 | 
			
		||||
 | 
			
		||||
  /// The engine that this diagnostic is to report to.
 | 
			
		||||
  DiagnosticEngine *owner = nullptr;
 | 
			
		||||
 | 
			
		||||
  /// The raw diagnostic that is inflight to be reported.
 | 
			
		||||
  Optional<Diagnostic> impl;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// DiagnosticEngine
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class manages the construction and emission of PDLL diagnostics.
 | 
			
		||||
class DiagnosticEngine {
 | 
			
		||||
public:
 | 
			
		||||
  /// A function used to handle diagnostics emitted by the engine.
 | 
			
		||||
  using HandlerFn = llvm::unique_function<void(Diagnostic &)>;
 | 
			
		||||
 | 
			
		||||
  /// Emit an error to the diagnostic engine.
 | 
			
		||||
  InFlightDiagnostic emitError(llvm::SMRange loc, const Twine &msg) {
 | 
			
		||||
    return InFlightDiagnostic(
 | 
			
		||||
        this, Diagnostic(Diagnostic::Severity::DK_Error, loc, msg));
 | 
			
		||||
  }
 | 
			
		||||
  InFlightDiagnostic emitWarning(llvm::SMRange loc, const Twine &msg) {
 | 
			
		||||
    return InFlightDiagnostic(
 | 
			
		||||
        this, Diagnostic(Diagnostic::Severity::DK_Warning, loc, msg));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Report the given diagnostic.
 | 
			
		||||
  void report(Diagnostic &&diagnostic) {
 | 
			
		||||
    if (handler)
 | 
			
		||||
      handler(diagnostic);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Get the current handler function of this diagnostic engine.
 | 
			
		||||
  const HandlerFn &getHandlerFn() const { return handler; }
 | 
			
		||||
 | 
			
		||||
  /// Take the current handler function, resetting the current handler to null.
 | 
			
		||||
  HandlerFn takeHandlerFn() {
 | 
			
		||||
    HandlerFn oldHandler = std::move(handler);
 | 
			
		||||
    handler = {};
 | 
			
		||||
    return oldHandler;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Set the handler function for this diagnostic engine.
 | 
			
		||||
  void setHandlerFn(HandlerFn &&newHandler) { handler = std::move(newHandler); }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /// The registered diagnostic handler function.
 | 
			
		||||
  HandlerFn handler;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace ast
 | 
			
		||||
} // namespace pdll
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,690 @@
 | 
			
		|||
//===- Nodes.h --------------------------------------------------*- C++ -*-===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
 | 
			
		||||
#define MLIR_TOOLS_PDLL_AST_NODES_H_
 | 
			
		||||
 | 
			
		||||
#include "mlir/Support/LLVM.h"
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Types.h"
 | 
			
		||||
#include "llvm/ADT/StringMap.h"
 | 
			
		||||
#include "llvm/ADT/StringRef.h"
 | 
			
		||||
#include "llvm/Support/SMLoc.h"
 | 
			
		||||
#include "llvm/Support/SourceMgr.h"
 | 
			
		||||
#include "llvm/Support/TrailingObjects.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace pdll {
 | 
			
		||||
namespace ast {
 | 
			
		||||
class Context;
 | 
			
		||||
class Decl;
 | 
			
		||||
class Expr;
 | 
			
		||||
class OpNameDecl;
 | 
			
		||||
class VariableDecl;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Name
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class provides a convenient API for interacting with source names. It
 | 
			
		||||
/// contains a string name as well as the source location for that name.
 | 
			
		||||
struct Name {
 | 
			
		||||
  static const Name &create(Context &ctx, StringRef name,
 | 
			
		||||
                            llvm::SMRange location);
 | 
			
		||||
 | 
			
		||||
  /// Return the raw string name.
 | 
			
		||||
  StringRef getName() const { return name; }
 | 
			
		||||
 | 
			
		||||
  /// Get the location of this name.
 | 
			
		||||
  llvm::SMRange getLoc() const { return location; }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  Name() = delete;
 | 
			
		||||
  Name(const Name &) = delete;
 | 
			
		||||
  Name &operator=(const Name &) = delete;
 | 
			
		||||
  Name(StringRef name, llvm::SMRange location)
 | 
			
		||||
      : name(name), location(location) {}
 | 
			
		||||
 | 
			
		||||
  /// The string name of the decl.
 | 
			
		||||
  StringRef name;
 | 
			
		||||
  /// The location of the decl name.
 | 
			
		||||
  llvm::SMRange location;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// DeclScope
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a scope for named AST decls. A scope determines the
 | 
			
		||||
/// visibility and lifetime of a named declaration.
 | 
			
		||||
class DeclScope {
 | 
			
		||||
public:
 | 
			
		||||
  /// Create a new scope with an optional parent scope.
 | 
			
		||||
  DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
 | 
			
		||||
 | 
			
		||||
  /// Return the parent scope of this scope, or nullptr if there is no parent.
 | 
			
		||||
  DeclScope *getParentScope() { return parent; }
 | 
			
		||||
  const DeclScope *getParentScope() const { return parent; }
 | 
			
		||||
 | 
			
		||||
  /// Return all of the decls within this scope.
 | 
			
		||||
  auto getDecls() const { return llvm::make_second_range(decls); }
 | 
			
		||||
 | 
			
		||||
  /// Add a new decl to the scope.
 | 
			
		||||
  void add(Decl *decl);
 | 
			
		||||
 | 
			
		||||
  /// Lookup a decl with the given name starting from this scope. Returns
 | 
			
		||||
  /// nullptr if no decl could be found.
 | 
			
		||||
  Decl *lookup(StringRef name);
 | 
			
		||||
  template <typename T> T *lookup(StringRef name) {
 | 
			
		||||
    return dyn_cast_or_null<T>(lookup(name));
 | 
			
		||||
  }
 | 
			
		||||
  const Decl *lookup(StringRef name) const {
 | 
			
		||||
    return const_cast<DeclScope *>(this)->lookup(name);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename T> const T *lookup(StringRef name) const {
 | 
			
		||||
    return dyn_cast_or_null<T>(lookup(name));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /// The parent scope, or null if this is a top-level scope.
 | 
			
		||||
  DeclScope *parent;
 | 
			
		||||
  /// The decls defined within this scope.
 | 
			
		||||
  llvm::StringMap<Decl *> decls;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Node
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a base AST node. All AST nodes are derived from this
 | 
			
		||||
/// class, and it contains many of the base functionality for interacting with
 | 
			
		||||
/// nodes.
 | 
			
		||||
class Node {
 | 
			
		||||
public:
 | 
			
		||||
  /// This CRTP class provides several utilies when defining new AST nodes.
 | 
			
		||||
  template <typename T, typename BaseT> class NodeBase : public BaseT {
 | 
			
		||||
  public:
 | 
			
		||||
    using Base = NodeBase<T, BaseT>;
 | 
			
		||||
 | 
			
		||||
    /// Provide type casting support.
 | 
			
		||||
    static bool classof(const Node *node) {
 | 
			
		||||
      return node->getTypeID() == TypeID::get<T>();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  protected:
 | 
			
		||||
    template <typename... Args>
 | 
			
		||||
    explicit NodeBase(llvm::SMRange loc, Args &&...args)
 | 
			
		||||
        : BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  /// Return the type identifier of this node.
 | 
			
		||||
  TypeID getTypeID() const { return typeID; }
 | 
			
		||||
 | 
			
		||||
  /// Return the location of this node.
 | 
			
		||||
  llvm::SMRange getLoc() const { return loc; }
 | 
			
		||||
 | 
			
		||||
  /// Print this node to the given stream.
 | 
			
		||||
  void print(raw_ostream &os) const;
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  Node(TypeID typeID, llvm::SMRange loc) : typeID(typeID), loc(loc) {}
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /// A unique type identifier for this node.
 | 
			
		||||
  TypeID typeID;
 | 
			
		||||
 | 
			
		||||
  /// The location of this node.
 | 
			
		||||
  llvm::SMRange loc;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Stmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a base AST Statement node.
 | 
			
		||||
class Stmt : public Node {
 | 
			
		||||
public:
 | 
			
		||||
  using Node::Node;
 | 
			
		||||
 | 
			
		||||
  /// Provide type casting support.
 | 
			
		||||
  static bool classof(const Node *node);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// CompoundStmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This statement represents a compound statement, which contains a collection
 | 
			
		||||
/// of other statements.
 | 
			
		||||
class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
 | 
			
		||||
                           private llvm::TrailingObjects<CompoundStmt, Stmt *> {
 | 
			
		||||
public:
 | 
			
		||||
  static CompoundStmt *create(Context &ctx, llvm::SMRange location,
 | 
			
		||||
                              ArrayRef<Stmt *> children);
 | 
			
		||||
 | 
			
		||||
  /// Return the children of this compound statement.
 | 
			
		||||
  MutableArrayRef<Stmt *> getChildren() {
 | 
			
		||||
    return {getTrailingObjects<Stmt *>(), numChildren};
 | 
			
		||||
  }
 | 
			
		||||
  ArrayRef<Stmt *> getChildren() const {
 | 
			
		||||
    return const_cast<CompoundStmt *>(this)->getChildren();
 | 
			
		||||
  }
 | 
			
		||||
  ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
 | 
			
		||||
  ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  CompoundStmt(llvm::SMRange location, unsigned numChildren)
 | 
			
		||||
      : Base(location), numChildren(numChildren) {}
 | 
			
		||||
 | 
			
		||||
  /// The number of held children statements.
 | 
			
		||||
  unsigned numChildren;
 | 
			
		||||
 | 
			
		||||
  // Allow access to various privates.
 | 
			
		||||
  friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// LetStmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This statement represents a `let` statement in PDLL. This statement is used
 | 
			
		||||
/// to define variables.
 | 
			
		||||
class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
 | 
			
		||||
public:
 | 
			
		||||
  static LetStmt *create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                         VariableDecl *varDecl);
 | 
			
		||||
 | 
			
		||||
  /// Return the variable defined by this statement.
 | 
			
		||||
  VariableDecl *getVarDecl() const { return varDecl; }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  LetStmt(llvm::SMRange loc, VariableDecl *varDecl)
 | 
			
		||||
      : Base(loc), varDecl(varDecl) {}
 | 
			
		||||
 | 
			
		||||
  /// The variable defined by this statement.
 | 
			
		||||
  VariableDecl *varDecl;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// OpRewriteStmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a base operation rewrite statement. Operation rewrite
 | 
			
		||||
/// statements perform a set of transformations on a given root operation.
 | 
			
		||||
class OpRewriteStmt : public Stmt {
 | 
			
		||||
public:
 | 
			
		||||
  /// Provide type casting support.
 | 
			
		||||
  static bool classof(const Node *node);
 | 
			
		||||
 | 
			
		||||
  /// Return the root operation of this rewrite.
 | 
			
		||||
  Expr *getRootOpExpr() const { return rootOp; }
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  OpRewriteStmt(TypeID typeID, llvm::SMRange loc, Expr *rootOp)
 | 
			
		||||
      : Stmt(typeID, loc), rootOp(rootOp) {}
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  /// The root operation being rewritten.
 | 
			
		||||
  Expr *rootOp;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// EraseStmt
 | 
			
		||||
 | 
			
		||||
/// This statement represents the `erase` statement in PDLL. This statement
 | 
			
		||||
/// erases the given root operation, corresponding roughly to the
 | 
			
		||||
/// PatternRewriter::eraseOp API.
 | 
			
		||||
class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
 | 
			
		||||
public:
 | 
			
		||||
  static EraseStmt *create(Context &ctx, llvm::SMRange loc, Expr *rootOp);
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  EraseStmt(llvm::SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Expr
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a base AST Expression node.
 | 
			
		||||
class Expr : public Stmt {
 | 
			
		||||
public:
 | 
			
		||||
  /// Return the type of this expression.
 | 
			
		||||
  Type getType() const { return type; }
 | 
			
		||||
 | 
			
		||||
  /// Provide type casting support.
 | 
			
		||||
  static bool classof(const Node *node);
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  Expr(TypeID typeID, llvm::SMRange loc, Type type)
 | 
			
		||||
      : Stmt(typeID, loc), type(type) {}
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /// The type of this expression.
 | 
			
		||||
  Type type;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// DeclRefExpr
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This expression represents a reference to a Decl node.
 | 
			
		||||
class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
 | 
			
		||||
public:
 | 
			
		||||
  static DeclRefExpr *create(Context &ctx, llvm::SMRange loc, Decl *decl,
 | 
			
		||||
                             Type type);
 | 
			
		||||
 | 
			
		||||
  /// Get the decl referenced by this expression.
 | 
			
		||||
  Decl *getDecl() const { return decl; }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  DeclRefExpr(llvm::SMRange loc, Decl *decl, Type type)
 | 
			
		||||
      : Base(loc, type), decl(decl) {}
 | 
			
		||||
 | 
			
		||||
  /// The decl referenced by this expression.
 | 
			
		||||
  Decl *decl;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// MemberAccessExpr
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This expression represents a named member or field access of a given parent
 | 
			
		||||
/// expression.
 | 
			
		||||
class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
 | 
			
		||||
public:
 | 
			
		||||
  static MemberAccessExpr *create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                  const Expr *parentExpr, StringRef memberName,
 | 
			
		||||
                                  Type type);
 | 
			
		||||
 | 
			
		||||
  /// Get the parent expression of this access.
 | 
			
		||||
  const Expr *getParentExpr() const { return parentExpr; }
 | 
			
		||||
 | 
			
		||||
  /// Return the name of the member being accessed.
 | 
			
		||||
  StringRef getMemberName() const { return memberName; }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  MemberAccessExpr(llvm::SMRange loc, const Expr *parentExpr,
 | 
			
		||||
                   StringRef memberName, Type type)
 | 
			
		||||
      : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
 | 
			
		||||
 | 
			
		||||
  /// The parent expression of this access.
 | 
			
		||||
  const Expr *parentExpr;
 | 
			
		||||
 | 
			
		||||
  /// The name of the member being accessed from the parent.
 | 
			
		||||
  StringRef memberName;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Decl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents the base Decl node.
 | 
			
		||||
class Decl : public Node {
 | 
			
		||||
public:
 | 
			
		||||
  /// Return the name of the decl, or nullptr if it doesn't have one.
 | 
			
		||||
  const Name *getName() const { return name; }
 | 
			
		||||
 | 
			
		||||
  /// Provide type casting support.
 | 
			
		||||
  static bool classof(const Node *node);
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  Decl(TypeID typeID, llvm::SMRange loc, const Name *name = nullptr)
 | 
			
		||||
      : Node(typeID, loc), name(name) {}
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /// The name of the decl. This is optional for some decls, such as
 | 
			
		||||
  /// PatternDecl.
 | 
			
		||||
  const Name *name;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ConstraintDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents the base of all AST Constraint decls. Constraints
 | 
			
		||||
/// apply matcher conditions to, and define the type of PDLL variables.
 | 
			
		||||
class ConstraintDecl : public Decl {
 | 
			
		||||
public:
 | 
			
		||||
  /// Provide type casting support.
 | 
			
		||||
  static bool classof(const Node *node);
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  ConstraintDecl(TypeID typeID, llvm::SMRange loc, const Name *name = nullptr)
 | 
			
		||||
      : Decl(typeID, loc, name) {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/// This class represents a reference to a constraint, and contains a constraint
 | 
			
		||||
/// and the location of the reference.
 | 
			
		||||
struct ConstraintRef {
 | 
			
		||||
  ConstraintRef(const ConstraintDecl *constraint, llvm::SMRange refLoc)
 | 
			
		||||
      : constraint(constraint), referenceLoc(refLoc) {}
 | 
			
		||||
  explicit ConstraintRef(const ConstraintDecl *constraint)
 | 
			
		||||
      : ConstraintRef(constraint, constraint->getLoc()) {}
 | 
			
		||||
 | 
			
		||||
  const ConstraintDecl *constraint;
 | 
			
		||||
  llvm::SMRange referenceLoc;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// CoreConstraintDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents the base of all "core" constraints. Core constraints
 | 
			
		||||
/// are those that generally represent a concrete IR construct, such as
 | 
			
		||||
/// `Type`s or `Value`s.
 | 
			
		||||
class CoreConstraintDecl : public ConstraintDecl {
 | 
			
		||||
public:
 | 
			
		||||
  /// Provide type casting support.
 | 
			
		||||
  static bool classof(const Node *node);
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  CoreConstraintDecl(TypeID typeID, llvm::SMRange loc,
 | 
			
		||||
                     const Name *name = nullptr)
 | 
			
		||||
      : ConstraintDecl(typeID, loc, name) {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// AttrConstraintDecl
 | 
			
		||||
 | 
			
		||||
/// The class represents an Attribute constraint, and constrains a variable to
 | 
			
		||||
/// be an Attribute.
 | 
			
		||||
class AttrConstraintDecl
 | 
			
		||||
    : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
 | 
			
		||||
public:
 | 
			
		||||
  static AttrConstraintDecl *create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                    Expr *typeExpr = nullptr);
 | 
			
		||||
 | 
			
		||||
  /// Return the optional type the attribute is constrained to.
 | 
			
		||||
  Expr *getTypeExpr() { return typeExpr; }
 | 
			
		||||
  const Expr *getTypeExpr() const { return typeExpr; }
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  AttrConstraintDecl(llvm::SMRange loc, Expr *typeExpr)
 | 
			
		||||
      : Base(loc), typeExpr(typeExpr) {}
 | 
			
		||||
 | 
			
		||||
  /// An optional type that the attribute is constrained to.
 | 
			
		||||
  Expr *typeExpr;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// OpConstraintDecl
 | 
			
		||||
 | 
			
		||||
/// The class represents an Operation constraint, and constrains a variable to
 | 
			
		||||
/// be an Operation.
 | 
			
		||||
class OpConstraintDecl
 | 
			
		||||
    : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
 | 
			
		||||
public:
 | 
			
		||||
  static OpConstraintDecl *create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                  const OpNameDecl *nameDecl = nullptr);
 | 
			
		||||
 | 
			
		||||
  /// Return the name of the operation, or None if there isn't one.
 | 
			
		||||
  Optional<StringRef> getName() const;
 | 
			
		||||
 | 
			
		||||
  /// Return the declaration of the operation name.
 | 
			
		||||
  const OpNameDecl *getNameDecl() const { return nameDecl; }
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  explicit OpConstraintDecl(llvm::SMRange loc, const OpNameDecl *nameDecl)
 | 
			
		||||
      : Base(loc), nameDecl(nameDecl) {}
 | 
			
		||||
 | 
			
		||||
  /// The operation name of this constraint.
 | 
			
		||||
  const OpNameDecl *nameDecl;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TypeConstraintDecl
 | 
			
		||||
 | 
			
		||||
/// The class represents a Type constraint, and constrains a variable to be a
 | 
			
		||||
/// Type.
 | 
			
		||||
class TypeConstraintDecl
 | 
			
		||||
    : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
 | 
			
		||||
public:
 | 
			
		||||
  static TypeConstraintDecl *create(Context &ctx, llvm::SMRange loc);
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TypeRangeConstraintDecl
 | 
			
		||||
 | 
			
		||||
/// The class represents a TypeRange constraint, and constrains a variable to be
 | 
			
		||||
/// a TypeRange.
 | 
			
		||||
class TypeRangeConstraintDecl
 | 
			
		||||
    : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
 | 
			
		||||
public:
 | 
			
		||||
  static TypeRangeConstraintDecl *create(Context &ctx, llvm::SMRange loc);
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ValueConstraintDecl
 | 
			
		||||
 | 
			
		||||
/// The class represents a Value constraint, and constrains a variable to be a
 | 
			
		||||
/// Value.
 | 
			
		||||
class ValueConstraintDecl
 | 
			
		||||
    : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
 | 
			
		||||
public:
 | 
			
		||||
  static ValueConstraintDecl *create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                     Expr *typeExpr);
 | 
			
		||||
 | 
			
		||||
  /// Return the optional type the value is constrained to.
 | 
			
		||||
  Expr *getTypeExpr() { return typeExpr; }
 | 
			
		||||
  const Expr *getTypeExpr() const { return typeExpr; }
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  ValueConstraintDecl(llvm::SMRange loc, Expr *typeExpr)
 | 
			
		||||
      : Base(loc), typeExpr(typeExpr) {}
 | 
			
		||||
 | 
			
		||||
  /// An optional type that the value is constrained to.
 | 
			
		||||
  Expr *typeExpr;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ValueRangeConstraintDecl
 | 
			
		||||
 | 
			
		||||
/// The class represents a ValueRange constraint, and constrains a variable to
 | 
			
		||||
/// be a ValueRange.
 | 
			
		||||
class ValueRangeConstraintDecl
 | 
			
		||||
    : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
 | 
			
		||||
public:
 | 
			
		||||
  static ValueRangeConstraintDecl *create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                          Expr *typeExpr);
 | 
			
		||||
 | 
			
		||||
  /// Return the optional type the value range is constrained to.
 | 
			
		||||
  Expr *getTypeExpr() { return typeExpr; }
 | 
			
		||||
  const Expr *getTypeExpr() const { return typeExpr; }
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  ValueRangeConstraintDecl(llvm::SMRange loc, Expr *typeExpr)
 | 
			
		||||
      : Base(loc), typeExpr(typeExpr) {}
 | 
			
		||||
 | 
			
		||||
  /// An optional type that the value range is constrained to.
 | 
			
		||||
  Expr *typeExpr;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// OpNameDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This Decl represents an OperationName.
 | 
			
		||||
class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
 | 
			
		||||
public:
 | 
			
		||||
  static OpNameDecl *create(Context &ctx, const Name &name);
 | 
			
		||||
  static OpNameDecl *create(Context &ctx, llvm::SMRange loc);
 | 
			
		||||
 | 
			
		||||
  /// Return the name of this operation, or none if the name is unknown.
 | 
			
		||||
  Optional<StringRef> getName() const {
 | 
			
		||||
    const Name *name = Decl::getName();
 | 
			
		||||
    return name ? Optional<StringRef>(name->getName()) : llvm::None;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {}
 | 
			
		||||
  explicit OpNameDecl(llvm::SMRange loc) : Base(loc) {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// PatternDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This Decl represents a single Pattern.
 | 
			
		||||
class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
 | 
			
		||||
public:
 | 
			
		||||
  static PatternDecl *create(Context &ctx, llvm::SMRange location,
 | 
			
		||||
                             const Name *name, Optional<uint16_t> benefit,
 | 
			
		||||
                             bool hasBoundedRecursion,
 | 
			
		||||
                             const CompoundStmt *body);
 | 
			
		||||
 | 
			
		||||
  /// Return the benefit of this pattern if specified, or None.
 | 
			
		||||
  Optional<uint16_t> getBenefit() const { return benefit; }
 | 
			
		||||
 | 
			
		||||
  /// Return if this pattern has bounded rewrite recursion.
 | 
			
		||||
  bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
 | 
			
		||||
 | 
			
		||||
  /// Return the body of this pattern.
 | 
			
		||||
  const CompoundStmt *getBody() const { return patternBody; }
 | 
			
		||||
 | 
			
		||||
  /// Return the root rewrite statement of this pattern.
 | 
			
		||||
  const OpRewriteStmt *getRootRewriteStmt() const {
 | 
			
		||||
    return cast<OpRewriteStmt>(patternBody->getChildren().back());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  PatternDecl(llvm::SMRange loc, const Name *name, Optional<uint16_t> benefit,
 | 
			
		||||
              bool hasBoundedRecursion, const CompoundStmt *body)
 | 
			
		||||
      : Base(loc, name), benefit(benefit),
 | 
			
		||||
        hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
 | 
			
		||||
 | 
			
		||||
  /// The benefit of the pattern if it was explicitly specified, None otherwise.
 | 
			
		||||
  Optional<uint16_t> benefit;
 | 
			
		||||
 | 
			
		||||
  /// If the pattern has properly bounded rewrite recursion or not.
 | 
			
		||||
  bool hasBoundedRecursion;
 | 
			
		||||
 | 
			
		||||
  /// The compound statement representing the body of the pattern.
 | 
			
		||||
  const CompoundStmt *patternBody;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// VariableDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This Decl represents the definition of a PDLL variable.
 | 
			
		||||
class VariableDecl final
 | 
			
		||||
    : public Node::NodeBase<VariableDecl, Decl>,
 | 
			
		||||
      private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
 | 
			
		||||
public:
 | 
			
		||||
  static VariableDecl *create(Context &ctx, const Name &name, Type type,
 | 
			
		||||
                              Expr *initExpr,
 | 
			
		||||
                              ArrayRef<ConstraintRef> constraints);
 | 
			
		||||
 | 
			
		||||
  /// Return the constraints of this variable.
 | 
			
		||||
  MutableArrayRef<ConstraintRef> getConstraints() {
 | 
			
		||||
    return {getTrailingObjects<ConstraintRef>(), numConstraints};
 | 
			
		||||
  }
 | 
			
		||||
  ArrayRef<ConstraintRef> getConstraints() const {
 | 
			
		||||
    return const_cast<VariableDecl *>(this)->getConstraints();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Return the initializer expression of this statement, or nullptr if there
 | 
			
		||||
  /// was no initializer.
 | 
			
		||||
  Expr *getInitExpr() const { return initExpr; }
 | 
			
		||||
 | 
			
		||||
  /// Return the name of the decl.
 | 
			
		||||
  const Name &getName() const { return *Decl::getName(); }
 | 
			
		||||
 | 
			
		||||
  /// Return the type of the decl.
 | 
			
		||||
  Type getType() const { return type; }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  VariableDecl(const Name &name, Type type, Expr *initExpr,
 | 
			
		||||
               unsigned numConstraints)
 | 
			
		||||
      : Base(name.getLoc(), &name), type(type), initExpr(initExpr),
 | 
			
		||||
        numConstraints(numConstraints) {}
 | 
			
		||||
 | 
			
		||||
  /// The type of the variable.
 | 
			
		||||
  Type type;
 | 
			
		||||
 | 
			
		||||
  /// The optional initializer expression of this statement.
 | 
			
		||||
  Expr *initExpr;
 | 
			
		||||
 | 
			
		||||
  /// The number of constraints attached to this variable.
 | 
			
		||||
  unsigned numConstraints;
 | 
			
		||||
 | 
			
		||||
  /// Allow access to various internals.
 | 
			
		||||
  friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Module
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a top-level AST module.
 | 
			
		||||
class Module final : public Node::NodeBase<Module, Node>,
 | 
			
		||||
                     private llvm::TrailingObjects<Module, Decl *> {
 | 
			
		||||
public:
 | 
			
		||||
  static Module *create(Context &ctx, llvm::SMLoc loc,
 | 
			
		||||
                        ArrayRef<Decl *> children);
 | 
			
		||||
 | 
			
		||||
  /// Return the children of this module.
 | 
			
		||||
  MutableArrayRef<Decl *> getChildren() {
 | 
			
		||||
    return {getTrailingObjects<Decl *>(), numChildren};
 | 
			
		||||
  }
 | 
			
		||||
  ArrayRef<Decl *> getChildren() const {
 | 
			
		||||
    return const_cast<Module *>(this)->getChildren();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  Module(llvm::SMLoc loc, unsigned numChildren)
 | 
			
		||||
      : Base(llvm::SMRange{loc, loc}), numChildren(numChildren) {}
 | 
			
		||||
 | 
			
		||||
  /// The number of decls held by this module.
 | 
			
		||||
  unsigned numChildren;
 | 
			
		||||
 | 
			
		||||
  /// Allow access to various internals.
 | 
			
		||||
  friend llvm::TrailingObjects<Module, Decl *>;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Defered Method Definitions
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
inline bool Decl::classof(const Node *node) {
 | 
			
		||||
  return isa<ConstraintDecl, OpNameDecl, PatternDecl, VariableDecl>(node);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline bool ConstraintDecl::classof(const Node *node) {
 | 
			
		||||
  return isa<CoreConstraintDecl>(node);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline bool CoreConstraintDecl::classof(const Node *node) {
 | 
			
		||||
  return isa<AttrConstraintDecl, OpConstraintDecl, TypeConstraintDecl,
 | 
			
		||||
             TypeRangeConstraintDecl, ValueConstraintDecl,
 | 
			
		||||
             ValueRangeConstraintDecl>(node);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline bool Expr::classof(const Node *node) {
 | 
			
		||||
  return isa<DeclRefExpr, MemberAccessExpr>(node);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline bool OpRewriteStmt::classof(const Node *node) {
 | 
			
		||||
  return isa<EraseStmt>(node);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline bool Stmt::classof(const Node *node) {
 | 
			
		||||
  return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace ast
 | 
			
		||||
} // namespace pdll
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // MLIR_TOOLS_PDLL_AST_NODES_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,257 @@
 | 
			
		|||
//===- Types.h --------------------------------------------------*- C++ -*-===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#ifndef MLIR_TOOLS_PDLL_AST_TYPES_H_
 | 
			
		||||
#define MLIR_TOOLS_PDLL_AST_TYPES_H_
 | 
			
		||||
 | 
			
		||||
#include "mlir/Support/LLVM.h"
 | 
			
		||||
#include "mlir/Support/StorageUniquer.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace pdll {
 | 
			
		||||
namespace ast {
 | 
			
		||||
class Context;
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
struct AttributeTypeStorage;
 | 
			
		||||
struct ConstraintTypeStorage;
 | 
			
		||||
struct OperationTypeStorage;
 | 
			
		||||
struct RangeTypeStorage;
 | 
			
		||||
struct TypeTypeStorage;
 | 
			
		||||
struct ValueTypeStorage;
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Type
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
class Type {
 | 
			
		||||
public:
 | 
			
		||||
  /// This class represents the internal storage of the Type class.
 | 
			
		||||
  struct Storage;
 | 
			
		||||
 | 
			
		||||
  /// This class provides several utilities when defining derived type classes.
 | 
			
		||||
  template <typename ImplT, typename BaseT = Type>
 | 
			
		||||
  class TypeBase : public BaseT {
 | 
			
		||||
  public:
 | 
			
		||||
    using Base = TypeBase<ImplT, BaseT>;
 | 
			
		||||
    using ImplTy = ImplT;
 | 
			
		||||
    using BaseT::BaseT;
 | 
			
		||||
 | 
			
		||||
    /// Provide type casting support.
 | 
			
		||||
    static bool classof(Type type) {
 | 
			
		||||
      return type.getTypeID() == TypeID::get<ImplTy>();
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  Type(Storage *impl = nullptr) : impl(impl) {}
 | 
			
		||||
  Type(const Type &other) = default;
 | 
			
		||||
 | 
			
		||||
  bool operator==(const Type &other) const { return impl == other.impl; }
 | 
			
		||||
  bool operator!=(const Type &other) const { return !(*this == other); }
 | 
			
		||||
  explicit operator bool() const { return impl; }
 | 
			
		||||
 | 
			
		||||
  /// Provide type casting support.
 | 
			
		||||
  template <typename U> bool isa() const {
 | 
			
		||||
    assert(impl && "isa<> used on a null type.");
 | 
			
		||||
    return U::classof(*this);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename U, typename V, typename... Others> bool isa() const {
 | 
			
		||||
    return isa<U>() || isa<V, Others...>();
 | 
			
		||||
  }
 | 
			
		||||
  template <typename U> U dyn_cast() const {
 | 
			
		||||
    return isa<U>() ? U(impl) : U(nullptr);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename U> U dyn_cast_or_null() const {
 | 
			
		||||
    return (impl && isa<U>()) ? U(impl) : U(nullptr);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename U> U cast() const {
 | 
			
		||||
    assert(isa<U>());
 | 
			
		||||
    return U(impl);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Return the internal storage instance of this type.
 | 
			
		||||
  Storage *getImpl() const { return impl; }
 | 
			
		||||
 | 
			
		||||
  /// Return the TypeID instance of this type.
 | 
			
		||||
  TypeID getTypeID() const;
 | 
			
		||||
 | 
			
		||||
  /// Print this type to the given stream.
 | 
			
		||||
  void print(raw_ostream &os) const;
 | 
			
		||||
 | 
			
		||||
  /// Try to refine this type with the one provided. Given two compatible types,
 | 
			
		||||
  /// this will return a merged type contains as much detail from the two types.
 | 
			
		||||
  /// For example, if refining two operation types and one contains a name,
 | 
			
		||||
  /// while the other doesn't, the refined type contains the name. If the two
 | 
			
		||||
  /// types are incompatible, null is returned.
 | 
			
		||||
  Type refineWith(Type other) const;
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  /// Return the internal storage instance of this type reinterpreted as the
 | 
			
		||||
  /// given derived storage type.
 | 
			
		||||
  template <typename T> const T *getImplAs() const {
 | 
			
		||||
    return static_cast<const T *>(impl);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  Storage *impl;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
inline llvm::hash_code hash_value(Type type) {
 | 
			
		||||
  return DenseMapInfo<Type::Storage *>::getHashValue(type.getImpl());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline raw_ostream &operator<<(raw_ostream &os, Type type) {
 | 
			
		||||
  type.print(os);
 | 
			
		||||
  return os;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// AttributeType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a PDLL type that corresponds to an mlir::Attribute.
 | 
			
		||||
class AttributeType : public Type::TypeBase<detail::AttributeTypeStorage> {
 | 
			
		||||
public:
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
 | 
			
		||||
  /// Return an instance of the Attribute type.
 | 
			
		||||
  static AttributeType get(Context &context);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ConstraintType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a PDLL type that corresponds to a constraint. This
 | 
			
		||||
/// type has no MLIR C++ API correspondance.
 | 
			
		||||
class ConstraintType : public Type::TypeBase<detail::ConstraintTypeStorage> {
 | 
			
		||||
public:
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
 | 
			
		||||
  /// Return an instance of the Constraint type.
 | 
			
		||||
  static ConstraintType get(Context &context);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// OperationType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a PDLL type that corresponds to an mlir::Operation.
 | 
			
		||||
class OperationType : public Type::TypeBase<detail::OperationTypeStorage> {
 | 
			
		||||
public:
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
 | 
			
		||||
  /// Return an instance of the Operation type with an optional operation name.
 | 
			
		||||
  /// If no name is provided, this type may refer to any operation.
 | 
			
		||||
  static OperationType get(Context &context,
 | 
			
		||||
                           Optional<StringRef> name = llvm::None);
 | 
			
		||||
 | 
			
		||||
  /// Return the name of this operation type, or None if it doesn't have on.
 | 
			
		||||
  Optional<StringRef> getName() const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// RangeType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a PDLL type that corresponds to a range of elements
 | 
			
		||||
/// with a given element type.
 | 
			
		||||
class RangeType : public Type::TypeBase<detail::RangeTypeStorage> {
 | 
			
		||||
public:
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
 | 
			
		||||
  /// Return an instance of the Range type with the given element type.
 | 
			
		||||
  static RangeType get(Context &context, Type elementType);
 | 
			
		||||
 | 
			
		||||
  /// Return the element type of this range.
 | 
			
		||||
  Type getElementType() const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TypeRangeType
 | 
			
		||||
 | 
			
		||||
/// This class represents a PDLL type that corresponds to an mlir::TypeRange.
 | 
			
		||||
class TypeRangeType : public RangeType {
 | 
			
		||||
public:
 | 
			
		||||
  using RangeType::RangeType;
 | 
			
		||||
 | 
			
		||||
  /// Provide type casting support.
 | 
			
		||||
  static bool classof(Type type);
 | 
			
		||||
 | 
			
		||||
  /// Return an instance of the TypeRange type.
 | 
			
		||||
  static TypeRangeType get(Context &context);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ValueRangeType
 | 
			
		||||
 | 
			
		||||
/// This class represents a PDLL type that corresponds to an mlir::ValueRange.
 | 
			
		||||
class ValueRangeType : public RangeType {
 | 
			
		||||
public:
 | 
			
		||||
  using RangeType::RangeType;
 | 
			
		||||
 | 
			
		||||
  /// Provide type casting support.
 | 
			
		||||
  static bool classof(Type type);
 | 
			
		||||
 | 
			
		||||
  /// Return an instance of the ValueRange type.
 | 
			
		||||
  static ValueRangeType get(Context &context);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TypeType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a PDLL type that corresponds to an mlir::Type.
 | 
			
		||||
class TypeType : public Type::TypeBase<detail::TypeTypeStorage> {
 | 
			
		||||
public:
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
 | 
			
		||||
  /// Return an instance of the Type type.
 | 
			
		||||
  static TypeType get(Context &context);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ValueType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// This class represents a PDLL type that corresponds to an mlir::Value.
 | 
			
		||||
class ValueType : public Type::TypeBase<detail::ValueTypeStorage> {
 | 
			
		||||
public:
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
 | 
			
		||||
  /// Return an instance of the Value type.
 | 
			
		||||
  static ValueType get(Context &context);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace ast
 | 
			
		||||
} // namespace pdll
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
namespace llvm {
 | 
			
		||||
template <> struct DenseMapInfo<mlir::pdll::ast::Type> {
 | 
			
		||||
  static mlir::pdll::ast::Type getEmptyKey() {
 | 
			
		||||
    void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
 | 
			
		||||
    return mlir::pdll::ast::Type(
 | 
			
		||||
        static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
 | 
			
		||||
  }
 | 
			
		||||
  static mlir::pdll::ast::Type getTombstoneKey() {
 | 
			
		||||
    void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
 | 
			
		||||
    return mlir::pdll::ast::Type(
 | 
			
		||||
        static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
 | 
			
		||||
  }
 | 
			
		||||
  static unsigned getHashValue(mlir::pdll::ast::Type val) {
 | 
			
		||||
    return llvm::hash_value(val.getImpl());
 | 
			
		||||
  }
 | 
			
		||||
  static bool isEqual(mlir::pdll::ast::Type lhs, mlir::pdll::ast::Type rhs) {
 | 
			
		||||
    return lhs == rhs;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
} // namespace llvm
 | 
			
		||||
 | 
			
		||||
#endif // MLIR_TOOLS_PDLL_AST_TYPES_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,33 @@
 | 
			
		|||
//===- Parser.h - MLIR PDLL Frontend Parser ---------------------*- C++ -*-===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#ifndef MLIR_TOOLS_PDLL_PARSER_PARSER_H_
 | 
			
		||||
#define MLIR_TOOLS_PDLL_PARSER_PARSER_H_
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
#include "mlir/Support/LogicalResult.h"
 | 
			
		||||
 | 
			
		||||
namespace llvm {
 | 
			
		||||
class SourceMgr;
 | 
			
		||||
} // namespace llvm
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace pdll {
 | 
			
		||||
namespace ast {
 | 
			
		||||
class Context;
 | 
			
		||||
class Module;
 | 
			
		||||
} // namespace ast
 | 
			
		||||
 | 
			
		||||
/// Parse an AST module from the main file of the given source manager.
 | 
			
		||||
FailureOr<ast::Module *> parsePDLAST(ast::Context &ctx,
 | 
			
		||||
                                     llvm::SourceMgr &sourceMgr);
 | 
			
		||||
} // namespace pdll
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // MLIR_TOOLS_PDLL_PARSER_PARSER_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -1,2 +1,3 @@
 | 
			
		|||
add_subdirectory(mlir-lsp-server)
 | 
			
		||||
add_subdirectory(mlir-reduce)
 | 
			
		||||
add_subdirectory(PDLL)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,10 @@
 | 
			
		|||
add_mlir_library(MLIRPDLLAST
 | 
			
		||||
  Context.cpp
 | 
			
		||||
  Diagnostic.cpp
 | 
			
		||||
  NodePrinter.cpp
 | 
			
		||||
  Nodes.cpp
 | 
			
		||||
  Types.cpp
 | 
			
		||||
 | 
			
		||||
  LINK_LIBS PUBLIC
 | 
			
		||||
  MLIRSupport
 | 
			
		||||
  )
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,23 @@
 | 
			
		|||
//===- Context.cpp --------------------------------------------------------===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Context.h"
 | 
			
		||||
#include "TypeDetail.h"
 | 
			
		||||
 | 
			
		||||
using namespace mlir;
 | 
			
		||||
using namespace mlir::pdll::ast;
 | 
			
		||||
 | 
			
		||||
Context::Context() {
 | 
			
		||||
  typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>();
 | 
			
		||||
  typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>();
 | 
			
		||||
  typeUniquer.registerSingletonStorageType<detail::TypeTypeStorage>();
 | 
			
		||||
  typeUniquer.registerSingletonStorageType<detail::ValueTypeStorage>();
 | 
			
		||||
 | 
			
		||||
  typeUniquer.registerParametricStorageType<detail::OperationTypeStorage>();
 | 
			
		||||
  typeUniquer.registerParametricStorageType<detail::RangeTypeStorage>();
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,26 @@
 | 
			
		|||
//===- Diagnostic.cpp -----------------------------------------------------===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
 | 
			
		||||
 | 
			
		||||
using namespace mlir;
 | 
			
		||||
using namespace mlir::pdll::ast;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// InFlightDiagnostic
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
void InFlightDiagnostic::report() {
 | 
			
		||||
  // If this diagnostic is still inflight and it hasn't been abandoned, then
 | 
			
		||||
  // report it.
 | 
			
		||||
  if (isInFlight()) {
 | 
			
		||||
    owner->report(std::move(*impl));
 | 
			
		||||
    owner = nullptr;
 | 
			
		||||
  }
 | 
			
		||||
  impl.reset();
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,266 @@
 | 
			
		|||
//===- NodePrinter.cpp ----------------------------------------------------===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Context.h"
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Nodes.h"
 | 
			
		||||
#include "llvm/ADT/StringExtras.h"
 | 
			
		||||
#include "llvm/ADT/TypeSwitch.h"
 | 
			
		||||
#include "llvm/Support/SaveAndRestore.h"
 | 
			
		||||
#include "llvm/Support/ScopedPrinter.h"
 | 
			
		||||
 | 
			
		||||
using namespace mlir;
 | 
			
		||||
using namespace mlir::pdll::ast;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// NodePrinter
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
class NodePrinter {
 | 
			
		||||
public:
 | 
			
		||||
  NodePrinter(raw_ostream &os) : os(os) {}
 | 
			
		||||
 | 
			
		||||
  /// Print the given type to the stream.
 | 
			
		||||
  void print(Type type);
 | 
			
		||||
 | 
			
		||||
  /// Print the given node to the stream.
 | 
			
		||||
  void print(const Node *node);
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /// Print a range containing children of a node.
 | 
			
		||||
  template <typename RangeT,
 | 
			
		||||
            std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
 | 
			
		||||
                * = nullptr>
 | 
			
		||||
  void printChildren(RangeT &&range) {
 | 
			
		||||
    if (llvm::empty(range))
 | 
			
		||||
      return;
 | 
			
		||||
 | 
			
		||||
    // Print the first N-1 elements with a prefix of "|-".
 | 
			
		||||
    auto it = std::begin(range);
 | 
			
		||||
    for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
 | 
			
		||||
      print(*it);
 | 
			
		||||
 | 
			
		||||
    // Print the last element.
 | 
			
		||||
    elementIndentStack.back() = true;
 | 
			
		||||
    print(*it);
 | 
			
		||||
  }
 | 
			
		||||
  template <typename RangeT, typename... OthersT,
 | 
			
		||||
            std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
 | 
			
		||||
                * = nullptr>
 | 
			
		||||
  void printChildren(RangeT &&range, OthersT &&...others) {
 | 
			
		||||
    printChildren(ArrayRef<const Node *>({range, others...}));
 | 
			
		||||
  }
 | 
			
		||||
  /// Print a range containing children of a node, nesting the children under
 | 
			
		||||
  /// the given label.
 | 
			
		||||
  template <typename RangeT>
 | 
			
		||||
  void printChildren(StringRef label, RangeT &&range) {
 | 
			
		||||
    if (llvm::empty(range))
 | 
			
		||||
      return;
 | 
			
		||||
    elementIndentStack.reserve(elementIndentStack.size() + 1);
 | 
			
		||||
    llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true);
 | 
			
		||||
 | 
			
		||||
    printIndent();
 | 
			
		||||
    os << label << "`\n";
 | 
			
		||||
    elementIndentStack.push_back(/*isLastElt*/ false);
 | 
			
		||||
    printChildren(std::forward<RangeT>(range));
 | 
			
		||||
    elementIndentStack.pop_back();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Print the given derived node to the stream.
 | 
			
		||||
  void printImpl(const CompoundStmt *stmt);
 | 
			
		||||
  void printImpl(const EraseStmt *stmt);
 | 
			
		||||
  void printImpl(const LetStmt *stmt);
 | 
			
		||||
 | 
			
		||||
  void printImpl(const DeclRefExpr *expr);
 | 
			
		||||
  void printImpl(const MemberAccessExpr *expr);
 | 
			
		||||
 | 
			
		||||
  void printImpl(const AttrConstraintDecl *decl);
 | 
			
		||||
  void printImpl(const OpConstraintDecl *decl);
 | 
			
		||||
  void printImpl(const TypeConstraintDecl *decl);
 | 
			
		||||
  void printImpl(const TypeRangeConstraintDecl *decl);
 | 
			
		||||
  void printImpl(const ValueConstraintDecl *decl);
 | 
			
		||||
  void printImpl(const ValueRangeConstraintDecl *decl);
 | 
			
		||||
  void printImpl(const OpNameDecl *decl);
 | 
			
		||||
  void printImpl(const PatternDecl *decl);
 | 
			
		||||
  void printImpl(const VariableDecl *decl);
 | 
			
		||||
  void printImpl(const Module *module);
 | 
			
		||||
 | 
			
		||||
  /// Print the current indent stack.
 | 
			
		||||
  void printIndent() {
 | 
			
		||||
    if (elementIndentStack.empty())
 | 
			
		||||
      return;
 | 
			
		||||
 | 
			
		||||
    for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
 | 
			
		||||
      os << (isLastElt ? "  " : " |");
 | 
			
		||||
    os << (elementIndentStack.back() ? " `" : " |");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// The raw output stream.
 | 
			
		||||
  raw_ostream &os;
 | 
			
		||||
 | 
			
		||||
  /// A stack of indents and a flag indicating if the current element being
 | 
			
		||||
  /// printed at that indent is the last element.
 | 
			
		||||
  SmallVector<bool> elementIndentStack;
 | 
			
		||||
};
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
void NodePrinter::print(Type type) {
 | 
			
		||||
  // Protect against invalid inputs.
 | 
			
		||||
  if (!type) {
 | 
			
		||||
    os << "Type<NULL>";
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TypeSwitch<Type>(type)
 | 
			
		||||
      .Case([&](AttributeType) { os << "Attr"; })
 | 
			
		||||
      .Case([&](ConstraintType) { os << "Constraint"; })
 | 
			
		||||
      .Case([&](OperationType type) {
 | 
			
		||||
        os << "Op";
 | 
			
		||||
        if (Optional<StringRef> name = type.getName())
 | 
			
		||||
          os << "<" << *name << ">";
 | 
			
		||||
      })
 | 
			
		||||
      .Case([&](RangeType type) {
 | 
			
		||||
        print(type.getElementType());
 | 
			
		||||
        os << "Range";
 | 
			
		||||
      })
 | 
			
		||||
      .Case([&](TypeType) { os << "Type"; })
 | 
			
		||||
      .Case([&](ValueType) { os << "Value"; })
 | 
			
		||||
      .Default([](Type) { llvm_unreachable("unknown AST type"); });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::print(const Node *node) {
 | 
			
		||||
  printIndent();
 | 
			
		||||
  os << "-";
 | 
			
		||||
 | 
			
		||||
  elementIndentStack.push_back(/*isLastElt*/ false);
 | 
			
		||||
  TypeSwitch<const Node *>(node)
 | 
			
		||||
      .Case<
 | 
			
		||||
          // Statements.
 | 
			
		||||
          const CompoundStmt, const EraseStmt, const LetStmt,
 | 
			
		||||
 | 
			
		||||
          // Expressions.
 | 
			
		||||
          const DeclRefExpr, const MemberAccessExpr,
 | 
			
		||||
 | 
			
		||||
          // Decls.
 | 
			
		||||
          const AttrConstraintDecl, const OpConstraintDecl,
 | 
			
		||||
          const TypeConstraintDecl, const TypeRangeConstraintDecl,
 | 
			
		||||
          const ValueConstraintDecl, const ValueRangeConstraintDecl,
 | 
			
		||||
          const OpNameDecl, const PatternDecl, const VariableDecl,
 | 
			
		||||
 | 
			
		||||
          const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
 | 
			
		||||
      .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
 | 
			
		||||
  elementIndentStack.pop_back();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const CompoundStmt *stmt) {
 | 
			
		||||
  os << "CompoundStmt " << stmt << "\n";
 | 
			
		||||
  printChildren(stmt->getChildren());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const EraseStmt *stmt) {
 | 
			
		||||
  os << "EraseStmt " << stmt << "\n";
 | 
			
		||||
  printChildren(stmt->getRootOpExpr());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const LetStmt *stmt) {
 | 
			
		||||
  os << "LetStmt " << stmt << "\n";
 | 
			
		||||
  printChildren(stmt->getVarDecl());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const DeclRefExpr *expr) {
 | 
			
		||||
  os << "DeclRefExpr " << expr << " Type<";
 | 
			
		||||
  print(expr->getType());
 | 
			
		||||
  os << ">\n";
 | 
			
		||||
  printChildren(expr->getDecl());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const MemberAccessExpr *expr) {
 | 
			
		||||
  os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
 | 
			
		||||
     << "> Type<";
 | 
			
		||||
  print(expr->getType());
 | 
			
		||||
  os << ">\n";
 | 
			
		||||
  printChildren(expr->getParentExpr());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
 | 
			
		||||
  os << "AttrConstraintDecl " << decl << "\n";
 | 
			
		||||
  if (const auto *typeExpr = decl->getTypeExpr())
 | 
			
		||||
    printChildren(typeExpr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const OpConstraintDecl *decl) {
 | 
			
		||||
  os << "OpConstraintDecl " << decl << "\n";
 | 
			
		||||
  printChildren(decl->getNameDecl());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
 | 
			
		||||
  os << "TypeConstraintDecl " << decl << "\n";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
 | 
			
		||||
  os << "TypeRangeConstraintDecl " << decl << "\n";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
 | 
			
		||||
  os << "ValueConstraintDecl " << decl << "\n";
 | 
			
		||||
  if (const auto *typeExpr = decl->getTypeExpr())
 | 
			
		||||
    printChildren(typeExpr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
 | 
			
		||||
  os << "ValueRangeConstraintDecl " << decl << "\n";
 | 
			
		||||
  if (const auto *typeExpr = decl->getTypeExpr())
 | 
			
		||||
    printChildren(typeExpr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const OpNameDecl *decl) {
 | 
			
		||||
  os << "OpNameDecl " << decl;
 | 
			
		||||
  if (Optional<StringRef> name = decl->getName())
 | 
			
		||||
    os << " Name<" << name << ">";
 | 
			
		||||
  os << "\n";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const PatternDecl *decl) {
 | 
			
		||||
  os << "PatternDecl " << decl;
 | 
			
		||||
  if (const Name *name = decl->getName())
 | 
			
		||||
    os << " Name<" << name->getName() << ">";
 | 
			
		||||
  if (Optional<uint16_t> benefit = decl->getBenefit())
 | 
			
		||||
    os << " Benefit<" << *benefit << ">";
 | 
			
		||||
  if (decl->hasBoundedRewriteRecursion())
 | 
			
		||||
    os << " Recursion";
 | 
			
		||||
 | 
			
		||||
  os << "\n";
 | 
			
		||||
  printChildren(decl->getBody());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const VariableDecl *decl) {
 | 
			
		||||
  os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
 | 
			
		||||
     << "> Type<";
 | 
			
		||||
  print(decl->getType());
 | 
			
		||||
  os << ">\n";
 | 
			
		||||
  if (Expr *initExpr = decl->getInitExpr())
 | 
			
		||||
    printChildren(initExpr);
 | 
			
		||||
 | 
			
		||||
  auto constraints =
 | 
			
		||||
      llvm::map_range(decl->getConstraints(),
 | 
			
		||||
                      [](const ConstraintRef &ref) { return ref.constraint; });
 | 
			
		||||
  printChildren("Constraints", constraints);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void NodePrinter::printImpl(const Module *module) {
 | 
			
		||||
  os << "Module " << module << "\n";
 | 
			
		||||
  printChildren(module->getChildren());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Entry point
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
 | 
			
		||||
 | 
			
		||||
void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,231 @@
 | 
			
		|||
//===- Nodes.cpp ----------------------------------------------------------===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Nodes.h"
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Context.h"
 | 
			
		||||
#include "llvm/ADT/SmallPtrSet.h"
 | 
			
		||||
 | 
			
		||||
using namespace mlir;
 | 
			
		||||
using namespace mlir::pdll::ast;
 | 
			
		||||
 | 
			
		||||
/// Copy a string reference into the context with a null terminator.
 | 
			
		||||
static StringRef copyStringWithNull(Context &ctx, StringRef str) {
 | 
			
		||||
  if (str.empty())
 | 
			
		||||
    return str;
 | 
			
		||||
 | 
			
		||||
  char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
 | 
			
		||||
  std::copy(str.begin(), str.end(), data);
 | 
			
		||||
  data[str.size()] = 0;
 | 
			
		||||
  return StringRef(data, str.size());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Name
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
const Name &Name::create(Context &ctx, StringRef name, llvm::SMRange location) {
 | 
			
		||||
  return *new (ctx.getAllocator().Allocate<Name>())
 | 
			
		||||
      Name(copyStringWithNull(ctx, name), location);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// DeclScope
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
void DeclScope::add(Decl *decl) {
 | 
			
		||||
  const Name *name = decl->getName();
 | 
			
		||||
  assert(name && "expected a named decl");
 | 
			
		||||
  assert(!decls.count(name->getName()) && "decl with this name already exists");
 | 
			
		||||
  decls.try_emplace(name->getName(), decl);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Decl *DeclScope::lookup(StringRef name) {
 | 
			
		||||
  if (Decl *decl = decls.lookup(name))
 | 
			
		||||
    return decl;
 | 
			
		||||
  return parent ? parent->lookup(name) : nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// CompoundStmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
CompoundStmt *CompoundStmt::create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                   ArrayRef<Stmt *> children) {
 | 
			
		||||
  unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
 | 
			
		||||
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
 | 
			
		||||
 | 
			
		||||
  CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
 | 
			
		||||
  std::uninitialized_copy(children.begin(), children.end(),
 | 
			
		||||
                          stmt->getChildren().begin());
 | 
			
		||||
  return stmt;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// LetStmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
LetStmt *LetStmt::create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                         VariableDecl *varDecl) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// OpRewriteStmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// EraseStmt
 | 
			
		||||
 | 
			
		||||
EraseStmt *EraseStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// DeclRefExpr
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
DeclRefExpr *DeclRefExpr::create(Context &ctx, llvm::SMRange loc, Decl *decl,
 | 
			
		||||
                                 Type type) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<DeclRefExpr>())
 | 
			
		||||
      DeclRefExpr(loc, decl, type);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// MemberAccessExpr
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
MemberAccessExpr *MemberAccessExpr::create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                           const Expr *parentExpr,
 | 
			
		||||
                                           StringRef memberName, Type type) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
 | 
			
		||||
      loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// AttrConstraintDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                               Expr *typeExpr) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
 | 
			
		||||
      AttrConstraintDecl(loc, typeExpr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// OpConstraintDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
OpConstraintDecl *OpConstraintDecl::create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                           const OpNameDecl *nameDecl) {
 | 
			
		||||
  if (!nameDecl)
 | 
			
		||||
    nameDecl = OpNameDecl::create(ctx, llvm::SMRange());
 | 
			
		||||
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
 | 
			
		||||
      OpConstraintDecl(loc, nameDecl);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Optional<StringRef> OpConstraintDecl::getName() const {
 | 
			
		||||
  return getNameDecl()->getName();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TypeConstraintDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
 | 
			
		||||
                                               llvm::SMRange loc) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
 | 
			
		||||
      TypeConstraintDecl(loc);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TypeRangeConstraintDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
 | 
			
		||||
                                                         llvm::SMRange loc) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
 | 
			
		||||
      TypeRangeConstraintDecl(loc);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ValueConstraintDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
ValueConstraintDecl *
 | 
			
		||||
ValueConstraintDecl::create(Context &ctx, llvm::SMRange loc, Expr *typeExpr) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
 | 
			
		||||
      ValueConstraintDecl(loc, typeExpr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ValueRangeConstraintDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
 | 
			
		||||
                                                           llvm::SMRange loc,
 | 
			
		||||
                                                           Expr *typeExpr) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
 | 
			
		||||
      ValueRangeConstraintDecl(loc, typeExpr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// OpNameDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
 | 
			
		||||
}
 | 
			
		||||
OpNameDecl *OpNameDecl::create(Context &ctx, llvm::SMRange loc) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// PatternDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
PatternDecl *PatternDecl::create(Context &ctx, llvm::SMRange loc,
 | 
			
		||||
                                 const Name *name, Optional<uint16_t> benefit,
 | 
			
		||||
                                 bool hasBoundedRecursion,
 | 
			
		||||
                                 const CompoundStmt *body) {
 | 
			
		||||
  return new (ctx.getAllocator().Allocate<PatternDecl>())
 | 
			
		||||
      PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// VariableDecl
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
 | 
			
		||||
                                   Expr *initExpr,
 | 
			
		||||
                                   ArrayRef<ConstraintRef> constraints) {
 | 
			
		||||
  unsigned allocSize =
 | 
			
		||||
      VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
 | 
			
		||||
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
 | 
			
		||||
 | 
			
		||||
  VariableDecl *varDecl =
 | 
			
		||||
      new (rawData) VariableDecl(name, type, initExpr, constraints.size());
 | 
			
		||||
  std::uninitialized_copy(constraints.begin(), constraints.end(),
 | 
			
		||||
                          varDecl->getConstraints().begin());
 | 
			
		||||
  return varDecl;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Module
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
Module *Module::create(Context &ctx, llvm::SMLoc loc,
 | 
			
		||||
                       ArrayRef<Decl *> children) {
 | 
			
		||||
  unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
 | 
			
		||||
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
 | 
			
		||||
 | 
			
		||||
  Module *module = new (rawData) Module(loc, children.size());
 | 
			
		||||
  std::uninitialized_copy(children.begin(), children.end(),
 | 
			
		||||
                          module->getChildren().begin());
 | 
			
		||||
  return module;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,116 @@
 | 
			
		|||
//===- TypeDetail.h ---------------------------------------------*- C++ -*-===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#ifndef LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
 | 
			
		||||
#define LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
 | 
			
		||||
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Types.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
namespace pdll {
 | 
			
		||||
namespace ast {
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Type
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
struct Type::Storage : public StorageUniquer::BaseStorage {
 | 
			
		||||
  Storage(TypeID typeID) : typeID(typeID) {}
 | 
			
		||||
 | 
			
		||||
  /// The type identifier for the derived type class.
 | 
			
		||||
  TypeID typeID;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
/// A utility CRTP base class that defines many of the necessary utilities for
 | 
			
		||||
/// defining a PDLL AST Type.
 | 
			
		||||
template <typename ConcreteT, typename KeyT = void>
 | 
			
		||||
struct TypeStorageBase : public Type::Storage {
 | 
			
		||||
  using KeyTy = KeyT;
 | 
			
		||||
  using Base = TypeStorageBase<ConcreteT, KeyT>;
 | 
			
		||||
 | 
			
		||||
  /// Construct an instance with the given storage allocator.
 | 
			
		||||
  static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
 | 
			
		||||
                              const KeyTy &key) {
 | 
			
		||||
    return new (alloc.allocate<ConcreteT>()) ConcreteT(key);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Utility methods required by the storage allocator.
 | 
			
		||||
  bool operator==(const KeyTy &key) const { return this->key == key; }
 | 
			
		||||
 | 
			
		||||
  /// Return the key value of this storage class.
 | 
			
		||||
  const KeyTy &getValue() const { return key; }
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  TypeStorageBase(KeyTy key)
 | 
			
		||||
      : Type::Storage(TypeID::get<ConcreteT>()), key(key) {}
 | 
			
		||||
 | 
			
		||||
  KeyTy key;
 | 
			
		||||
};
 | 
			
		||||
/// A specialization of the storage base for singleton types.
 | 
			
		||||
template <typename ConcreteT>
 | 
			
		||||
struct TypeStorageBase<ConcreteT, void> : public Type::Storage {
 | 
			
		||||
  using Base = TypeStorageBase<ConcreteT, void>;
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
  TypeStorageBase() : Type::Storage(TypeID::get<ConcreteT>()) {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// AttributeType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
struct AttributeTypeStorage : public TypeStorageBase<AttributeTypeStorage> {};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ConstraintType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
struct ConstraintTypeStorage : public TypeStorageBase<ConstraintTypeStorage> {};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// OperationType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
struct OperationTypeStorage
 | 
			
		||||
    : public TypeStorageBase<OperationTypeStorage, StringRef> {
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
 | 
			
		||||
  static OperationTypeStorage *
 | 
			
		||||
  construct(StorageUniquer::StorageAllocator &alloc, StringRef key) {
 | 
			
		||||
    return new (alloc.allocate<OperationTypeStorage>())
 | 
			
		||||
        OperationTypeStorage(alloc.copyInto(key));
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// RangeType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> {
 | 
			
		||||
  using Base::Base;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TypeType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
struct TypeTypeStorage : public TypeStorageBase<TypeTypeStorage> {};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ValueType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
struct ValueTypeStorage : public TypeStorageBase<ValueTypeStorage> {};
 | 
			
		||||
 | 
			
		||||
} // namespace detail
 | 
			
		||||
} // namespace ast
 | 
			
		||||
} // namespace pdll
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,124 @@
 | 
			
		|||
//===- Types.cpp ----------------------------------------------------------===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Types.h"
 | 
			
		||||
#include "TypeDetail.h"
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Context.h"
 | 
			
		||||
 | 
			
		||||
using namespace mlir;
 | 
			
		||||
using namespace mlir::pdll::ast;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Type
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
TypeID Type::getTypeID() const { return impl->typeID; }
 | 
			
		||||
 | 
			
		||||
Type Type::refineWith(Type other) const {
 | 
			
		||||
  if (*this == other)
 | 
			
		||||
    return *this;
 | 
			
		||||
 | 
			
		||||
  // Operation types are compatible if the operation names don't conflict.
 | 
			
		||||
  if (auto opTy = dyn_cast<OperationType>()) {
 | 
			
		||||
    auto otherOpTy = other.dyn_cast<ast::OperationType>();
 | 
			
		||||
    if (!otherOpTy)
 | 
			
		||||
      return nullptr;
 | 
			
		||||
    if (!otherOpTy.getName())
 | 
			
		||||
      return *this;
 | 
			
		||||
    if (!opTy.getName())
 | 
			
		||||
      return other;
 | 
			
		||||
 | 
			
		||||
    return nullptr;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// AttributeType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
AttributeType AttributeType::get(Context &context) {
 | 
			
		||||
  return context.getTypeUniquer().get<ImplTy>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ConstraintType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
ConstraintType ConstraintType::get(Context &context) {
 | 
			
		||||
  return context.getTypeUniquer().get<ImplTy>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// OperationType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
OperationType OperationType::get(Context &context, Optional<StringRef> name) {
 | 
			
		||||
  return context.getTypeUniquer().get<ImplTy>(
 | 
			
		||||
      /*initFn=*/function_ref<void(ImplTy *)>(), name.getValueOr(""));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Optional<StringRef> OperationType::getName() const {
 | 
			
		||||
  StringRef name = getImplAs<ImplTy>()->getValue();
 | 
			
		||||
  return name.empty() ? Optional<StringRef>() : Optional<StringRef>(name);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// RangeType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
RangeType RangeType::get(Context &context, Type elementType) {
 | 
			
		||||
  return context.getTypeUniquer().get<ImplTy>(
 | 
			
		||||
      /*initFn=*/function_ref<void(ImplTy *)>(), elementType);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Type RangeType::getElementType() const {
 | 
			
		||||
  return getImplAs<ImplTy>()->getValue();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TypeRangeType
 | 
			
		||||
 | 
			
		||||
bool TypeRangeType::classof(Type type) {
 | 
			
		||||
  RangeType range = type.dyn_cast<RangeType>();
 | 
			
		||||
  return range && range.getElementType().isa<TypeType>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TypeRangeType TypeRangeType::get(Context &context) {
 | 
			
		||||
  return RangeType::get(context, TypeType::get(context)).cast<TypeRangeType>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ValueRangeType
 | 
			
		||||
 | 
			
		||||
bool ValueRangeType::classof(Type type) {
 | 
			
		||||
  RangeType range = type.dyn_cast<RangeType>();
 | 
			
		||||
  return range && range.getElementType().isa<ValueType>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ValueRangeType ValueRangeType::get(Context &context) {
 | 
			
		||||
  return RangeType::get(context, ValueType::get(context))
 | 
			
		||||
      .cast<ValueRangeType>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// TypeType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
TypeType TypeType::get(Context &context) {
 | 
			
		||||
  return context.getTypeUniquer().get<ImplTy>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// ValueType
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
ValueType ValueType::get(Context &context) {
 | 
			
		||||
  return context.getTypeUniquer().get<ImplTy>();
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,2 @@
 | 
			
		|||
add_subdirectory(AST)
 | 
			
		||||
add_subdirectory(Parser)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,8 @@
 | 
			
		|||
add_mlir_library(MLIRPDLLParser
 | 
			
		||||
  Lexer.cpp
 | 
			
		||||
  Parser.cpp
 | 
			
		||||
 | 
			
		||||
  LINK_LIBS PUBLIC
 | 
			
		||||
  MLIRPDLLAST
 | 
			
		||||
  MLIRSupport
 | 
			
		||||
  )
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,366 @@
 | 
			
		|||
//===- Lexer.cpp ----------------------------------------------------------===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "Lexer.h"
 | 
			
		||||
#include "mlir/Support/LogicalResult.h"
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
 | 
			
		||||
#include "llvm/ADT/StringExtras.h"
 | 
			
		||||
#include "llvm/ADT/StringSwitch.h"
 | 
			
		||||
#include "llvm/Support/SourceMgr.h"
 | 
			
		||||
 | 
			
		||||
using namespace mlir;
 | 
			
		||||
using namespace mlir::pdll;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Token
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
std::string Token::getStringValue() const {
 | 
			
		||||
  assert(getKind() == string || getKind() == string_block);
 | 
			
		||||
 | 
			
		||||
  // Start by dropping the quotes.
 | 
			
		||||
  StringRef bytes = getSpelling().drop_front().drop_back();
 | 
			
		||||
  if (is(string_block)) bytes = bytes.drop_front().drop_back();
 | 
			
		||||
 | 
			
		||||
  std::string result;
 | 
			
		||||
  result.reserve(bytes.size());
 | 
			
		||||
  for (unsigned i = 0, e = bytes.size(); i != e;) {
 | 
			
		||||
    auto c = bytes[i++];
 | 
			
		||||
    if (c != '\\') {
 | 
			
		||||
      result.push_back(c);
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    assert(i + 1 <= e && "invalid string should be caught by lexer");
 | 
			
		||||
    auto c1 = bytes[i++];
 | 
			
		||||
    switch (c1) {
 | 
			
		||||
      case '"':
 | 
			
		||||
      case '\\':
 | 
			
		||||
        result.push_back(c1);
 | 
			
		||||
        continue;
 | 
			
		||||
      case 'n':
 | 
			
		||||
        result.push_back('\n');
 | 
			
		||||
        continue;
 | 
			
		||||
      case 't':
 | 
			
		||||
        result.push_back('\t');
 | 
			
		||||
        continue;
 | 
			
		||||
      default:
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    assert(i + 1 <= e && "invalid string should be caught by lexer");
 | 
			
		||||
    auto c2 = bytes[i++];
 | 
			
		||||
 | 
			
		||||
    assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape");
 | 
			
		||||
    result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Lexer
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine)
 | 
			
		||||
    : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false) {
 | 
			
		||||
  curBufferID = mgr.getMainFileID();
 | 
			
		||||
  curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
 | 
			
		||||
  curPtr = curBuffer.begin();
 | 
			
		||||
 | 
			
		||||
  // If the diag engine has no handler, add a default that emits to the
 | 
			
		||||
  // SourceMgr.
 | 
			
		||||
  if (!diagEngine.getHandlerFn()) {
 | 
			
		||||
    diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
 | 
			
		||||
      srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
 | 
			
		||||
                          diag.getMessage());
 | 
			
		||||
      for (const ast::Diagnostic ¬e : diag.getNotes())
 | 
			
		||||
        srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
 | 
			
		||||
                            note.getMessage());
 | 
			
		||||
    });
 | 
			
		||||
    addedHandlerToDiagEngine = true;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Lexer::~Lexer() {
 | 
			
		||||
  if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
LogicalResult Lexer::pushInclude(StringRef filename) {
 | 
			
		||||
  std::string includedFile;
 | 
			
		||||
  int bufferID = srcMgr.AddIncludeFile(
 | 
			
		||||
      filename.str(), llvm::SMLoc::getFromPointer(curPtr), includedFile);
 | 
			
		||||
  if (!bufferID) return failure();
 | 
			
		||||
 | 
			
		||||
  curBufferID = bufferID;
 | 
			
		||||
  curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
 | 
			
		||||
  curPtr = curBuffer.begin();
 | 
			
		||||
  return success();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Token Lexer::emitError(llvm::SMRange loc, const Twine &msg) {
 | 
			
		||||
  diagEngine.emitError(loc, msg);
 | 
			
		||||
  return formToken(Token::error, loc.Start.getPointer());
 | 
			
		||||
}
 | 
			
		||||
Token Lexer::emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
 | 
			
		||||
                              llvm::SMRange noteLoc, const Twine ¬e) {
 | 
			
		||||
  diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
 | 
			
		||||
  return formToken(Token::error, loc.Start.getPointer());
 | 
			
		||||
}
 | 
			
		||||
Token Lexer::emitError(const char *loc, const Twine &msg) {
 | 
			
		||||
  return emitError(llvm::SMRange(llvm::SMLoc::getFromPointer(loc),
 | 
			
		||||
                                 llvm::SMLoc::getFromPointer(loc + 1)),
 | 
			
		||||
                   msg);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int Lexer::getNextChar() {
 | 
			
		||||
  char curChar = *curPtr++;
 | 
			
		||||
  switch (curChar) {
 | 
			
		||||
    default:
 | 
			
		||||
      return static_cast<unsigned char>(curChar);
 | 
			
		||||
    case 0: {
 | 
			
		||||
      // A nul character in the stream is either the end of the current buffer
 | 
			
		||||
      // or a random nul in the file. Disambiguate that here.
 | 
			
		||||
      if (curPtr - 1 != curBuffer.end()) return 0;
 | 
			
		||||
 | 
			
		||||
      // Otherwise, return end of file.
 | 
			
		||||
      --curPtr;
 | 
			
		||||
      return EOF;
 | 
			
		||||
    }
 | 
			
		||||
    case '\n':
 | 
			
		||||
    case '\r':
 | 
			
		||||
      // Handle the newline character by ignoring it and incrementing the line
 | 
			
		||||
      // count. However, be careful about 'dos style' files with \n\r in them.
 | 
			
		||||
      // Only treat a \n\r or \r\n as a single line.
 | 
			
		||||
      if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
 | 
			
		||||
        ++curPtr;
 | 
			
		||||
      return '\n';
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Token Lexer::lexToken() {
 | 
			
		||||
  while (true) {
 | 
			
		||||
    const char *tokStart = curPtr;
 | 
			
		||||
 | 
			
		||||
    // This always consumes at least one character.
 | 
			
		||||
    int curChar = getNextChar();
 | 
			
		||||
    switch (curChar) {
 | 
			
		||||
      default:
 | 
			
		||||
        // Handle identifiers: [a-zA-Z_]
 | 
			
		||||
        if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart);
 | 
			
		||||
 | 
			
		||||
        // Unknown character, emit an error.
 | 
			
		||||
        return emitError(tokStart, "unexpected character");
 | 
			
		||||
      case EOF: {
 | 
			
		||||
        // Return EOF denoting the end of lexing.
 | 
			
		||||
        Token eof = formToken(Token::eof, tokStart);
 | 
			
		||||
 | 
			
		||||
        // Check to see if we are in an included file.
 | 
			
		||||
        llvm::SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
 | 
			
		||||
        if (parentIncludeLoc.isValid()) {
 | 
			
		||||
          curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
 | 
			
		||||
          curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
 | 
			
		||||
          curPtr = parentIncludeLoc.getPointer();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return eof;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Lex punctuation.
 | 
			
		||||
      case '-':
 | 
			
		||||
        if (*curPtr == '>') {
 | 
			
		||||
          ++curPtr;
 | 
			
		||||
          return formToken(Token::arrow, tokStart);
 | 
			
		||||
        }
 | 
			
		||||
        return emitError(tokStart, "unexpected character");
 | 
			
		||||
      case ':':
 | 
			
		||||
        return formToken(Token::colon, tokStart);
 | 
			
		||||
      case ',':
 | 
			
		||||
        return formToken(Token::comma, tokStart);
 | 
			
		||||
      case '.':
 | 
			
		||||
        return formToken(Token::dot, tokStart);
 | 
			
		||||
      case '=':
 | 
			
		||||
        if (*curPtr == '>') {
 | 
			
		||||
          ++curPtr;
 | 
			
		||||
          return formToken(Token::equal_arrow, tokStart);
 | 
			
		||||
        }
 | 
			
		||||
        return formToken(Token::equal, tokStart);
 | 
			
		||||
      case ';':
 | 
			
		||||
        return formToken(Token::semicolon, tokStart);
 | 
			
		||||
      case '[':
 | 
			
		||||
        if (*curPtr == '{') {
 | 
			
		||||
          ++curPtr;
 | 
			
		||||
          return lexString(tokStart, /*isStringBlock=*/true);
 | 
			
		||||
        }
 | 
			
		||||
        return formToken(Token::l_square, tokStart);
 | 
			
		||||
      case ']':
 | 
			
		||||
        return formToken(Token::r_square, tokStart);
 | 
			
		||||
 | 
			
		||||
      case '<':
 | 
			
		||||
        return formToken(Token::less, tokStart);
 | 
			
		||||
      case '>':
 | 
			
		||||
        return formToken(Token::greater, tokStart);
 | 
			
		||||
      case '{':
 | 
			
		||||
        return formToken(Token::l_brace, tokStart);
 | 
			
		||||
      case '}':
 | 
			
		||||
        return formToken(Token::r_brace, tokStart);
 | 
			
		||||
      case '(':
 | 
			
		||||
        return formToken(Token::l_paren, tokStart);
 | 
			
		||||
      case ')':
 | 
			
		||||
        return formToken(Token::r_paren, tokStart);
 | 
			
		||||
      case '/':
 | 
			
		||||
        if (*curPtr == '/') {
 | 
			
		||||
          lexComment();
 | 
			
		||||
          continue;
 | 
			
		||||
        }
 | 
			
		||||
        return emitError(tokStart, "unexpected character");
 | 
			
		||||
 | 
			
		||||
      // Ignore whitespace characters.
 | 
			
		||||
      case 0:
 | 
			
		||||
      case ' ':
 | 
			
		||||
      case '\t':
 | 
			
		||||
      case '\n':
 | 
			
		||||
        return lexToken();
 | 
			
		||||
 | 
			
		||||
      case '#':
 | 
			
		||||
        return lexDirective(tokStart);
 | 
			
		||||
      case '"':
 | 
			
		||||
        return lexString(tokStart, /*isStringBlock=*/false);
 | 
			
		||||
 | 
			
		||||
      case '0':
 | 
			
		||||
      case '1':
 | 
			
		||||
      case '2':
 | 
			
		||||
      case '3':
 | 
			
		||||
      case '4':
 | 
			
		||||
      case '5':
 | 
			
		||||
      case '6':
 | 
			
		||||
      case '7':
 | 
			
		||||
      case '8':
 | 
			
		||||
      case '9':
 | 
			
		||||
        return lexNumber(tokStart);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Skip a comment line, starting with a '//'.
 | 
			
		||||
void Lexer::lexComment() {
 | 
			
		||||
  // Advance over the second '/' in a '//' comment.
 | 
			
		||||
  assert(*curPtr == '/');
 | 
			
		||||
  ++curPtr;
 | 
			
		||||
 | 
			
		||||
  while (true) {
 | 
			
		||||
    switch (*curPtr++) {
 | 
			
		||||
      case '\n':
 | 
			
		||||
      case '\r':
 | 
			
		||||
        // Newline is end of comment.
 | 
			
		||||
        return;
 | 
			
		||||
      case 0:
 | 
			
		||||
        // If this is the end of the buffer, end the comment.
 | 
			
		||||
        if (curPtr - 1 == curBuffer.end()) {
 | 
			
		||||
          --curPtr;
 | 
			
		||||
          return;
 | 
			
		||||
        }
 | 
			
		||||
        LLVM_FALLTHROUGH;
 | 
			
		||||
      default:
 | 
			
		||||
        // Skip over other characters.
 | 
			
		||||
        break;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Token Lexer::lexDirective(const char *tokStart) {
 | 
			
		||||
  // Match the rest with an identifier regex: [0-9a-zA-Z_]*
 | 
			
		||||
  while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
 | 
			
		||||
 | 
			
		||||
  StringRef str(tokStart, curPtr - tokStart);
 | 
			
		||||
  return Token(Token::directive, str);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Token Lexer::lexIdentifier(const char *tokStart) {
 | 
			
		||||
  // Match the rest of the identifier regex: [0-9a-zA-Z_]*
 | 
			
		||||
  while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
 | 
			
		||||
 | 
			
		||||
  // Check to see if this identifier is a keyword.
 | 
			
		||||
  StringRef str(tokStart, curPtr - tokStart);
 | 
			
		||||
  Token::Kind kind = StringSwitch<Token::Kind>(str)
 | 
			
		||||
                         .Case("attr", Token::kw_attr)
 | 
			
		||||
                         .Case("Attr", Token::kw_Attr)
 | 
			
		||||
                         .Case("erase", Token::kw_erase)
 | 
			
		||||
                         .Case("let", Token::kw_let)
 | 
			
		||||
                         .Case("Constraint", Token::kw_Constraint)
 | 
			
		||||
                         .Case("op", Token::kw_op)
 | 
			
		||||
                         .Case("Op", Token::kw_Op)
 | 
			
		||||
                         .Case("OpName", Token::kw_OpName)
 | 
			
		||||
                         .Case("Pattern", Token::kw_Pattern)
 | 
			
		||||
                         .Case("replace", Token::kw_replace)
 | 
			
		||||
                         .Case("rewrite", Token::kw_rewrite)
 | 
			
		||||
                         .Case("type", Token::kw_type)
 | 
			
		||||
                         .Case("Type", Token::kw_Type)
 | 
			
		||||
                         .Case("TypeRange", Token::kw_TypeRange)
 | 
			
		||||
                         .Case("Value", Token::kw_Value)
 | 
			
		||||
                         .Case("ValueRange", Token::kw_ValueRange)
 | 
			
		||||
                         .Case("with", Token::kw_with)
 | 
			
		||||
                         .Case("_", Token::underscore)
 | 
			
		||||
                         .Default(Token::identifier);
 | 
			
		||||
  return Token(kind, str);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Token Lexer::lexNumber(const char *tokStart) {
 | 
			
		||||
  assert(isdigit(curPtr[-1]));
 | 
			
		||||
 | 
			
		||||
  // Handle the normal decimal case.
 | 
			
		||||
  while (isdigit(*curPtr)) ++curPtr;
 | 
			
		||||
 | 
			
		||||
  return formToken(Token::integer, tokStart);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
 | 
			
		||||
  while (true) {
 | 
			
		||||
    switch (*curPtr++) {
 | 
			
		||||
      case '"':
 | 
			
		||||
        // If this is a string block, we only end the string when we encounter a
 | 
			
		||||
        // `}]`.
 | 
			
		||||
        if (!isStringBlock) return formToken(Token::string, tokStart);
 | 
			
		||||
        continue;
 | 
			
		||||
      case '}':
 | 
			
		||||
        // If this is a string block, we only end the string when we encounter a
 | 
			
		||||
        // `}]`.
 | 
			
		||||
        if (!isStringBlock || *curPtr != ']') continue;
 | 
			
		||||
        ++curPtr;
 | 
			
		||||
        return formToken(Token::string_block, tokStart);
 | 
			
		||||
      case 0:
 | 
			
		||||
        // If this is a random nul character in the middle of a string, just
 | 
			
		||||
        // include it.  If it is the end of file, then it is an error.
 | 
			
		||||
        if (curPtr - 1 != curBuffer.end()) continue;
 | 
			
		||||
        LLVM_FALLTHROUGH;
 | 
			
		||||
      case '\n':
 | 
			
		||||
      case '\v':
 | 
			
		||||
      case '\f':
 | 
			
		||||
        // String blocks allow multiple lines.
 | 
			
		||||
        if (!isStringBlock)
 | 
			
		||||
          return emitError(curPtr - 1, "expected '\"' in string literal");
 | 
			
		||||
        continue;
 | 
			
		||||
 | 
			
		||||
      case '\\':
 | 
			
		||||
        // Handle explicitly a few escapes.
 | 
			
		||||
        if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
 | 
			
		||||
            *curPtr == 't') {
 | 
			
		||||
          ++curPtr;
 | 
			
		||||
        } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
 | 
			
		||||
          // Support \xx for two hex digits.
 | 
			
		||||
          curPtr += 2;
 | 
			
		||||
        } else {
 | 
			
		||||
          return emitError(curPtr - 1, "unknown escape in string literal");
 | 
			
		||||
        }
 | 
			
		||||
        continue;
 | 
			
		||||
 | 
			
		||||
      default:
 | 
			
		||||
        continue;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,220 @@
 | 
			
		|||
//===- Lexer.h - MLIR PDLL Frontend Lexer -----------------------*- C++ -*-===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#ifndef LIB_TOOLS_PDLL_PARSER_LEXER_H_
 | 
			
		||||
#define LIB_TOOLS_PDLL_PARSER_LEXER_H_
 | 
			
		||||
 | 
			
		||||
#include "mlir/Support/LLVM.h"
 | 
			
		||||
#include "llvm/ADT/StringRef.h"
 | 
			
		||||
#include "llvm/Support/SMLoc.h"
 | 
			
		||||
 | 
			
		||||
namespace llvm {
 | 
			
		||||
class SourceMgr;
 | 
			
		||||
} // namespace llvm
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
struct LogicalResult;
 | 
			
		||||
 | 
			
		||||
namespace pdll {
 | 
			
		||||
namespace ast {
 | 
			
		||||
class DiagnosticEngine;
 | 
			
		||||
} // namespace ast
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Token
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
class Token {
 | 
			
		||||
public:
 | 
			
		||||
  enum Kind {
 | 
			
		||||
    // Markers.
 | 
			
		||||
    eof,
 | 
			
		||||
    error,
 | 
			
		||||
 | 
			
		||||
    // Keywords.
 | 
			
		||||
    KW_BEGIN,
 | 
			
		||||
    // Dependent keywords, i.e. those that are treated as keywords depending on
 | 
			
		||||
    // the current parser context.
 | 
			
		||||
    KW_DEPENDENT_BEGIN,
 | 
			
		||||
    kw_attr,
 | 
			
		||||
    kw_op,
 | 
			
		||||
    kw_type,
 | 
			
		||||
    KW_DEPENDENT_END,
 | 
			
		||||
 | 
			
		||||
    // General keywords.
 | 
			
		||||
    kw_Attr,
 | 
			
		||||
    kw_erase,
 | 
			
		||||
    kw_let,
 | 
			
		||||
    kw_Constraint,
 | 
			
		||||
    kw_Op,
 | 
			
		||||
    kw_OpName,
 | 
			
		||||
    kw_Pattern,
 | 
			
		||||
    kw_replace,
 | 
			
		||||
    kw_rewrite,
 | 
			
		||||
    kw_Type,
 | 
			
		||||
    kw_TypeRange,
 | 
			
		||||
    kw_Value,
 | 
			
		||||
    kw_ValueRange,
 | 
			
		||||
    kw_with,
 | 
			
		||||
    KW_END,
 | 
			
		||||
 | 
			
		||||
    // Punctuation.
 | 
			
		||||
    arrow,
 | 
			
		||||
    colon,
 | 
			
		||||
    comma,
 | 
			
		||||
    dot,
 | 
			
		||||
    equal,
 | 
			
		||||
    equal_arrow,
 | 
			
		||||
    semicolon,
 | 
			
		||||
    // Paired punctuation.
 | 
			
		||||
    less,
 | 
			
		||||
    greater,
 | 
			
		||||
    l_brace,
 | 
			
		||||
    r_brace,
 | 
			
		||||
    l_paren,
 | 
			
		||||
    r_paren,
 | 
			
		||||
    l_square,
 | 
			
		||||
    r_square,
 | 
			
		||||
    underscore,
 | 
			
		||||
 | 
			
		||||
    // Tokens.
 | 
			
		||||
    directive,
 | 
			
		||||
    identifier,
 | 
			
		||||
    integer,
 | 
			
		||||
    string_block,
 | 
			
		||||
    string
 | 
			
		||||
  };
 | 
			
		||||
  Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
 | 
			
		||||
 | 
			
		||||
  /// Given a token containing a string literal, return its value, including
 | 
			
		||||
  /// removing the quote characters and unescaping the contents of the string.
 | 
			
		||||
  std::string getStringValue() const;
 | 
			
		||||
 | 
			
		||||
  /// Returns true if the current token is a string literal.
 | 
			
		||||
  bool isString() const { return isAny(Token::string, Token::string_block); }
 | 
			
		||||
 | 
			
		||||
  /// Returns true if the current token is a keyword.
 | 
			
		||||
  bool isKeyword() const {
 | 
			
		||||
    return kind > Token::KW_BEGIN && kind < Token::KW_END;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Returns true if the current token is a keyword in a dependent context, and
 | 
			
		||||
  /// in any other situation (e.g. variable names) may be treated as an
 | 
			
		||||
  /// identifier.
 | 
			
		||||
  bool isDependentKeyword() const {
 | 
			
		||||
    return kind > Token::KW_DEPENDENT_BEGIN && kind < Token::KW_DEPENDENT_END;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Return the bytes that make up this token.
 | 
			
		||||
  StringRef getSpelling() const { return spelling; }
 | 
			
		||||
 | 
			
		||||
  /// Return the kind of this token.
 | 
			
		||||
  Kind getKind() const { return kind; }
 | 
			
		||||
 | 
			
		||||
  /// Return true if this token is one of the specified kinds.
 | 
			
		||||
  bool isAny(Kind k1, Kind k2) const { return is(k1) || is(k2); }
 | 
			
		||||
  template <typename... T>
 | 
			
		||||
  bool isAny(Kind k1, Kind k2, Kind k3, T... others) const {
 | 
			
		||||
    return is(k1) || isAny(k2, k3, others...);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Return if the token does not have the given kind.
 | 
			
		||||
  bool isNot(Kind k) const { return k != kind; }
 | 
			
		||||
  template <typename... T> bool isNot(Kind k1, Kind k2, T... others) const {
 | 
			
		||||
    return !isAny(k1, k2, others...);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Return if the token has the given kind.
 | 
			
		||||
  bool is(Kind K) const { return kind == K; }
 | 
			
		||||
 | 
			
		||||
  /// Return a location for the start of this token.
 | 
			
		||||
  llvm::SMLoc getStartLoc() const {
 | 
			
		||||
    return llvm::SMLoc::getFromPointer(spelling.data());
 | 
			
		||||
  }
 | 
			
		||||
  /// Return a location at the end of this token.
 | 
			
		||||
  llvm::SMLoc getEndLoc() const {
 | 
			
		||||
    return llvm::SMLoc::getFromPointer(spelling.data() + spelling.size());
 | 
			
		||||
  }
 | 
			
		||||
  /// Return a location for the range of this token.
 | 
			
		||||
  llvm::SMRange getLoc() const {
 | 
			
		||||
    return llvm::SMRange(getStartLoc(), getEndLoc());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  /// Discriminator that indicates the kind of token this is.
 | 
			
		||||
  Kind kind;
 | 
			
		||||
 | 
			
		||||
  /// A reference to the entire token contents; this is always a pointer into
 | 
			
		||||
  /// a memory buffer owned by the source manager.
 | 
			
		||||
  StringRef spelling;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Lexer
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
class Lexer {
 | 
			
		||||
public:
 | 
			
		||||
  Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine);
 | 
			
		||||
  ~Lexer();
 | 
			
		||||
 | 
			
		||||
  /// Return a reference to the source manager used by the lexer.
 | 
			
		||||
  llvm::SourceMgr &getSourceMgr() { return srcMgr; }
 | 
			
		||||
 | 
			
		||||
  /// Return a reference to the diagnostic engine used by the lexer.
 | 
			
		||||
  ast::DiagnosticEngine &getDiagEngine() { return diagEngine; }
 | 
			
		||||
 | 
			
		||||
  /// Push an include of the given file. This will cause the lexer to start
 | 
			
		||||
  /// processing the provided file. Returns failure if the file could not be
 | 
			
		||||
  /// opened, success otherwise.
 | 
			
		||||
  LogicalResult pushInclude(StringRef filename);
 | 
			
		||||
 | 
			
		||||
  /// Lex the next token and return it.
 | 
			
		||||
  Token lexToken();
 | 
			
		||||
 | 
			
		||||
  /// Change the position of the lexer cursor. The next token we lex will start
 | 
			
		||||
  /// at the designated point in the input.
 | 
			
		||||
  void resetPointer(const char *newPointer) { curPtr = newPointer; }
 | 
			
		||||
 | 
			
		||||
  /// Emit an error to the lexer with the given location and message.
 | 
			
		||||
  Token emitError(llvm::SMRange loc, const Twine &msg);
 | 
			
		||||
  Token emitError(const char *loc, const Twine &msg);
 | 
			
		||||
  Token emitErrorAndNote(llvm::SMRange loc, const Twine &msg,
 | 
			
		||||
                         llvm::SMRange noteLoc, const Twine ¬e);
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  Token formToken(Token::Kind kind, const char *tokStart) {
 | 
			
		||||
    return Token(kind, StringRef(tokStart, curPtr - tokStart));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Return the next character in the stream.
 | 
			
		||||
  int getNextChar();
 | 
			
		||||
 | 
			
		||||
  /// Lex methods.
 | 
			
		||||
  void lexComment();
 | 
			
		||||
  Token lexDirective(const char *tokStart);
 | 
			
		||||
  Token lexIdentifier(const char *tokStart);
 | 
			
		||||
  Token lexNumber(const char *tokStart);
 | 
			
		||||
  Token lexString(const char *tokStart, bool isStringBlock);
 | 
			
		||||
 | 
			
		||||
  llvm::SourceMgr &srcMgr;
 | 
			
		||||
  int curBufferID;
 | 
			
		||||
  StringRef curBuffer;
 | 
			
		||||
  const char *curPtr;
 | 
			
		||||
 | 
			
		||||
  /// The engine used to emit diagnostics during lexing/parsing.
 | 
			
		||||
  ast::DiagnosticEngine &diagEngine;
 | 
			
		||||
 | 
			
		||||
  /// A flag indicating if we added a default diagnostic handler to the provided
 | 
			
		||||
  /// diagEngine.
 | 
			
		||||
  bool addedHandlerToDiagEngine;
 | 
			
		||||
};
 | 
			
		||||
} // namespace pdll
 | 
			
		||||
} // namespace mlir
 | 
			
		||||
 | 
			
		||||
#endif // LIB_TOOLS_PDLL_PARSER_LEXER_H_
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| 
						 | 
				
			
			@ -78,6 +78,7 @@ set(MLIR_TEST_DEPENDS
 | 
			
		|||
  mlir-linalg-ods-yaml-gen
 | 
			
		||||
  mlir-lsp-server
 | 
			
		||||
  mlir-opt
 | 
			
		||||
  mlir-pdll
 | 
			
		||||
  mlir-reduce
 | 
			
		||||
  mlir-tblgen
 | 
			
		||||
  mlir-translate
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,7 +21,7 @@ config.name = 'MLIR'
 | 
			
		|||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 | 
			
		||||
 | 
			
		||||
# suffixes: A list of file extensions to treat as test files.
 | 
			
		||||
config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test']
 | 
			
		||||
config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test', '.pdll']
 | 
			
		||||
 | 
			
		||||
# test_source_root: The root path where tests are located.
 | 
			
		||||
config.test_source_root = os.path.dirname(__file__)
 | 
			
		||||
| 
						 | 
				
			
			@ -68,6 +68,7 @@ tools = [
 | 
			
		|||
    'mlir-cpu-runner',
 | 
			
		||||
    'mlir-linalg-ods-yaml-gen',
 | 
			
		||||
    'mlir-reduce',
 | 
			
		||||
    'mlir-pdll',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# The following tools are optional
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,23 @@
 | 
			
		|||
// RUN: not mlir-pdll %s -split-input-file 2>&1  | FileCheck %s
 | 
			
		||||
 | 
			
		||||
// CHECK: unknown directive `#foo`
 | 
			
		||||
#foo
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Include
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// CHECK: expected string file name after `include` directive
 | 
			
		||||
#include <>
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK: unable to open include file `unknown_file.pdll`
 | 
			
		||||
#include "unknown_file.pdll"
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK: expected include filename to end with `.pdll`
 | 
			
		||||
#include "unknown_file.foo"
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,62 @@
 | 
			
		|||
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Reference Expr
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected identifier constraint
 | 
			
		||||
  let foo = Foo: ;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: undefined reference to `bar`
 | 
			
		||||
  let foo = bar;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern FooPattern {
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: invalid reference to `FooPattern`
 | 
			
		||||
  let foo = FooPattern;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected `:` after `_` variable
 | 
			
		||||
  let foo = _;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected identifier constraint
 | 
			
		||||
  let foo = _: ;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Member Access Expr
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected identifier or numeric member name
 | 
			
		||||
  let root: Op;
 | 
			
		||||
  erase root.<>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: invalid member access `unknown_result` on expression of type `Op`
 | 
			
		||||
  let root: Op;
 | 
			
		||||
  erase root.unknown_result;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,15 @@
 | 
			
		|||
// RUN: mlir-pdll %s -I %S | FileCheck %s
 | 
			
		||||
 | 
			
		||||
Pattern BeforeIncludedPattern {
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#include "include/included.pdll"
 | 
			
		||||
 | 
			
		||||
Pattern AfterIncludedPattern {
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CHECK: PatternDecl {{.*}} Name<BeforeIncludedPattern>
 | 
			
		||||
// CHECK: PatternDecl {{.*}} Name<IncludedPattern>
 | 
			
		||||
// CHECK: PatternDecl {{.*}} Name<AfterIncludedPattern>
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,5 @@
 | 
			
		|||
// This file is included by 'include.pdll' as part of testing include files.
 | 
			
		||||
 | 
			
		||||
Pattern IncludedPattern {
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,2 @@
 | 
			
		|||
config.suffixes = ['.pdll']
 | 
			
		||||
config.excludes = ['include']
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,26 @@
 | 
			
		|||
// RUN: not mlir-pdll %s -split-input-file 2>&1  | FileCheck %s
 | 
			
		||||
 | 
			
		||||
// CHECK: expected `{` to start pattern body
 | 
			
		||||
Pattern }
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK: :6:9: error: `Foo` has already been defined
 | 
			
		||||
// CHECK: :5:9: note: see previous definition here
 | 
			
		||||
Pattern Foo { erase root: Op; }
 | 
			
		||||
Pattern Foo { erase root: Op; }
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK: expected Pattern body to terminate with an operation rewrite statement
 | 
			
		||||
Pattern {
 | 
			
		||||
  let value: Value;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK: Pattern body was terminated by an operation rewrite statement, but found trailing statements
 | 
			
		||||
Pattern {
 | 
			
		||||
  erase root: Op;
 | 
			
		||||
  let value: Value;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,17 @@
 | 
			
		|||
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: `-PatternDecl
 | 
			
		||||
// CHECK:   `-CompoundStmt
 | 
			
		||||
// CHECK:     `-EraseStmt
 | 
			
		||||
Pattern {
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: `-PatternDecl {{.*}} Name<NamedPattern>
 | 
			
		||||
Pattern NamedPattern {
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,222 @@
 | 
			
		|||
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
 | 
			
		||||
 | 
			
		||||
// CHECK: expected top-level declaration, such as a `Pattern`
 | 
			
		||||
10
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected `;` after statement
 | 
			
		||||
  erase _: Op
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// `erase`
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected expression
 | 
			
		||||
  erase;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected `Op` expression
 | 
			
		||||
  erase _: Attr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// `let`
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected identifier after `let` to name a new variable
 | 
			
		||||
  let 5;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: `_` may only be used to define "inline" variables
 | 
			
		||||
  let _;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected expression
 | 
			
		||||
  let foo: Attr<>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected expression of `Type` in type constraint
 | 
			
		||||
  let foo: Attr<_: Attr>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected `>` after variable type constraint
 | 
			
		||||
  let foo: Attr<_: Type{};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: the type of this variable has already been constrained
 | 
			
		||||
  let foo: [Attr<_: Type>, Attr<_: Type];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected `.` after dialect namespace
 | 
			
		||||
  let foo: Op<builtin>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected operation name after dialect namespace
 | 
			
		||||
  let foo: Op<builtin.>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected `>` after operation name
 | 
			
		||||
  let foo: Op<builtin.func<;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected expression
 | 
			
		||||
  let foo: Value<>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected expression of `Type` in type constraint
 | 
			
		||||
  let foo: Value<_: Attr>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected `>` after variable type constraint
 | 
			
		||||
  let foo: Value<_: Type{};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: the type of this variable has already been constrained
 | 
			
		||||
  let foo: [Value<_: Type>, Value<_: Type];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected expression
 | 
			
		||||
  let foo: ValueRange<10>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected expression of `TypeRange` in type constraint
 | 
			
		||||
  let foo: ValueRange<_: Type>;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected `>` after variable type constraint
 | 
			
		||||
  let foo: ValueRange<_: Type{};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: the type of this variable has already been constrained
 | 
			
		||||
  let foo: [ValueRange<_: Type>, ValueRange<_: Type];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: unknown reference to constraint `UnknownConstraint`
 | 
			
		||||
  let foo: UnknownConstraint;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern Foo {
 | 
			
		||||
  erase root: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: invalid reference to non-constraint
 | 
			
		||||
  let foo: Foo;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: constraint type `Attr` is incompatible with the previously inferred type `Value`
 | 
			
		||||
  let foo: [Value, Attr];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected `]` after constraint list
 | 
			
		||||
  let foo: [Attr[];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: expected expression
 | 
			
		||||
  let foo: Attr = ;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: type constraints are not permitted on variables with initializers
 | 
			
		||||
  let foo: ValueRange<_: Type> = _: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: unable to infer type for variable `foo`
 | 
			
		||||
  // CHECK: note: the type of a variable must be inferable from the constraint list or the initializer
 | 
			
		||||
  let foo;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: unable to convert expression of type `Attr` to the expected type of `Value`
 | 
			
		||||
  let foo: Value = _: Attr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
Pattern {
 | 
			
		||||
  // CHECK: :7:7: error: `foo` has already been defined
 | 
			
		||||
  // CHECK: :6:7: note: see previous definition here
 | 
			
		||||
  let foo: Attr;
 | 
			
		||||
  let foo: Attr;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,155 @@
 | 
			
		|||
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// CompoundStmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: CompoundStmt
 | 
			
		||||
// CHECK:   |-LetStmt
 | 
			
		||||
// CHECK:   `-EraseStmt
 | 
			
		||||
Pattern {
 | 
			
		||||
  let root: Op;
 | 
			
		||||
  erase root;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// EraseStmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: EraseStmt
 | 
			
		||||
// CHECK:   `-DeclRefExpr {{.*}} Type<Op>
 | 
			
		||||
Pattern {
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// LetStmt
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: LetStmt
 | 
			
		||||
// CHECK:   `-VariableDecl {{.*}} Name<attrVar> Type<Attr>
 | 
			
		||||
// CHECK:     `Constraints`
 | 
			
		||||
// CHECK:       `-AttrConstraintDecl
 | 
			
		||||
// CHECK:   `-VariableDecl {{.*}} Name<var> Type<Op>
 | 
			
		||||
// CHECK:     `Constraints`
 | 
			
		||||
// CHECK:       `-OpConstraintDecl
 | 
			
		||||
// CHECK:         `-OpNameDecl
 | 
			
		||||
Pattern {
 | 
			
		||||
  let attrVar: Attr;
 | 
			
		||||
  let var: Op;
 | 
			
		||||
  erase var;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Check for proper refinement between constraint types.
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: LetStmt
 | 
			
		||||
// CHECK:   `-VariableDecl {{.*}} Name<var> Type<Op<dialect.op>>
 | 
			
		||||
// CHECK:     `Constraints`
 | 
			
		||||
// CHECK:       `-OpConstraintDecl
 | 
			
		||||
// CHECK:         `-OpNameDecl
 | 
			
		||||
// CHECK:       `-OpConstraintDecl
 | 
			
		||||
// CHECK:         `-OpNameDecl {{.*}} Name<dialect.op>
 | 
			
		||||
Pattern {
 | 
			
		||||
  let var: [Op, Op<dialect.op>];
 | 
			
		||||
  erase var;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Check for proper conversion between initializer and constraint type.
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: LetStmt
 | 
			
		||||
// CHECK:   `-VariableDecl {{.*}} Name<var> Type<Op<dialect.op>>
 | 
			
		||||
// CHECK:     `-DeclRefExpr {{.*}} Type<Op<dialect.op>>
 | 
			
		||||
// CHECK:       `-VariableDecl {{.*}} Name<input>
 | 
			
		||||
// CHECK:     `Constraints`
 | 
			
		||||
// CHECK:       `-OpConstraintDecl
 | 
			
		||||
// CHECK:         `-OpNameDecl
 | 
			
		||||
Pattern {
 | 
			
		||||
  let input: Op<dialect.op>;
 | 
			
		||||
  let var: Op = input;
 | 
			
		||||
  erase var;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Check for proper conversion between initializer and constraint type.
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: LetStmt
 | 
			
		||||
// CHECK:   `-VariableDecl {{.*}} Name<var> Type<Value>
 | 
			
		||||
// CHECK:     `-MemberAccessExpr {{.*}} Member<$results> Type<Value>
 | 
			
		||||
// CHECK:       `-DeclRefExpr {{.*}} Type<Op<dialect.op>>
 | 
			
		||||
// CHECK:         `-VariableDecl {{.*}} Name<input>
 | 
			
		||||
// CHECK:     `Constraints`
 | 
			
		||||
// CHECK:       `-ValueConstraintDecl
 | 
			
		||||
Pattern {
 | 
			
		||||
  let input: Op<dialect.op>;
 | 
			
		||||
  let var: Value = input;
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Check for proper conversion between initializer and constraint type.
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: LetStmt
 | 
			
		||||
// CHECK:   `-VariableDecl {{.*}} Name<var> Type<ValueRange>
 | 
			
		||||
// CHECK:     `-MemberAccessExpr {{.*}} Member<$results> Type<ValueRange>
 | 
			
		||||
// CHECK:       `-DeclRefExpr {{.*}} Type<Op<dialect.op>>
 | 
			
		||||
// CHECK:         `-VariableDecl {{.*}} Name<input>
 | 
			
		||||
// CHECK:     `Constraints`
 | 
			
		||||
// CHECK:       `-ValueRangeConstraintDecl
 | 
			
		||||
Pattern {
 | 
			
		||||
  let input: Op<dialect.op>;
 | 
			
		||||
  let var: ValueRange = input;
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Check for proper handling of type constraints.
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: LetStmt
 | 
			
		||||
// CHECK:   `-VariableDecl {{.*}} Name<var> Type<Value>
 | 
			
		||||
// CHECK:     `Constraints`
 | 
			
		||||
// CHECK:       `-ValueConstraintDecl
 | 
			
		||||
// CHECK:         `-DeclRefExpr {{.*}} Type<Type>
 | 
			
		||||
// CHECK:           `-VariableDecl {{.*}} Name<_> Type<Type>
 | 
			
		||||
// CHECK:             `Constraints`
 | 
			
		||||
// CHECK:               `-TypeConstraintDecl
 | 
			
		||||
Pattern {
 | 
			
		||||
  let var: Value<_: Type>;
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// -----
 | 
			
		||||
 | 
			
		||||
// Check for proper handling of type constraints.
 | 
			
		||||
 | 
			
		||||
// CHECK: Module
 | 
			
		||||
// CHECK: LetStmt
 | 
			
		||||
// CHECK:   `-VariableDecl {{.*}} Name<var> Type<ValueRange>
 | 
			
		||||
// CHECK:     `Constraints`
 | 
			
		||||
// CHECK:       `-ValueRangeConstraintDecl
 | 
			
		||||
// CHECK:         `-DeclRefExpr {{.*}} Type<TypeRange>
 | 
			
		||||
// CHECK:           `-VariableDecl {{.*}} Name<_> Type<TypeRange>
 | 
			
		||||
// CHECK:             `Constraints`
 | 
			
		||||
// CHECK:               `-TypeRangeConstraintDecl
 | 
			
		||||
Pattern {
 | 
			
		||||
  let var: ValueRange<_: TypeRange>;
 | 
			
		||||
  erase _: Op;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -1,6 +1,7 @@
 | 
			
		|||
add_subdirectory(mlir-cpu-runner)
 | 
			
		||||
add_subdirectory(mlir-lsp-server)
 | 
			
		||||
add_subdirectory(mlir-opt)
 | 
			
		||||
add_subdirectory(mlir-pdll)
 | 
			
		||||
add_subdirectory(mlir-reduce)
 | 
			
		||||
add_subdirectory(mlir-shlib)
 | 
			
		||||
add_subdirectory(mlir-spirv-cpu-runner)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,16 @@
 | 
			
		|||
set(LIBS
 | 
			
		||||
  MLIRPDLLAST
 | 
			
		||||
  MLIRPDLLParser
 | 
			
		||||
  )
 | 
			
		||||
 | 
			
		||||
add_llvm_tool(mlir-pdll
 | 
			
		||||
  mlir-pdll.cpp
 | 
			
		||||
 | 
			
		||||
  DEPENDS
 | 
			
		||||
  ${LIBS}
 | 
			
		||||
  )
 | 
			
		||||
 | 
			
		||||
target_link_libraries(mlir-pdll PRIVATE ${LIBS})
 | 
			
		||||
llvm_update_compile_flags(mlir-pdll)
 | 
			
		||||
 | 
			
		||||
mlir_check_all_link_libraries(mlir-pdll)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,111 @@
 | 
			
		|||
//===- mlir-pdll.cpp - MLIR PDLL frontend -----------------------*- C++ -*-===//
 | 
			
		||||
//
 | 
			
		||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
			
		||||
// See https://llvm.org/LICENSE.txt for license information.
 | 
			
		||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
			
		||||
//
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
#include "mlir/Support/FileUtilities.h"
 | 
			
		||||
#include "mlir/Support/ToolUtilities.h"
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Context.h"
 | 
			
		||||
#include "mlir/Tools/PDLL/AST/Nodes.h"
 | 
			
		||||
#include "mlir/Tools/PDLL/Parser/Parser.h"
 | 
			
		||||
#include "llvm/Support/CommandLine.h"
 | 
			
		||||
#include "llvm/Support/InitLLVM.h"
 | 
			
		||||
#include "llvm/Support/SourceMgr.h"
 | 
			
		||||
#include "llvm/Support/ToolOutputFile.h"
 | 
			
		||||
 | 
			
		||||
using namespace mlir;
 | 
			
		||||
using namespace mlir::pdll;
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// main
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
 | 
			
		||||
/// The desired output type.
 | 
			
		||||
enum class OutputType {
 | 
			
		||||
  AST,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
static LogicalResult
 | 
			
		||||
processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
 | 
			
		||||
              OutputType outputType, std::vector<std::string> &includeDirs) {
 | 
			
		||||
  llvm::SourceMgr sourceMgr;
 | 
			
		||||
  sourceMgr.setIncludeDirs(includeDirs);
 | 
			
		||||
  sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), llvm::SMLoc());
 | 
			
		||||
 | 
			
		||||
  ast::Context astContext;
 | 
			
		||||
  FailureOr<ast::Module *> module = parsePDLAST(astContext, sourceMgr);
 | 
			
		||||
  if (failed(module))
 | 
			
		||||
    return failure();
 | 
			
		||||
 | 
			
		||||
  switch (outputType) {
 | 
			
		||||
  case OutputType::AST:
 | 
			
		||||
    (*module)->print(os);
 | 
			
		||||
    break;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return success();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int main(int argc, char **argv) {
 | 
			
		||||
  llvm::cl::opt<std::string> inputFilename(
 | 
			
		||||
      llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
 | 
			
		||||
      llvm::cl::value_desc("filename"));
 | 
			
		||||
 | 
			
		||||
  llvm::cl::opt<std::string> outputFilename(
 | 
			
		||||
      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
 | 
			
		||||
      llvm::cl::init("-"));
 | 
			
		||||
 | 
			
		||||
  llvm::cl::list<std::string> includeDirs(
 | 
			
		||||
      "I", llvm::cl::desc("Directory of include files"),
 | 
			
		||||
      llvm::cl::value_desc("directory"), llvm::cl::Prefix);
 | 
			
		||||
 | 
			
		||||
  llvm::cl::opt<bool> splitInputFile(
 | 
			
		||||
      "split-input-file",
 | 
			
		||||
      llvm::cl::desc("Split the input file into pieces and process each "
 | 
			
		||||
                     "chunk independently"),
 | 
			
		||||
      llvm::cl::init(false));
 | 
			
		||||
  llvm::cl::opt<enum OutputType> outputType(
 | 
			
		||||
      "x", llvm::cl::init(OutputType::AST),
 | 
			
		||||
      llvm::cl::desc("The type of output desired"),
 | 
			
		||||
      llvm::cl::values(clEnumValN(OutputType::AST, "ast",
 | 
			
		||||
                                  "generate the AST for the input file")));
 | 
			
		||||
 | 
			
		||||
  llvm::InitLLVM y(argc, argv);
 | 
			
		||||
  llvm::cl::ParseCommandLineOptions(argc, argv, "PDLL Frontend");
 | 
			
		||||
 | 
			
		||||
  // Set up the input file.
 | 
			
		||||
  std::string errorMessage;
 | 
			
		||||
  std::unique_ptr<llvm::MemoryBuffer> inputFile =
 | 
			
		||||
      openInputFile(inputFilename, &errorMessage);
 | 
			
		||||
  if (!inputFile) {
 | 
			
		||||
    llvm::errs() << errorMessage << "\n";
 | 
			
		||||
    return 1;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Set up the output file.
 | 
			
		||||
  std::unique_ptr<llvm::ToolOutputFile> outputFile =
 | 
			
		||||
      openOutputFile(outputFilename, &errorMessage);
 | 
			
		||||
  if (!outputFile) {
 | 
			
		||||
    llvm::errs() << errorMessage << "\n";
 | 
			
		||||
    return 1;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // The split-input-file mode is a very specific mode that slices the file
 | 
			
		||||
  // up into small pieces and checks each independently.
 | 
			
		||||
  auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
 | 
			
		||||
                       raw_ostream &os) {
 | 
			
		||||
    return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs);
 | 
			
		||||
  };
 | 
			
		||||
  if (splitInputFile) {
 | 
			
		||||
    if (failed(splitAndProcessBuffer(std::move(inputFile), processFn,
 | 
			
		||||
                                     outputFile->os())))
 | 
			
		||||
      return 1;
 | 
			
		||||
  } else if (failed(processFn(std::move(inputFile), outputFile->os()))) {
 | 
			
		||||
    return 1;
 | 
			
		||||
  }
 | 
			
		||||
  outputFile->keep();
 | 
			
		||||
  return 0;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -7923,3 +7923,51 @@ cc_binary(
 | 
			
		|||
        "//mlir/test:TestDialect",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "PDLLAST",
 | 
			
		||||
    srcs = glob(
 | 
			
		||||
        [
 | 
			
		||||
            "lib/Tools/PDLL/AST/*.cpp",
 | 
			
		||||
            "lib/Tools/PDLL/AST/*.h",
 | 
			
		||||
        ],
 | 
			
		||||
    ),
 | 
			
		||||
    hdrs = glob(["include/mlir/Tools/PDLL/AST/*.h"]),
 | 
			
		||||
    includes = ["include"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//llvm:Support",
 | 
			
		||||
        "//mlir:Support",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "PDLLParser",
 | 
			
		||||
    srcs = glob(
 | 
			
		||||
        [
 | 
			
		||||
            "lib/Tools/PDLL/Parser/*.cpp",
 | 
			
		||||
            "lib/Tools/PDLL/Parser/*.h",
 | 
			
		||||
        ],
 | 
			
		||||
    ),
 | 
			
		||||
    hdrs = glob(["include/mlir/Tools/PDLL/Parser/*.h"]),
 | 
			
		||||
    includes = ["include"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":PDLLAST",
 | 
			
		||||
        ":Support",
 | 
			
		||||
        ":TableGen",
 | 
			
		||||
        "//llvm:Support",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_binary(
 | 
			
		||||
    name = "mlir-pdll",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "tools/mlir-pdll/mlir-pdll.cpp",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":PDLLAST",
 | 
			
		||||
        ":PDLLParser",
 | 
			
		||||
        ":Support",
 | 
			
		||||
        "//llvm:Support",
 | 
			
		||||
        "//llvm:config",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue