[OM][Python] Support string and list as evaluator arguments (#8556)

Add support for passing Python strings and lists directly as arguments to
the OM evaluator's instantiate method:

* Fix Attribute value creation to work well with list evaluation.
* Add StringType C API to construct typed strings in Python bindings
* Implement conversion between Python lists and OM ListAttr
This commit is contained in:
Hideto Ueno 2025-06-12 16:38:13 -07:00 committed by GitHub
parent 9925b444d3
commit 5dc12802ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 279 additions and 31 deletions

View File

@ -70,6 +70,12 @@ MLIR_CAPI_EXPORTED MlirType omMapTypeGetKeyType(MlirType type);
/// Is the Type a StringType.
MLIR_CAPI_EXPORTED bool omTypeIsAStringType(MlirType type);
/// Get the TypeID for a StringType.
MLIR_CAPI_EXPORTED MlirTypeID omStringTypeGetTypeID(void);
/// Get a StringType.
MLIR_CAPI_EXPORTED MlirType omStringTypeGet(MlirContext ctx);
//===----------------------------------------------------------------------===//
// Evaluator data structures.
//===----------------------------------------------------------------------===//
@ -265,6 +271,10 @@ MLIR_CAPI_EXPORTED intptr_t omListAttrGetNumElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute omListAttrGetElement(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED MlirAttribute omListAttrGet(MlirType elementType,
intptr_t numElements,
const MlirAttribute *elements);
//===----------------------------------------------------------------------===//
// MapAttr API
//===----------------------------------------------------------------------===//

View File

@ -130,20 +130,6 @@ private:
/// Values which can be directly representable by MLIR attributes.
struct AttributeValue : EvaluatorValue {
AttributeValue(Attribute attr)
: AttributeValue(attr, mlir::UnknownLoc::get(attr.getContext())) {}
AttributeValue(Attribute attr, Location loc)
: EvaluatorValue(attr.getContext(), Kind::Attr, loc), attr(attr),
type(cast<TypedAttr>(attr).getType()) {
markFullyEvaluated();
}
// Constructors for partially evaluated AttributeValue.
AttributeValue(Type type)
: AttributeValue(mlir::UnknownLoc::get(type.getContext())) {}
AttributeValue(Type type, Location loc)
: EvaluatorValue(type.getContext(), Kind::Attr, loc), type(type) {}
Attribute getAttr() const { return attr; }
template <typename AttrTy>
AttrTy getAs() const {
@ -161,9 +147,32 @@ struct AttributeValue : EvaluatorValue {
Type getType() const { return type; }
// Factory methods that create AttributeValue objects
static std::shared_ptr<EvaluatorValue> get(Attribute attr,
LocationAttr loc = {});
static std::shared_ptr<EvaluatorValue> get(Type type, LocationAttr loc = {});
private:
// Make AttributeValue constructible only by the factory methods
struct PrivateTag {};
// Constructor that requires a PrivateTag
AttributeValue(PrivateTag, Attribute attr, Location loc)
: EvaluatorValue(attr.getContext(), Kind::Attr, loc), attr(attr),
type(cast<TypedAttr>(attr).getType()) {
markFullyEvaluated();
}
// Constructor for partially evaluated AttributeValue
AttributeValue(PrivateTag, Type type, Location loc)
: EvaluatorValue(type.getContext(), Kind::Attr, loc), type(type) {}
Attribute attr = {};
Type type;
// Friend declaration for the factory methods
friend std::shared_ptr<EvaluatorValue> get(Attribute attr, LocationAttr loc);
friend std::shared_ptr<EvaluatorValue> get(Type type, LocationAttr loc);
};
// This perform finalization to `value`.

View File

@ -104,6 +104,11 @@ with Context() as ctx, Location.unknown():
%5 = om.integer.add %1, %3 : !om.integer
om.class.fields %5 : !om.integer
}
om.class @AppendList(%head: !om.string, %tail: !om.list<!om.string>) -> (result: !om.list<!om.string>) {
%0 = om.list_create %head : !om.string
%1 = om.list_concat %0, %tail : !om.list<!om.string>
om.class.fields %1 : !om.list<!om.string>
}
}
""")
@ -144,12 +149,12 @@ print(obj.type.name)
print(obj.field)
# location of the om.class.field @field
# CHECK: field: loc("-":50:7)
# CHECK: field: loc("-":{{.*}}:{{.*}})
print("field:", obj.get_field_loc("field"))
# CHECK: child.foo: 14
print("child.foo: ", obj.child.foo)
# CHECK: child.foo.loc loc("-":54:7)
# CHECK: child.foo.loc loc("-":{{.*}}:{{.*}})
print("child.foo.loc", obj.child.get_field_loc("foo"))
# CHECK: ('Root', 'x')
print(obj.reference)
@ -157,10 +162,10 @@ print(obj.reference)
# CHECK: 14
print(snd)
# CHECK: loc("-":50:7)
# CHECK: loc("-":{{.*}}:{{.*}})
print("tuple", obj.get_field_loc("tuple"))
# CHECK: loc("-":22:5)
# CHECK: loc("-":{{.*}}:{{.*}})
print(obj.loc)
try:
@ -172,18 +177,20 @@ except IndexError as e:
for (name, field) in obj:
# location from om.class.field @child, %0 : !om.class.type<@Child>
# CHECK: name: child, field: <circt.dialects.om.Object object
# CHECK-SAME: loc: loc("-":26:12)
# CHECK-SAME: loc: loc("-":{{.*}}:{{.*}})
# location from om.class.field @field, %param : !om.integer
# CHECK: name: field, field: 42
# CHECK-SAME: loc: loc("-":50:7)
# CHECK-SAME: loc: loc("-":{{.*}}:{{.*}})
# location from om.class.field @reference, %sym : !om.ref
# CHECK: name: reference, field: ('Root', 'x')
# CHECK-SAME: loc: loc("-":50:7)
# CHECK-SAME: loc: loc("-":{{.*}}:{{.*}})
loc = obj.get_field_loc(name)
print(f"name: {name}, field: {field}, loc: {loc}")
print("Check list")
# CHECK: Check list
# CHECK: ['X', 'Y']
print(obj.list)
print(list(obj.list))
for child in obj.nest.list_child:
# CHECK: 14
# CHECK-NEXT: 15
@ -259,6 +266,24 @@ delayed = evaluator.instantiate("IntegerBinaryArithmeticObjectsDelayed")
# CHECK: 3
print(delayed.result)
# Test string and list arguments
obj = evaluator.instantiate("AppendList", "a", ["b", "c"])
# CHECK: ['a', 'b', 'c']
print(list(obj.result))
# Test string and list arguments
try:
obj = evaluator.instantiate("AppendList", "a", [1, "b"])
except TypeError as e:
# CHECK: List elements must be of the same type
print(e)
try:
obj = evaluator.instantiate("AppendList", "a", [])
except TypeError as e:
# CHECK: Empty list is prohibited now
print(e)
with Context() as ctx:
circt.register_dialects(ctx)

View File

@ -445,13 +445,36 @@ static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
if (auto *attr = std::get_if<nb::str>(&value)) {
auto str = nb::cast<std::string>(*attr);
auto strRef = mlirStringRefCreate(str.data(), str.length());
return mlirStringAttrGet(ctx, strRef);
auto omStringType = omStringTypeGet(ctx);
return mlirStringAttrTypedGet(omStringType, strRef);
}
if (auto *attr = std::get_if<nb::bool_>(&value)) {
return mlirBoolAttrGet(ctx, nb::cast<bool>(*attr));
}
// For a python list try constructing OM list attribute. The element
// type must be uniform.
if (auto *attr = std::get_if<nb::list>(&value)) {
if (attr->size() == 0)
throw nb::type_error("Empty list is prohibited now");
std::vector<MlirAttribute> attrs;
attrs.reserve(attr->size());
std::optional<MlirType> elemenType;
for (auto v : *attr) {
attrs.push_back(
omPythonValueToPrimitive(nb::cast<PythonPrimitive>(v), ctx));
if (!elemenType)
elemenType = mlirAttributeGetType(attrs.back());
else if (!mlirTypeEqual(*elemenType,
mlirAttributeGetType(attrs.back()))) {
throw nb::type_error("List elements must be of the same type");
}
}
return omListAttrGet(*elemenType, attrs.size(), attrs.data());
}
throw nb::type_error("Unexpected OM primitive value");
}

View File

@ -87,6 +87,9 @@ MlirType omStringTypeGet(MlirContext ctx) {
return wrap(StringType::get(unwrap(ctx)));
}
/// Get the TypeID for a StringType.
MlirTypeID omStringTypeGetTypeID(void) { return wrap(StringType::getTypeID()); }
/// Is the Type a MapType.
bool omTypeIsAMapType(MlirType type) { return isa<MapType>(unwrap(type)); }
@ -263,7 +266,7 @@ MlirAttribute omEvaluatorValueGetPrimitive(OMEvaluatorValue evaluatorValue) {
/// Get the Primitive from an EvaluatorValue, which must contain a Primitive.
OMEvaluatorValue omEvaluatorValueFromPrimitive(MlirAttribute primitive) {
// Assert the Attribute is non-null, and return it.
return wrap(std::make_shared<evaluator::AttributeValue>(unwrap(primitive)));
return wrap(evaluator::AttributeValue::get(unwrap(primitive)));
}
/// Query if the EvaluatorValue is a List.
@ -429,6 +432,15 @@ MlirAttribute omListAttrGetElement(MlirAttribute attr, intptr_t pos) {
return wrap(listAttr.getElements()[pos]);
}
MlirAttribute omListAttrGet(MlirType elementType, intptr_t numElements,
const MlirAttribute *elements) {
SmallVector<Attribute, 8> attrs;
(void)unwrapList(static_cast<size_t>(numElements), elements, attrs);
auto type = unwrap(elementType);
auto *ctx = type.getContext();
return wrap(ListAttr::get(ctx, type, ArrayAttr::get(ctx, attrs)));
}
//===----------------------------------------------------------------------===//
// MapAttr API.
//===----------------------------------------------------------------------===//

View File

@ -36,7 +36,7 @@ circt::om::getEvaluatorValuesFromAttributes(MLIRContext *context,
SmallVector<evaluator::EvaluatorValuePtr> values;
values.reserve(attributes.size());
for (auto attr : attributes)
values.push_back(std::make_shared<evaluator::AttributeValue>(attr));
values.push_back(evaluator::AttributeValue::get(cast<TypedAttr>(attr)));
return values;
}
@ -132,8 +132,8 @@ FailureOr<evaluator::EvaluatorValuePtr> circt::om::Evaluator::getOrCreateValue(
// Create a partially evaluated AttributeValue of
// om::IntegerType in case we need to delay evaluation.
evaluator::EvaluatorValuePtr result =
std::make_shared<evaluator::AttributeValue>(
op.getResult().getType(), loc);
evaluator::AttributeValue::get(op.getResult().getType(),
loc);
return success(result);
})
.Case<ObjectFieldOp>([&](auto op) {
@ -411,8 +411,8 @@ FailureOr<circt::om::evaluator::EvaluatorValuePtr>
circt::om::Evaluator::evaluateConstant(ConstantOp op,
ActualParameters actualParams,
Location loc) {
return success(std::make_shared<circt::om::evaluator::AttributeValue>(
op.getValue(), loc));
// For list constants, create ListValue.
return success(om::evaluator::AttributeValue::get(op.getValue(), loc));
}
// Evaluator dispatch function for integer binary arithmetic.
@ -949,3 +949,41 @@ LogicalResult circt::om::evaluator::AttributeValue::finalizeImpl() {
getLoc(), "cannot finalize AttributeValue that is not fully evaluated");
return success();
}
std::shared_ptr<evaluator::EvaluatorValue>
circt::om::evaluator::AttributeValue::get(Attribute attr, LocationAttr loc) {
auto type = cast<TypedAttr>(attr).getType();
auto *context = type.getContext();
if (!loc)
loc = UnknownLoc::get(context);
// Special handling for ListType to create proper ListValue objects instead of
// AttributeValue objects.
if (auto listType = dyn_cast<circt::om::ListType>(type)) {
SmallVector<EvaluatorValuePtr> elements;
auto listAttr = cast<om::ListAttr>(attr);
auto values = getEvaluatorValuesFromAttributes(
listAttr.getContext(), listAttr.getElements().getValue());
elements.append(values.begin(), values.end());
auto list = std::make_shared<evaluator::ListValue>(listType, elements, loc);
return list;
}
return std::shared_ptr<AttributeValue>(
new AttributeValue(PrivateTag{}, attr, loc));
}
std::shared_ptr<evaluator::EvaluatorValue>
circt::om::evaluator::AttributeValue::get(Type type, LocationAttr loc) {
auto *context = type.getContext();
if (!loc)
loc = UnknownLoc::get(context);
// Special handling for ListType to create proper ListValue objects instead of
// AttributeValue objects.
if (auto listType = dyn_cast<circt::om::ListType>(type))
return std::make_shared<evaluator::ListValue>(listType, loc);
// Create the AttributeValue with the private tag
return std::shared_ptr<AttributeValue>(
new AttributeValue(PrivateTag{}, type, loc));
}

View File

@ -17,6 +17,84 @@
#include <mlir-c/Support.h>
#include <stdio.h>
void testTypes(MlirContext ctx) {
// Test StringType creation and type checking
MlirType stringType = omStringTypeGet(ctx);
// CHECK: string type is StringType: 1
fprintf(stderr, "string type is StringType: %d\n",
omTypeIsAStringType(stringType));
// CHECK: !om.string
mlirTypeDump(stringType);
MlirTypeID stringTypeID = omStringTypeGetTypeID();
MlirTypeID actualStringTypeID = mlirTypeGetTypeID(stringType);
// CHECK: StringType TypeID matches: 1
fprintf(stderr, "StringType TypeID matches: %d\n",
mlirTypeIDEqual(stringTypeID, actualStringTypeID));
}
void testListAttr(MlirContext ctx) {
// Test creating ListAttr with integer elements
MlirType i64Type = mlirIntegerTypeGet(ctx, 64);
MlirAttribute elem1 = mlirIntegerAttrGet(i64Type, 42);
MlirAttribute elem2 = mlirIntegerAttrGet(i64Type, 84);
const MlirAttribute elements[] = {elem1, elem2};
MlirAttribute listAttr = omListAttrGet(i64Type, 2, elements);
// CHECK: list attr is ListAttr: 1
fprintf(stderr, "list attr is ListAttr: %d\n", omAttrIsAListAttr(listAttr));
// CHECK: list attr num elements: 2
fprintf(stderr, "list attr num elements: %ld\n",
omListAttrGetNumElements(listAttr));
// CHECK: #om.list<i64, [42, 84]>
mlirAttributeDump(listAttr);
// Test accessing elements
MlirAttribute retrievedElem1 = omListAttrGetElement(listAttr, 0);
MlirAttribute retrievedElem2 = omListAttrGetElement(listAttr, 1);
// CHECK: first element: 42 : i64
fprintf(stderr, "first element: ");
mlirAttributeDump(retrievedElem1);
// CHECK: second element: 84 : i64
fprintf(stderr, "second element: ");
mlirAttributeDump(retrievedElem2);
// Test creating empty ListAttr
MlirAttribute emptyListAttr = omListAttrGet(i64Type, 0, NULL);
// CHECK: empty list attr is ListAttr: 1
fprintf(stderr, "empty list attr is ListAttr: %d\n",
omAttrIsAListAttr(emptyListAttr));
// CHECK: empty list attr num elements: 0
fprintf(stderr, "empty list attr num elements: %ld\n",
omListAttrGetNumElements(emptyListAttr));
// Test creating ListAttr with string elements
MlirType stringType = omStringTypeGet(ctx);
MlirAttribute strElem = mlirStringAttrTypedGet(
stringType, mlirStringRefCreateFromCString("hello"));
const MlirAttribute stringElements[] = {strElem};
MlirAttribute stringListAttr = omListAttrGet(stringType, 1, stringElements);
// CHECK: string list attr is ListAttr: 1
fprintf(stderr, "string list attr is ListAttr: %d\n",
omAttrIsAListAttr(stringListAttr));
// CHECK: #om.list<!om.string, ["hello" : !om.string]>
mlirAttributeDump(stringListAttr);
}
void testEvaluator(MlirContext ctx) {
const char *testIR =
"module {"
@ -161,6 +239,14 @@ void testEvaluator(MlirContext ctx) {
int main(void) {
MlirContext ctx = mlirContextCreate();
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__om__(), ctx);
MlirDialectHandle omDialectHandle = mlirGetDialectHandle__om__();
// Load the OM dialect to ensure types are properly initialized
mlirContextGetOrLoadDialect(ctx,
mlirDialectHandleGetNamespace(omDialectHandle));
testTypes(ctx);
testListAttr(ctx);
testEvaluator(ctx);
return 0;
}

View File

@ -312,7 +312,7 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObject) {
auto result = evaluator.instantiate(
builder.getStringAttr("MyClass"),
{std::make_shared<evaluator::AttributeValue>(circt::om::IntegerAttr::get(
{evaluator::AttributeValue::get(circt::om::IntegerAttr::get(
&context, builder.getI32IntegerAttr(42)))});
ASSERT_TRUE(succeeded(result));
@ -376,7 +376,7 @@ TEST(EvaluatorTests, InstantiateObjectWithFieldAccess) {
auto result = evaluator.instantiate(
builder.getStringAttr("MyClass"),
{std::make_shared<evaluator::AttributeValue>(circt::om::IntegerAttr::get(
{evaluator::AttributeValue::get(circt::om::IntegerAttr::get(
&context, builder.getI32IntegerAttr(42)))});
ASSERT_TRUE(succeeded(result));
@ -1439,4 +1439,49 @@ TEST(EvaluatorTests, NestedReferenceValue) {
.get()));
}
TEST(EvaluatorTests, ListAttrConcat) {
StringRef mod =
"om.class @ConcatListAttribute() -> (result: !om.list<!om.string>) {"
"%0 = om.constant #om.list<!om.string, [\"X\" : !om.string, \"Y\" : "
"!om.string]> : !om.list<!om.string>"
"%1 = om.list_concat %0, %0 : !om.list<!om.string>"
"om.class.fields %1 : !om.list<!om.string>"
"}";
DialectRegistry registry;
registry.insert<OMDialect>();
MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();
OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));
Evaluator evaluator(owning.release());
auto result = evaluator.instantiate(
StringAttr::get(&context, "ConcatListAttribute"), {});
ASSERT_TRUE(succeeded(result));
auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
->getField("result")
.value();
auto listVal =
llvm::cast<evaluator::ListValue>(fieldValue.get())->getElements();
ASSERT_EQ(4, listVal.size());
auto checkEq = [](evaluator::EvaluatorValue *val, const char *str) {
ASSERT_EQ(str, llvm::cast<evaluator::AttributeValue>(val)
->getAs<StringAttr>()
.getValue()
.str());
};
checkEq(listVal[0].get(), "X");
checkEq(listVal[1].get(), "Y");
checkEq(listVal[2].get(), "X");
checkEq(listVal[3].get(), "Y");
}
} // namespace