[RTG] Add custom tuple type to support empty tuples (#8711)

The builtin tuple type does not allow empty tuples, but they can avoid special casing in the frontend and can also be used as indicators (something like a none type could otherwise be used for but which we don't have in RTG).
This commit is contained in:
Martin Erhart 2025-07-17 08:13:28 +01:00 committed by GitHub
parent 5ee54acb05
commit 916ff355ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 159 additions and 51 deletions

View File

@ -45,7 +45,7 @@ def _FromCirctValue(value: ir.Value) -> Value:
if isinstance(type, rtg.ImmediateType):
from .immediates import Immediate
return Immediate(type.width, value)
if isinstance(type, ir.TupleType):
if isinstance(type, rtg.TupleType):
from .tuples import Tuple
return Tuple(value)
if isinstance(type, rtg.MemoryType):
@ -99,10 +99,9 @@ def _FromCirctType(type: Union[ir.Type, Type]) -> Type:
if isinstance(type, rtgtest.CPUType):
from .contexts import CPUCoreType
return CPUCoreType()
if isinstance(type, ir.TupleType):
if isinstance(type, rtg.TupleType):
from .tuples import TupleType
return TupleType(
[_FromCirctType(type.get_type(i)) for i in range(type.num_types)])
return TupleType([_FromCirctType(ty) for ty in type.fields])
if isinstance(type, rtg.MemoryType):
from .memories import MemoryType
return MemoryType(type.address_width)

View File

@ -29,9 +29,6 @@ class Tuple(Value):
element must be provided. Each element can be of a different type.
"""
if len(elements) == 0:
raise ValueError("at least one element must be present")
return rtg.TupleCreateOp(elements)
def __getitem__(self, i) -> Value:
@ -67,4 +64,4 @@ class TupleType(Type):
TupleType) and self.element_types == other.element_types
def _codegen(self) -> ir.Type:
return ir.TupleType.get_tuple([ty._codegen() for ty in self.element_types])
return rtg.TupleType.get([ty._codegen() for ty in self.element_types])

View File

@ -297,7 +297,7 @@ def test2_labels(config):
# MLIR-LABEL: rtg.test @test3_registers_and_immediates()
# MLIR-NEXT: %idx2097152 = index.constant 2097152
# MLIR-NEXT: %idx2097151 = index.constant 2097151
# MLIR-NEXT: %idx0 = index.constant 0
# MLIR-NEXT: [[IMM32:%.+]] = rtg.constant #rtg.isa.immediate<32, 32>
# MLIR-NEXT: [[IMM21:%.+]] = rtg.constant #rtg.isa.immediate<21, 16>
@ -313,7 +313,7 @@ def test2_labels(config):
# MLIR-NEXT: rtgtest.rv32i.beq [[VREG]], [[T2]], [[IMM13]] : !rtg.isa.immediate<13>
# MLIR-NEXT: rtgtest.rv32i.jal [[VREG]], [[IMM21]] : !rtg.isa.immediate<21>
# MLIR-NEXT: rtgtest.rv32i.auipc [[VREG]], [[IMM32]] : !rtg.isa.immediate<32>
# MLIR-NEXT: [[RND:%.+]] = rtg.random_number_in_range [%idx0, %idx2097152)
# MLIR-NEXT: [[RND:%.+]] = rtg.random_number_in_range [%idx0, %idx2097151]
# MLIR-NEXT: [[RND_IMM:%.+]] = rtg.isa.int_to_immediate [[RND]]
# MLIR-NEXT: rtgtest.rv32i.jal [[VREG]], [[RND_IMM]] : !rtg.isa.immediate<21>
# MLIR-NEXT: }
@ -395,7 +395,7 @@ def test7_bools(config):
# MLIR-LABEL: rtg.test @test8_random_integer
# MLIR-NEXT: rtg.random_number_in_range [%a, %b)
# MLIR-NEXT: rtg.random_number_in_range [%a, %b]
@sequence([IntegerType()])
@ -410,7 +410,7 @@ def test8_random_integer(config):
# MLIR-LABEL: rtg.test @test90_tuples
# MLIR-NEXT: [[V0:%.+]] = rtg.tuple_create %a, %b : index, i1
# MLIR-NEXT: rtg.tuple_extract [[V0]] at 1 : tuple<index, i1>
# MLIR-NEXT: rtg.tuple_extract [[V0]] at 1 : !rtg.tuple<index, i1>
@config

View File

@ -99,6 +99,21 @@ MLIR_CAPI_EXPORTED bool rtgTypeIsAArray(MlirType type);
/// Returns the element type of the RTG array.
MLIR_CAPI_EXPORTED MlirType rtgArrayTypeGetElementType(MlirType type);
/// Creates an RTG tuple type in the context.
MLIR_CAPI_EXPORTED MlirType rtgTupleTypeGet(MlirContext ctxt,
intptr_t numFields,
MlirType const *fieldTypes);
/// If the type is an RTG tuple.
MLIR_CAPI_EXPORTED bool rtgTypeIsATuple(MlirType type);
/// Returns the number of fields in the RTG tuple.
MLIR_CAPI_EXPORTED intptr_t rtgTypeGetNumFields(MlirType type);
/// Returns a field type of the RTG tuple.
MLIR_CAPI_EXPORTED MlirType rtgTupleTypeGetFieldType(MlirType type,
intptr_t idx);
/// If the type is an RTG memory.
MLIR_CAPI_EXPORTED bool rtgTypeIsAMemory(MlirType type);

View File

@ -602,7 +602,7 @@ def TupleCreateOp : RTGOp<"tuple_create", [
let summary = "create a tuple";
let arguments = (ins Variadic<AnyType>:$elements);
let results = (outs Builtin_Tuple:$result);
let results = (outs TupleType:$result);
let assemblyFormat = [{
($elements^ `:` qualified(type($elements)))? attr-dict
@ -615,7 +615,7 @@ def TupleExtractOp : RTGOp<"tuple_extract", [
]> {
let summary = "get an element from a tuple";
let arguments = (ins Builtin_Tuple:$tuple, IndexAttr:$index);
let arguments = (ins TupleType:$tuple, IndexAttr:$index);
let results = (outs AnyType:$result);
let assemblyFormat = [{

View File

@ -148,6 +148,21 @@ def ArrayType : RTGTypeDef<"Array"> {
];
}
def TupleType : RTGTypeDef<"Tuple"> {
let summary = "a tuple of zero or more fields";
let description = [{
This type represents a tuple of zero or more fields. The fields can be
of any type. The builtin tuple type is not used because it does not allow
zero fields.
}];
let parameters = (ins
OptionalArrayRefParameter<"::mlir::Type", "tuple field types">:$fieldTypes);
let mnemonic = "tuple";
let assemblyFormat = "(`<` $fieldTypes^ `>`)?";
}
//===----------------------------------------------------------------------===//
// Types for ISA targets
//===----------------------------------------------------------------------===//

View File

@ -251,3 +251,16 @@ with Context() as ctx, Location.unknown():
print(f"element_type={arr.element_type}")
# CHECK: !rtg.array<index>
print(arr)
tup = rtg.TupleType.get([indexTy, indexTy])
# CHECK: fields=[IndexType(index), IndexType(index)]
print(f"fields={tup.fields}")
# CHECK: !rtg.tuple<index, index>
print(tup)
tup = rtg.TupleType.get([])
# CHECK: fields=[]
print(f"fields={tup.fields}")
# CHECK: !rtg.tuple
print(tup)

View File

@ -111,6 +111,23 @@ void circt::python::populateDialectRTGSubmodule(nb::module_ &m) {
return rtgArrayTypeGetElementType(self);
});
mlir_type_subclass(m, "TupleType", rtgTypeIsATuple)
.def_classmethod(
"get",
[](nb::object cls, const std::vector<MlirType> &fieldTypes,
MlirContext ctxt) {
return cls(
rtgTupleTypeGet(ctxt, fieldTypes.size(), fieldTypes.data()));
},
nb::arg("self"), nb::arg("field_types") = std::vector<MlirType>(),
nb::arg("ctxt") = nullptr)
.def_property_readonly("fields", [](MlirType self) {
std::vector<MlirType> fields;
for (intptr_t i = 0; i < rtgTypeGetNumFields(self); ++i)
fields.push_back(rtgTupleTypeGetFieldType(self, i));
return fields;
});
// Types for ISA targets
//===--------------------------------------------------------------------===//

View File

@ -173,6 +173,10 @@ def type_to_pytype(t) -> ir.Type:
return rtg.MemoryBlockType(t)
except ValueError:
pass
try:
return rtg.TupleType(t)
except ValueError:
pass
try:
return rtgtest.IntegerRegisterType(t)
except ValueError:

View File

@ -136,6 +136,29 @@ MlirType rtgArrayTypeGetElementType(MlirType type) {
return wrap(cast<ArrayType>(unwrap(type)).getElementType());
}
// TupleType
//===----------------------------------------------------------------------===//
MlirType rtgTupleTypeGet(MlirContext ctxt, intptr_t numFields,
MlirType const *fieldTypes) {
SmallVector<Type> types;
for (unsigned i = 0; i < numFields; ++i)
types.emplace_back(unwrap(fieldTypes[i]));
return wrap(rtg::TupleType::get(unwrap(ctxt), types));
}
bool rtgTypeIsATuple(MlirType type) {
return isa<rtg::TupleType>(unwrap(type));
}
intptr_t rtgTypeGetNumFields(MlirType type) {
return cast<rtg::TupleType>(unwrap(type)).getFieldTypes().size();
}
MlirType rtgTupleTypeGetFieldType(MlirType type, intptr_t idx) {
return wrap(cast<rtg::TupleType>(unwrap(type)).getFieldTypes()[idx]);
}
// ImmediateType
//===----------------------------------------------------------------------===//

View File

@ -285,7 +285,7 @@ LogicalResult SetCartesianProductOp::inferReturnTypes(
for (auto operand : operands)
elementTypes.push_back(cast<SetType>(operand.getType()).getElementType());
inferredReturnTypes.push_back(
SetType::get(TupleType::get(context, elementTypes)));
SetType::get(rtg::TupleType::get(context, elementTypes)));
return success();
}
@ -369,16 +369,10 @@ LogicalResult TupleCreateOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands.empty()) {
if (loc)
return mlir::emitError(*loc) << "empty tuples not allowed";
return failure();
}
SmallVector<Type> elementTypes;
for (auto operand : operands)
elementTypes.push_back(operand.getType());
inferredReturnTypes.push_back(TupleType::get(context, elementTypes));
inferredReturnTypes.push_back(rtg::TupleType::get(context, elementTypes));
return success();
}
@ -392,18 +386,24 @@ LogicalResult TupleExtractOp::inferReturnTypes(
SmallVectorImpl<Type> &inferredReturnTypes) {
assert(operands.size() == 1 && "must have exactly one operand");
auto tupleTy = dyn_cast<TupleType>(operands[0].getType());
auto tupleTy = dyn_cast<rtg::TupleType>(operands[0].getType());
size_t idx = properties.as<Properties *>()->getIndex().getInt();
if (!tupleTy || tupleTy.getTypes().size() <= idx) {
if (!tupleTy) {
if (loc)
return mlir::emitError(*loc) << "only RTG tuples are supported";
return failure();
}
if (tupleTy.getFieldTypes().size() <= idx) {
if (loc)
return mlir::emitError(*loc)
<< "index (" << idx
<< ") must be smaller than number of elements in tuple ("
<< tupleTy.getTypes().size() << ")";
<< tupleTy.getFieldTypes().size() << ")";
return failure();
}
inferredReturnTypes.push_back(tupleTy.getTypes()[idx]);
inferredReturnTypes.push_back(tupleTy.getFieldTypes()[idx]);
return success();
}

View File

@ -115,6 +115,27 @@ static void testArrayType(MlirContext ctx) {
mlirTypeDump(arrayTy);
}
static void testTupleType(MlirContext ctx) {
MlirType i32Type = mlirIntegerTypeGet(ctx, 32);
MlirType tupleTy = rtgTupleTypeGet(ctx, 2, (MlirType[]){i32Type, i32Type});
// CHECK: is_tuple
fprintf(stderr, rtgTypeIsATuple(tupleTy) ? "is_tuple\n" : "isnot_tuple\n");
// CHECK: 2
fprintf(stderr, "%ld\n", rtgTypeGetNumFields(tupleTy));
// CHECK: i32
mlirTypeDump(rtgTupleTypeGetFieldType(tupleTy, 1));
// CHECK: !rtg.tuple<i32, i32>
mlirTypeDump(tupleTy);
MlirType emptyTupleTy = rtgTupleTypeGet(ctx, 0, NULL);
// CHECK: is_tuple
fprintf(stderr,
rtgTypeIsATuple(emptyTupleTy) ? "is_tuple\n" : "isnot_tuple\n");
// CHECK: !rtg.tuple
mlirTypeDump(emptyTupleTy);
}
static void testLabelVisibilityAttr(MlirContext ctx) {
MlirAttribute labelVisibility =
rtgLabelVisibilityAttrGet(ctx, RTG_LABEL_VISIBILITY_GLOBAL);
@ -219,6 +240,7 @@ int main(int argc, char **argv) {
testBagType(ctx);
testDictType(ctx);
testArrayType(ctx);
testTupleType(ctx);
testLabelVisibilityAttr(ctx);
testContextAttrs(ctx);

View File

@ -62,7 +62,7 @@ rtg.sequence @seqRandomizationAndEmbedding() {
}
// CHECK-LABEL: @sets
func.func @sets(%arg0: i32, %arg1: i32) -> !rtg.set<tuple<i32, i32>> {
func.func @sets(%arg0: i32, %arg1: i32) -> !rtg.set<!rtg.tuple<i32, i32>> {
// CHECK: [[SET:%.+]] = rtg.set_create %arg0, %arg1 : i32
// CHECK: [[R:%.+]] = rtg.set_select_random [[SET]] : !rtg.set<i32>
// CHECK: [[EMPTY:%.+]] = rtg.set_create : i32
@ -80,7 +80,7 @@ func.func @sets(%arg0: i32, %arg1: i32) -> !rtg.set<tuple<i32, i32>> {
%prod = rtg.set_cartesian_product %set, %set : !rtg.set<i32>, !rtg.set<i32>
%bag = rtg.set_convert_to_bag %set : !rtg.set<i32>
return %prod : !rtg.set<tuple<i32, i32>>
return %prod : !rtg.set<!rtg.tuple<i32, i32>>
}
// CHECK-LABEL: @bags
@ -184,15 +184,18 @@ rtg.test @arrays(arr = %arr: !rtg.array<index>) {
}
// CHECK-LABEL: rtg.test @tuples
rtg.test @tuples() {
// CHECK-SAME: (tup = %{{.*}}: !rtg.tuple)
rtg.test @tuples(tup = %tup: !rtg.tuple) {
// CHECK-NEXT: [[IDX0:%.+]] = index.constant 0
// CHECK-NEXT: [[TRUE:%.+]] = index.bool.constant true
// CHECK-NEXT: [[TUPLE:%.+]] = rtg.tuple_create [[IDX0]], [[TRUE]] : index, i1
// CHECK-NEXT: rtg.tuple_extract [[TUPLE]] at 1 : tuple<index, i1>
// CHECK-NEXT: rtg.tuple_extract [[TUPLE]] at 1 : !rtg.tuple<index, i1>
// CHECK-NEXT: rtg.tuple_create
%idx0 = index.constant 0
%true = index.bool.constant true
%0 = rtg.tuple_create %idx0, %true : index, i1
%1 = rtg.tuple_extract %0 at 1 : tuple<index, i1>
%1 = rtg.tuple_extract %0 at 1 : !rtg.tuple<index, i1>
rtg.tuple_create
}
// CHECK-LABEL: @memoryBlocks : !rtg.dict<mem_base_address: !rtg.isa.immediate<32>, mem_block: !rtg.isa.memory_block<32>, mem_size: index>

View File

@ -176,7 +176,7 @@ rtg.sequence @seq() {
rtg.sequence @setCartesianProduct() {
// expected-error @below {{at least one set must be provided}}
// expected-error @below {{failed to infer returned types}}
%0 = "rtg.set_cartesian_product"() : () -> (!rtg.set<tuple<index>>)
%0 = "rtg.set_cartesian_product"() : () -> (!rtg.set<!rtg.tuple<index>>)
}
// -----
@ -250,16 +250,16 @@ rtg.test @test(a = %a: i32, b = %b: index) {
// -----
rtg.test @emptyTuple() {
// expected-error @below {{empty tuples not allowed}}
%0 = rtg.tuple_create
rtg.test @incorrect_tuple_type(tup = %tup : tuple<index, i1>) {
// expected-error @below {{only RTG tuples are supported}}
rtg.tuple_extract %tup at 2 : tuple<index, i1>
}
// -----
rtg.test @tupleExtractOOB(tup = %tup : tuple<index, i1>) {
rtg.test @tupleExtractOOB(tup = %tup : !rtg.tuple<index, i1>) {
// expected-error @below {{index (2) must be smaller than number of elements in tuple (2)}}
rtg.tuple_extract %tup at 2 : tuple<index, i1>
rtg.tuple_extract %tup at 2 : !rtg.tuple<index, i1>
}
// -----

View File

@ -7,9 +7,9 @@ func.func @dummy4(%arg0: index, %arg1: index, %arg2: !rtg.bag<index>, %arg3: !rt
func.func @dummy5(%arg0: i1) -> () {return}
func.func @dummy6(%arg0: !rtg.isa.immediate<2>) -> () {return}
func.func @dummy7(%arg0: !rtg.array<index>) -> () {return}
func.func @dummy8(%arg0: tuple<index, index>) -> () {return}
func.func @dummy9(%arg0: !rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> () {return}
func.func @dummy10(%arg0: !rtg.set<tuple<index>>) -> () {return}
func.func @dummy8(%arg0: !rtg.tuple<index, index>) -> () {return}
func.func @dummy9(%arg0: !rtg.set<!rtg.tuple<index, i1, !rtgtest.ireg>>) -> () {return}
func.func @dummy10(%arg0: !rtg.set<!rtg.tuple<index>>) -> () {return}
func.func @dummy11(%arg0: !rtg.set<index>) -> () {return}
func.func @dummy12(%arg0: !rtg.bag<index>) -> () {return}
func.func @dummy13(%arg0: !rtg.isa.memory_block<32>) -> () {return}
@ -89,23 +89,23 @@ rtg.test @setCartesianProduct(singleton = %none: index) {
// CHECK-DAG: [[T6:%.+]] = rtg.tuple_create [[IDX0]], [[FALSE]], [[S0]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[T7:%.+]] = rtg.tuple_create [[IDX1]], [[TRUE]], [[S0]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[T8:%.+]] = rtg.tuple_create [[IDX0]], [[TRUE]], [[S0]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[SET:%.+]] = rtg.set_create [[T1]], [[T2]], [[T3]], [[T4]], [[T5]], [[T6]], [[T7]], [[T8]] : tuple<index, i1, !rtgtest.ireg>
// CHECK-NEXT: func.call @dummy9([[SET]]) : (!rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> ()
// CHECK-DAG: [[SET:%.+]] = rtg.set_create [[T1]], [[T2]], [[T3]], [[T4]], [[T5]], [[T6]], [[T7]], [[T8]] : !rtg.tuple<index, i1, !rtgtest.ireg>
// CHECK-NEXT: func.call @dummy9([[SET]]) : (!rtg.set<!rtg.tuple<index, i1, !rtgtest.ireg>>) -> ()
%3 = rtg.set_cartesian_product %0, %1, %2 : !rtg.set<index>, !rtg.set<i1>, !rtg.set<!rtgtest.ireg>
func.call @dummy9(%3) : (!rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> ()
func.call @dummy9(%3) : (!rtg.set<!rtg.tuple<index, i1, !rtgtest.ireg>>) -> ()
// CHECK-NEXT: [[EMPTY:%.+]] = rtg.set_create : tuple<index, i1, !rtgtest.ireg>
// CHECK-NEXT: func.call @dummy9([[EMPTY]]) : (!rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> ()
// CHECK-NEXT: [[EMPTY:%.+]] = rtg.set_create : !rtg.tuple<index, i1, !rtgtest.ireg>
// CHECK-NEXT: func.call @dummy9([[EMPTY]]) : (!rtg.set<!rtg.tuple<index, i1, !rtgtest.ireg>>) -> ()
%4 = rtg.set_create : !rtgtest.ireg
%5 = rtg.set_cartesian_product %0, %1, %4 : !rtg.set<index>, !rtg.set<i1>, !rtg.set<!rtgtest.ireg>
func.call @dummy9(%5) : (!rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> ()
func.call @dummy9(%5) : (!rtg.set<!rtg.tuple<index, i1, !rtgtest.ireg>>) -> ()
// CHECK-NEXT: [[T9:%.+]] = rtg.tuple_create [[IDX1]] : index
// CHECK-NEXT: [[T10:%.+]] = rtg.tuple_create [[IDX0]] : index
// CHECK-NEXT: [[SET2:%.+]] = rtg.set_create [[T9]], [[T10]] : tuple<index>
// CHECK-NEXT: func.call @dummy10([[SET2]]) : (!rtg.set<tuple<index>>) -> ()
// CHECK-NEXT: [[SET2:%.+]] = rtg.set_create [[T9]], [[T10]] : !rtg.tuple<index>
// CHECK-NEXT: func.call @dummy10([[SET2]]) : (!rtg.set<!rtg.tuple<index>>) -> ()
%6 = rtg.set_cartesian_product %0 : !rtg.set<index>
func.call @dummy10(%6) : (!rtg.set<tuple<index>>) -> ()
func.call @dummy10(%6) : (!rtg.set<!rtg.tuple<index>>) -> ()
}
// CHECK-LABEL: rtg.test @bagOperations
@ -675,13 +675,13 @@ rtg.test @tuples(singleton = %none: index) {
%idx0 = index.constant 0
%idx1 = index.constant 1
%0 = rtg.tuple_create %idx1, %idx0 : index, index
%1 = rtg.tuple_extract %0 at 1 : tuple<index, index>
%1 = rtg.tuple_extract %0 at 1 : !rtg.tuple<index, index>
// CHECK-NEXT: %idx1 = index.constant 1
// CHECK-NEXT: %idx0 = index.constant 0
// CHECK-NEXT: [[V0:%.+]] = rtg.tuple_create %idx1, %idx0 : index, index
// CHECK-NEXT: func.call @dummy8([[V0]])
func.call @dummy8(%0) : (tuple<index, index>) -> ()
func.call @dummy8(%0) : (!rtg.tuple<index, index>) -> ()
// CHECK-NEXT: func.call @dummy2(%idx0)
func.call @dummy2(%1) : (index) -> ()