[RTG][CAPI] Expose getter for bag and set element type (#8227)

This commit is contained in:
Martin Erhart 2025-02-19 13:29:33 +00:00 committed by GitHub
parent 90aec03b8c
commit 435dceeaf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 66 additions and 10 deletions

View File

@ -21,7 +21,7 @@ class Test:
sig = inspect.signature(test_func)
assert len(sig.parameters) == 0, "test arguments not supported yet"
self.type = rtg.DictType.get(None, [])
self.type = rtg.DictType.get([])
@property
def name(self) -> str:

View File

@ -58,12 +58,18 @@ MLIR_CAPI_EXPORTED bool rtgTypeIsASet(MlirType type);
/// Creates an RTG set type in the context.
MLIR_CAPI_EXPORTED MlirType rtgSetTypeGet(MlirType elementType);
/// Return the element type of the RTG set.
MLIR_CAPI_EXPORTED MlirType rtgSetTypeGetElementType(MlirType type);
/// If the type is an RTG bag.
MLIR_CAPI_EXPORTED bool rtgTypeIsABag(MlirType type);
/// Creates an RTG bag type in the context.
MLIR_CAPI_EXPORTED MlirType rtgBagTypeGet(MlirType elementType);
/// Return the element type of the RTG bag.
MLIR_CAPI_EXPORTED MlirType rtgBagTypeGetElementType(MlirType type);
/// If the type is an RTG dict.
MLIR_CAPI_EXPORTED bool rtgTypeIsADict(MlirType type);

View File

@ -13,8 +13,8 @@ with Context() as ctx, Location.unknown():
m = Module.create()
with InsertionPoint(m.body):
cpuTy = rtgtest.CPUType.get()
dictTy = rtg.DictType.get(ctx, [(StringAttr.get('cpu0'), cpuTy),
(StringAttr.get('cpu1'), cpuTy)])
dictTy = rtg.DictType.get([(StringAttr.get('cpu0'), cpuTy),
(StringAttr.get('cpu1'), cpuTy)], ctx)
target = rtg.TargetOp('target_name', TypeAttr.get(dictTy))
targetBlock = Block.create_at_start(target.bodyRegion, [])
@ -44,6 +44,9 @@ with Context() as ctx, Location.unknown():
seq = rtg.SequenceOp('seq', TypeAttr.get(rtg.SequenceType.get([setTy])))
seqBlock = Block.create_at_start(seq.bodyRegion, [setTy])
# CHECK: !rtg.sequence{{$}}
print(setTy.element_type)
# CHECK: rtg.sequence @seq(%{{.*}}: !rtg.set<!rtg.sequence>) {
# CHECK: }
print(m)
@ -100,6 +103,9 @@ with Context() as ctx, Location.unknown():
seq.bodyRegion,
[sequenceTy, labelTy, setTy, bagTy, ireg, randomizedSequenceTy])
# CHECK: index{{$}}
print(bagTy.element_type)
# CHECK: rtg.sequence @seq(%{{.*}}: !rtg.sequence, %{{.*}}: !rtg.label, %{{.*}}: !rtg.set<index>, %{{.*}}: !rtg.bag<index>, %{{.*}}: !rtgtest.ireg, %{{.*}}: !rtg.randomized_sequence)
print(m)

View File

@ -62,7 +62,10 @@ void circt::python::populateDialectRTGSubmodule(nb::module_ &m) {
[](nb::object cls, MlirType elementType) {
return cls(rtgSetTypeGet(elementType));
},
nb::arg("self"), nb::arg("element_type"));
nb::arg("self"), nb::arg("element_type"))
.def_property_readonly("element_type", [](MlirType self) {
return rtgSetTypeGetElementType(self);
});
mlir_type_subclass(m, "BagType", rtgTypeIsABag)
.def_classmethod(
@ -70,13 +73,17 @@ void circt::python::populateDialectRTGSubmodule(nb::module_ &m) {
[](nb::object cls, MlirType elementType) {
return cls(rtgBagTypeGet(elementType));
},
nb::arg("self"), nb::arg("element_type"));
nb::arg("self"), nb::arg("element_type"))
.def_property_readonly("element_type", [](MlirType self) {
return rtgBagTypeGetElementType(self);
});
mlir_type_subclass(m, "DictType", rtgTypeIsADict)
.def_classmethod(
"get",
[](nb::object cls, MlirContext ctxt,
const std::vector<std::pair<MlirAttribute, MlirType>> &entries) {
[](nb::object cls,
const std::vector<std::pair<MlirAttribute, MlirType>> &entries,
MlirContext ctxt) {
std::vector<MlirAttribute> names;
std::vector<MlirType> types;
for (auto entry : entries) {
@ -86,9 +93,10 @@ void circt::python::populateDialectRTGSubmodule(nb::module_ &m) {
return cls(
rtgDictTypeGet(ctxt, types.size(), names.data(), types.data()));
},
nb::arg("self"), nb::arg("ctxt") = nullptr,
nb::arg("self"),
nb::arg("entries") =
std::vector<std::pair<MlirAttribute, MlirType>>());
std::vector<std::pair<MlirAttribute, MlirType>>(),
nb::arg("ctxt") = nullptr);
nb::enum_<RTGLabelVisibility>(m, "LabelVisibility")
.value("LOCAL", RTG_LABEL_VISIBILITY_LOCAL)

View File

@ -88,7 +88,7 @@ def type_to_pytype(t) -> ir.Type:
if t.__class__ != ir.Type:
return t
from .dialects import esi, hw, seq
from .dialects import esi, hw, seq, rtg
try:
return ir.IntegerType(t)
except ValueError:
@ -129,6 +129,30 @@ def type_to_pytype(t) -> ir.Type:
return esi.BundleType(t)
except ValueError:
pass
try:
return rtg.LabelType(t)
except ValueError:
pass
try:
return rtg.SetType(t)
except ValueError:
pass
try:
return rtg.BagType(t)
except ValueError:
pass
try:
return rtg.SequenceType(t)
except ValueError:
pass
try:
return rtg.RandomizedSequenceType(t)
except ValueError:
pass
try:
return rtg.DictType(t)
except ValueError:
pass
raise TypeError(f"Cannot convert {repr(t)} to python type")

View File

@ -79,6 +79,10 @@ MlirType rtgSetTypeGet(MlirType elementType) {
return wrap(SetType::get(ty.getContext(), ty));
}
MlirType rtgSetTypeGetElementType(MlirType type) {
return wrap(cast<SetType>(unwrap(type)).getElementType());
}
// BagType
//===----------------------------------------------------------------------===//
@ -89,6 +93,10 @@ MlirType rtgBagTypeGet(MlirType elementType) {
return wrap(BagType::get(ty.getContext(), ty));
}
MlirType rtgBagTypeGetElementType(MlirType type) {
return wrap(cast<BagType>(unwrap(type)).getElementType());
}
// DictType
//===----------------------------------------------------------------------===//

View File

@ -64,6 +64,8 @@ static void testSetType(MlirContext ctx) {
fprintf(stderr, rtgTypeIsASet(setTy) ? "is_set\n" : "isnot_set\n");
// CHECK: !rtg.set<i32>
mlirTypeDump(setTy);
// CHECK: i32{{$}}
mlirTypeDump(rtgSetTypeGetElementType(setTy));
}
static void testBagType(MlirContext ctx) {
@ -74,6 +76,8 @@ static void testBagType(MlirContext ctx) {
fprintf(stderr, rtgTypeIsABag(bagTy) ? "is_bag\n" : "isnot_bag\n");
// CHECK: !rtg.bag<i32>
mlirTypeDump(bagTy);
// CHECK: i32{{$}}
mlirTypeDump(rtgBagTypeGetElementType(bagTy));
}
static void testDictType(MlirContext ctx) {