diff --git a/src/fastcs/util.py b/src/fastcs/util.py index a73cd591b..02905b4b8 100644 --- a/src/fastcs/util.py +++ b/src/fastcs/util.py @@ -32,13 +32,23 @@ def numpy_to_fastcs_datatype(np_type) -> DataType: def validate_hinted_attributes(controller: BaseController): - """Validates that type-hinted attributes exist in the controller, and are accessible - via the dot accessor, from the attributes dictionary and with the right datatype. + """Validates that type-hinted attributes in the controller and all subcontrollers + exist with the right datatype and access mode. """ - hints = get_type_hints(type(controller)) - alias_hints = {k: v for k, v in hints.items() if isinstance(v, _GenericAlias)} - for name, hint in alias_hints.items(): - attr_class = get_origin(hint) + for subcontroller in controller.get_sub_controllers().values(): + validate_hinted_attributes(subcontroller) + hints = { + k: v + for k, v in get_type_hints(type(controller)).items() + if isinstance(v, _GenericAlias | type) + } + for name, hint in hints.items(): + if isinstance(hint, type): + attr_class = hint + attr_dtype = None + else: + attr_class = get_origin(hint) + attr_dtype = get_args(hint)[0] if not issubclass(attr_class, Attribute): continue attr = getattr(controller, name, None) @@ -47,14 +57,16 @@ def validate_hinted_attributes(controller: BaseController): f"Controller `{controller.__class__.__name__}` failed to introspect " f"hinted attribute `{name}` during initialisation" ) - if type(attr) is not attr_class: + if attr_class is not type(attr): + # skip validation if access mode not specified + if attr_class is Attribute and isinstance(attr, Attribute): + continue raise RuntimeError( f"Controller '{controller.__class__.__name__}' introspection of hinted " f"attribute '{name}' does not match defined access mode. " f"Expected '{attr_class.__name__}', got '{type(attr).__name__}'." ) - attr_dtype = get_args(hint)[0] - if attr.datatype.dtype != attr_dtype: + if attr_dtype is not None and attr_dtype != attr.datatype.dtype: raise RuntimeError( f"Controller '{controller.__class__.__name__}' introspection of hinted " f"attribute '{name}' does not match defined datatype. " diff --git a/tests/test_util.py b/tests/test_util.py index e4ea0af63..542f18e81 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -6,7 +6,7 @@ from pvi.device import SignalR from pydantic import ValidationError -from fastcs.attributes import AttrR, AttrRW +from fastcs.attributes import Attribute, AttrR, AttrRW from fastcs.controller import Controller from fastcs.datatypes import Bool, Enum, Float, Int, String from fastcs.launch import FastCS @@ -125,3 +125,52 @@ class ControllerWrongEnumClass(Controller): "'hinted_enum' does not match defined datatype. " "Expected 'MyEnum', got 'MyEnum2'." ) + + class ControllerUnspecifiedAccessMode(Controller): + hinted: Attribute[int] + + async def initialise(self): + self.hinted = AttrR(Int()) + + # no assertion thrown + FastCS(ControllerUnspecifiedAccessMode(), [], loop) + + +def test_hinted_attributes_verified_on_subcontrollers(): + loop = asyncio.get_event_loop() + + class ControllerWithWrongType(Controller): + hinted_missing: AttrR[int] + + async def connect(self): + return + + class TopController(Controller): + async def initialise(self): + subcontroller = ControllerWithWrongType() + self.register_sub_controller("MySubController", subcontroller) + + with pytest.raises(RuntimeError, match="failed to introspect hinted attribute"): + FastCS(TopController(), [], loop) + + +def test_hinted_attribute_types_verified(): + # test verification works with non-GenericAlias type hints + loop = asyncio.get_event_loop() + + class ControllerAttrWrongAccessMode(Controller): + read_attr: AttrR + + async def initialise(self): + self.read_attr = AttrRW(Int()) + + with pytest.raises(RuntimeError, match="does not match defined access mode"): + FastCS(ControllerAttrWrongAccessMode(), [], loop) + + class ControllerUnspecifiedAccessMode(Controller): + unspecified_access_mode: Attribute + + async def initialise(self): + self.unspecified_access_mode = AttrRW(Int()) + + FastCS(ControllerUnspecifiedAccessMode(), [], loop)