diff --git a/qubesadmin/base.py b/qubesadmin/base.py index ca61ba22..1f310c59 100644 --- a/qubesadmin/base.py +++ b/qubesadmin/base.py @@ -19,11 +19,24 @@ # with this program; if not, see . '''Base classes for managed objects''' +from __future__ import annotations + +import typing +from typing import BinaryIO, Any, TypeAlias, TypeVar, Generic +from collections.abc import Generator import qubesadmin.exc +if typing.TYPE_CHECKING: + from qubesadmin.vm import QubesVM + from qubesadmin.app import QubesBase + DEFAULT = object() +# We use Any because the dynamic metatada handling of the current code +# is too complex for type checkers otherwise +VMProperty: TypeAlias = Any # noqa: ANN401 + class PropertyHolder: '''A base class for object having properties retrievable using mgmt API. @@ -34,28 +47,29 @@ class PropertyHolder: ''' #: a place for appropriate Qubes() object (QubesLocal or QubesRemote), # use None for self - app = None + app: QubesBase - def __init__(self, app, method_prefix, method_dest): + def __init__(self, app: QubesBase, method_prefix: str, method_dest: str): #: appropriate Qubes() object (QubesLocal or QubesRemote), use None # for self self.app = app self._method_prefix = method_prefix self._method_dest = method_dest - self._properties = None + self._properties: list[str] | None = None self._properties_help = None # the cache is maintained by EventsDispatcher(), # through helper functions in QubesBase() self._properties_cache = {} - def clear_cache(self): + def clear_cache(self) -> None: """ Clear property cache. """ self._properties_cache = {} - def qubesd_call(self, dest, method, arg=None, payload=None, - payload_stream=None): + def qubesd_call(self, dest: str | None, method: str, + arg: str | None=None, payload: bytes | None=None, + payload_stream: BinaryIO | None=None) -> bytes: ''' Call into qubesd using appropriate mechanism. This method should be defined by a subclass. @@ -69,10 +83,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None, :param payload_stream: file-like object to read payload from :return: Data returned by qubesd (string) ''' - if not self.app: - raise NotImplementedError - if dest is None: - dest = self._method_dest + dest: str = dest or self._method_dest if ( getattr(self, "_redirect_dispvm_calls", False) and dest.startswith("@dispvm") @@ -80,7 +91,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None, if dest.startswith("@dispvm:"): dest = dest[len("@dispvm:") :] else: - dest = getattr(self.app, "default_dispvm", None) + dest: QubesVM | None = getattr(self.app, "default_dispvm", None) if dest: dest = dest.name # have the actual implementation at Qubes() instance @@ -88,7 +99,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None, payload_stream) @staticmethod - def _parse_qubesd_response(response_data): + def _parse_qubesd_response(response_data: bytes) -> bytes: '''Parse response from qubesd. In case of success, return actual data. In case of error, @@ -122,7 +133,7 @@ def _parse_qubesd_response(response_data): raise qubesadmin.exc.QubesDaemonCommunicationError( 'Invalid response format') - def property_list(self): + def property_list(self) -> list[str]: ''' List available properties (their names). @@ -138,7 +149,7 @@ def property_list(self): # TODO: make it somehow immutable return self._properties - def property_help(self, name): + def property_help(self, name: str) -> str: ''' Get description of a property. @@ -151,7 +162,7 @@ def property_help(self, name): None) return help_text.decode('ascii') - def property_is_default(self, item): + def property_is_default(self, item: str) -> bool: ''' Check if given property have default value @@ -183,7 +194,7 @@ def property_is_default(self, item): self._properties_cache[item] = (is_default, value) return is_default - def property_get_default(self, item): + def property_get_default(self, item: str) -> VMProperty: ''' Get default property value, regardless of the current value @@ -206,7 +217,8 @@ def property_get_default(self, item): (prop_type, value) = property_str.split(b' ', 1) return self._parse_type_value(prop_type, value) - def clone_properties(self, src, proplist=None): + def clone_properties(self, src: PropertyHolder, + proplist: list[str] | None=None) -> None: '''Clone properties from other object. :param PropertyHolder src: source object @@ -223,7 +235,7 @@ def clone_properties(self, src, proplist=None): except AttributeError: continue - def __getattr__(self, item): + def __getattr__(self, item: str) -> VMProperty: if item.startswith('_'): raise AttributeError(item) # pre-fill cache if enabled @@ -254,7 +266,8 @@ def __getattr__(self, item): raise AttributeError(item) return value - def _deserialize_property(self, api_response): + def _deserialize_property(self, api_response: bytes) \ + -> tuple[bool, VMProperty]: """ Deserialize property.Get response format :param api_response: bytes, as retrieved from qubesd @@ -267,23 +280,26 @@ def _deserialize_property(self, api_response): value = self._parse_type_value(prop_type, value) return is_default, value - def _parse_type_value(self, prop_type, value): + def _parse_type_value(self, prop_type: bytes, value: bytes) -> VMProperty: ''' Parse `type=... ...` qubesd response format. Return a value of appropriate type. + Returns AttributeError instead of ValueError since this is used + to access named field + :param bytes prop_type: 'type=...' part of the response (including `type=` prefix) :param bytes value: 'value' part of the response :return: parsed value ''' # pylint: disable=too-many-return-statements - prop_type = prop_type.decode('ascii') + prop_type: str = prop_type.decode('ascii') if not prop_type.startswith('type='): raise qubesadmin.exc.QubesDaemonCommunicationError( 'Invalid type prefix received: {}'.format(prop_type)) (_, prop_type) = prop_type.split('=', 1) - value = value.decode() + value: str = value.decode() if prop_type == 'str': return str(value) if prop_type == 'bool': @@ -305,7 +321,7 @@ def _parse_type_value(self, prop_type, value): raise qubesadmin.exc.QubesDaemonCommunicationError( 'Received invalid value type: {}'.format(prop_type)) - def _fetch_all_properties(self): + def _fetch_all_properties(self) -> None: """ Retrieve all properties values at once using (prefix).property.GetAll method. If it succeed, save retrieved values in the properties cache. @@ -315,7 +331,7 @@ def _fetch_all_properties(self): :return: None """ - def unescape(line): + def unescape(line: bytes) -> Generator[int]: """Handle \\-escaped values, generates a list of character codes""" escaped = False for char in line: @@ -342,15 +358,15 @@ def unescape(line): return for line in properties_str.splitlines(): # decode newlines - line = bytes(unescape(line)) - name, property_str = line.split(b' ', 1) + line_bytes = bytes(list(unescape(line))) + name, property_str = line_bytes.split(b' ', 1) name = name.decode() is_default, value = self._deserialize_property(property_str) self._properties_cache[name] = (is_default, value) self._properties = list(self._properties_cache.keys()) @classmethod - def _local_properties(cls): + def _local_properties(cls) -> set: ''' Get set of property names that are properties on the Python object, and must not be set on the remote object @@ -367,7 +383,7 @@ def _local_properties(cls): return cls._local_properties_set - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: typing.Any) -> None: # noqa: ANN401 if key.startswith('_') or key in self._local_properties(): return super().__setattr__(key, value) if value is qubesadmin.DEFAULT: @@ -381,7 +397,9 @@ def __setattr__(self, key, value): qubesadmin.exc.QubesVMNotFoundError): raise qubesadmin.exc.QubesPropertyAccessError(key) else: - if isinstance(value, qubesadmin.vm.QubesVM): + # Dynamic import because qubesadmin.vm imports base.py + from qubesadmin.vm import QubesVM + if isinstance(value, QubesVM): value = value.name if value is None: value = '' @@ -395,7 +413,7 @@ def __setattr__(self, key, value): qubesadmin.exc.QubesVMNotFoundError): raise qubesadmin.exc.QubesPropertyAccessError(key) - def __delattr__(self, name): + def __delattr__(self, name: str) -> None: if name.startswith('_') or name in self._local_properties(): return super().__delattr__(name) try: @@ -408,10 +426,13 @@ def __delattr__(self, name): qubesadmin.exc.QubesVMNotFoundError): raise qubesadmin.exc.QubesPropertyAccessError(name) +WrapperObjectsCollectionKey: TypeAlias = int | str +T = TypeVar('T') -class WrapperObjectsCollection: +class WrapperObjectsCollection(Generic[T]): '''Collection of simple named objects''' - def __init__(self, app, list_method, object_class): + def __init__(self, app: QubesBase, + list_method: str, object_class: type[T]): ''' Construct manager of named wrapper objects. @@ -425,11 +446,13 @@ def __init__(self, app, list_method, object_class): self._list_method = list_method self._object_class = object_class #: names cache - self._names_list = None + self._names_list: list[WrapperObjectsCollectionKey] | None = None #: returned objects cache - self._objects = {} + self._objects: dict[WrapperObjectsCollectionKey, T] = {} - def clear_cache(self, invalidate_name=None): + def clear_cache(self, + invalidate_name: WrapperObjectsCollectionKey | None=None)\ + -> None: """Clear cached list of names. If *invalidate_name* is given, remove that object from cache explicitly too. @@ -438,7 +461,7 @@ def clear_cache(self, invalidate_name=None): if invalidate_name: self._objects.pop(invalidate_name, None) - def refresh_cache(self, force=False): + def refresh_cache(self, force: bool=False) -> None: '''Refresh cached list of names''' if not force and self._names_list is not None: return @@ -447,17 +470,18 @@ def refresh_cache(self, force=False): assert list_data[-1] == '\n' self._names_list = [str(name) for name in list_data[:-1].splitlines()] - for name, obj in list(self._objects.items()): + for name, obj in self._objects.items(): + assert hasattr(obj, "name") if obj.name not in self._names_list: # Object no longer exists del self._objects[name] - def __getitem__(self, item): + def __getitem__(self, item: WrapperObjectsCollectionKey) -> T: if not self.app.blind_mode and item not in self: raise KeyError(item) return self.get_blind(item) - def get_blind(self, item): + def get_blind(self, item: WrapperObjectsCollectionKey) -> T: ''' Get a property without downloading the list and checking if it's present @@ -466,25 +490,30 @@ def get_blind(self, item): self._objects[item] = self._object_class(self.app, item) return self._objects[item] - def __contains__(self, item): + def __contains__(self, item: WrapperObjectsCollectionKey) -> bool: self.refresh_cache() + assert self._names_list is not None return item in self._names_list - def __iter__(self): + def __iter__(self) -> Generator[WrapperObjectsCollectionKey]: self.refresh_cache() + assert self._names_list is not None yield from self._names_list - def keys(self): + def keys(self) -> list[WrapperObjectsCollectionKey]: '''Get list of names.''' self.refresh_cache() + assert self._names_list is not None return list(self._names_list) - def items(self): + def items(self) -> list[tuple[WrapperObjectsCollectionKey, T]]: '''Get list of (key, value) pairs''' self.refresh_cache() + assert self._names_list is not None return [(key, self.get_blind(key)) for key in self._names_list] - def values(self): + def values(self) -> list[T]: '''Get list of objects''' self.refresh_cache() + assert self._names_list is not None return [self.get_blind(key) for key in self._names_list] diff --git a/qubesadmin/vm/__init__.py b/qubesadmin/vm/__init__.py index 00785db8..4a93dfbd 100644 --- a/qubesadmin/vm/__init__.py +++ b/qubesadmin/vm/__init__.py @@ -64,7 +64,7 @@ def __init__(self, app, name, klass=None, power_state=None): self.firewall = qubesadmin.firewall.Firewall(self) @property - def name(self): + def name(self) -> str: """Domain name""" return self._method_dest @@ -80,7 +80,7 @@ def name(self, new_value): self._volumes = None self.app.domains.clear_cache() - def __str__(self): + def __str__(self) -> str: return self._method_dest def __lt__(self, other):