[RTG] Custom parser/printer for sequence op and type for sequence families (#8146)

Allow type parameters in the sequence type and attach it as a type attribute to the sequence op. That way ops referring to a sequence don't have to access the operation's body to verify the type.
This commit is contained in:
Martin Erhart 2025-02-04 16:33:48 +00:00 committed by GitHub
parent 5d016eb43b
commit e4d67daaf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 188 additions and 56 deletions

View File

@ -29,7 +29,16 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(RTG, rtg);
MLIR_CAPI_EXPORTED bool rtgTypeIsASequence(MlirType type);
/// Creates an RTG sequence type in the context.
MLIR_CAPI_EXPORTED MlirType rtgSequenceTypeGet(MlirContext ctxt);
MLIR_CAPI_EXPORTED MlirType rtgSequenceTypeGet(MlirContext ctxt,
intptr_t numElements,
MlirType const *elementTypes);
/// The number of substitution elements of the RTG sequence.
MLIR_CAPI_EXPORTED unsigned rtgSequenceTypeGetNumElements(MlirType type);
/// The type of of the substitution element at the given index.
MLIR_CAPI_EXPORTED MlirType rtgSequenceTypeGetElement(MlirType type,
unsigned i);
/// If the type is an RTG label.
MLIR_CAPI_EXPORTED bool rtgTypeIsALabel(MlirType type);

View File

@ -50,12 +50,13 @@ def SequenceOp : RTGOp<"sequence", [
stronger top-level isolation guarantees.
}];
let arguments = (ins SymbolNameAttr:$sym_name);
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<SequenceType>:$sequenceType);
let regions = (region SizedRegion<1>:$bodyRegion);
let assemblyFormat = [{
$sym_name attr-dict-with-keyword $bodyRegion
}];
let hasCustomAssemblyFormat = 1;
let hasRegionVerifier = 1;
}
def SequenceClosureOp : RTGOp<"sequence_closure", [
@ -77,7 +78,7 @@ def SequenceClosureOp : RTGOp<"sequence_closure", [
}];
let arguments = (ins SymbolNameAttr:$sequence, Variadic<AnyType>:$args);
let results = (outs SequenceType:$ref);
let results = (outs FullySubstitutedSequenceType:$ref);
let assemblyFormat = [{
$sequence (`(` $args^ `:` qualified(type($args)) `)`)? attr-dict
@ -94,7 +95,7 @@ def InvokeSequenceOp : RTGOp<"invoke_sequence", []> {
were directly inlined relacing this operation.
}];
let arguments = (ins SequenceType:$sequence);
let arguments = (ins FullySubstitutedSequenceType:$sequence);
let assemblyFormat = "$sequence attr-dict";
}

View File

@ -19,16 +19,27 @@ include "mlir/IR/AttrTypeBase.td"
class RTGTypeDef<string name> : TypeDef<RTGDialect, name>;
def SequenceType : RTGTypeDef<"Sequence"> {
let summary = "handle to a sequence closure";
let summary = "handle to a sequence or sequence family";
let description = [{
An SSA value of this type refers to an `rtg.sequence` operation and the
argument values it should be invoked with (if it has any).
An SSA value of this type refers to a sequence if the list of element types
is empty or a sequence family if there are elements left to be substituted.
}];
let parameters = (ins OptionalArrayRefParameter<
"mlir::Type", "element types">:$elementTypes);
let mnemonic = "sequence";
let assemblyFormat = "";
let assemblyFormat = "(`<` $elementTypes^ `>`)?";
}
def FullySubstitutedSequenceType : DialectType<RTGDialect,
CPred<"llvm::isa<rtg::SequenceType>($_self) && "
"llvm::cast<rtg::SequenceType>($_self).getElementTypes().empty()">,
"fully substituted sequence type", "::circt::rtg::SequenceType">,
BuildableType<
"::circt::rtg::SequenceType::get($_builder.getContext(), " #
"llvm::ArrayRef<::mlir::Type>{})">;
def LabelType : RTGTypeDef<"Label"> {
let summary = "a reference to a label";
let description = [{

View File

@ -41,12 +41,11 @@ with Context() as ctx, Location.unknown():
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
seq = rtg.SequenceOp('seq')
setTy = rtg.SetType.get(rtg.SequenceType.get())
seq = rtg.SequenceOp('seq', TypeAttr.get(rtg.SequenceType.get([setTy])))
seqBlock = Block.create_at_start(seq.bodyRegion, [setTy])
# CHECK: rtg.sequence @seq {
# CHECK: ^bb{{.*}}(%{{.*}}: !rtg.set<!rtg.sequence>):
# CHECK: rtg.sequence @seq(%{{.*}}: !rtg.set<!rtg.sequence>) {
# CHECK: }
print(m)
@ -54,7 +53,8 @@ with Context() as ctx, Location.unknown():
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
seq = rtg.SequenceOp('sequence_name')
seq = rtg.SequenceOp('sequence_name',
TypeAttr.get(rtg.SequenceType.get([])))
Block.create_at_start(seq.bodyRegion, [])
test = rtg.TestOp('test_name', TypeAttr.get(rtg.DictType.get()))
@ -89,12 +89,14 @@ with Context() as ctx, Location.unknown():
setTy = rtg.SetType.get(indexTy)
bagTy = rtg.BagType.get(indexTy)
ireg = rtgtest.IntegerRegisterType.get()
seq = rtg.SequenceOp('seq')
seq = rtg.SequenceOp(
'seq',
TypeAttr.get(
rtg.SequenceType.get([sequenceTy, labelTy, setTy, bagTy, ireg])))
Block.create_at_start(seq.bodyRegion,
[sequenceTy, labelTy, setTy, bagTy, ireg])
# CHECK: rtg.sequence @seq
# CHECK: (%{{.*}}: !rtg.sequence, %{{.*}}: !rtg.label, %{{.*}}: !rtg.set<index>, %{{.*}}: !rtg.bag<index>, %{{.*}}: !rtgtest.ireg):
# CHECK: rtg.sequence @seq(%{{.*}}: !rtg.sequence, %{{.*}}: !rtg.label, %{{.*}}: !rtg.set<index>, %{{.*}}: !rtg.bag<index>, %{{.*}}: !rtgtest.ireg)
print(m)
with Context() as ctx, Location.unknown():
@ -189,7 +191,7 @@ with Context() as ctx, Location.unknown():
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
seq = rtg.SequenceOp('seq')
seq = rtg.SequenceOp('seq', TypeAttr.get(rtg.SequenceType.get([])))
block = Block.create_at_start(seq.bodyRegion, [])
with InsertionPoint(block):
l = rtg.label_decl("label", [])

View File

@ -26,10 +26,19 @@ void circt::python::populateDialectRTGSubmodule(nb::module_ &m) {
mlir_type_subclass(m, "SequenceType", rtgTypeIsASequence)
.def_classmethod(
"get",
[](nb::object cls, MlirContext ctxt) {
return cls(rtgSequenceTypeGet(ctxt));
[](nb::object cls, std::vector<MlirType> &elementTypes,
MlirContext ctxt) {
return cls(rtgSequenceTypeGet(ctxt, elementTypes.size(),
elementTypes.data()));
},
nb::arg("self"), nb::arg("ctxt") = nullptr);
nb::arg("self"), nb::arg("elementTypes") = std::vector<MlirType>(),
nb::arg("ctxt") = nullptr)
.def_property_readonly(
"num_elements",
[](MlirType self) { return rtgSequenceTypeGetNumElements(self); })
.def("get_element", [](MlirType self, unsigned i) {
return rtgSequenceTypeGetElement(self, i);
});
mlir_type_subclass(m, "LabelType", rtgTypeIsALabel)
.def_classmethod(

View File

@ -32,8 +32,20 @@ bool rtgTypeIsASequence(MlirType type) {
return isa<SequenceType>(unwrap(type));
}
MlirType rtgSequenceTypeGet(MlirContext ctxt) {
return wrap(SequenceType::get(unwrap(ctxt)));
MlirType rtgSequenceTypeGet(MlirContext ctxt, intptr_t numElements,
MlirType const *elementTypes) {
SmallVector<Type> types;
for (unsigned i = 0; i < numElements; ++i)
types.emplace_back(unwrap(elementTypes[i]));
return wrap(SequenceType::get(unwrap(ctxt), types));
}
unsigned rtgSequenceTypeGetNumElements(MlirType type) {
return cast<SequenceType>(unwrap(type)).getElementTypes().size();
}
MlirType rtgSequenceTypeGetElement(MlirType type, unsigned i) {
return wrap(cast<SequenceType>(unwrap(type)).getElementTypes()[i]);
}
// LabelType

View File

@ -18,6 +18,77 @@ using namespace mlir;
using namespace circt;
using namespace rtg;
//===----------------------------------------------------------------------===//
// SequenceOp
//===----------------------------------------------------------------------===//
LogicalResult SequenceOp::verifyRegions() {
if (TypeRange(getSequenceType().getElementTypes()) !=
getBody()->getArgumentTypes())
return emitOpError("sequence type does not match block argument types");
return success();
}
ParseResult SequenceOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the name as a symbol.
if (parser.parseSymbolName(
result.getOrAddProperties<SequenceOp::Properties>().sym_name))
return failure();
// Parse the function signature.
SmallVector<OpAsmParser::Argument> arguments;
if (parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren,
/*allowType=*/true, /*allowAttrs=*/true))
return failure();
SmallVector<Type> argTypes;
SmallVector<Location> argLocs;
argTypes.reserve(arguments.size());
argLocs.reserve(arguments.size());
for (auto &arg : arguments) {
argTypes.push_back(arg.type);
argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
}
Type type = SequenceType::get(result.getContext(), argTypes);
result.getOrAddProperties<SequenceOp::Properties>().sequenceType =
TypeAttr::get(type);
auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
return parser.emitError(loc)
<< "'" << result.name.getStringRef() << "' op ";
})))
return failure();
std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
if (parser.parseRegion(*bodyRegionRegion, arguments))
return failure();
if (bodyRegionRegion->empty()) {
bodyRegionRegion->emplaceBlock();
bodyRegionRegion->addArguments(argTypes, argLocs);
}
result.addRegion(std::move(bodyRegionRegion));
return success();
}
void SequenceOp::print(OpAsmPrinter &p) {
p << ' ';
p.printSymbolName(getSymNameAttr().getValue());
p << "(";
llvm::interleaveComma(getBody()->getArguments(), p,
[&](auto arg) { p.printRegionArgument(arg); });
p << ")";
p.printOptionalAttrDictWithKeyword(
(*this)->getAttrs(), {getSymNameAttrName(), getSequenceTypeAttrName()});
p << ' ';
p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// SequenceClosureOp
//===----------------------------------------------------------------------===//
@ -31,7 +102,8 @@ SequenceClosureOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
<< "'" << getSequence()
<< "' does not reference a valid 'rtg.sequence' operation";
if (seq.getBodyRegion().getArgumentTypes() != getArgs().getTypes())
if (TypeRange(seq.getSequenceType().getElementTypes()) !=
getArgs().getTypes())
return emitOpError("referenced 'rtg.sequence' op's argument types must "
"match 'args' types");

View File

@ -18,7 +18,7 @@ int main(int argc, char **argv) {
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__rtg__(), ctx);
MlirModule moduleOp = mlirModuleCreateParse(
ctx, mlirStringRefCreateFromCString("rtg.sequence @seq {\n"
ctx, mlirStringRefCreateFromCString("rtg.sequence @seq() {\n"
"}\n"
"rtg.test @test : !rtg.dict<> {\n"
" %0 = rtg.sequence_closure @seq\n"

View File

@ -15,13 +15,24 @@
#include "mlir-c/BuiltinTypes.h"
static void testSequenceType(MlirContext ctx) {
MlirType sequenceTy = rtgSequenceTypeGet(ctx);
MlirType sequenceTy = rtgSequenceTypeGet(ctx, 0, NULL);
// CHECK: is_sequence
fprintf(stderr, rtgTypeIsASequence(sequenceTy) ? "is_sequence\n"
: "isnot_sequence\n");
// CHECK: !rtg.sequence
mlirTypeDump(sequenceTy);
MlirType sequenceWithArgsTy = rtgSequenceTypeGet(ctx, 1, &sequenceTy);
// CHECK: is_sequence
fprintf(stderr, rtgTypeIsASequence(sequenceWithArgsTy) ? "is_sequence\n"
: "isnot_sequence\n");
// CHECK: 1
fprintf(stderr, "%d\n", rtgSequenceTypeGetNumElements(sequenceWithArgsTy));
// CHECK: !rtg.sequence
mlirTypeDump(rtgSequenceTypeGetElement(sequenceWithArgsTy, 0));
// CHECK: !rtg.sequence<!rtg.sequence>
mlirTypeDump(sequenceWithArgsTy);
}
static void testLabelType(MlirContext ctx) {

View File

@ -1,8 +1,7 @@
// RUN: circt-opt %s | FileCheck %s
// RUN: circt-opt %s --verify-roundtrip | FileCheck %s
// CHECK-LABEL: rtg.sequence @seq
// CHECK-SAME: attributes {rtg.some_attr} {
rtg.sequence @seq0 attributes {rtg.some_attr} {
rtg.sequence @seq0() {
%arg = arith.constant 1 : index
// CHECK: [[LBL0:%.*]] = rtg.label_decl "label_string_{0}_{1}", %{{.*}}, %{{.*}}
%0 = rtg.label_decl "label_string_{0}_{1}", %arg, %arg
@ -16,14 +15,16 @@ rtg.sequence @seq0 attributes {rtg.some_attr} {
rtg.label external %0
}
// CHECK-LABEL: rtg.sequence @seqAttrsAndTypeElements
// CHECK-SAME: (%arg0: !rtg.sequence<!rtg.sequence<!rtg.label, !rtg.set<index>>>) attributes {rtg.some_attr} {
rtg.sequence @seqAttrsAndTypeElements(%arg0: !rtg.sequence<!rtg.sequence<!rtg.label, !rtg.set<index>>>) attributes {rtg.some_attr} {}
// CHECK-LABEL: rtg.sequence @seq1
// CHECK: ^bb0(%arg0: i32, %arg1: !rtg.sequence):
rtg.sequence @seq1 {
^bb0(%arg0: i32, %arg1: !rtg.sequence):
}
// CHECK-SAME: (%arg0: i32, %arg1: !rtg.sequence)
rtg.sequence @seq1(%arg0: i32, %arg1: !rtg.sequence) { }
// CHECK-LABEL: rtg.sequence @invocations
rtg.sequence @invocations {
rtg.sequence @invocations() {
// CHECK: [[V0:%.+]] = rtg.sequence_closure @seq0
// CHECK: [[C0:%.+]] = arith.constant 0 : i32
// CHECK: [[V1:%.+]] = rtg.sequence_closure @seq1([[C0]], [[V0]] : i32, !rtg.sequence)
@ -55,8 +56,7 @@ func.func @sets(%arg0: i32, %arg1: i32) {
}
// CHECK-LABEL: @bags
rtg.sequence @bags {
^bb0(%arg0: i32, %arg1: i32, %arg2: index):
rtg.sequence @bags(%arg0: i32, %arg1: i32, %arg2: index) {
// CHECK: [[BAG:%.+]] = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 {rtg.some_attr}
// CHECK: [[R:%.+]] = rtg.bag_select_random [[BAG]] : !rtg.bag<i32> {rtg.some_attr}
// CHECK: [[EMPTY:%.+]] = rtg.bag_create : i32

View File

@ -2,7 +2,7 @@
// CHECK-LABEL: rtg.sequence @seq
// CHECK-SAME: attributes {rtg.some_attr} {
rtg.sequence @seq0 attributes {rtg.some_attr} {
rtg.sequence @seq0() attributes {rtg.some_attr} {
// CHECK-NEXT: arith.constant
%arg = arith.constant 1 : index
// CHECK-NEXT: rtg.label_decl "label_string_{0}_{1}", %{{.*}}, %{{.*}}

View File

@ -10,15 +10,26 @@ rtg.sequence_closure @seq0
// -----
rtg.sequence @seq0 {
^bb0(%arg0: i32):
}
rtg.sequence @seq0(%arg0: i32) { }
// expected-error @below {{referenced 'rtg.sequence' op's argument types must match 'args' types}}
rtg.sequence_closure @seq0
// -----
// expected-error @below {{sequence type does not match block argument types}}
"rtg.sequence"()<{sym_name="seq0", sequenceType=!rtg.sequence<i32>}>({^bb0:}) : () -> ()
// -----
// expected-note @below {{prior use here}}
rtg.sequence @seq0(%arg0: !rtg.sequence<i32>) {
// expected-error @below {{use of value '%arg0' expects different type than prior uses: '!rtg.sequence' vs '!rtg.sequence<i32>'}}
rtg.invoke_sequence %arg0
}
// -----
// expected-error @below {{terminator operand types must match dict entry types}}
rtg.target @target : !rtg.dict<a: i32> {
rtg.yield
@ -53,30 +64,28 @@ rtg.test @test : !rtg.dict<"": i32> {
// -----
rtg.sequence @seq {
^bb0(%arg0: i32, %arg1: i64, %arg2: index):
rtg.sequence @seq(%arg0: i32, %arg1: i64, %arg2: index) {
// expected-error @below {{types of all elements must match}}
"rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i32, i64, index, index) -> !rtg.bag<i32>
}
// -----
rtg.sequence @seq {
^bb0(%arg0: i64, %arg1: i64, %arg2: index):
rtg.sequence @seq(%arg0: i64, %arg1: i64, %arg2: index) {
// expected-error @below {{operand types must match bag element type}}
"rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i64, i64, index, index) -> !rtg.bag<i32>
}
// -----
rtg.sequence @seq {
rtg.sequence @seq() {
// expected-error @below {{expected 1 or more operands, but found 0}}
rtg.set_union : !rtg.set<i32>
}
// -----
rtg.sequence @seq {
rtg.sequence @seq() {
// expected-error @below {{expected 1 or more operands, but found 0}}
rtg.bag_union : !rtg.bag<i32>
}

View File

@ -108,15 +108,13 @@ rtg.target @target1 : !rtg.dict<num_cpus: index> {
// Unused sequences are removed
// CHECK-NOT: rtg.sequence @unused
rtg.sequence @unused {}
rtg.sequence @unused() {}
rtg.sequence @seq0 {
^bb0(%arg0: index):
rtg.sequence @seq0(%arg0: index) {
func.call @dummy2(%arg0) : (index) -> ()
}
rtg.sequence @seq1 {
^bb0(%arg0: index):
rtg.sequence @seq1(%arg0: index) {
%0 = rtg.sequence_closure @seq0(%arg0 : index)
func.call @dummy2(%arg0) : (index) -> ()
rtg.invoke_sequence %0
@ -134,8 +132,7 @@ rtg.test @nestedSequences : !rtg.dict<> {
rtg.invoke_sequence %1
}
rtg.sequence @seq2 {
^bb0(%arg0: index):
rtg.sequence @seq2(%arg0: index) {
func.call @dummy2(%arg0) : (index) -> ()
}
@ -153,8 +150,7 @@ rtg.test @sameSequenceDifferentArgs : !rtg.dict<> {
rtg.invoke_sequence %3
}
rtg.sequence @seq3 {
^bb0(%arg0: !rtg.set<index>):
rtg.sequence @seq3(%arg0: !rtg.set<index>) {
%0 = rtg.set_select_random %arg0 : !rtg.set<index> // we can't use a custom seed here because it would render the test useless
func.call @dummy2(%0) : (index) -> ()
}