mirror of https://github.com/llvm/circt.git
488 lines
18 KiB
C++
488 lines
18 KiB
C++
//===- OMModule.cpp - OM API nanobind module ------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "CIRCTModules.h"
|
|
#include "circt-c/Dialect/HW.h"
|
|
#include "circt-c/Dialect/OM.h"
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
#include "mlir-c/BuiltinTypes.h"
|
|
#include "mlir-c/IR.h"
|
|
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
|
#include <nanobind/nanobind.h>
|
|
#include <nanobind/stl/variant.h>
|
|
#include <nanobind/stl/vector.h>
|
|
namespace nb = nanobind;
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::python;
|
|
using namespace mlir::python::nanobind_adaptors;
|
|
|
|
namespace {
|
|
|
|
struct List;
|
|
struct Object;
|
|
struct BasePath;
|
|
struct Path;
|
|
|
|
/// These are the Python types that are represented by the different primitive
|
|
/// OMEvaluatorValues as Attributes.
|
|
using PythonPrimitive = std::variant<nb::int_, nb::float_, nb::str, nb::bool_,
|
|
nb::tuple, nb::list, nb::dict>;
|
|
|
|
/// None is used to by nanobind when default initializing a PythonValue. The
|
|
/// order of types in the variant matters here, and we want nanobind to try
|
|
/// casting to the Python classes defined in this file first, before
|
|
/// MlirAttribute and the upstream MLIR type casters. If the MlirAttribute
|
|
/// is tried first, then we can hit an assert inside the MLIR codebase.
|
|
struct None {};
|
|
using PythonValue =
|
|
std::variant<None, Object, List, BasePath, Path, PythonPrimitive>;
|
|
|
|
/// Map an opaque OMEvaluatorValue into a python value.
|
|
PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result);
|
|
OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result,
|
|
MlirContext ctx);
|
|
static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr);
|
|
static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
|
|
MlirContext ctx);
|
|
|
|
/// Provides a List class by simply wrapping the OMObject CAPI.
|
|
struct List {
|
|
// Instantiate a List with a reference to the underlying OMEvaluatorValue.
|
|
List(OMEvaluatorValue value) : value(value) {}
|
|
|
|
/// Return the number of elements.
|
|
intptr_t getNumElements() { return omEvaluatorListGetNumElements(value); }
|
|
|
|
PythonValue getElement(intptr_t i);
|
|
OMEvaluatorValue getValue() const { return value; }
|
|
|
|
private:
|
|
// The underlying CAPI value.
|
|
OMEvaluatorValue value;
|
|
};
|
|
|
|
/// Provides a BasePath class by simply wrapping the OMObject CAPI.
|
|
struct BasePath {
|
|
/// Instantiate a BasePath with a reference to the underlying
|
|
/// OMEvaluatorValue.
|
|
BasePath(OMEvaluatorValue value) : value(value) {}
|
|
|
|
static BasePath getEmpty(MlirContext context) {
|
|
return BasePath(omEvaluatorBasePathGetEmpty(context));
|
|
}
|
|
|
|
/// Return a context from an underlying value.
|
|
MlirContext getContext() const { return omEvaluatorValueGetContext(value); }
|
|
|
|
OMEvaluatorValue getValue() const { return value; }
|
|
|
|
private:
|
|
// The underlying CAPI value.
|
|
OMEvaluatorValue value;
|
|
};
|
|
|
|
/// Provides a Path class by simply wrapping the OMObject CAPI.
|
|
struct Path {
|
|
/// Instantiate a Path with a reference to the underlying OMEvaluatorValue.
|
|
Path(OMEvaluatorValue value) : value(value) {}
|
|
|
|
/// Return a context from an underlying value.
|
|
MlirContext getContext() const { return omEvaluatorValueGetContext(value); }
|
|
|
|
OMEvaluatorValue getValue() const { return value; }
|
|
|
|
std::string dunderStr() {
|
|
auto ref = mlirStringAttrGetValue(omEvaluatorPathGetAsString(getValue()));
|
|
return std::string(ref.data, ref.length);
|
|
}
|
|
|
|
private:
|
|
// The underlying CAPI value.
|
|
OMEvaluatorValue value;
|
|
};
|
|
|
|
/// Provides an Object class by simply wrapping the OMObject CAPI.
|
|
struct Object {
|
|
// Instantiate an Object with a reference to the underlying OMObject.
|
|
Object(OMEvaluatorValue value) : value(value) {}
|
|
|
|
/// Get the Type from an Object, which will be a ClassType.
|
|
MlirType getType() { return omEvaluatorObjectGetType(value); }
|
|
|
|
/// Get the Location from an Object, which will be an MlirLocation.
|
|
MlirLocation getLocation() { return omEvaluatorValueGetLoc(value); }
|
|
|
|
// Get the field location info.
|
|
MlirLocation getFieldLoc(const std::string &name) {
|
|
// Wrap the requested field name in an attribute.
|
|
MlirContext context = mlirTypeGetContext(omEvaluatorObjectGetType(value));
|
|
MlirStringRef cName = mlirStringRefCreateFromCString(name.c_str());
|
|
MlirAttribute nameAttr = mlirStringAttrGet(context, cName);
|
|
|
|
// Get the field's ObjectValue via the CAPI.
|
|
OMEvaluatorValue result = omEvaluatorObjectGetField(value, nameAttr);
|
|
|
|
return omEvaluatorValueGetLoc(result);
|
|
}
|
|
|
|
// Get a field from the Object, using nanobind's support for variant to return
|
|
// a Python object that is either an Object or Attribute.
|
|
PythonValue getField(const std::string &name) {
|
|
// Wrap the requested field name in an attribute.
|
|
MlirContext context = mlirTypeGetContext(omEvaluatorObjectGetType(value));
|
|
MlirStringRef cName = mlirStringRefCreateFromCString(name.c_str());
|
|
MlirAttribute nameAttr = mlirStringAttrGet(context, cName);
|
|
|
|
// Get the field's ObjectValue via the CAPI.
|
|
OMEvaluatorValue result = omEvaluatorObjectGetField(value, nameAttr);
|
|
|
|
return omEvaluatorValueToPythonValue(result);
|
|
}
|
|
|
|
// Get a list with the names of all the fields in the Object.
|
|
std::vector<std::string> getFieldNames() {
|
|
MlirAttribute fieldNames = omEvaluatorObjectGetFieldNames(value);
|
|
intptr_t numFieldNames = mlirArrayAttrGetNumElements(fieldNames);
|
|
|
|
std::vector<std::string> pyFieldNames;
|
|
for (intptr_t i = 0; i < numFieldNames; ++i) {
|
|
MlirAttribute fieldName = mlirArrayAttrGetElement(fieldNames, i);
|
|
MlirStringRef fieldNameStr = mlirStringAttrGetValue(fieldName);
|
|
pyFieldNames.emplace_back(fieldNameStr.data, fieldNameStr.length);
|
|
}
|
|
|
|
return pyFieldNames;
|
|
}
|
|
|
|
// Get the hash of the object
|
|
unsigned getHash() { return omEvaluatorObjectGetHash(value); }
|
|
|
|
// Check the equality of the underlying values.
|
|
bool eq(Object &other) { return omEvaluatorObjectIsEq(value, other.value); }
|
|
|
|
OMEvaluatorValue getValue() const { return value; }
|
|
|
|
private:
|
|
// The underlying CAPI OMObject.
|
|
OMEvaluatorValue value;
|
|
};
|
|
|
|
/// Provides an Evaluator class by simply wrapping the OMEvaluator CAPI.
|
|
struct Evaluator {
|
|
// Instantiate an Evaluator with a reference to the underlying OMEvaluator.
|
|
Evaluator(MlirModule mod) : evaluator(omEvaluatorNew(mod)) {}
|
|
|
|
// Instantiate an Object.
|
|
Object instantiate(MlirAttribute className,
|
|
std::vector<PythonValue> actualParams) {
|
|
std::vector<OMEvaluatorValue> values;
|
|
for (auto ¶m : actualParams)
|
|
values.push_back(pythonValueToOMEvaluatorValue(
|
|
param, mlirModuleGetContext(getModule())));
|
|
|
|
// Instantiate the Object via the CAPI.
|
|
OMEvaluatorValue result = omEvaluatorInstantiate(
|
|
evaluator, className, values.size(), values.data());
|
|
|
|
// If the Object is null, something failed. Diagnostic handling is
|
|
// implemented in pure Python, so nothing to do here besides throwing an
|
|
// error to halt execution.
|
|
if (omEvaluatorObjectIsNull(result))
|
|
throw nb::value_error(
|
|
"unable to instantiate object, see previous error(s)");
|
|
|
|
// Return a new Object.
|
|
return Object(result);
|
|
}
|
|
|
|
// Get the Module the Evaluator is built from.
|
|
MlirModule getModule() { return omEvaluatorGetModule(evaluator); }
|
|
|
|
private:
|
|
// The underlying CAPI OMEvaluator.
|
|
OMEvaluator evaluator;
|
|
};
|
|
|
|
class PyListAttrIterator {
|
|
public:
|
|
PyListAttrIterator(MlirAttribute attr) : attr(std::move(attr)) {}
|
|
|
|
PyListAttrIterator &dunderIter() { return *this; }
|
|
|
|
MlirAttribute dunderNext() {
|
|
if (nextIndex >= omListAttrGetNumElements(attr))
|
|
throw nb::stop_iteration();
|
|
return omListAttrGetElement(attr, nextIndex++);
|
|
}
|
|
|
|
static void bind(nb::module_ &m) {
|
|
nb::class_<PyListAttrIterator>(m, "ListAttributeIterator")
|
|
.def("__iter__", &PyListAttrIterator::dunderIter)
|
|
.def("__next__", &PyListAttrIterator::dunderNext);
|
|
}
|
|
|
|
private:
|
|
MlirAttribute attr;
|
|
intptr_t nextIndex = 0;
|
|
};
|
|
|
|
PythonValue List::getElement(intptr_t i) {
|
|
return omEvaluatorValueToPythonValue(omEvaluatorListGetElement(value, i));
|
|
}
|
|
// Convert a generic MLIR Attribute to a PythonValue. This is basically a C++
|
|
// fast path of the parts of attribute_to_var that we use in the OM dialect.
|
|
static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
|
|
if (omAttrIsAIntegerAttr(attr)) {
|
|
auto strRef = omIntegerAttrToString(attr);
|
|
return nb::int_(nb::str(strRef.data, strRef.length));
|
|
}
|
|
|
|
if (mlirAttributeIsAFloat(attr)) {
|
|
return nb::float_(mlirFloatAttrGetValueDouble(attr));
|
|
}
|
|
|
|
if (mlirAttributeIsAString(attr)) {
|
|
auto strRef = mlirStringAttrGetValue(attr);
|
|
return nb::str(strRef.data, strRef.length);
|
|
}
|
|
|
|
// BoolAttr's are IntegerAttr's, check this first.
|
|
if (mlirAttributeIsABool(attr)) {
|
|
return nb::bool_(mlirBoolAttrGetValue(attr));
|
|
}
|
|
|
|
if (mlirAttributeIsAInteger(attr)) {
|
|
MlirType type = mlirAttributeGetType(attr);
|
|
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
|
|
return nb::int_(mlirIntegerAttrGetValueInt(attr));
|
|
if (mlirIntegerTypeIsSigned(type))
|
|
return nb::int_(mlirIntegerAttrGetValueSInt(attr));
|
|
return nb::int_(mlirIntegerAttrGetValueUInt(attr));
|
|
}
|
|
|
|
if (omAttrIsAReferenceAttr(attr)) {
|
|
auto innerRef = omReferenceAttrGetInnerRef(attr);
|
|
auto moduleStrRef =
|
|
mlirStringAttrGetValue(hwInnerRefAttrGetModule(innerRef));
|
|
auto nameStrRef = mlirStringAttrGetValue(hwInnerRefAttrGetName(innerRef));
|
|
auto moduleStr = nb::str(moduleStrRef.data, moduleStrRef.length);
|
|
auto nameStr = nb::str(nameStrRef.data, nameStrRef.length);
|
|
return nb::make_tuple(moduleStr, nameStr);
|
|
}
|
|
|
|
if (omAttrIsAListAttr(attr)) {
|
|
nb::list results;
|
|
for (intptr_t i = 0, e = omListAttrGetNumElements(attr); i < e; ++i)
|
|
results.append(omPrimitiveToPythonValue(omListAttrGetElement(attr, i)));
|
|
return results;
|
|
}
|
|
|
|
mlirAttributeDump(attr);
|
|
throw nb::type_error("Unexpected OM primitive attribute");
|
|
}
|
|
|
|
// Convert a primitive PythonValue to a generic MLIR Attribute. This is
|
|
// basically a C++ fast path of the parts of var_to_attribute that we use in the
|
|
// OM dialect.
|
|
static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
|
|
MlirContext ctx) {
|
|
if (auto *intValue = std::get_if<nb::int_>(&value)) {
|
|
auto intType = mlirIntegerTypeGet(ctx, 64);
|
|
auto intAttr = mlirIntegerAttrGet(intType, nb::cast<int64_t>(*intValue));
|
|
return omIntegerAttrGet(intAttr);
|
|
}
|
|
|
|
if (auto *attr = std::get_if<nb::float_>(&value)) {
|
|
auto floatType = mlirF64TypeGet(ctx);
|
|
return mlirFloatAttrDoubleGet(ctx, floatType, nb::cast<double>(*attr));
|
|
}
|
|
|
|
if (auto *attr = std::get_if<nb::str>(&value)) {
|
|
auto str = nb::cast<std::string>(*attr);
|
|
auto strRef = mlirStringRefCreate(str.data(), str.length());
|
|
auto omStringType = omStringTypeGet(ctx);
|
|
return mlirStringAttrTypedGet(omStringType, strRef);
|
|
}
|
|
|
|
if (auto *attr = std::get_if<nb::bool_>(&value)) {
|
|
return mlirBoolAttrGet(ctx, nb::cast<bool>(*attr));
|
|
}
|
|
|
|
// For a python list try constructing OM list attribute. The element
|
|
// type must be uniform.
|
|
if (auto *attr = std::get_if<nb::list>(&value)) {
|
|
if (attr->size() == 0)
|
|
throw nb::type_error("Empty list is prohibited now");
|
|
|
|
std::vector<MlirAttribute> attrs;
|
|
attrs.reserve(attr->size());
|
|
std::optional<MlirType> elemenType;
|
|
for (auto v : *attr) {
|
|
attrs.push_back(
|
|
omPythonValueToPrimitive(nb::cast<PythonPrimitive>(v), ctx));
|
|
if (!elemenType)
|
|
elemenType = mlirAttributeGetType(attrs.back());
|
|
else if (!mlirTypeEqual(*elemenType,
|
|
mlirAttributeGetType(attrs.back()))) {
|
|
throw nb::type_error("List elements must be of the same type");
|
|
}
|
|
}
|
|
return omListAttrGet(*elemenType, attrs.size(), attrs.data());
|
|
}
|
|
|
|
throw nb::type_error("Unexpected OM primitive value");
|
|
}
|
|
|
|
PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result) {
|
|
// If the result is null, something failed. Diagnostic handling is
|
|
// implemented in pure Python, so nothing to do here besides throwing an
|
|
// error to halt execution.
|
|
if (omEvaluatorValueIsNull(result))
|
|
throw nb::value_error("unable to get field, see previous error(s)");
|
|
|
|
// If the field was an Object, return a new Object.
|
|
if (omEvaluatorValueIsAObject(result))
|
|
return Object(result);
|
|
|
|
// If the field was a list, return a new List.
|
|
if (omEvaluatorValueIsAList(result))
|
|
return List(result);
|
|
|
|
// If the field was a base path, return a new BasePath.
|
|
if (omEvaluatorValueIsABasePath(result))
|
|
return BasePath(result);
|
|
|
|
// If the field was a path, return a new Path.
|
|
if (omEvaluatorValueIsAPath(result))
|
|
return Path(result);
|
|
|
|
if (omEvaluatorValueIsAReference(result))
|
|
return omEvaluatorValueToPythonValue(
|
|
omEvaluatorValueGetReferenceValue(result));
|
|
|
|
// If the field was a primitive, return the Attribute.
|
|
assert(omEvaluatorValueIsAPrimitive(result));
|
|
return omPrimitiveToPythonValue(omEvaluatorValueGetPrimitive(result));
|
|
}
|
|
|
|
OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result,
|
|
MlirContext ctx) {
|
|
if (auto *list = std::get_if<List>(&result))
|
|
return list->getValue();
|
|
|
|
if (auto *basePath = std::get_if<BasePath>(&result))
|
|
return basePath->getValue();
|
|
|
|
if (auto *path = std::get_if<Path>(&result))
|
|
return path->getValue();
|
|
|
|
if (auto *object = std::get_if<Object>(&result))
|
|
return object->getValue();
|
|
|
|
auto primitive = std::get<PythonPrimitive>(result);
|
|
return omEvaluatorValueFromPrimitive(
|
|
omPythonValueToPrimitive(primitive, ctx));
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/// Populate the OM Python module.
|
|
void circt::python::populateDialectOMSubmodule(nb::module_ &m) {
|
|
m.doc() = "OM dialect Python native extension";
|
|
|
|
// Add the Evaluator class definition.
|
|
nb::class_<Evaluator>(m, "Evaluator")
|
|
.def(nb::init<MlirModule>(), nb::arg("module"))
|
|
.def("instantiate", &Evaluator::instantiate, "Instantiate an Object",
|
|
nb::arg("class_name"), nb::arg("actual_params"))
|
|
.def_prop_ro("module", &Evaluator::getModule,
|
|
"The Module the Evaluator is built from");
|
|
|
|
// Add the List class definition.
|
|
nb::class_<List>(m, "List")
|
|
.def(nb::init<List>(), nb::arg("list"))
|
|
.def("__getitem__", &List::getElement)
|
|
.def("__len__", &List::getNumElements);
|
|
|
|
// Add the BasePath class definition.
|
|
nb::class_<BasePath>(m, "BasePath")
|
|
.def(nb::init<BasePath>(), nb::arg("basepath"))
|
|
.def_static("get_empty", &BasePath::getEmpty,
|
|
nb::arg("context") = nb::none());
|
|
|
|
// Add the Path class definition.
|
|
nb::class_<Path>(m, "Path")
|
|
.def(nb::init<Path>(), nb::arg("path"))
|
|
.def("__str__", &Path::dunderStr);
|
|
|
|
// Add the Object class definition.
|
|
nb::class_<Object>(m, "Object")
|
|
.def(nb::init<Object>(), nb::arg("object"))
|
|
.def("__getattr__", &Object::getField, "Get a field from an Object",
|
|
nb::arg("name"))
|
|
.def("get_field_loc", &Object::getFieldLoc,
|
|
"Get the location of a field from an Object", nb::arg("name"))
|
|
.def_prop_ro("field_names", &Object::getFieldNames,
|
|
"Get field names from an Object")
|
|
.def_prop_ro("type", &Object::getType, "The Type of the Object")
|
|
.def_prop_ro("loc", &Object::getLocation, "The Location of the Object")
|
|
.def("__hash__", &Object::getHash, "Get object hash")
|
|
.def("__eq__", &Object::eq, "Check if two objects are same");
|
|
|
|
// Add the ReferenceAttr definition
|
|
mlir_attribute_subclass(m, "ReferenceAttr", omAttrIsAReferenceAttr)
|
|
.def_property_readonly("inner_ref", [](MlirAttribute self) {
|
|
return omReferenceAttrGetInnerRef(self);
|
|
});
|
|
|
|
// Add the IntegerAttr definition
|
|
mlir_attribute_subclass(m, "OMIntegerAttr", omAttrIsAIntegerAttr)
|
|
.def_classmethod("get",
|
|
[](nb::object cls, MlirAttribute intVal) {
|
|
return cls(omIntegerAttrGet(intVal));
|
|
})
|
|
.def_property_readonly(
|
|
"integer",
|
|
[](MlirAttribute self) { return omIntegerAttrGetInt(self); })
|
|
.def("__str__", [](MlirAttribute self) {
|
|
MlirStringRef str = omIntegerAttrToString(self);
|
|
return std::string(str.data, str.length);
|
|
});
|
|
|
|
// Add the OMListAttr definition
|
|
mlir_attribute_subclass(m, "ListAttr", omAttrIsAListAttr)
|
|
.def("__getitem__", &omListAttrGetElement)
|
|
.def("__len__", &omListAttrGetNumElements)
|
|
.def("__iter__",
|
|
[](MlirAttribute arr) { return PyListAttrIterator(arr); });
|
|
PyListAttrIterator::bind(m);
|
|
|
|
// Add the AnyType class definition.
|
|
mlir_type_subclass(m, "AnyType", omTypeIsAAnyType, omAnyTypeGetTypeID);
|
|
|
|
// Add the ClassType class definition.
|
|
mlir_type_subclass(m, "ClassType", omTypeIsAClassType, omClassTypeGetTypeID)
|
|
.def_property_readonly("name", [](MlirType type) {
|
|
MlirStringRef name = mlirIdentifierStr(omClassTypeGetName(type));
|
|
return std::string(name.data, name.length);
|
|
});
|
|
|
|
// Add the BasePathType class definition.
|
|
mlir_type_subclass(m, "BasePathType", omTypeIsAFrozenBasePathType,
|
|
omFrozenBasePathTypeGetTypeID);
|
|
|
|
// Add the ListType class definition.
|
|
mlir_type_subclass(m, "ListType", omTypeIsAListType, omListTypeGetTypeID)
|
|
.def_property_readonly("element_type", omListTypeGetElementType);
|
|
|
|
// Add the PathType class definition.
|
|
mlir_type_subclass(m, "PathType", omTypeIsAFrozenPathType,
|
|
omFrozenPathTypeGetTypeID);
|
|
}
|