[PyCDE] Improving support for Any type (#8044)

- Adds a `castable` method to Type which passes through to the
`checkInnerTypeMatch` C++ function.
- Use that method to type check the results of BundleSignal.unpack.
This commit is contained in:
John Demme 2025-01-08 18:29:11 -05:00 committed by GitHub
parent 42e4c202f2
commit 405b6b873f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 39 additions and 8 deletions

View File

@ -5,7 +5,7 @@
from __future__ import annotations
from .support import get_user_loc, _obj_to_value_infer_type
from .types import ChannelDirection, ChannelSignaling, Type
from .types import BundledChannel, ChannelDirection, ChannelSignaling, Type
from .circt.dialects import esi, sv
from .circt import support
@ -784,15 +784,14 @@ class BundleSignal(Signal):
def reg(self, clk, rst=None, name=None):
raise TypeError("Cannot register a bundle")
def unpack(self, **kwargs: Dict[str,
ChannelSignal]) -> Dict[str, ChannelSignal]:
def unpack(self, **kwargs: ChannelSignal) -> Dict[str, ChannelSignal]:
"""Given FROM channels, unpack a bundle into the TO channels."""
from_channels = {
bc.name: (idx, bc) for idx, bc in enumerate(
filter(lambda c: c.direction == ChannelDirection.FROM,
self.type.channels))
}
to_channels = [
to_channels: List[BundledChannel] = [
c for c in self.type.channels if c.direction == ChannelDirection.TO
]
@ -801,7 +800,7 @@ class BundleSignal(Signal):
if name not in from_channels:
raise ValueError(f"Unknown channel name '{name}'")
idx, bc = from_channels[name]
if value.type != bc.channel:
if not bc.channel.castable(value.type):
raise TypeError(f"Expected channel type {bc.channel}, got {value.type} "
f"on channel '{name}'")
operands[idx] = value.value
@ -814,10 +813,13 @@ class BundleSignal(Signal):
self.value, operands)
to_channels_results = unpack_op.toChannels
return {
ret = {
bc.name: _FromCirctValue(to_channels_results[idx])
for idx, bc in enumerate(to_channels)
}
if not all([bc.channel.castable(ret[bc.name].type) for bc in to_channels]):
raise TypeError("Unpacked bundle did not match expected types")
return ret
def connect(self, other: BundleSignal):
"""Connect two bundles together such that one drives the other."""

View File

@ -134,6 +134,12 @@ class Type:
"""Create an array type"""
return Array(self, len)
def castable(self, value: Type) -> bool:
"""Return True if a value of 'value' can be cast to this type."""
if not isinstance(value, Type):
raise TypeError("Can only cast to a Type")
return esi.check_inner_type_match(self._type, value._type)
def __repr__(self):
return self._type.__repr__()
@ -551,6 +557,16 @@ class Any(Type):
def is_hw_type(self) -> bool:
return False
def _from_obj_or_sig(self,
obj,
alias: typing.Optional["TypeAlias"] = None) -> "Signal":
"""Any signal can be any type. Skip the type check."""
from .signals import Signal
if isinstance(obj, Signal):
return obj
return self._from_obj(obj, alias)
class Channel(Type):
"""An ESI channel type."""
@ -713,12 +729,15 @@ class Bundle(Type):
return False
@property
def channels(self):
def channels(self) -> typing.List[BundledChannel]:
return [
BundledChannel(name, dir, _FromCirctType(type))
for (name, dir, type) in self._type.channels
]
def castable(self, _) -> bool:
raise TypeError("Cannot check cast-ablity to a bundle")
def inverted(self) -> "Bundle":
"""Return a new bundle with all the channels direction inverted."""
return Bundle([

View File

@ -39,6 +39,8 @@ MLIR_CAPI_EXPORTED void circtESIAppendMlirFile(MlirModule,
MLIR_CAPI_EXPORTED MlirOperation circtESILookup(MlirModule,
MlirStringRef symbol);
MLIR_CAPI_EXPORTED bool circtESICheckInnerTypeMatch(MlirType to, MlirType from);
//===----------------------------------------------------------------------===//
// Channel bundles
//===----------------------------------------------------------------------===//

View File

@ -33,7 +33,7 @@ struct ServicePortInfo {
ChannelBundleType type;
};
// Check that the channels on two bundles match allowing for AnyType.
// Check that two types match, allowing for AnyType in 'expected'.
// NOLINTNEXTLINE(misc-no-recursion)
LogicalResult checkInnerTypeMatch(Type expected, Type actual);
/// Check that the channels on two bundles match allowing for AnyType in the

View File

@ -221,6 +221,10 @@ void circt::python::populateDialectESISubmodule(py::module &m) {
.def("__len__", &circtESIAppIDAttrPathGetNumComponents)
.def("__getitem__", &circtESIAppIDAttrPathGetComponent);
m.def("check_inner_type_match", &circtESICheckInnerTypeMatch,
"Check that two types match, allowing for AnyType in 'expected'.",
py::arg("expected"), py::arg("actual"));
py::class_<PyAppIDIndex>(m, "AppIDIndex")
.def(py::init<MlirOperation>(), py::arg("root"))
.def("get_child_appids_of", &PyAppIDIndex::getChildAppIDsOf,

View File

@ -70,6 +70,10 @@ MlirType circtESIListTypeGetElementType(MlirType list) {
return wrap(cast<ListType>(unwrap(list)).getElementType());
}
bool circtESICheckInnerTypeMatch(MlirType to, MlirType from) {
return succeeded(checkInnerTypeMatch(unwrap(to), unwrap(from)));
}
void circtESIAppendMlirFile(MlirModule cMod, MlirStringRef filename) {
ModuleOp modOp = unwrap(cMod);
auto loadedMod =