forked from OSchip/llvm-project
				
			
		
			
				
	
	
		
			775 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			775 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			C++
		
	
	
	
//===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
 | 
						|
//
 | 
						|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
						|
// See https://llvm.org/LICENSE.txt for license information.
 | 
						|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
						|
//
 | 
						|
//===----------------------------------------------------------------------===//
 | 
						|
 | 
						|
#include "IRModules.h"
 | 
						|
#include "PybindUtils.h"
 | 
						|
 | 
						|
#include "mlir-c/StandardAttributes.h"
 | 
						|
#include "mlir-c/StandardTypes.h"
 | 
						|
 | 
						|
namespace py = pybind11;
 | 
						|
using namespace mlir;
 | 
						|
using namespace mlir::python;
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
// Docstrings (trivial, non-duplicated docstrings are included inline).
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
static const char kContextParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new Module or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetUnknownLocationDocstring[] =
    R"(Gets a Location representing an unknown location)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextCreateBlockDocstring[] =
    R"(Creates a detached block)";

static const char kContextCreateRegionDocstring[] =
    R"(Creates a detached region)";

static const char kRegionAppendBlockDocstring[] =
    R"(Appends a block to a region.

Raises:
  ValueError: If the block is already attached to another region.
)";

static const char kRegionInsertBlockDocstring[] =
    R"(Inserts a block at a position in a region.

Raises:
  ValueError: If the block is already attached to another region.
)";

static const char kRegionFirstBlockDocstring[] =
    R"(Gets the first block in a region.

Blocks can also be accessed via the `blocks` container.

Raises:
  IndexError: If the region has no blocks.
)";

static const char kBlockNextInRegionDocstring[] =
    R"(Gets the next block in the enclosing region.

Blocks can also be accessed via the `blocks` container of the owning region.
This method exists to mirror the lower level API and should not be preferred.

Raises:
  IndexError: If there are no further blocks.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Prints the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print method, which supports keyword arguments to customize
behavior.
)";

static const char kTypeStrDunderDocstring[] =
    R"(Prints the assembly form of the type.)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
// Conversion utilities.
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
namespace {
 | 
						|
 | 
						|
/// Accumulates into a python string from a method that accepts an
 | 
						|
/// MlirStringCallback.
 | 
						|
struct PyPrintAccumulator {
 | 
						|
  py::list parts;
 | 
						|
 | 
						|
  void *getUserData() { return this; }
 | 
						|
 | 
						|
  MlirStringCallback getCallback() {
 | 
						|
    return [](const char *part, intptr_t size, void *userData) {
 | 
						|
      PyPrintAccumulator *printAccum =
 | 
						|
          static_cast<PyPrintAccumulator *>(userData);
 | 
						|
      py::str pyPart(part, size); // Decodes as UTF-8 by default.
 | 
						|
      printAccum->parts.append(std::move(pyPart));
 | 
						|
    };
 | 
						|
  }
 | 
						|
 | 
						|
  py::str join() {
 | 
						|
    py::str delim("", 0);
 | 
						|
    return delim.attr("join")(parts);
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
/// Accumulates into a python string from a method that is expected to make
 | 
						|
/// one (no more, no less) call to the callback (asserts internally on
 | 
						|
/// violation).
 | 
						|
struct PySinglePartStringAccumulator {
 | 
						|
  void *getUserData() { return this; }
 | 
						|
 | 
						|
  MlirStringCallback getCallback() {
 | 
						|
    return [](const char *part, intptr_t size, void *userData) {
 | 
						|
      PySinglePartStringAccumulator *accum =
 | 
						|
          static_cast<PySinglePartStringAccumulator *>(userData);
 | 
						|
      assert(!accum->invoked &&
 | 
						|
             "PySinglePartStringAccumulator called back multiple times");
 | 
						|
      accum->invoked = true;
 | 
						|
      accum->value = py::str(part, size);
 | 
						|
    };
 | 
						|
  }
 | 
						|
 | 
						|
  py::str takeValue() {
 | 
						|
    assert(invoked && "PySinglePartStringAccumulator not called back");
 | 
						|
    return std::move(value);
 | 
						|
  }
 | 
						|
 | 
						|
private:
 | 
						|
  py::str value;
 | 
						|
  bool invoked = false;
 | 
						|
};
 | 
						|
 | 
						|
} // namespace
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
// PyBlock, PyRegion, and PyOperation.
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
void PyRegion::attachToParent() {
 | 
						|
  if (!detached) {
 | 
						|
    throw SetPyError(PyExc_ValueError, "Region is already attached to an op");
 | 
						|
  }
 | 
						|
  detached = false;
 | 
						|
}
 | 
						|
 | 
						|
void PyBlock::attachToParent() {
 | 
						|
  if (!detached) {
 | 
						|
    throw SetPyError(PyExc_ValueError, "Block is already attached to an op");
 | 
						|
  }
 | 
						|
  detached = false;
 | 
						|
}
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
// PyAttribute.
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
/// Returns true when mlirAttributeEqual reports the wrapped attributes as
/// equal.
bool PyAttribute::operator==(const PyAttribute &other) {
  const bool sameAttribute = mlirAttributeEqual(attr, other.attr);
  return sameAttribute;
}
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
// PyNamedAttribute.
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
/// Pairs `attr` with a name whose character storage is owned by this object.
///
/// The name is moved into a separately heap-allocated std::string, and the
/// C string passed to mlirNamedAttributeGet points into that allocation.
/// NOTE(review): presumably this keeps the pointer stable when the
/// PyNamedAttribute itself is moved; confirm against the header.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
  // Inside the body, plain `ownedName` names the moved-from parameter, so the
  // member must be reached through `this`.
  const char *nameStorage = this->ownedName->c_str();
  namedAttr = mlirNamedAttributeGet(nameStorage, attr);
}
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
// PyType.
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
/// Returns true when mlirTypeEqual reports the wrapped types as equal.
bool PyType::operator==(const PyType &other) {
  return mlirTypeEqual(type, other.type) != 0;
}
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
// Standard attribute subclasses.
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
namespace {
 | 
						|
 | 
						|
/// CRTP base classes for Python attributes that subclass Attribute and should
 | 
						|
/// be castable from it (i.e. via something like StringAttr(attr)).
 | 
						|
template <typename T>
 | 
						|
class PyConcreteAttribute : public PyAttribute {
 | 
						|
public:
 | 
						|
  // Derived classes must define statics for:
 | 
						|
  //   IsAFunctionTy isaFunction
 | 
						|
  //   const char *pyClassName
 | 
						|
  using ClassTy = py::class_<T, PyAttribute>;
 | 
						|
  using IsAFunctionTy = int (*)(MlirAttribute);
 | 
						|
 | 
						|
  PyConcreteAttribute() = default;
 | 
						|
  PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
 | 
						|
  PyConcreteAttribute(PyAttribute &orig)
 | 
						|
      : PyConcreteAttribute(castFrom(orig)) {}
 | 
						|
 | 
						|
  static MlirAttribute castFrom(PyAttribute &orig) {
 | 
						|
    if (!T::isaFunction(orig.attr)) {
 | 
						|
      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
 | 
						|
      throw SetPyError(PyExc_ValueError,
 | 
						|
                       llvm::Twine("Cannot cast attribute to ") +
 | 
						|
                           T::pyClassName + " (from " + origRepr + ")");
 | 
						|
    }
 | 
						|
    return orig.attr;
 | 
						|
  }
 | 
						|
 | 
						|
  static void bind(py::module &m) {
 | 
						|
    auto cls = ClassTy(m, T::pyClassName);
 | 
						|
    cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
 | 
						|
    T::bindDerived(cls);
 | 
						|
  }
 | 
						|
 | 
						|
  /// Implemented by derived classes to add methods to the Python subclass.
 | 
						|
  static void bindDerived(ClassTy &m) {}
 | 
						|
};
 | 
						|
 | 
						|
class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
 | 
						|
public:
 | 
						|
  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
 | 
						|
  static constexpr const char *pyClassName = "StringAttr";
 | 
						|
  using PyConcreteAttribute::PyConcreteAttribute;
 | 
						|
 | 
						|
  static void bindDerived(ClassTy &c) {
 | 
						|
    c.def_static(
 | 
						|
        "get",
 | 
						|
        [](PyMlirContext &context, std::string value) {
 | 
						|
          MlirAttribute attr =
 | 
						|
              mlirStringAttrGet(context.context, value.size(), &value[0]);
 | 
						|
          return PyStringAttribute(attr);
 | 
						|
        },
 | 
						|
        py::keep_alive<0, 1>(), "Gets a uniqued string attribute");
 | 
						|
    c.def_static(
 | 
						|
        "get_typed",
 | 
						|
        [](PyType &type, std::string value) {
 | 
						|
          MlirAttribute attr =
 | 
						|
              mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
 | 
						|
          return PyStringAttribute(attr);
 | 
						|
        },
 | 
						|
        py::keep_alive<0, 1>(),
 | 
						|
        "Gets a uniqued string attribute associated to a type");
 | 
						|
    c.def_property_readonly(
 | 
						|
        "value",
 | 
						|
        [](PyStringAttribute &self) {
 | 
						|
          PySinglePartStringAccumulator accum;
 | 
						|
          mlirStringAttrGetValue(self.attr, accum.getCallback(),
 | 
						|
                                 accum.getUserData());
 | 
						|
          return accum.takeValue();
 | 
						|
        },
 | 
						|
        "Returns the value of the string attribute");
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
} // namespace
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
// Standard type subclasses.
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
namespace {
 | 
						|
 | 
						|
/// CRTP base classes for Python types that subclass Type and should be
 | 
						|
/// castable from it (i.e. via something like IntegerType(t)).
 | 
						|
template <typename T>
 | 
						|
class PyConcreteType : public PyType {
 | 
						|
public:
 | 
						|
  // Derived classes must define statics for:
 | 
						|
  //   IsAFunctionTy isaFunction
 | 
						|
  //   const char *pyClassName
 | 
						|
  using ClassTy = py::class_<T, PyType>;
 | 
						|
  using IsAFunctionTy = int (*)(MlirType);
 | 
						|
 | 
						|
  PyConcreteType() = default;
 | 
						|
  PyConcreteType(MlirType t) : PyType(t) {}
 | 
						|
  PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
 | 
						|
 | 
						|
  static MlirType castFrom(PyType &orig) {
 | 
						|
    if (!T::isaFunction(orig.type)) {
 | 
						|
      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
 | 
						|
      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
 | 
						|
                                             T::pyClassName + " (from " +
 | 
						|
                                             origRepr + ")");
 | 
						|
    }
 | 
						|
    return orig.type;
 | 
						|
  }
 | 
						|
 | 
						|
  static void bind(py::module &m) {
 | 
						|
    auto cls = ClassTy(m, T::pyClassName);
 | 
						|
    cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
 | 
						|
    T::bindDerived(cls);
 | 
						|
  }
 | 
						|
 | 
						|
  /// Implemented by derived classes to add methods to the Python subclass.
 | 
						|
  static void bindDerived(ClassTy &m) {}
 | 
						|
};
 | 
						|
 | 
						|
class PyIntegerType : public PyConcreteType<PyIntegerType> {
 | 
						|
public:
 | 
						|
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
 | 
						|
  static constexpr const char *pyClassName = "IntegerType";
 | 
						|
  using PyConcreteType::PyConcreteType;
 | 
						|
 | 
						|
  static void bindDerived(ClassTy &c) {
 | 
						|
    c.def_static(
 | 
						|
        "get_signless",
 | 
						|
        [](PyMlirContext &context, unsigned width) {
 | 
						|
          MlirType t = mlirIntegerTypeGet(context.context, width);
 | 
						|
          return PyIntegerType(t);
 | 
						|
        },
 | 
						|
        py::keep_alive<0, 1>(), "Create a signless integer type");
 | 
						|
    c.def_static(
 | 
						|
        "get_signed",
 | 
						|
        [](PyMlirContext &context, unsigned width) {
 | 
						|
          MlirType t = mlirIntegerTypeSignedGet(context.context, width);
 | 
						|
          return PyIntegerType(t);
 | 
						|
        },
 | 
						|
        py::keep_alive<0, 1>(), "Create a signed integer type");
 | 
						|
    c.def_static(
 | 
						|
        "get_unsigned",
 | 
						|
        [](PyMlirContext &context, unsigned width) {
 | 
						|
          MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
 | 
						|
          return PyIntegerType(t);
 | 
						|
        },
 | 
						|
        py::keep_alive<0, 1>(), "Create an unsigned integer type");
 | 
						|
    c.def_property_readonly(
 | 
						|
        "width",
 | 
						|
        [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
 | 
						|
        "Returns the width of the integer type");
 | 
						|
    c.def_property_readonly(
 | 
						|
        "is_signless",
 | 
						|
        [](PyIntegerType &self) -> bool {
 | 
						|
          return mlirIntegerTypeIsSignless(self.type);
 | 
						|
        },
 | 
						|
        "Returns whether this is a signless integer");
 | 
						|
    c.def_property_readonly(
 | 
						|
        "is_signed",
 | 
						|
        [](PyIntegerType &self) -> bool {
 | 
						|
          return mlirIntegerTypeIsSigned(self.type);
 | 
						|
        },
 | 
						|
        "Returns whether this is a signed integer");
 | 
						|
    c.def_property_readonly(
 | 
						|
        "is_unsigned",
 | 
						|
        [](PyIntegerType &self) -> bool {
 | 
						|
          return mlirIntegerTypeIsUnsigned(self.type);
 | 
						|
        },
 | 
						|
        "Returns whether this is an unsigned integer");
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
/// Index Type subclass - IndexType.
 | 
						|
class PyIndexType : public PyConcreteType<PyIndexType> {
 | 
						|
public:
 | 
						|
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
 | 
						|
  static constexpr const char *pyClassName = "IndexType";
 | 
						|
  using PyConcreteType::PyConcreteType;
 | 
						|
 | 
						|
  static void bindDerived(ClassTy &c) {
 | 
						|
    c.def(py::init([](PyMlirContext &context) {
 | 
						|
            MlirType t = mlirIndexTypeGet(context.context);
 | 
						|
            return PyIndexType(t);
 | 
						|
          }),
 | 
						|
          py::keep_alive<0, 1>(), "Create a index type.");
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
/// Floating Point Type subclass - BF16Type.
 | 
						|
class PyBF16Type : public PyConcreteType<PyBF16Type> {
 | 
						|
public:
 | 
						|
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
 | 
						|
  static constexpr const char *pyClassName = "BF16Type";
 | 
						|
  using PyConcreteType::PyConcreteType;
 | 
						|
 | 
						|
  static void bindDerived(ClassTy &c) {
 | 
						|
    c.def(py::init([](PyMlirContext &context) {
 | 
						|
            MlirType t = mlirBF16TypeGet(context.context);
 | 
						|
            return PyBF16Type(t);
 | 
						|
          }),
 | 
						|
          py::keep_alive<0, 1>(), "Create a bf16 type.");
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
/// Floating Point Type subclass - F16Type.
 | 
						|
class PyF16Type : public PyConcreteType<PyF16Type> {
 | 
						|
public:
 | 
						|
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
 | 
						|
  static constexpr const char *pyClassName = "F16Type";
 | 
						|
  using PyConcreteType::PyConcreteType;
 | 
						|
 | 
						|
  static void bindDerived(ClassTy &c) {
 | 
						|
    c.def(py::init([](PyMlirContext &context) {
 | 
						|
            MlirType t = mlirF16TypeGet(context.context);
 | 
						|
            return PyF16Type(t);
 | 
						|
          }),
 | 
						|
          py::keep_alive<0, 1>(), "Create a f16 type.");
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
/// Floating Point Type subclass - F32Type.
 | 
						|
class PyF32Type : public PyConcreteType<PyF32Type> {
 | 
						|
public:
 | 
						|
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
 | 
						|
  static constexpr const char *pyClassName = "F32Type";
 | 
						|
  using PyConcreteType::PyConcreteType;
 | 
						|
 | 
						|
  static void bindDerived(ClassTy &c) {
 | 
						|
    c.def(py::init([](PyMlirContext &context) {
 | 
						|
            MlirType t = mlirF32TypeGet(context.context);
 | 
						|
            return PyF32Type(t);
 | 
						|
          }),
 | 
						|
          py::keep_alive<0, 1>(), "Create a f32 type.");
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
/// Floating Point Type subclass - F64Type.
 | 
						|
class PyF64Type : public PyConcreteType<PyF64Type> {
 | 
						|
public:
 | 
						|
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
 | 
						|
  static constexpr const char *pyClassName = "F64Type";
 | 
						|
  using PyConcreteType::PyConcreteType;
 | 
						|
 | 
						|
  static void bindDerived(ClassTy &c) {
 | 
						|
    c.def(py::init([](PyMlirContext &context) {
 | 
						|
            MlirType t = mlirF64TypeGet(context.context);
 | 
						|
            return PyF64Type(t);
 | 
						|
          }),
 | 
						|
          py::keep_alive<0, 1>(), "Create a f64 type.");
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
/// None Type subclass - NoneType.
 | 
						|
class PyNoneType : public PyConcreteType<PyNoneType> {
 | 
						|
public:
 | 
						|
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
 | 
						|
  static constexpr const char *pyClassName = "NoneType";
 | 
						|
  using PyConcreteType::PyConcreteType;
 | 
						|
 | 
						|
  static void bindDerived(ClassTy &c) {
 | 
						|
    c.def(py::init([](PyMlirContext &context) {
 | 
						|
            MlirType t = mlirNoneTypeGet(context.context);
 | 
						|
            return PyNoneType(t);
 | 
						|
          }),
 | 
						|
          py::keep_alive<0, 1>(), "Create a none type.");
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
} // namespace
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
// Populates the pybind11 IR submodule.
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
void mlir::python::populateIRSubmodule(py::module &m) {
 | 
						|
  // Mapping of MlirContext
 | 
						|
  py::class_<PyMlirContext>(m, "Context")
 | 
						|
      .def(py::init<>())
 | 
						|
      .def(
 | 
						|
          "parse_module",
 | 
						|
          [](PyMlirContext &self, const std::string module) {
 | 
						|
            auto moduleRef =
 | 
						|
                mlirModuleCreateParse(self.context, module.c_str());
 | 
						|
            // TODO: Rework error reporting once diagnostic engine is exposed
 | 
						|
            // in C API.
 | 
						|
            if (mlirModuleIsNull(moduleRef)) {
 | 
						|
              throw SetPyError(
 | 
						|
                  PyExc_ValueError,
 | 
						|
                  "Unable to parse module assembly (see diagnostics)");
 | 
						|
            }
 | 
						|
            return PyModule(moduleRef);
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>(), kContextParseDocstring)
 | 
						|
      .def(
 | 
						|
          "parse_attr",
 | 
						|
          [](PyMlirContext &self, std::string attrSpec) {
 | 
						|
            MlirAttribute type =
 | 
						|
                mlirAttributeParseGet(self.context, attrSpec.c_str());
 | 
						|
            // TODO: Rework error reporting once diagnostic engine is exposed
 | 
						|
            // in C API.
 | 
						|
            if (mlirAttributeIsNull(type)) {
 | 
						|
              throw SetPyError(PyExc_ValueError,
 | 
						|
                               llvm::Twine("Unable to parse attribute: '") +
 | 
						|
                                   attrSpec + "'");
 | 
						|
            }
 | 
						|
            return PyAttribute(type);
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>())
 | 
						|
      .def(
 | 
						|
          "parse_type",
 | 
						|
          [](PyMlirContext &self, std::string typeSpec) {
 | 
						|
            MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
 | 
						|
            // TODO: Rework error reporting once diagnostic engine is exposed
 | 
						|
            // in C API.
 | 
						|
            if (mlirTypeIsNull(type)) {
 | 
						|
              throw SetPyError(PyExc_ValueError,
 | 
						|
                               llvm::Twine("Unable to parse type: '") +
 | 
						|
                                   typeSpec + "'");
 | 
						|
            }
 | 
						|
            return PyType(type);
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>(), kContextParseTypeDocstring)
 | 
						|
      .def(
 | 
						|
          "get_unknown_location",
 | 
						|
          [](PyMlirContext &self) {
 | 
						|
            return PyLocation(mlirLocationUnknownGet(self.context));
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>(), kContextGetUnknownLocationDocstring)
 | 
						|
      .def(
 | 
						|
          "get_file_location",
 | 
						|
          [](PyMlirContext &self, std::string filename, int line, int col) {
 | 
						|
            return PyLocation(mlirLocationFileLineColGet(
 | 
						|
                self.context, filename.c_str(), line, col));
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>(), kContextGetFileLocationDocstring,
 | 
						|
          py::arg("filename"), py::arg("line"), py::arg("col"))
 | 
						|
      .def(
 | 
						|
          "create_region",
 | 
						|
          [](PyMlirContext &self) {
 | 
						|
            // The creating context is explicitly captured on regions to
 | 
						|
            // facilitate illegal assemblies of objects from multiple contexts
 | 
						|
            // that would invalidate the memory model.
 | 
						|
            return PyRegion(self.context, mlirRegionCreate(),
 | 
						|
                            /*detached=*/true);
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
 | 
						|
      .def(
 | 
						|
          "create_block",
 | 
						|
          [](PyMlirContext &self, std::vector<PyType> pyTypes) {
 | 
						|
            // In order for the keep_alive extend the proper lifetime, all
 | 
						|
            // types must be from the same context.
 | 
						|
            for (auto pyType : pyTypes) {
 | 
						|
              if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
 | 
						|
                                    self.context)) {
 | 
						|
                throw SetPyError(
 | 
						|
                    PyExc_ValueError,
 | 
						|
                    "All types used to construct a block must be from "
 | 
						|
                    "the same context as the block");
 | 
						|
              }
 | 
						|
            }
 | 
						|
            llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
 | 
						|
                                                 pyTypes.end());
 | 
						|
            return PyBlock(self.context,
 | 
						|
                           mlirBlockCreate(types.size(), &types[0]),
 | 
						|
                           /*detached=*/true);
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
 | 
						|
 | 
						|
  py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
 | 
						|
    PyPrintAccumulator printAccum;
 | 
						|
    mlirLocationPrint(self.loc, printAccum.getCallback(),
 | 
						|
                      printAccum.getUserData());
 | 
						|
    return printAccum.join();
 | 
						|
  });
 | 
						|
 | 
						|
  // Mapping of Module
 | 
						|
  py::class_<PyModule>(m, "Module")
 | 
						|
      .def(
 | 
						|
          "dump",
 | 
						|
          [](PyModule &self) {
 | 
						|
            mlirOperationDump(mlirModuleGetOperation(self.module));
 | 
						|
          },
 | 
						|
          kDumpDocstring)
 | 
						|
      .def(
 | 
						|
          "__str__",
 | 
						|
          [](PyModule &self) {
 | 
						|
            auto operation = mlirModuleGetOperation(self.module);
 | 
						|
            PyPrintAccumulator printAccum;
 | 
						|
            mlirOperationPrint(operation, printAccum.getCallback(),
 | 
						|
                               printAccum.getUserData());
 | 
						|
            return printAccum.join();
 | 
						|
          },
 | 
						|
          kOperationStrDunderDocstring);
 | 
						|
 | 
						|
  // Mapping of PyRegion.
 | 
						|
  py::class_<PyRegion>(m, "Region")
 | 
						|
      .def(
 | 
						|
          "append_block",
 | 
						|
          [](PyRegion &self, PyBlock &block) {
 | 
						|
            if (!mlirContextEqual(self.context, block.context)) {
 | 
						|
              throw SetPyError(
 | 
						|
                  PyExc_ValueError,
 | 
						|
                  "Block must have been created from the same context as "
 | 
						|
                  "this region");
 | 
						|
            }
 | 
						|
 | 
						|
            block.attachToParent();
 | 
						|
            mlirRegionAppendOwnedBlock(self.region, block.block);
 | 
						|
          },
 | 
						|
          kRegionAppendBlockDocstring)
 | 
						|
      .def(
 | 
						|
          "insert_block",
 | 
						|
          [](PyRegion &self, int pos, PyBlock &block) {
 | 
						|
            if (!mlirContextEqual(self.context, block.context)) {
 | 
						|
              throw SetPyError(
 | 
						|
                  PyExc_ValueError,
 | 
						|
                  "Block must have been created from the same context as "
 | 
						|
                  "this region");
 | 
						|
            }
 | 
						|
            block.attachToParent();
 | 
						|
            // TODO: Make this return a failure and raise if out of bounds.
 | 
						|
            mlirRegionInsertOwnedBlock(self.region, pos, block.block);
 | 
						|
          },
 | 
						|
          kRegionInsertBlockDocstring)
 | 
						|
      .def_property_readonly(
 | 
						|
          "first_block",
 | 
						|
          [](PyRegion &self) {
 | 
						|
            MlirBlock block = mlirRegionGetFirstBlock(self.region);
 | 
						|
            if (mlirBlockIsNull(block)) {
 | 
						|
              throw SetPyError(PyExc_IndexError, "Region has no blocks");
 | 
						|
            }
 | 
						|
            return PyBlock(self.context, block, /*detached=*/false);
 | 
						|
          },
 | 
						|
          kRegionFirstBlockDocstring);
 | 
						|
 | 
						|
  // Mapping of PyBlock.
 | 
						|
  py::class_<PyBlock>(m, "Block")
 | 
						|
      .def_property_readonly(
 | 
						|
          "next_in_region",
 | 
						|
          [](PyBlock &self) {
 | 
						|
            MlirBlock block = mlirBlockGetNextInRegion(self.block);
 | 
						|
            if (mlirBlockIsNull(block)) {
 | 
						|
              throw SetPyError(PyExc_IndexError,
 | 
						|
                               "Attempt to read past last block");
 | 
						|
            }
 | 
						|
            return PyBlock(self.context, block, /*detached=*/false);
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>(), kBlockNextInRegionDocstring)
 | 
						|
      .def(
 | 
						|
          "__str__",
 | 
						|
          [](PyBlock &self) {
 | 
						|
            PyPrintAccumulator printAccum;
 | 
						|
            mlirBlockPrint(self.block, printAccum.getCallback(),
 | 
						|
                           printAccum.getUserData());
 | 
						|
            return printAccum.join();
 | 
						|
          },
 | 
						|
          kTypeStrDunderDocstring);
 | 
						|
 | 
						|
  // Mapping of Type.
 | 
						|
  py::class_<PyAttribute>(m, "Attribute")
 | 
						|
      .def(
 | 
						|
          "get_named",
 | 
						|
          [](PyAttribute &self, std::string name) {
 | 
						|
            return PyNamedAttribute(self.attr, std::move(name));
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>(), "Binds a name to the attribute")
 | 
						|
      .def("__eq__",
 | 
						|
           [](PyAttribute &self, py::object &other) {
 | 
						|
             try {
 | 
						|
               PyAttribute otherAttribute = other.cast<PyAttribute>();
 | 
						|
               return self == otherAttribute;
 | 
						|
             } catch (std::exception &e) {
 | 
						|
               return false;
 | 
						|
             }
 | 
						|
           })
 | 
						|
      .def(
 | 
						|
          "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
 | 
						|
          kDumpDocstring)
 | 
						|
      .def(
 | 
						|
          "__str__",
 | 
						|
          [](PyAttribute &self) {
 | 
						|
            PyPrintAccumulator printAccum;
 | 
						|
            mlirAttributePrint(self.attr, printAccum.getCallback(),
 | 
						|
                               printAccum.getUserData());
 | 
						|
            return printAccum.join();
 | 
						|
          },
 | 
						|
          kTypeStrDunderDocstring)
 | 
						|
      .def("__repr__", [](PyAttribute &self) {
 | 
						|
        // Generally, assembly formats are not printed for __repr__ because
 | 
						|
        // this can cause exceptionally long debug output and exceptions.
 | 
						|
        // However, attribute values are generally considered useful and are
 | 
						|
        // printed. This may need to be re-evaluated if debug dumps end up
 | 
						|
        // being excessive.
 | 
						|
        PyPrintAccumulator printAccum;
 | 
						|
        printAccum.parts.append("Attribute(");
 | 
						|
        mlirAttributePrint(self.attr, printAccum.getCallback(),
 | 
						|
                           printAccum.getUserData());
 | 
						|
        printAccum.parts.append(")");
 | 
						|
        return printAccum.join();
 | 
						|
      });
 | 
						|
 | 
						|
  py::class_<PyNamedAttribute>(m, "NamedAttribute")
 | 
						|
      .def("__repr__",
 | 
						|
           [](PyNamedAttribute &self) {
 | 
						|
             PyPrintAccumulator printAccum;
 | 
						|
             printAccum.parts.append("NamedAttribute(");
 | 
						|
             printAccum.parts.append(self.namedAttr.name);
 | 
						|
             printAccum.parts.append("=");
 | 
						|
             mlirAttributePrint(self.namedAttr.attribute,
 | 
						|
                                printAccum.getCallback(),
 | 
						|
                                printAccum.getUserData());
 | 
						|
             printAccum.parts.append(")");
 | 
						|
             return printAccum.join();
 | 
						|
           })
 | 
						|
      .def_property_readonly(
 | 
						|
          "name",
 | 
						|
          [](PyNamedAttribute &self) {
 | 
						|
            return py::str(self.namedAttr.name, strlen(self.namedAttr.name));
 | 
						|
          },
 | 
						|
          "The name of the NamedAttribute binding")
 | 
						|
      .def_property_readonly(
 | 
						|
          "attr",
 | 
						|
          [](PyNamedAttribute &self) {
 | 
						|
            return PyAttribute(self.namedAttr.attribute);
 | 
						|
          },
 | 
						|
          py::keep_alive<0, 1>(),
 | 
						|
          "The underlying generic attribute of the NamedAttribute binding");
 | 
						|
 | 
						|
  // Standard attribute bindings.
 | 
						|
  PyStringAttribute::bind(m);
 | 
						|
 | 
						|
  // Mapping of Type.
 | 
						|
  py::class_<PyType>(m, "Type")
 | 
						|
      .def("__eq__",
 | 
						|
           [](PyType &self, py::object &other) {
 | 
						|
             try {
 | 
						|
               PyType otherType = other.cast<PyType>();
 | 
						|
               return self == otherType;
 | 
						|
             } catch (std::exception &e) {
 | 
						|
               return false;
 | 
						|
             }
 | 
						|
           })
 | 
						|
      .def(
 | 
						|
          "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
 | 
						|
      .def(
 | 
						|
          "__str__",
 | 
						|
          [](PyType &self) {
 | 
						|
            PyPrintAccumulator printAccum;
 | 
						|
            mlirTypePrint(self.type, printAccum.getCallback(),
 | 
						|
                          printAccum.getUserData());
 | 
						|
            return printAccum.join();
 | 
						|
          },
 | 
						|
          kTypeStrDunderDocstring)
 | 
						|
      .def("__repr__", [](PyType &self) {
 | 
						|
        // Generally, assembly formats are not printed for __repr__ because
 | 
						|
        // this can cause exceptionally long debug output and exceptions.
 | 
						|
        // However, types are an exception as they typically have compact
 | 
						|
        // assembly forms and printing them is useful.
 | 
						|
        PyPrintAccumulator printAccum;
 | 
						|
        printAccum.parts.append("Type(");
 | 
						|
        mlirTypePrint(self.type, printAccum.getCallback(),
 | 
						|
                      printAccum.getUserData());
 | 
						|
        printAccum.parts.append(")");
 | 
						|
        return printAccum.join();
 | 
						|
      });
 | 
						|
 | 
						|
  // Standard type bindings.
 | 
						|
  PyIntegerType::bind(m);
 | 
						|
  PyIndexType::bind(m);
 | 
						|
  PyBF16Type::bind(m);
 | 
						|
  PyF16Type::bind(m);
 | 
						|
  PyF32Type::bind(m);
 | 
						|
  PyF64Type::bind(m);
 | 
						|
  PyNoneType::bind(m);
 | 
						|
}
 |