111 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			111 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			C++
		
	
	
	
//===- MainModule.cpp - Main pybind module --------------------------------===//
 | 
						|
//
 | 
						|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
						|
// See https://llvm.org/LICENSE.txt for license information.
 | 
						|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
						|
//
 | 
						|
//===----------------------------------------------------------------------===//
 | 
						|
 | 
						|
#include <string>
#include <tuple>
#include <utility>
#include <vector>
 | 
						|
 | 
						|
#include "PybindUtils.h"
 | 
						|
 | 
						|
#include "Dialects.h"
 | 
						|
#include "Globals.h"
 | 
						|
#include "IRModule.h"
 | 
						|
#include "Pass.h"
 | 
						|
 | 
						|
namespace py = pybind11;
 | 
						|
using namespace mlir;
 | 
						|
using namespace mlir::python;
 | 
						|
 | 
						|
// -----------------------------------------------------------------------------
 | 
						|
// Module initialization.
 | 
						|
// -----------------------------------------------------------------------------
 | 
						|
 | 
						|
/// Entry point of the `_mlir` native extension module.
///
/// Binds the `PyGlobals` registry type as `_Globals` and publishes an instance
/// of it as `_mlir.globals`. Installs the `register_dialect` and
/// `register_operation` class decorators, then creates and populates the `ir`,
/// `passmanager` and `dialects` submodules.
PYBIND11_MODULE(_mlir, m) {
  m.doc() = "MLIR Python Native Extension";

  // `_Globals` exposes the dialect search path plus testing-only hooks for
  // registering dialect/operation classes directly. It is bound with
  // py::module_local(), so this binding is local to this extension module.
  py::class_<PyGlobals>(m, "_Globals", py::module_local())
      .def_property(
          "dialect_search_modules", &PyGlobals::getDialectSearchPrefixes,
          // Mirror append_dialect_search_prefix below: after the search
          // prefixes change, drop any import results cached against the
          // previous prefixes so later lookups honor the new list.
          [](PyGlobals &self, std::vector<std::string> prefixes) {
            self.setDialectSearchPrefixes(std::move(prefixes));
            self.clearImportCache();
          })
      .def(
          "append_dialect_search_prefix",
          [](PyGlobals &self, std::string moduleName) {
            self.getDialectSearchPrefixes().push_back(std::move(moduleName));
            // The search path changed; invalidate cached import results.
            self.clearImportCache();
          },
          py::arg("module_name"))
      .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
           py::arg("dialect_namespace"), py::arg("dialect_class"),
           "Testing hook for directly registering a dialect")
      .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
           py::arg("operation_name"), py::arg("operation_class"),
           py::arg("raw_opview_class"),
           "Testing hook for directly registering an operation");

  // Aside from making the globals accessible to python, having python manage
  // it is necessary to make sure it is destroyed (and releases its python
  // resources) properly.
  m.attr("globals") =
      py::cast(new PyGlobals, py::return_value_policy::take_ownership);

  // Registration decorators.

  // Class decorator. It reads `DIALECT_NAMESPACE` from the decorated class,
  // registers the class with the global registry under that namespace, and
  // returns the class unchanged.
  m.def(
      "register_dialect",
      [](py::object pyClass) {
        std::string dialectNamespace =
            pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
        PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
        return pyClass;
      },
      py::arg("dialect_class"),
      "Class decorator for registering a custom Dialect wrapper");

  // Decorator factory. It returns a class decorator bound to `dialectClass`.
  // The decorator does the following:
  //   1. Reads `OPERATION_NAME` from the op class.
  //   2. Creates its raw OpView subclass.
  //   3. Registers both classes with the global registry.
  //   4. Attaches the op class to `dialectClass` under the op class's
  //      `__name__`.
  //   5. Stores the raw subclass as `_Raw` on the op class.
  m.def(
      "register_operation",
      [](py::object dialectClass) -> py::cpp_function {
        return py::cpp_function(
            [dialectClass](py::object opClass) -> py::object {
              std::string operationName =
                  opClass.attr("OPERATION_NAME").cast<std::string>();
              auto rawSubclass = PyOpView::createRawSubclass(opClass);
              PyGlobals::get().registerOperationImpl(operationName, opClass,
                                                     rawSubclass);

              // Dict-stuff the new opClass by name onto the dialect class.
              py::object opClassName = opClass.attr("__name__");
              dialectClass.attr(opClassName) = opClass;

              // Now create a special "Raw" subclass that passes through
              // construction to the OpView parent (bypasses the intermediate
              // child's __init__).
              opClass.attr("_Raw") = rawSubclass;
              return opClass;
            });
      },
      py::arg("dialect_class"),
      "Produce a class decorator for registering an Operation class as part of "
      "a dialect");

  // Define and populate IR submodule.
  auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
  populateIRCore(irModule);
  populateIRAffine(irModule);
  populateIRAttributes(irModule);
  populateIRInterfaces(irModule);
  populateIRTypes(irModule);

  // Define and populate PassManager submodule.
  auto passModule =
      m.def_submodule("passmanager", "MLIR Pass Management Bindings");
  populatePassManagerSubmodule(passModule);

  // Define and populate dialect submodules.
  auto dialectsModule = m.def_submodule("dialects");
  auto linalgModule = dialectsModule.def_submodule("linalg");
  populateDialectLinalgSubmodule(linalgModule);
  // The sparse_tensor population also receives the IR submodule.
  populateDialectSparseTensorSubmodule(
      dialectsModule.def_submodule("sparse_tensor"), irModule);
}
 |