152 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			152 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			C++
		
	
	
	
//===- IRModule.cpp - IR pybind module ------------------------------------===//
 | 
						|
//
 | 
						|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 | 
						|
// See https://llvm.org/LICENSE.txt for license information.
 | 
						|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 | 
						|
//
 | 
						|
//===----------------------------------------------------------------------===//
 | 
						|
 | 
						|
#include "IRModule.h"
 | 
						|
#include "Globals.h"
 | 
						|
#include "PybindUtils.h"
 | 
						|
 | 
						|
#include <vector>
 | 
						|
 | 
						|
#include "mlir-c/Bindings/Python/Interop.h"
 | 
						|
 | 
						|
namespace py = pybind11;
 | 
						|
using namespace mlir;
 | 
						|
using namespace mlir::python;
 | 
						|
 | 
						|
// -----------------------------------------------------------------------------
 | 
						|
// PyGlobals
 | 
						|
// -----------------------------------------------------------------------------
 | 
						|
 | 
						|
// Pointer to the single live PyGlobals object. The constructor sets it to
// `this` (asserting that it was previously null) and the destructor resets it
// to nullptr.
PyGlobals *PyGlobals::instance = nullptr;
 | 
						|
 | 
						|
// Constructs the globals object and publishes it through `instance`. Only one
// PyGlobals may exist at a time; constructing a second one trips the assert.
PyGlobals::PyGlobals() {
  assert(instance == nullptr && "PyGlobals already constructed");
  instance = this;

  // Seed the dialect search path with its default entry, {mlir.}dialects,
  // where {mlir.} is the package prefix configured at compile time.
  dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
}
 | 
						|
 | 
						|
// Unpublishes the object so that `instance` no longer refers to it.
PyGlobals::~PyGlobals() {
  instance = nullptr;
}
 | 
						|
 | 
						|
// Imports the Python module that provides `dialectNamespace`.
//
// Each entry of `dialectSearchPrefixes` is tried in order as
// `<prefix>.<dialectNamespace>`; the search stops at the first import that
// succeeds. An import that raises ModuleNotFoundError moves on to the next
// prefix; any other Python exception is rethrown to the caller.
//
// The namespace is recorded in `loadedDialectModulesCache` once the search
// finishes, whether or not a module was found, so later calls for the same
// namespace return immediately until that cache is cleared.
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
  // The guard has to be a named object: an unnamed temporary is destroyed at
  // the end of its own statement and would not keep the GIL held.
  py::gil_scoped_acquire acquire;
  if (loadedDialectModulesCache.contains(dialectNamespace))
    return;
  // Since re-entrancy is possible, make a copy of the search prefixes.
  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
  py::object loaded;
  // Iterate by value on purpose: each prefix is extended in place to form the
  // fully qualified module name.
  for (std::string moduleName : localSearchPrefixes) {
    moduleName.push_back('.');
    moduleName.append(dialectNamespace.data(), dialectNamespace.size());

    try {
      // Importing runs Python code, so the GIL stays held across this call.
      loaded = py::module::import(moduleName.c_str());
    } catch (py::error_already_set &e) {
      if (e.matches(PyExc_ModuleNotFoundError))
        continue;
      throw;
    }
    break;
  }

  // Note: Iterator cannot be shared from prior to loading, since re-entrancy
  // may have occurred, which may do anything.
  loadedDialectModulesCache.insert(dialectNamespace);
}
 | 
						|
 | 
						|
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
 | 
						|
                                    py::object pyClass) {
 | 
						|
  py::gil_scoped_acquire();
 | 
						|
  py::object &found = dialectClassMap[dialectNamespace];
 | 
						|
  if (found) {
 | 
						|
    throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
 | 
						|
                                             dialectNamespace +
 | 
						|
                                             "' is already registered.");
 | 
						|
  }
 | 
						|
  found = std::move(pyClass);
 | 
						|
}
 | 
						|
 | 
						|
void PyGlobals::registerOperationImpl(const std::string &operationName,
 | 
						|
                                      py::object pyClass,
 | 
						|
                                      py::object rawOpViewClass) {
 | 
						|
  py::gil_scoped_acquire();
 | 
						|
  py::object &found = operationClassMap[operationName];
 | 
						|
  if (found) {
 | 
						|
    throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
 | 
						|
                                             operationName +
 | 
						|
                                             "' is already registered.");
 | 
						|
  }
 | 
						|
  found = std::move(pyClass);
 | 
						|
  rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
 | 
						|
}
 | 
						|
 | 
						|
// Returns the Python class registered for `dialectNamespace`, or llvm::None if
// there is none.
//
// The dialect module is loaded first (a no-op when it is already in the load
// cache), giving it the chance to register its class. A miss is remembered by
// storing py::none() in `dialectClassMap`, so later lookups short-circuit.
llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
  // The guard has to be a named object: an unnamed temporary is destroyed at
  // the end of its own statement and would not keep the GIL held.
  py::gil_scoped_acquire acquire;
  loadDialectModule(dialectNamespace);
  // Fast match against the class map first (common case).
  const auto foundIt = dialectClassMap.find(dialectNamespace);
  if (foundIt != dialectClassMap.end()) {
    // A none entry is the negative-cache marker written below.
    if (foundIt->second.is_none())
      return llvm::None;
    assert(foundIt->second && "py::object is defined");
    return foundIt->second;
  }

  // Not found and loading did not yield a registration. Negative cache.
  dialectClassMap[dialectNamespace] = py::none();
  return llvm::None;
}
 | 
						|
 | 
						|
// Returns the raw OpView class registered for `operationName`, or llvm::None
// if there is none.
//
// Lookup proceeds in three steps:
//   1. Consult `rawOpViewClassMapCache`.
//   2. On a miss, load the dialect module named by the part of
//      `operationName` before the first '.'.
//   3. Consult the canonical `rawOpViewClassMap`, caching a hit in
//      `rawOpViewClassMapCache` and recording a miss as py::none().
llvm::Optional<pybind11::object>
PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
  {
    // The guard has to be a named object: an unnamed temporary is destroyed
    // at the end of its own statement and would not keep the GIL held.
    py::gil_scoped_acquire acquire;
    auto foundIt = rawOpViewClassMapCache.find(operationName);
    if (foundIt != rawOpViewClassMapCache.end()) {
      if (foundIt->second.is_none())
        return llvm::None;
      assert(foundIt->second && "py::object is defined");
      return foundIt->second;
    }
  }

  // Not found. Load the dialect namespace.
  auto split = operationName.split('.');
  llvm::StringRef dialectNamespace = split.first;
  loadDialectModule(dialectNamespace);

  // Attempt to find from the canonical map and cache.
  {
    py::gil_scoped_acquire acquire;
    auto foundIt = rawOpViewClassMap.find(operationName);
    if (foundIt != rawOpViewClassMap.end()) {
      if (foundIt->second.is_none())
        return llvm::None;
      assert(foundIt->second && "py::object is defined");
      // Positive cache.
      rawOpViewClassMapCache[operationName] = foundIt->second;
      return foundIt->second;
    } else {
      // Negative cache.
      // NOTE(review): the miss is written to the canonical map rather than to
      // rawOpViewClassMapCache, so clearImportCache() does not discard it --
      // confirm this is intended.
      rawOpViewClassMap[operationName] = py::none();
      return llvm::None;
    }
  }
}
 | 
						|
 | 
						|
// Empties `loadedDialectModulesCache` and `rawOpViewClassMapCache`, so that
// subsequent lookups attempt to load dialect modules again.
void PyGlobals::clearImportCache() {
  // The guard has to be a named object: an unnamed temporary is destroyed at
  // the end of its own statement and would not keep the GIL held.
  py::gil_scoped_acquire acquire;
  loadedDialectModulesCache.clear();
  rawOpViewClassMapCache.clear();
}
 |