[RTG] Add memory type and operations (#8368)

This commit is contained in:
Martin Erhart 2025-05-14 11:07:03 +01:00 committed by GitHub
parent aa8b980f6a
commit 3c9922be92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 296 additions and 6 deletions

View File

@ -22,6 +22,7 @@ declare_mlir_python_sources(PyRTGSources
pyrtg/index.py
pyrtg/integers.py
pyrtg/labels.py
pyrtg/memories.py
pyrtg/rtg.py
pyrtg/sequences.py
pyrtg/sets.py

View File

@ -20,3 +20,4 @@ from .arrays import Array
from .contexts import CPUCore
from .control_flow import If, Else, EndIf, For, Foreach
from .tuples import Tuple
from .memories import Memory, MemoryBlock

View File

@ -0,0 +1,97 @@
from __future__ import annotations
from .core import Value
from .circt import ir
from .index import index
from .rtg import rtg
from .integers import Integer
from .resources import Immediate
from typing import Union
class MemoryBlock(Value):
def __init__(self, value: ir.Value):
"""
For library internal usage only.
"""
self._value = value
def declare(base_address: int, end_address: int,
address_width: int) -> MemoryBlock:
"""
Declare a new memory block with the specified parameters.
Args:
base_address: The first valid address of the memory
end_address: The last valid address of the memory
address_width: The width of the memory block addresses in bits.
"""
return rtg.MemoryBlockDeclareOp(
rtg.MemoryBlockType.get(address_width),
ir.IntegerAttr.get(ir.IntegerType.get_signless(address_width),
base_address),
ir.IntegerAttr.get(ir.IntegerType.get_signless(address_width),
end_address))
def _get_ssa_value(self) -> ir.Value:
return self._value
def get_type(self) -> ir.Type:
return self._value.type
def type(address_width: int) -> ir.Type:
return rtg.MemoryBlockType.get(address_width)
class Memory(Value):
def __init__(self, value: ir.Value):
"""
For library internal usage only.
"""
self._value = value
def alloc(mem_block: MemoryBlock, size: Union[Integer, int],
align: Union[Integer, int]) -> Memory:
"""
Allocate a new memory from a memory block with the specified parameters.
Args:
size: The size of the memory in bytes.
align: The alignment of the memory in bytes.
"""
if isinstance(size, int):
size = index.ConstantOp(size)
if isinstance(align, int):
align = index.ConstantOp(align)
return rtg.MemoryAllocOp(mem_block, size, align)
def size(self) -> Integer:
"""
Get the size of the memory in bytes.
"""
return rtg.MemorySizeOp(self._value)
def base_address(self) -> Immediate:
"""
Get the base address of the memory as an immediate matching the memories
address width.
"""
return rtg.MemoryBaseAddressOp(self._value)
def _get_ssa_value(self) -> ir.Value:
return self._value
def get_type(self) -> ir.Type:
return self._value.type
def type(address_width: int) -> ir.Type:
return rtg.MemoryType.get(address_width)

View File

@ -46,6 +46,12 @@ def _FromCirctValue(value: ir.Value) -> Value:
if isinstance(type, ir.TupleType):
from .tuples import Tuple
return Tuple(value)
if isinstance(type, rtg.MemoryType):
from .memories import Memory
return Memory(value)
if isinstance(type, rtg.MemoryBlockType):
from .memories import MemoryBlock
return MemoryBlock(value)
assert False, "Unsupported value"

View File

@ -2,7 +2,7 @@
# RUN: %rtgtool% %s --seed=0 --output-format=elaborated | FileCheck %s --check-prefix=ELABORATED
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm && FileCheck %s --input-file=%t --check-prefix=ASM
from pyrtg import test, sequence, target, entry, rtg, Label, Set, Integer, Bag, rtgtest, Immediate, IntegerRegister, Array, Bool, Tuple, embed_comment
from pyrtg import test, sequence, target, entry, rtg, Label, Set, Integer, Bag, rtgtest, Immediate, IntegerRegister, Array, Bool, Tuple, embed_comment, MemoryBlock, Memory
# MLIR-LABEL: rtg.target @Tgt0 : !rtg.dict<entry0: !rtg.set<index>>
# MLIR-NEXT: [[C0:%.+]] = index.constant 0
@ -39,6 +39,19 @@ class Tgt1:
return Label.declare("l0")
# MLIR-LABEL: rtg.target @Tgt2
# MLIR-NEXT: [[V0:%.+]] = rtg.isa.memory_block_declare [0x0 - 0x3f] : !rtg.isa.memory_block<32>
# MLIR-NEXT: rtg.yield [[V0]] : !rtg.isa.memory_block<32>
@target
class Tgt2:
@entry
def mem_blk_0():
return MemoryBlock.declare(base_address=0, end_address=63, address_width=32)
# MLIR-LABEL: rtg.target @Tgt4
# MLIR-NEXT: [[IDX12:%.+]] = index.constant 12
# MLIR-NEXT: [[IDX11:%.+]] = index.constant 11
@ -361,6 +374,25 @@ def test4_integer_to_immediate():
Immediate(12, Integer(2)))
# MLIR-LABEL: rtg.test @test6_memories
# MLIR-NEXT: [[REG:%.+]] = rtg.fixed_reg #rtgtest.t0 : !rtgtest.ireg
# MLIR-NEXT: [[IDX8:%.+]] = index.constant 8
# MLIR-NEXT: [[IDX4:%.+]] = index.constant 4
# MLIR-NEXT: [[MEM:%.+]] = rtg.isa.memory_alloc %mem_blk, [[IDX8]], [[IDX4]] : !rtg.isa.memory_block<32>
# MLIR-NEXT: [[SIZE:%.+]] = rtg.isa.memory_size [[MEM]] : !rtg.isa.memory<32>
# MLIR-NEXT: [[IMM:%.+]] = rtg.isa.int_to_immediate [[SIZE]] : !rtg.isa.immediate<32>
# MLIR-NEXT: rtgtest.rv32i.auipc [[REG]], [[IMM]] : !rtg.isa.immediate<32>
# MLIR-NEXT: [[BASE:%.+]] = rtg.isa.memory_base_address [[MEM]] : !rtg.isa.memory<32>
# MLIR-NEXT: rtgtest.rv32i.auipc [[REG]], [[BASE]] : !rtg.isa.immediate<32>
@test(("mem_blk", MemoryBlock.type(32)))
def test6_memories(mem_blk):
mem = Memory.alloc(mem_blk, size=8, align=4)
rtgtest.AUIPC(IntegerRegister.t0(), Immediate(32, mem.size()))
rtgtest.AUIPC(IntegerRegister.t0(), mem.base_address())
# MLIR-LABEL: rtg.test @test7_bools
# MLIR: index.bool.constant false
# MLIR: index.bool.constant true

View File

@ -98,6 +98,16 @@ MLIR_CAPI_EXPORTED bool rtgTypeIsAArray(MlirType type);
/// Returns the element type of the RTG array.
MLIR_CAPI_EXPORTED MlirType rtgArrayTypeGetElementType(MlirType type);
/// If the type is an RTG memory.
MLIR_CAPI_EXPORTED bool rtgTypeIsAMemory(MlirType type);
/// Creates an RTG memory type in the context.
MLIR_CAPI_EXPORTED MlirType rtgMemoryTypeGet(MlirContext ctx,
uint32_t addressWidth);
/// Returns the address with of an RTG memory type.
MLIR_CAPI_EXPORTED uint32_t rtgMemoryTypeGetAddressWidth(MlirType type);
/// If the type is an RTG memory block.
MLIR_CAPI_EXPORTED bool rtgTypeIsAMemoryBlock(MlirType type);

View File

@ -817,3 +817,55 @@ def MemoryBlockDeclareOp : RTGISAOp<"memory_block_declare", [
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
//===- ISA Memory Handling Operations -------------------------------------===//
def MemoryAllocOp : RTGISAOp<"memory_alloc", [
TypesMatchWith<"memory must have the same address width as the memory block",
"memoryBlock", "result",
"MemoryType::get($_ctxt, " #
"cast<MemoryBlockType>($_self).getAddressWidth())">,
]> {
let summary = "allocate a memory with the provided properties";
let description = [{
This operation declares a memory to be allocated with the provided
properties. It is only allowed to declare new memories in the `rtg.target`
operations and must be passed as argument to the `rtg.test`.
}];
let arguments = (ins MemoryBlockType:$memoryBlock,
Index:$size,
Index:$alignment);
let results = (outs MemoryType:$result);
let assemblyFormat = [{
$memoryBlock `,` $size `,` $alignment
`:` qualified(type($memoryBlock)) attr-dict
}];
}
def MemoryBaseAddressOp : RTGISAOp<"memory_base_address", [
Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
]> {
let summary = "get the memory base address as an immediate";
let description = [{
This operation returns the base address of the given memory. The bit-width
of the returned immediate must match the address width of the given memory.
}];
let arguments = (ins MemoryType:$memory);
let results = (outs ImmediateType:$result);
let assemblyFormat = "$memory `:` qualified(type($memory)) attr-dict";
}
def MemorySizeOp : RTGISAOp<"memory_size", [Pure]> {
let summary = "get the size of the memory in bytes";
let arguments = (ins MemoryType:$memory);
let results = (outs Index:$result);
let assemblyFormat = "$memory `:` qualified(type($memory)) attr-dict";
}

View File

@ -187,6 +187,18 @@ class ImmediateOfWidth<int width> : Type<
BuildableType<"::circt::rtg::ImmediateType::get($_builder.getContext(), " #
width # ")">;
def MemoryType : RTGISATypeDef<"Memory", "memory"> {
let summary = "handle to a memory";
let description = [{
This type is used to represent memory resources that are allocated from
memory blocks and can be accessed and manipulated by payload dialect
operations.
}];
let parameters = (ins "uint32_t":$addressWidth);
let assemblyFormat = "`<` $addressWidth `>`";
}
def MemoryBlockType : RTGISATypeDef<"MemoryBlock", "memory_block"> {
let summary = "handle to a memory block";
let description = [{

View File

@ -59,6 +59,8 @@ public:
TupleCreateOp, TupleExtractOp,
// Immediates
IntToImmediateOp,
// Memories
MemoryAllocOp, MemoryBaseAddressOp, MemorySizeOp,
// Memory Blocks
MemoryBlockDeclareOp,
// Misc ops
@ -134,6 +136,9 @@ public:
HANDLE(VirtualRegisterOp, Unhandled);
HANDLE(IntToImmediateOp, Unhandled);
HANDLE(MemoryBlockDeclareOp, Unhandled);
HANDLE(MemoryAllocOp, Unhandled);
HANDLE(MemoryBaseAddressOp, Unhandled);
HANDLE(MemorySizeOp, Unhandled);
#undef HANDLE
};

View File

@ -225,6 +225,12 @@ with Context() as ctx, Location.unknown():
# CHECK: !rtg.isa.memory_block<32>
print(memory_block_type)
memoryTy = rtg.MemoryType.get(32)
# CHECK: address_width=32
print(f'address_width={memoryTy.address_width}')
# CHECK: !rtg.isa.memory<32>
print(memoryTy)
with Context() as ctx, Location.unknown():
circt.register_dialects(ctx)
indexTy = IndexType.get()

View File

@ -144,6 +144,17 @@ void circt::python::populateDialectRTGSubmodule(nb::module_ &m) {
return rtgMemoryBlockTypeGetAddressWidth(self);
});
mlir_type_subclass(m, "MemoryType", rtgTypeIsAMemory)
.def_classmethod(
"get",
[](nb::object cls, uint32_t addressWidth, MlirContext ctxt) {
return cls(rtgMemoryTypeGet(ctxt, addressWidth));
},
nb::arg("self"), nb::arg("address_width"), nb::arg("ctxt") = nullptr)
.def_property_readonly("address_width", [](MlirType self) {
return rtgMemoryTypeGetAddressWidth(self);
});
//===--------------------------------------------------------------------===//
// Attributes
//===--------------------------------------------------------------------===//

View File

@ -165,6 +165,14 @@ def type_to_pytype(t) -> ir.Type:
return rtg.ArrayType(t)
except ValueError:
pass
try:
return rtg.MemoryType(t)
except ValueError:
pass
try:
return rtg.MemoryBlockType(t)
except ValueError:
pass
try:
return rtgtest.IntegerRegisterType(t)
except ValueError:

View File

@ -144,6 +144,19 @@ uint32_t rtgImmediateTypeGetWidth(MlirType type) {
return cast<ImmediateType>(unwrap(type)).getWidth();
}
// MemoryType
//===----------------------------------------------------------------------===//
bool rtgTypeIsAMemory(MlirType type) { return isa<MemoryType>(unwrap(type)); }
MlirType rtgMemoryTypeGet(MlirContext ctxt, uint32_t addressWidth) {
return wrap(MemoryType::get(unwrap(ctxt), addressWidth));
}
uint32_t rtgMemoryTypeGetAddressWidth(MlirType type) {
return cast<MemoryType>(unwrap(type)).getAddressWidth();
}
// MemoryBlockType
//===----------------------------------------------------------------------===//

View File

@ -731,6 +731,24 @@ void MemoryBlockDeclareOp::print(OpAsmPrinter &p) {
{getBaseAddressAttrName(), getEndAddressAttrName()});
}
//===----------------------------------------------------------------------===//
// MemoryBaseAddressOp
//===----------------------------------------------------------------------===//
LogicalResult MemoryBaseAddressOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands.empty())
return failure();
auto memTy = dyn_cast<MemoryType>(operands[0].getType());
if (!memTy)
return failure();
inferredReturnTypes.push_back(
ImmediateType::get(context, memTy.getAddressWidth()));
return success();
}
//===----------------------------------------------------------------------===//
// TableGen generated logic.
//===----------------------------------------------------------------------===//

View File

@ -175,7 +175,7 @@ static void testImmediate(MlirContext ctx) {
rtgImmediateAttrGetValue(immediateAttr));
}
static void testMemoryBlock(MlirContext ctx) {
static void testMemories(MlirContext ctx) {
MlirType memoryBlockTy = rtgMemoryBlockTypeGet(ctx, 32);
// CHECK: is_memory_block
@ -187,6 +187,15 @@ static void testMemoryBlock(MlirContext ctx) {
// CHECK: address_width=32
fprintf(stderr, "address_width=%u\n",
rtgMemoryBlockTypeGetAddressWidth(memoryBlockTy));
MlirType memoryTy = rtgMemoryTypeGet(ctx, 32);
// CHECK: is_memory
fprintf(stderr,
rtgTypeIsAMemory(memoryTy) ? "is_memory\n" : "isnot_memory\n");
// CHECK: addressWidth=32
fprintf(stderr, "addressWidth=%u\n", rtgMemoryTypeGetAddressWidth(memoryTy));
// CHECK: !rtg.isa.memory<32>
mlirTypeDump(memoryTy);
}
int main(int argc, char **argv) {
@ -205,7 +214,7 @@ int main(int argc, char **argv) {
testLabelVisibilityAttr(ctx);
testDefaultContextAttr(ctx);
testImmediate(ctx);
testMemoryBlock(ctx); // Add the new test
testMemories(ctx);
mlirContextDestroy(ctx);

View File

@ -192,10 +192,19 @@ rtg.test @tuples() {
%1 = rtg.tuple_extract %0 at 1 : tuple<index, i1>
}
// CHECK-LABEL: @memoryBlocks : !rtg.dict<mem_block: !rtg.isa.memory_block<32>>
rtg.target @memoryBlocks : !rtg.dict<mem_block: !rtg.isa.memory_block<32>> {
// CHECK-LABEL: @memoryBlocks : !rtg.dict<mem_base_address: !rtg.isa.immediate<32>, mem_block: !rtg.isa.memory_block<32>, mem_size: index>
rtg.target @memoryBlocks : !rtg.dict<mem_base_address: !rtg.isa.immediate<32>, mem_block: !rtg.isa.memory_block<32>, mem_size: index> {
// CHECK: rtg.isa.memory_block_declare [0x0 - 0x8] : !rtg.isa.memory_block<32>
%0 = rtg.isa.memory_block_declare [0x0 - 0x8] : !rtg.isa.memory_block<32>
// CHECK: [[IDX8:%.+]] = index.constant 8
// CHECK: [[V1:%.+]] = rtg.isa.memory_alloc %0, [[IDX8]], [[IDX8]] : !rtg.isa.memory_block<32>
// CHECK: [[V2:%.+]] = rtg.isa.memory_base_address [[V1]] : !rtg.isa.memory<32>
// CHECK: [[V3:%.+]] = rtg.isa.memory_size [[V1]] : !rtg.isa.memory<32>
%idx8 = index.constant 8
%1 = rtg.isa.memory_alloc %0, %idx8, %idx8 : !rtg.isa.memory_block<32>
%2 = rtg.isa.memory_base_address %1 : !rtg.isa.memory<32>
%3 = rtg.isa.memory_size %1 : !rtg.isa.memory<32>
rtg.yield %0 : !rtg.isa.memory_block<32>
rtg.yield %2, %0, %3 : !rtg.isa.immediate<32>, !rtg.isa.memory_block<32>, index
}