diff --git a/qubesadmin/base.py b/qubesadmin/base.py
index ca61ba22..1f310c59 100644
--- a/qubesadmin/base.py
+++ b/qubesadmin/base.py
@@ -19,11 +19,24 @@
# with this program; if not, see .
'''Base classes for managed objects'''
+from __future__ import annotations
+
+import typing
+from typing import BinaryIO, Any, TypeAlias, TypeVar, Generic
+from collections.abc import Generator
import qubesadmin.exc
+if typing.TYPE_CHECKING:
+ from qubesadmin.vm import QubesVM
+ from qubesadmin.app import QubesBase
+
DEFAULT = object()
+# We use Any because the dynamic metatada handling of the current code
+# is too complex for type checkers otherwise
+VMProperty: TypeAlias = Any # noqa: ANN401
+
class PropertyHolder:
'''A base class for object having properties retrievable using mgmt API.
@@ -34,28 +47,29 @@ class PropertyHolder:
'''
#: a place for appropriate Qubes() object (QubesLocal or QubesRemote),
# use None for self
- app = None
+ app: QubesBase
- def __init__(self, app, method_prefix, method_dest):
+ def __init__(self, app: QubesBase, method_prefix: str, method_dest: str):
#: appropriate Qubes() object (QubesLocal or QubesRemote), use None
# for self
self.app = app
self._method_prefix = method_prefix
self._method_dest = method_dest
- self._properties = None
+ self._properties: list[str] | None = None
self._properties_help = None
# the cache is maintained by EventsDispatcher(),
# through helper functions in QubesBase()
self._properties_cache = {}
- def clear_cache(self):
+ def clear_cache(self) -> None:
"""
Clear property cache.
"""
self._properties_cache = {}
- def qubesd_call(self, dest, method, arg=None, payload=None,
- payload_stream=None):
+ def qubesd_call(self, dest: str | None, method: str,
+ arg: str | None=None, payload: bytes | None=None,
+ payload_stream: BinaryIO | None=None) -> bytes:
'''
Call into qubesd using appropriate mechanism. This method should be
defined by a subclass.
@@ -69,10 +83,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None,
:param payload_stream: file-like object to read payload from
:return: Data returned by qubesd (string)
'''
- if not self.app:
- raise NotImplementedError
- if dest is None:
- dest = self._method_dest
+ dest: str = dest or self._method_dest
if (
getattr(self, "_redirect_dispvm_calls", False)
and dest.startswith("@dispvm")
@@ -80,7 +91,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None,
if dest.startswith("@dispvm:"):
dest = dest[len("@dispvm:") :]
else:
- dest = getattr(self.app, "default_dispvm", None)
+ dest: QubesVM | None = getattr(self.app, "default_dispvm", None)
if dest:
dest = dest.name
# have the actual implementation at Qubes() instance
@@ -88,7 +99,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None,
payload_stream)
@staticmethod
- def _parse_qubesd_response(response_data):
+ def _parse_qubesd_response(response_data: bytes) -> bytes:
'''Parse response from qubesd.
In case of success, return actual data. In case of error,
@@ -122,7 +133,7 @@ def _parse_qubesd_response(response_data):
raise qubesadmin.exc.QubesDaemonCommunicationError(
'Invalid response format')
- def property_list(self):
+ def property_list(self) -> list[str]:
'''
List available properties (their names).
@@ -138,7 +149,7 @@ def property_list(self):
# TODO: make it somehow immutable
return self._properties
- def property_help(self, name):
+ def property_help(self, name: str) -> str:
'''
Get description of a property.
@@ -151,7 +162,7 @@ def property_help(self, name):
None)
return help_text.decode('ascii')
- def property_is_default(self, item):
+ def property_is_default(self, item: str) -> bool:
'''
Check if given property have default value
@@ -183,7 +194,7 @@ def property_is_default(self, item):
self._properties_cache[item] = (is_default, value)
return is_default
- def property_get_default(self, item):
+ def property_get_default(self, item: str) -> VMProperty:
'''
Get default property value, regardless of the current value
@@ -206,7 +217,8 @@ def property_get_default(self, item):
(prop_type, value) = property_str.split(b' ', 1)
return self._parse_type_value(prop_type, value)
- def clone_properties(self, src, proplist=None):
+ def clone_properties(self, src: PropertyHolder,
+ proplist: list[str] | None=None) -> None:
'''Clone properties from other object.
:param PropertyHolder src: source object
@@ -223,7 +235,7 @@ def clone_properties(self, src, proplist=None):
except AttributeError:
continue
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> VMProperty:
if item.startswith('_'):
raise AttributeError(item)
# pre-fill cache if enabled
@@ -254,7 +266,8 @@ def __getattr__(self, item):
raise AttributeError(item)
return value
- def _deserialize_property(self, api_response):
+ def _deserialize_property(self, api_response: bytes) \
+ -> tuple[bool, VMProperty]:
"""
Deserialize property.Get response format
:param api_response: bytes, as retrieved from qubesd
@@ -267,23 +280,26 @@ def _deserialize_property(self, api_response):
value = self._parse_type_value(prop_type, value)
return is_default, value
- def _parse_type_value(self, prop_type, value):
+ def _parse_type_value(self, prop_type: bytes, value: bytes) -> VMProperty:
'''
Parse `type=... ...` qubesd response format. Return a value of
appropriate type.
+ Returns AttributeError instead of ValueError since this is used
+ to access named field
+
:param bytes prop_type: 'type=...' part of the response (including
`type=` prefix)
:param bytes value: 'value' part of the response
:return: parsed value
'''
# pylint: disable=too-many-return-statements
- prop_type = prop_type.decode('ascii')
+ prop_type: str = prop_type.decode('ascii')
if not prop_type.startswith('type='):
raise qubesadmin.exc.QubesDaemonCommunicationError(
'Invalid type prefix received: {}'.format(prop_type))
(_, prop_type) = prop_type.split('=', 1)
- value = value.decode()
+ value: str = value.decode()
if prop_type == 'str':
return str(value)
if prop_type == 'bool':
@@ -305,7 +321,7 @@ def _parse_type_value(self, prop_type, value):
raise qubesadmin.exc.QubesDaemonCommunicationError(
'Received invalid value type: {}'.format(prop_type))
- def _fetch_all_properties(self):
+ def _fetch_all_properties(self) -> None:
"""
Retrieve all properties values at once using (prefix).property.GetAll
method. If it succeed, save retrieved values in the properties cache.
@@ -315,7 +331,7 @@ def _fetch_all_properties(self):
:return: None
"""
- def unescape(line):
+ def unescape(line: bytes) -> Generator[int]:
"""Handle \\-escaped values, generates a list of character codes"""
escaped = False
for char in line:
@@ -342,15 +358,15 @@ def unescape(line):
return
for line in properties_str.splitlines():
# decode newlines
- line = bytes(unescape(line))
- name, property_str = line.split(b' ', 1)
+ line_bytes = bytes(list(unescape(line)))
+ name, property_str = line_bytes.split(b' ', 1)
name = name.decode()
is_default, value = self._deserialize_property(property_str)
self._properties_cache[name] = (is_default, value)
self._properties = list(self._properties_cache.keys())
@classmethod
- def _local_properties(cls):
+ def _local_properties(cls) -> set:
'''
Get set of property names that are properties on the Python object,
and must not be set on the remote object
@@ -367,7 +383,7 @@ def _local_properties(cls):
return cls._local_properties_set
- def __setattr__(self, key, value):
+ def __setattr__(self, key: str, value: typing.Any) -> None: # noqa: ANN401
if key.startswith('_') or key in self._local_properties():
return super().__setattr__(key, value)
if value is qubesadmin.DEFAULT:
@@ -381,7 +397,9 @@ def __setattr__(self, key, value):
qubesadmin.exc.QubesVMNotFoundError):
raise qubesadmin.exc.QubesPropertyAccessError(key)
else:
- if isinstance(value, qubesadmin.vm.QubesVM):
+ # Dynamic import because qubesadmin.vm imports base.py
+ from qubesadmin.vm import QubesVM
+ if isinstance(value, QubesVM):
value = value.name
if value is None:
value = ''
@@ -395,7 +413,7 @@ def __setattr__(self, key, value):
qubesadmin.exc.QubesVMNotFoundError):
raise qubesadmin.exc.QubesPropertyAccessError(key)
- def __delattr__(self, name):
+ def __delattr__(self, name: str) -> None:
if name.startswith('_') or name in self._local_properties():
return super().__delattr__(name)
try:
@@ -408,10 +426,13 @@ def __delattr__(self, name):
qubesadmin.exc.QubesVMNotFoundError):
raise qubesadmin.exc.QubesPropertyAccessError(name)
+WrapperObjectsCollectionKey: TypeAlias = int | str
+T = TypeVar('T')
-class WrapperObjectsCollection:
+class WrapperObjectsCollection(Generic[T]):
'''Collection of simple named objects'''
- def __init__(self, app, list_method, object_class):
+ def __init__(self, app: QubesBase,
+ list_method: str, object_class: type[T]):
'''
Construct manager of named wrapper objects.
@@ -425,11 +446,13 @@ def __init__(self, app, list_method, object_class):
self._list_method = list_method
self._object_class = object_class
#: names cache
- self._names_list = None
+ self._names_list: list[WrapperObjectsCollectionKey] | None = None
#: returned objects cache
- self._objects = {}
+ self._objects: dict[WrapperObjectsCollectionKey, T] = {}
- def clear_cache(self, invalidate_name=None):
+ def clear_cache(self,
+ invalidate_name: WrapperObjectsCollectionKey | None=None)\
+ -> None:
"""Clear cached list of names.
If *invalidate_name* is given, remove that object from cache
explicitly too.
@@ -438,7 +461,7 @@ def clear_cache(self, invalidate_name=None):
if invalidate_name:
self._objects.pop(invalidate_name, None)
- def refresh_cache(self, force=False):
+ def refresh_cache(self, force: bool=False) -> None:
'''Refresh cached list of names'''
if not force and self._names_list is not None:
return
@@ -447,17 +470,18 @@ def refresh_cache(self, force=False):
assert list_data[-1] == '\n'
self._names_list = [str(name) for name in list_data[:-1].splitlines()]
- for name, obj in list(self._objects.items()):
+ for name, obj in self._objects.items():
+ assert hasattr(obj, "name")
if obj.name not in self._names_list:
# Object no longer exists
del self._objects[name]
- def __getitem__(self, item):
+ def __getitem__(self, item: WrapperObjectsCollectionKey) -> T:
if not self.app.blind_mode and item not in self:
raise KeyError(item)
return self.get_blind(item)
- def get_blind(self, item):
+ def get_blind(self, item: WrapperObjectsCollectionKey) -> T:
'''
Get a property without downloading the list
and checking if it's present
@@ -466,25 +490,30 @@ def get_blind(self, item):
self._objects[item] = self._object_class(self.app, item)
return self._objects[item]
- def __contains__(self, item):
+ def __contains__(self, item: WrapperObjectsCollectionKey) -> bool:
self.refresh_cache()
+ assert self._names_list is not None
return item in self._names_list
- def __iter__(self):
+ def __iter__(self) -> Generator[WrapperObjectsCollectionKey]:
self.refresh_cache()
+ assert self._names_list is not None
yield from self._names_list
- def keys(self):
+ def keys(self) -> list[WrapperObjectsCollectionKey]:
'''Get list of names.'''
self.refresh_cache()
+ assert self._names_list is not None
return list(self._names_list)
- def items(self):
+ def items(self) -> list[tuple[WrapperObjectsCollectionKey, T]]:
'''Get list of (key, value) pairs'''
self.refresh_cache()
+ assert self._names_list is not None
return [(key, self.get_blind(key)) for key in self._names_list]
- def values(self):
+ def values(self) -> list[T]:
'''Get list of objects'''
self.refresh_cache()
+ assert self._names_list is not None
return [self.get_blind(key) for key in self._names_list]
diff --git a/qubesadmin/vm/__init__.py b/qubesadmin/vm/__init__.py
index 00785db8..4a93dfbd 100644
--- a/qubesadmin/vm/__init__.py
+++ b/qubesadmin/vm/__init__.py
@@ -64,7 +64,7 @@ def __init__(self, app, name, klass=None, power_state=None):
self.firewall = qubesadmin.firewall.Firewall(self)
@property
- def name(self):
+ def name(self) -> str:
"""Domain name"""
return self._method_dest
@@ -80,7 +80,7 @@ def name(self, new_value):
self._volumes = None
self.app.domains.clear_cache()
- def __str__(self):
+ def __str__(self) -> str:
return self._method_dest
def __lt__(self, other):