mirror of https://github.com/llvm/circt.git
[PyCDE] Improving support for Any type (#8044)
- Adds a `castable` method to Type which passes through to the `checkInnerTypeMatch` C++ function. - Use that method to type check the results of BundleSignal.unpack.
This commit is contained in:
parent
42e4c202f2
commit
405b6b873f
|
@ -5,7 +5,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .support import get_user_loc, _obj_to_value_infer_type
|
||||
from .types import ChannelDirection, ChannelSignaling, Type
|
||||
from .types import BundledChannel, ChannelDirection, ChannelSignaling, Type
|
||||
|
||||
from .circt.dialects import esi, sv
|
||||
from .circt import support
|
||||
|
@ -784,15 +784,14 @@ class BundleSignal(Signal):
|
|||
def reg(self, clk, rst=None, name=None):
|
||||
raise TypeError("Cannot register a bundle")
|
||||
|
||||
def unpack(self, **kwargs: Dict[str,
|
||||
ChannelSignal]) -> Dict[str, ChannelSignal]:
|
||||
def unpack(self, **kwargs: ChannelSignal) -> Dict[str, ChannelSignal]:
|
||||
"""Given FROM channels, unpack a bundle into the TO channels."""
|
||||
from_channels = {
|
||||
bc.name: (idx, bc) for idx, bc in enumerate(
|
||||
filter(lambda c: c.direction == ChannelDirection.FROM,
|
||||
self.type.channels))
|
||||
}
|
||||
to_channels = [
|
||||
to_channels: List[BundledChannel] = [
|
||||
c for c in self.type.channels if c.direction == ChannelDirection.TO
|
||||
]
|
||||
|
||||
|
@ -801,7 +800,7 @@ class BundleSignal(Signal):
|
|||
if name not in from_channels:
|
||||
raise ValueError(f"Unknown channel name '{name}'")
|
||||
idx, bc = from_channels[name]
|
||||
if value.type != bc.channel:
|
||||
if not bc.channel.castable(value.type):
|
||||
raise TypeError(f"Expected channel type {bc.channel}, got {value.type} "
|
||||
f"on channel '{name}'")
|
||||
operands[idx] = value.value
|
||||
|
@ -814,10 +813,13 @@ class BundleSignal(Signal):
|
|||
self.value, operands)
|
||||
|
||||
to_channels_results = unpack_op.toChannels
|
||||
return {
|
||||
ret = {
|
||||
bc.name: _FromCirctValue(to_channels_results[idx])
|
||||
for idx, bc in enumerate(to_channels)
|
||||
}
|
||||
if not all([bc.channel.castable(ret[bc.name].type) for bc in to_channels]):
|
||||
raise TypeError("Unpacked bundle did not match expected types")
|
||||
return ret
|
||||
|
||||
def connect(self, other: BundleSignal):
|
||||
"""Connect two bundles together such that one drives the other."""
|
||||
|
|
|
@ -134,6 +134,12 @@ class Type:
|
|||
"""Create an array type"""
|
||||
return Array(self, len)
|
||||
|
||||
def castable(self, value: Type) -> bool:
|
||||
"""Return True if a value of 'value' can be cast to this type."""
|
||||
if not isinstance(value, Type):
|
||||
raise TypeError("Can only cast to a Type")
|
||||
return esi.check_inner_type_match(self._type, value._type)
|
||||
|
||||
def __repr__(self):
|
||||
return self._type.__repr__()
|
||||
|
||||
|
@ -551,6 +557,16 @@ class Any(Type):
|
|||
def is_hw_type(self) -> bool:
|
||||
return False
|
||||
|
||||
def _from_obj_or_sig(self,
|
||||
obj,
|
||||
alias: typing.Optional["TypeAlias"] = None) -> "Signal":
|
||||
"""Any signal can be any type. Skip the type check."""
|
||||
|
||||
from .signals import Signal
|
||||
if isinstance(obj, Signal):
|
||||
return obj
|
||||
return self._from_obj(obj, alias)
|
||||
|
||||
|
||||
class Channel(Type):
|
||||
"""An ESI channel type."""
|
||||
|
@ -713,12 +729,15 @@ class Bundle(Type):
|
|||
return False
|
||||
|
||||
@property
|
||||
def channels(self):
|
||||
def channels(self) -> typing.List[BundledChannel]:
|
||||
return [
|
||||
BundledChannel(name, dir, _FromCirctType(type))
|
||||
for (name, dir, type) in self._type.channels
|
||||
]
|
||||
|
||||
def castable(self, _) -> bool:
|
||||
raise TypeError("Cannot check cast-ablity to a bundle")
|
||||
|
||||
def inverted(self) -> "Bundle":
|
||||
"""Return a new bundle with all the channels direction inverted."""
|
||||
return Bundle([
|
||||
|
|
|
@ -39,6 +39,8 @@ MLIR_CAPI_EXPORTED void circtESIAppendMlirFile(MlirModule,
|
|||
MLIR_CAPI_EXPORTED MlirOperation circtESILookup(MlirModule,
|
||||
MlirStringRef symbol);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool circtESICheckInnerTypeMatch(MlirType to, MlirType from);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Channel bundles
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -33,7 +33,7 @@ struct ServicePortInfo {
|
|||
ChannelBundleType type;
|
||||
};
|
||||
|
||||
// Check that the channels on two bundles match allowing for AnyType.
|
||||
// Check that two types match, allowing for AnyType in 'expected'.
|
||||
// NOLINTNEXTLINE(misc-no-recursion)
|
||||
LogicalResult checkInnerTypeMatch(Type expected, Type actual);
|
||||
/// Check that the channels on two bundles match allowing for AnyType in the
|
||||
|
|
|
@ -221,6 +221,10 @@ void circt::python::populateDialectESISubmodule(py::module &m) {
|
|||
.def("__len__", &circtESIAppIDAttrPathGetNumComponents)
|
||||
.def("__getitem__", &circtESIAppIDAttrPathGetComponent);
|
||||
|
||||
m.def("check_inner_type_match", &circtESICheckInnerTypeMatch,
|
||||
"Check that two types match, allowing for AnyType in 'expected'.",
|
||||
py::arg("expected"), py::arg("actual"));
|
||||
|
||||
py::class_<PyAppIDIndex>(m, "AppIDIndex")
|
||||
.def(py::init<MlirOperation>(), py::arg("root"))
|
||||
.def("get_child_appids_of", &PyAppIDIndex::getChildAppIDsOf,
|
||||
|
|
|
@ -70,6 +70,10 @@ MlirType circtESIListTypeGetElementType(MlirType list) {
|
|||
return wrap(cast<ListType>(unwrap(list)).getElementType());
|
||||
}
|
||||
|
||||
bool circtESICheckInnerTypeMatch(MlirType to, MlirType from) {
|
||||
return succeeded(checkInnerTypeMatch(unwrap(to), unwrap(from)));
|
||||
}
|
||||
|
||||
void circtESIAppendMlirFile(MlirModule cMod, MlirStringRef filename) {
|
||||
ModuleOp modOp = unwrap(cMod);
|
||||
auto loadedMod =
|
||||
|
|
Loading…
Reference in New Issue