diff --git a/qubesadmin/app.py b/qubesadmin/app.py index 808b0f7f..83f584e8 100644 --- a/qubesadmin/app.py +++ b/qubesadmin/app.py @@ -31,7 +31,11 @@ import sys import logging +import typing from logging import Logger +from subprocess import Popen +from typing import IO +from collections.abc import Generator, Iterable import qubesadmin.base import qubesadmin.exc @@ -39,69 +43,94 @@ import qubesadmin.storage import qubesadmin.utils import qubesadmin.vm +from qubesadmin.label import Label +from qubesadmin.vm import Klass, PowerState import qubesadmin.config import qubesadmin.device_protocol from qubesadmin.vm import QubesVM try: - import qubesdb + import qubesdb # type: ignore has_qubesdb = True except ImportError: has_qubesdb = False +DeviceClass = typing.Literal["mic", "block", "pci", "usb", "webcam"] + class VMCollection: - """Collection of VMs objects""" + """A lazily-loaded, cached view of the VMs known to qubesd. + + Membership (``__contains__``, ``__iter__``, ``keys()``, ``values()``) + reflects exclusively the VMs returned by the most recent + ``admin.vm.List`` call. That list is fetched on the first membership + query and cached; call :py:meth:`refresh_cache` to reload. + + :py:meth:`get_blind` is **not** a membership operation. + It returns (or creates) a :py:class:`~qubesadmin.vm.QubesVM` handle for + the given name without querying the API and without registering the name + in the membership view. Its purpose is to provide a unique + :py:class:`~qubesadmin.vm.QubesVM` object for the same name on repeated + calls. + Blind objects desynced with ``vm.List`` are pruned from the internal cache + on the next :py:meth:`refresh_cache` call. + + ``del domains[name]`` removes the VM by calling ``admin.vm.Remove``. + There is no ``__setitem__``; use + :py:meth:`~qubesadmin.app.QubesBase.add_new_vm` to create VMs. + """ - def __init__(self, app): + def __init__(self, app: "QubesBase"): self.app = app - self._vm_list = None - self._vm_objects = {} + # object cache — may contain blind objects not confirmed by vm.List + self._vms: dict[str, QubesVM] = {} + # names returned by the last admin.vm.List call + self._known_names: set[str] = set() + self._initialized: bool = False - def clear_cache(self, invalidate_name=None): + def clear_cache(self, invalidate_name: str | None=None) -> None: """Clear cached list of VMs - If *invalidate_name* is given, remove that object from cache - explicitly too. + If *invalidate_name* is given, remove that object from the collection. """ - self._vm_list = None + self._initialized = False + self._known_names = set() if invalidate_name: - self._vm_objects.pop(invalidate_name, None) + self._vms.pop(invalidate_name, None) - def refresh_cache(self, force=False): - """Refresh cached list of VMs""" - if not force and self._vm_list is not None: - return + def _ensure_cache_loaded(self) -> None: + """Load the VM list from qubesd if not already cached.""" + if not self._initialized: + self.refresh_cache() + + def refresh_cache(self) -> None: + """Reload the VM list from qubesd, discarding any cached data.""" vm_list_data = self.app.qubesd_call("dom0", "admin.vm.List") - new_vm_list = {} - # FIXME: this will probably change + # Handle renamed VMs + vms_by_current_name = {vm.name: vm for vm in self._vms.values()} + new_known_names: set[str] = set() for vm_data in vm_list_data.splitlines(): vm_name, props = vm_data.decode("ascii").split(" ", 1) - vm_name = str(vm_name) - props = props.split(" ") - new_vm_list[vm_name] = dict( - [vm_prop.split("=", 1) for vm_prop in props] - ) - # if cache not enabled, drop power state - if not self.app.cache_enabled: - try: - del new_vm_list[vm_name]["state"] - except KeyError: - pass - - self._vm_list = new_vm_list - for name, vm in list(self._vm_objects.items()): - if vm.name not in self._vm_list: - # VM no longer exists - del self._vm_objects[name] - elif vm.klass != self._vm_list[vm.name]["class"]: - # VM class have changed - del self._vm_objects[name] - # TODO: some generation ID, to detect VM re-creation - elif name != vm.name: - # renamed - self._vm_objects[vm.name] = vm - del self._vm_objects[name] + props_dict = dict(vm_prop.split("=", 1) + for vm_prop in props.split(" ")) + klass = typing.cast(Klass, props_dict["class"]) + power_state = typing.cast(PowerState, props_dict.get("state")) + new_known_names.add(vm_name) + existing_vm = self._vms.get(vm_name) or\ + vms_by_current_name.get(vm_name) + # TODO: some generation ID (e.g. uuid), to detect VM re-creation + if existing_vm is None or existing_vm.klass != klass: + self._vms[vm_name] = QubesVM( + self.app, vm_name, klass=klass, power_state=power_state + ) + elif existing_vm is not self._vms.get(vm_name): + # renamed: existing_vm was found via vms_by_current_name + self._vms[vm_name] = existing_vm + # Drop objects for VMs that no longer exist + self._vms = {name: vm for name, vm in self._vms.items() + if name in new_known_names} + self._known_names = new_known_names + self._initialized = True def __getitem__(self, item: str | QubesVM) -> QubesVM: if isinstance(item, QubesVM): @@ -112,25 +141,15 @@ def __getitem__(self, item: str | QubesVM) -> QubesVM: def get_blind(self, item: str) -> QubesVM: """ - Get a vm without downloading the list - and checking if exists + Get a vm from the collection. If the vm is not in the collection + already, a new basic entry will be created from the provided name. """ - if item not in self._vm_objects: - cls = qubesadmin.vm.QubesVM - # provide class name to constructor, if already cached (which can be - # done by 'item not in self' check above, unless blind_mode is - # enabled - klass = None - power_state = None - if self._vm_list and item in self._vm_list: - klass = self._vm_list[item]["class"] - power_state = self._vm_list[item].get("state") - self._vm_objects[item] = cls( - self.app, item, klass=klass, power_state=power_state - ) - return self._vm_objects[item] + if item not in self._vms: + self._vms[item] = QubesVM(self.app, item) + return self._vms[item] - def get(self, item, default=None) -> QubesVM: + def get(self, item: str | QubesVM, default: QubesVM | None=None)\ + -> QubesVM | None: """ Get a VM object, or return *default* if it can't be found. """ @@ -139,30 +158,31 @@ def get(self, item, default=None) -> QubesVM: except KeyError: return default - def __contains__(self, item): + def __contains__(self, item: QubesVM | str) -> bool: if isinstance(item, qubesadmin.vm.QubesVM): item = item.name - self.refresh_cache() - return item in self._vm_list + self._ensure_cache_loaded() + return item in self._known_names - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: self.app.qubesd_call(key, "admin.vm.Remove") - self.clear_cache() + self._known_names.discard(key) + self._vms.pop(key, None) - def __iter__(self): - self.refresh_cache() - for vm in sorted(self._vm_list): - yield self[vm] + def __iter__(self) -> Generator[QubesVM, None, None]: + self._ensure_cache_loaded() + for vm in sorted(self._known_names): + yield self.get_blind(vm) - def keys(self): + def keys(self) -> Iterable[str]: """Get list of VM names.""" - self.refresh_cache() - return self._vm_list.keys() + self._ensure_cache_loaded() + return self._known_names - def values(self): + def values(self) -> list[QubesVM]: """Get list of VM objects.""" - self.refresh_cache() - return [self[name] for name in self._vm_list] + self._ensure_cache_loaded() + return [self.get_blind(name) for name in self._known_names] class QubesBase(qubesadmin.base.PropertyHolder): @@ -177,12 +197,11 @@ class in py:class:`qubesadmin.Qubes` instead, which points at #: domains (VMs) collection domains: VMCollection #: labels collection - labels: qubesadmin.base.WrapperObjectsCollection + labels: qubesadmin.base.WrapperObjectsCollection[Label] #: storage pools - pools: qubesadmin.base.WrapperObjectsCollection + pools: qubesadmin.base.WrapperObjectsCollection[qubesadmin.storage.Pool] #: type of qubesd connection: either 'socket' or 'qrexec' - qubesd_connection_type: str | None = None # See in PR#416 why we keep - # =None here to not trip the CI + qubesd_connection_type: typing.Literal["socket", "qrexec"] #: logger log: Logger #: do not check for object (VM, label etc) existence before really needed @@ -190,7 +209,7 @@ class in py:class:`qubesadmin.Qubes` instead, which points at #: cache retrieved properties values cache_enabled: bool = False - def __init__(self): + def __init__(self) -> None: super().__init__(self, "admin.property.", "dom0") self.domains = VMCollection(self) self.labels = qubesadmin.base.WrapperObjectsCollection( @@ -200,27 +219,32 @@ def __init__(self): self, "admin.pool.List", qubesadmin.storage.Pool ) #: cache for available storage pool drivers and options to create them - self._pool_drivers = None + self._pool_drivers: dict[str, list[str]] | None = None self.log = logging.getLogger("app") self._local_name = None - def list_vmclass(self): + def list_vmclass(self) -> list[Klass]: """Call Qubesd in order to obtain the vm classes list""" vmclass = ( self.qubesd_call("dom0", "admin.vmclass.List").decode().splitlines() ) - return sorted(vmclass) + for e in vmclass: + assert e in typing.get_args(Klass) + return typing.cast(list[Klass], sorted(vmclass)) - def list_deviceclass(self): + def list_deviceclass(self) -> list[DeviceClass]: """Call Qubesd in order to obtain the device classes list""" deviceclasses = ( self.qubesd_call("dom0", "admin.deviceclass.List") .decode() .splitlines() ) - return sorted(deviceclasses) + for e in deviceclasses: + assert e in typing.get_args(DeviceClass) + + return typing.cast(list[DeviceClass], sorted(deviceclasses)) - def _refresh_pool_drivers(self): + def _refresh_pool_drivers(self) -> None: """ Refresh cached storage pool drivers and their parameters. @@ -240,17 +264,19 @@ def _refresh_pool_drivers(self): self._pool_drivers = pool_drivers @property - def pool_drivers(self): + def pool_drivers(self) -> Iterable[str]: """Available storage pool drivers""" self._refresh_pool_drivers() + assert self._pool_drivers is not None return self._pool_drivers.keys() - def pool_driver_parameters(self, driver): + def pool_driver_parameters(self, driver: str) -> list[str]: """Parameters to initialize storage pool using given driver""" self._refresh_pool_drivers() + assert self._pool_drivers is not None return self._pool_drivers[driver] - def add_pool(self, name, driver, **kwargs): + def add_pool(self, name: str, driver: str, **kwargs) -> None: """Add a storage pool to config :param name: name of storage pool to create @@ -268,12 +294,12 @@ def add_pool(self, name, driver, **kwargs): "dom0", "admin.pool.Add", driver, payload.encode("utf-8") ) - def remove_pool(self, name): + def remove_pool(self, name: str) -> None: """Remove a storage pool""" self.qubesd_call("dom0", "admin.pool.Remove", name, None) @property - def local_name(self): + def local_name(self) -> str: """Get localhost name""" if not self._local_name: local_name = None @@ -291,7 +317,7 @@ def local_name(self): return self._local_name - def get_label(self, label): + def get_label(self, label: str | int) -> Label: """Get label as identified by index or name :throws QubesLabelNotFoundError: when label is not found @@ -308,10 +334,11 @@ def get_label(self, label): for i in self.labels.values(): if i.index == int(label): return i + assert isinstance(label, str) raise qubesadmin.exc.QubesLabelNotFoundError(label) @staticmethod - def get_vm_class(clsname): + def get_vm_class(clsname: str) -> str: """Find the class for a domain. Compatibility function, client tools use str to identify domain classes. @@ -323,8 +350,10 @@ def get_vm_class(clsname): return clsname def add_new_vm( - self, cls, name, label, template=None, pool=None, pools=None - ): + self, cls: str | type[QubesVM], name: str, label: str, + template: str | QubesVM | None=None, pool: str | None=None, + pools: dict | None=None + ) -> QubesVM: """Create new Virtual Machine Example usage with custom storage pools: @@ -380,16 +409,16 @@ def add_new_vm( def clone_vm( self, - src_vm, - new_name, - new_cls=None, + src_vm: str | QubesVM, + new_name: str, + new_cls: str | None=None, *, - pool=None, - pools=None, - ignore_errors=False, - ignore_volumes=None, - ignore_devices=False, - ): + pool: str | None=None, + pools: dict | None=None, + ignore_errors: bool=False, + ignore_volumes: list | None=None, + ignore_devices: bool=False, + ) -> QubesVM: # pylint: disable=too-many-statements # pylint: disable=too-many-branches """Clone Virtual Machine @@ -625,8 +654,9 @@ def clone_vm( return dst_vm def qubesd_call( - self, dest, method, arg=None, payload=None, payload_stream=None - ): + self, dest: str | None, method: str, arg: str | None=None, + payload: bytes | None=None, payload_stream: IO | None=None + ) -> bytes: """ Execute Admin API method. @@ -649,16 +679,16 @@ def qubesd_call( def run_service( self, - dest, - service, - user=None, + dest: str, + service: str, + user: str | None=None, *, - filter_esc=False, - localcmd=None, - wait=True, - autostart=True, + filter_esc: bool=False, + localcmd: str | None=None, + wait: bool=True, + autostart: bool=True, **kwargs, - ): + ) -> Popen: """Run qrexec service in a given destination *kwargs* are passed verbatim to :py:meth:`subprocess.Popen`. @@ -680,7 +710,9 @@ def run_service( ) @staticmethod - def _call_with_stream(command, payload, payload_stream): + def _call_with_stream(command: str | list[str], payload: bytes | None, + payload_stream: IO)\ + -> tuple[Popen, bytes, bytes]: """Helper method to pass data to qubesd. Calls a command with payload and payload_stream as input. @@ -701,6 +733,7 @@ def _call_with_stream(command, payload, payload_stream): # because the process can get blocked on stdout or stderr pipe. # However, in practice the output should be always smaller # than 4K. + assert proc.stdin is not None proc.stdin.write(payload) try: shutil.copyfileobj(payload_stream, proc.stdin) @@ -713,7 +746,8 @@ def _call_with_stream(command, payload, payload_stream): stdout, stderr = proc.communicate() return proc, stdout, stderr - def _invalidate_cache(self, subject, event, name, **kwargs): + def _invalidate_cache(self, subject: QubesVM | None, + event: str, name: str, **kwargs) -> None: """Invalidate cached value of a property. This method is designed to be hooked as an event handler for: @@ -734,15 +768,18 @@ def _invalidate_cache(self, subject, event, name, **kwargs): :return: none """ # pylint: disable=unused-argument if subject is None: - subject = self + subject_or_self = self + else: + subject_or_self = subject try: # pylint: disable=protected-access - del subject._properties_cache[name] + del subject_or_self._properties_cache[name] except KeyError: pass - def _update_power_state_cache(self, subject, event, **kwargs): + def _update_power_state_cache(self, subject: QubesVM, + event: str, **kwargs) -> None: """Update cached VM power state. This method is designed to be hooked as an event handler for: @@ -784,7 +821,7 @@ def _update_power_state_cache(self, subject, event, **kwargs): # pylint: disable=protected-access subject._power_state_cache = power_state - def _invalidate_cache_all(self): + def _invalidate_cache_all(self) -> None: """Invalidate all cached data @@ -800,7 +837,7 @@ def _invalidate_cache_all(self): """ # pylint: disable=protected-access self.domains.clear_cache() - for vm in self.domains._vm_objects.values(): + for vm in self.domains._vms.values(): assert isinstance(vm, qubesadmin.vm.QubesVM) vm._power_state_cache = None vm._properties_cache = {} @@ -817,8 +854,9 @@ class QubesLocal(QubesBase): qubesd_connection_type = "socket" def qubesd_call( - self, dest, method, arg=None, payload=None, payload_stream=None - ): + self, dest: str | None, method: str, arg: str | None=None, + payload: bytes | None=None, payload_stream: IO | None=None + ) -> bytes: """ Execute Admin API method. @@ -846,13 +884,16 @@ def qubesd_call( raise qubesadmin.exc.QubesDaemonCommunicationError( "{} not found".format(method_path) ) + assert arg is not None + assert dest is not None command = [ "env", "QREXEC_REMOTE_DOMAIN=dom0", "QREXEC_REQUESTED_TARGET=" + dest, method_path, - arg, ] + if arg is not None: + command.append(arg) if os.getuid() != 0: command.insert(0, "sudo") (_, stdout, _) = self._call_with_stream( @@ -881,16 +922,16 @@ def qubesd_call( def run_service( self, - dest, - service, - user=None, + dest: str, + service: str, + user: str | None=None, *, - filter_esc=False, - localcmd=None, - wait=True, - autostart=True, + filter_esc: bool=False, + localcmd: str | None=None, + wait: bool=True, + autostart: bool=True, **kwargs, - ): + ) -> Popen: """Run qrexec service in a given destination :param str dest: Destination - may be a VM name or empty @@ -990,8 +1031,9 @@ class QubesRemote(QubesBase): qubesd_connection_type = "qrexec" def qubesd_call( - self, dest, method, arg=None, payload=None, payload_stream=None - ): + self, dest: str | None, method: str, arg: str | None=None, + payload: bytes | None=None, payload_stream: IO | None=None + ) -> bytes: """ Execute Admin API method. @@ -1008,6 +1050,7 @@ def qubesd_call( .. warning:: *payload_stream* will get closed by this function """ service_name = method + assert dest is not None if arg is not None: service_name += "+" + arg command = [qubesadmin.config.QREXEC_CLIENT_VM, dest, service_name] @@ -1032,16 +1075,16 @@ def qubesd_call( def run_service( self, - dest, - service, - user=None, + dest: str, + service: str, + user: str | None=None, *, - filter_esc=False, - localcmd=None, - wait=True, - autostart=True, + filter_esc: bool=False, + localcmd: str | None=None, + wait: bool=True, + autostart: bool=True, **kwargs, - ): + ) -> Popen: """Run qrexec service in a given destination :param str dest: Destination - may be a VM name or empty diff --git a/qubesadmin/backup/core3.py b/qubesadmin/backup/core3.py index bb540680..ce47d27f 100644 --- a/qubesadmin/backup/core3.py +++ b/qubesadmin/backup/core3.py @@ -28,12 +28,10 @@ import qubesadmin.backup import qubesadmin.firewall -from qubesadmin import device_protocol +from qubesadmin import utils from qubesadmin.vm import QubesVM - - class Core3VM(qubesadmin.backup.BackupVM): '''VM object''' @property @@ -138,7 +136,7 @@ def import_core3_vm(self, element: _Element) -> None: for opt_node in node.findall('./option'): opt_name = opt_node.get('name') options[opt_name] = opt_node.text - options['required'] = device_protocol.qbool( + options['required'] = utils.qbool( node.get('required', 'yes')) vm.devices[bus_name][(backend_domain, port_id)] = options diff --git a/qubesadmin/base.py b/qubesadmin/base.py index ca61ba22..6f1f8500 100644 --- a/qubesadmin/base.py +++ b/qubesadmin/base.py @@ -19,11 +19,24 @@ # with this program; if not, see . '''Base classes for managed objects''' +from __future__ import annotations + +import typing +from typing import BinaryIO, Any, TypeAlias, TypeVar, Generic +from collections.abc import Generator import qubesadmin.exc +if typing.TYPE_CHECKING: + from qubesadmin.vm import QubesVM + from qubesadmin.app import QubesBase + DEFAULT = object() +# We use Any because the dynamic metatada handling of the current code +# is too complex for type checkers otherwise +VMProperty: TypeAlias = Any # noqa: ANN401 + class PropertyHolder: '''A base class for object having properties retrievable using mgmt API. @@ -34,28 +47,29 @@ class PropertyHolder: ''' #: a place for appropriate Qubes() object (QubesLocal or QubesRemote), # use None for self - app = None + app: QubesBase - def __init__(self, app, method_prefix, method_dest): + def __init__(self, app: QubesBase, method_prefix: str, method_dest: str): #: appropriate Qubes() object (QubesLocal or QubesRemote), use None # for self self.app = app self._method_prefix = method_prefix self._method_dest = method_dest - self._properties = None + self._properties: list[str] | None = None self._properties_help = None # the cache is maintained by EventsDispatcher(), # through helper functions in QubesBase() self._properties_cache = {} - def clear_cache(self): + def clear_cache(self) -> None: """ Clear property cache. """ self._properties_cache = {} - def qubesd_call(self, dest, method, arg=None, payload=None, - payload_stream=None): + def qubesd_call(self, dest: str | None, method: str, + arg: str | None=None, payload: bytes | None=None, + payload_stream: BinaryIO | None=None) -> bytes: ''' Call into qubesd using appropriate mechanism. This method should be defined by a subclass. @@ -69,10 +83,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None, :param payload_stream: file-like object to read payload from :return: Data returned by qubesd (string) ''' - if not self.app: - raise NotImplementedError - if dest is None: - dest = self._method_dest + dest: str = dest or self._method_dest if ( getattr(self, "_redirect_dispvm_calls", False) and dest.startswith("@dispvm") @@ -80,7 +91,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None, if dest.startswith("@dispvm:"): dest = dest[len("@dispvm:") :] else: - dest = getattr(self.app, "default_dispvm", None) + dest: QubesVM | None = getattr(self.app, "default_dispvm", None) if dest: dest = dest.name # have the actual implementation at Qubes() instance @@ -88,7 +99,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None, payload_stream) @staticmethod - def _parse_qubesd_response(response_data): + def _parse_qubesd_response(response_data: bytes) -> bytes: '''Parse response from qubesd. In case of success, return actual data. In case of error, @@ -122,7 +133,7 @@ def _parse_qubesd_response(response_data): raise qubesadmin.exc.QubesDaemonCommunicationError( 'Invalid response format') - def property_list(self): + def property_list(self) -> list[str]: ''' List available properties (their names). @@ -138,7 +149,7 @@ def property_list(self): # TODO: make it somehow immutable return self._properties - def property_help(self, name): + def property_help(self, name: str) -> str: ''' Get description of a property. @@ -151,7 +162,7 @@ def property_help(self, name): None) return help_text.decode('ascii') - def property_is_default(self, item): + def property_is_default(self, item: str) -> bool: ''' Check if given property have default value @@ -183,7 +194,7 @@ def property_is_default(self, item): self._properties_cache[item] = (is_default, value) return is_default - def property_get_default(self, item): + def property_get_default(self, item: str) -> VMProperty: ''' Get default property value, regardless of the current value @@ -206,7 +217,8 @@ def property_get_default(self, item): (prop_type, value) = property_str.split(b' ', 1) return self._parse_type_value(prop_type, value) - def clone_properties(self, src, proplist=None): + def clone_properties(self, src: PropertyHolder, + proplist: list[str] | None=None) -> None: '''Clone properties from other object. :param PropertyHolder src: source object @@ -223,7 +235,7 @@ def clone_properties(self, src, proplist=None): except AttributeError: continue - def __getattr__(self, item): + def __getattr__(self, item: str) -> VMProperty: if item.startswith('_'): raise AttributeError(item) # pre-fill cache if enabled @@ -254,7 +266,8 @@ def __getattr__(self, item): raise AttributeError(item) return value - def _deserialize_property(self, api_response): + def _deserialize_property(self, api_response: bytes) \ + -> tuple[bool, VMProperty]: """ Deserialize property.Get response format :param api_response: bytes, as retrieved from qubesd @@ -267,7 +280,7 @@ def _deserialize_property(self, api_response): value = self._parse_type_value(prop_type, value) return is_default, value - def _parse_type_value(self, prop_type, value): + def _parse_type_value(self, prop_type: bytes, value: bytes) -> VMProperty: ''' Parse `type=... ...` qubesd response format. Return a value of appropriate type. @@ -278,20 +291,23 @@ def _parse_type_value(self, prop_type, value): :return: parsed value ''' # pylint: disable=too-many-return-statements - prop_type = prop_type.decode('ascii') + prop_type: str = prop_type.decode('ascii') if not prop_type.startswith('type='): raise qubesadmin.exc.QubesDaemonCommunicationError( 'Invalid type prefix received: {}'.format(prop_type)) (_, prop_type) = prop_type.split('=', 1) - value = value.decode() + value: str = value.decode() if prop_type == 'str': return str(value) if prop_type == 'bool': if value == '': + # TODO shouldn't that at least be ValueError ? + # but then we need to properly propagate that modification return AttributeError return value == "True" if prop_type == 'int': if value == '': + # TODO same as above return AttributeError return int(value) if prop_type == 'vm': @@ -305,7 +321,7 @@ def _parse_type_value(self, prop_type, value): raise qubesadmin.exc.QubesDaemonCommunicationError( 'Received invalid value type: {}'.format(prop_type)) - def _fetch_all_properties(self): + def _fetch_all_properties(self) -> None: """ Retrieve all properties values at once using (prefix).property.GetAll method. If it succeed, save retrieved values in the properties cache. @@ -315,7 +331,7 @@ def _fetch_all_properties(self): :return: None """ - def unescape(line): + def unescape(line: bytes) -> Generator[int]: """Handle \\-escaped values, generates a list of character codes""" escaped = False for char in line: @@ -342,15 +358,15 @@ def unescape(line): return for line in properties_str.splitlines(): # decode newlines - line = bytes(unescape(line)) - name, property_str = line.split(b' ', 1) + line_bytes = bytes(list(unescape(line))) + name, property_str = line_bytes.split(b' ', 1) name = name.decode() is_default, value = self._deserialize_property(property_str) self._properties_cache[name] = (is_default, value) self._properties = list(self._properties_cache.keys()) @classmethod - def _local_properties(cls): + def _local_properties(cls) -> set: ''' Get set of property names that are properties on the Python object, and must not be set on the remote object @@ -367,7 +383,7 @@ def _local_properties(cls): return cls._local_properties_set - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: typing.Any) -> None: # noqa: ANN401 if key.startswith('_') or key in self._local_properties(): return super().__setattr__(key, value) if value is qubesadmin.DEFAULT: @@ -381,7 +397,9 @@ def __setattr__(self, key, value): qubesadmin.exc.QubesVMNotFoundError): raise qubesadmin.exc.QubesPropertyAccessError(key) else: - if isinstance(value, qubesadmin.vm.QubesVM): + # Dynamic import because qubesadmin.vm imports base.py + from qubesadmin.vm import QubesVM + if isinstance(value, QubesVM): value = value.name if value is None: value = '' @@ -395,7 +413,7 @@ def __setattr__(self, key, value): qubesadmin.exc.QubesVMNotFoundError): raise qubesadmin.exc.QubesPropertyAccessError(key) - def __delattr__(self, name): + def __delattr__(self, name: str) -> None: if name.startswith('_') or name in self._local_properties(): return super().__delattr__(name) try: @@ -408,10 +426,13 @@ def __delattr__(self, name): qubesadmin.exc.QubesVMNotFoundError): raise qubesadmin.exc.QubesPropertyAccessError(name) +WrapperObjectsCollectionKey: TypeAlias = int | str +T = TypeVar('T') -class WrapperObjectsCollection: +class WrapperObjectsCollection(Generic[T]): '''Collection of simple named objects''' - def __init__(self, app, list_method, object_class): + def __init__(self, app: QubesBase, + list_method: str, object_class: type[T]): ''' Construct manager of named wrapper objects. @@ -425,11 +446,13 @@ def __init__(self, app, list_method, object_class): self._list_method = list_method self._object_class = object_class #: names cache - self._names_list = None + self._names_list: list[WrapperObjectsCollectionKey] | None = None #: returned objects cache - self._objects = {} + self._objects: dict[WrapperObjectsCollectionKey, T] = {} - def clear_cache(self, invalidate_name=None): + def clear_cache(self, + invalidate_name: WrapperObjectsCollectionKey | None=None)\ + -> None: """Clear cached list of names. If *invalidate_name* is given, remove that object from cache explicitly too. @@ -438,7 +461,7 @@ def clear_cache(self, invalidate_name=None): if invalidate_name: self._objects.pop(invalidate_name, None) - def refresh_cache(self, force=False): + def refresh_cache(self, force: bool=False) -> None: '''Refresh cached list of names''' if not force and self._names_list is not None: return @@ -447,17 +470,18 @@ def refresh_cache(self, force=False): assert list_data[-1] == '\n' self._names_list = [str(name) for name in list_data[:-1].splitlines()] - for name, obj in list(self._objects.items()): + for name, obj in self._objects.items(): + assert hasattr(obj, "name") if obj.name not in self._names_list: # Object no longer exists del self._objects[name] - def __getitem__(self, item): + def __getitem__(self, item: WrapperObjectsCollectionKey) -> T: if not self.app.blind_mode and item not in self: raise KeyError(item) return self.get_blind(item) - def get_blind(self, item): + def get_blind(self, item: WrapperObjectsCollectionKey) -> T: ''' Get a property without downloading the list and checking if it's present @@ -466,25 +490,30 @@ def get_blind(self, item): self._objects[item] = self._object_class(self.app, item) return self._objects[item] - def __contains__(self, item): + def __contains__(self, item: WrapperObjectsCollectionKey) -> bool: self.refresh_cache() + assert self._names_list is not None return item in self._names_list - def __iter__(self): + def __iter__(self) -> Generator[WrapperObjectsCollectionKey]: self.refresh_cache() + assert self._names_list is not None yield from self._names_list - def keys(self): + def keys(self) -> list[WrapperObjectsCollectionKey]: '''Get list of names.''' self.refresh_cache() + assert self._names_list is not None return list(self._names_list) - def items(self): + def items(self) -> list[tuple[WrapperObjectsCollectionKey, T]]: '''Get list of (key, value) pairs''' self.refresh_cache() + assert self._names_list is not None return [(key, self.get_blind(key)) for key in self._names_list] - def values(self): + def values(self) -> list[T]: '''Get list of objects''' self.refresh_cache() + assert self._names_list is not None return [self.get_blind(key) for key in self._names_list] diff --git a/qubesadmin/device_protocol.py b/qubesadmin/device_protocol.py index c7b60a54..6234266d 100644 --- a/qubesadmin/device_protocol.py +++ b/qubesadmin/device_protocol.py @@ -29,15 +29,20 @@ The same in `qubes-core-admin` and `qubes-core-admin-client`, should be moved to one place. """ - +from __future__ import annotations import string import sys from enum import Enum -from typing import Optional, Dict, Any, List, Union, Tuple, Callable +from typing import Any, TYPE_CHECKING +from collections.abc import Callable import qubesadmin.exc + from qubesadmin.exc import QubesValueError +if TYPE_CHECKING: + from qubesadmin.vm import QubesVM + from qubesadmin.app import VMCollection class ProtocolError(AssertionError): @@ -46,51 +51,27 @@ class ProtocolError(AssertionError): """ -QubesVM = 'qubesadmin.vm.QubesVM' - - class UnexpectedDeviceProperty(qubesadmin.exc.QubesException, ValueError): """ Device has unexpected property such as backend_domain, devclass etc. """ -def qbool(value): - """ - Property setter for boolean properties. - - It accepts (case-insensitive) ``'0'``, ``'no'`` and ``false`` as - :py:obj:`False` and ``'1'``, ``'yes'`` and ``'true'`` as - :py:obj:`True`. - """ - - if isinstance(value, str): - lcvalue = value.lower() - if lcvalue in ("0", "no", "false", "off"): - return False - if lcvalue in ("1", "yes", "true", "on"): - return True - raise QubesValueError( - "Invalid literal for boolean property: {!r}".format(value) - ) - - return bool(value) - - class DeviceSerializer: """ Group of method for serialization of device properties. """ - ALLOWED_CHARS_KEY = set( + ALLOWED_CHARS_KEY: set[str] = set( string.digits + string.ascii_letters + r"!#$%&()*+,-./:;<>?@[\]^_{|}~" ) - ALLOWED_CHARS_PARAM = ALLOWED_CHARS_KEY.union(set(string.punctuation + " ")) + ALLOWED_CHARS_PARAM: set[str]\ + = ALLOWED_CHARS_KEY.union(set(string.punctuation + " ")) @classmethod def unpack_properties( cls, untrusted_serialization: bytes - ) -> Tuple[Dict, Dict]: + ) -> tuple[dict, dict]: """ Unpacks basic port properties from a serialized encoded string. @@ -106,8 +87,8 @@ def unpack_properties( "ascii", errors="strict" ).strip() - properties: Dict[str, str] = {} - options: Dict[str, str] = {} + properties: dict[str, str] = {} + options: dict[str, str] = {} if not ut_decoded: return properties, options @@ -155,7 +136,7 @@ def unpack_properties( return properties, options @classmethod - def pack_property(cls, key: str, value: Optional[str]): + def pack_property(cls, key: str, value: object) -> bytes: """ Add property `key=value` to serialization. """ @@ -175,8 +156,8 @@ def pack_property(cls, key: str, value: Optional[str]): @staticmethod def parse_basic_device_properties( - expected_device: "VirtualDevice", properties: Dict[str, Any] - ): + expected_device: VirtualDevice, properties: dict + ) -> None: """ Validates properties against an expected port configuration. @@ -228,7 +209,7 @@ def parse_basic_device_properties( properties["port"] = expected @staticmethod - def serialize_str(value: str) -> str: + def serialize_str(value: object) -> str: """ Serialize python string to ensure consistency. """ @@ -245,7 +226,7 @@ def deserialize_str(value: str) -> str: def sanitize_str( untrusted_value: str, allowed_chars: set, - replace_char: Optional[str] = None, + replace_char: str | None = None, error_message: str = "", ) -> str: """ @@ -281,18 +262,18 @@ class Port: def __init__( self, - backend_domain: Optional[QubesVM], - port_id: Optional[str], - devclass: Optional[str], + backend_domain: QubesVM | None, + port_id: str | None, + devclass: str | None, ): self.__backend_domain = backend_domain self.__port_id = port_id self.__devclass = devclass - def __hash__(self): + def __hash__(self) -> int: return hash((self.backend_name, self.port_id, self.devclass)) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Port): return ( self.backend_name == other.backend_name @@ -301,7 +282,7 @@ def __eq__(self, other): ) return False - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if isinstance(other, Port): return (self.backend_name, self.devclass, self.port_id) < ( other.backend_name, @@ -313,10 +294,10 @@ def __lt__(self, other): "is not supported" ) - def __repr__(self): + def __repr__(self) -> str: return f"{self.backend_name}+{self.port_id}" - def __str__(self): + def __str__(self) -> str: return f"{self.backend_name}:{self.port_id}" @property @@ -328,8 +309,9 @@ def backend_name(self) -> str: @classmethod def from_qarg( - cls, representation: str, devclass, domains, blind=False - ) -> "Port": + cls, representation: str, devclass: str, + domains: VMCollection, blind: bool=False + ) -> Port: """ Parse qrexec argument + to retrieve Port. """ @@ -341,8 +323,9 @@ def from_qarg( @classmethod def from_str( - cls, representation: str, devclass, domains, blind=False - ) -> "Port": + cls, representation: str, devclass: str, + domains: VMCollection, blind: bool=False + ) -> Port: """ Parse string : to retrieve Port. """ @@ -355,7 +338,7 @@ def from_str( @classmethod def _parse( cls, representation: str, devclass: str, get_domain: Callable, sep: str - ) -> "Port": + ) -> Port: """ Parse string representation and return instance of Port. """ @@ -375,7 +358,7 @@ def port_id(self) -> str: return "*" @property - def backend_domain(self) -> Optional[QubesVM]: + def backend_domain(self) -> QubesVM | None: """Which domain exposed this port. (immutable)""" return self.__backend_domain @@ -390,7 +373,7 @@ def devclass(self) -> str: return "peripheral" @property - def has_devclass(self): + def has_devclass(self) -> bool: """Returns True if devclass is set.""" return self.__devclass is not None @@ -401,10 +384,10 @@ class AnyPort(Port): def __init__(self, devclass: str): super().__init__(None, "*", devclass) - def __repr__(self): + def __repr__(self) -> str: return "*" - def __str__(self): + def __str__(self) -> str: return "*" @@ -419,18 +402,18 @@ class VirtualDevice: def __init__( self, - port: Optional[Port] = None, - device_id: Optional[str] = None, + port: Port | None = None, + device_id: str | None = None, ): assert not isinstance(port, AnyPort) or device_id is not None - self.port: Optional[Port] = port # type: ignore + self.port: Port | None = port self._device_id = device_id - def clone(self, **kwargs) -> "VirtualDevice": + def clone(self, **kwargs) -> VirtualDevice: """ Clone object and substitute attributes with explicitly given. """ - attr: Dict[str, Any] = { + attr: dict[str, Any] = { # noqa:ANN401 "port": self.port, "device_id": self.device_id, } @@ -443,7 +426,7 @@ def port(self) -> Port: return self._port @port.setter - def port(self, value: Union[Port, str, None]): + def port(self, value: Port | str | None) -> None: # pylint: disable=missing-function-docstring if isinstance(value, Port): self._port = value @@ -469,7 +452,7 @@ def is_device_id_set(self) -> bool: return self._device_id is not None @property - def backend_domain(self) -> Optional[QubesVM]: + def backend_domain(self) -> QubesVM | None: # pylint: disable=missing-function-docstring return self.port.backend_domain @@ -499,10 +482,10 @@ def description(self) -> str: return "any device" return self.device_id - def __hash__(self): + def __hash__(self) -> int: return hash((self.port, self.device_id)) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, (VirtualDevice, DeviceAssignment)): result = ( self.port == other.port and self.device_id == other.device_id @@ -512,7 +495,7 @@ def __eq__(self, other): return self.port == other and self.device_id == "*" return super().__eq__(other) - def __lt__(self, other): + def __lt__(self, other: object) -> bool: """ Desired order (important for auto-attachment): @@ -543,11 +526,11 @@ def __lt__(self, other): "is not supported" ) - def __repr__(self): + def __repr__(self) -> str: return f"{self.port!r}:{self.device_id}" @property - def repr_for_qarg(self): + def repr_for_qarg(self) -> str: """Object representation for qrexec argument""" res = repr(self).replace(":", "+") # replace '?' in category @@ -555,18 +538,18 @@ def repr_for_qarg(self): res = res.replace(unknown_dev, "_" * len(unknown_dev)) return res.replace("*", "_") - def __str__(self): + def __str__(self) -> str: return f"{self.port}:{self.device_id}" @classmethod def from_qarg( cls, representation: str, - devclass: Optional[str], - domains, + devclass: str | None, + domains: VMCollection, blind: bool = False, - backend: Optional[QubesVM] = None, - ) -> "VirtualDevice": + backend: QubesVM | None = None, + ) -> VirtualDevice: """ Parse qrexec argument +: to get device info """ @@ -583,15 +566,16 @@ def from_qarg( def from_str( cls, representation: str, - devclass: Optional[str], - domains, + devclass: str | None, + domains: VMCollection | None, blind: bool = False, - backend: Optional[QubesVM] = None, - ) -> "VirtualDevice": + backend: QubesVM | None = None, + ) -> VirtualDevice: """ Parse string +: to get device info """ if backend is None: + assert domains is not None if blind: get_domain = domains.get_blind else: @@ -604,15 +588,16 @@ def from_str( def _parse( cls, representation: str, - devclass: Optional[str], - get_domain: Callable, - backend: Optional[QubesVM], + devclass: str | None, + get_domain: Callable | None, + backend: QubesVM | None, sep: str, - ) -> "VirtualDevice": + ) -> VirtualDevice: """ Parse string representation and return instance of VirtualDevice. """ if backend is None: + assert get_domain is not None backend_name, identity = representation.split(sep, 1) if backend_name == "_": backend_name = "*" @@ -690,7 +675,7 @@ class DeviceCategory(Enum): PCI_USB = ("p0c03**",) @staticmethod - def from_str(interface_encoding: str) -> "DeviceCategory": + def from_str(interface_encoding: str) -> DeviceCategory: """ Returns `DeviceCategory` from data encoded in string. """ @@ -721,7 +706,7 @@ class DeviceInterface: Peripheral device interface wrapper. """ - def __init__(self, interface_encoding: str, devclass: Optional[str] = None): + def __init__(self, interface_encoding: str, devclass: str | None = None): ifc_padded = interface_encoding.ljust(6, "*") if devclass: if len(ifc_padded) > 6: @@ -762,7 +747,7 @@ def __init__(self, interface_encoding: str, devclass: Optional[str] = None): self._category = DeviceCategory.from_str(self._interface_encoding) @property - def devclass(self) -> Optional[str]: + def devclass(self) -> str | None: """Immutable Device class such like: 'usb', 'pci' etc.""" return self._devclass @@ -772,12 +757,12 @@ def category(self) -> DeviceCategory: return self._category @classmethod - def unknown(cls) -> "DeviceInterface": + def unknown(cls) -> DeviceInterface: """Value for unknown device interface.""" return cls("?******") @staticmethod - def from_str_bulk(interfaces: Optional[str]) -> List["DeviceInterface"]: + def from_str_bulk(interfaces: str | None) -> list[DeviceInterface]: """Interprets string of interfaces as list of `DeviceInterface`. Examples: @@ -796,18 +781,18 @@ def from_str_bulk(interfaces: Optional[str]) -> List["DeviceInterface"]: for i in range(0, len(interfaces), 7) ] - def __repr__(self): + def __repr__(self) -> str: return self._interface_encoding - def __hash__(self): + def __hash__(self) -> int: return hash(repr(self)) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, DeviceInterface): return False return repr(self) == repr(other) - def __str__(self): + def __str__(self) -> str: if self.devclass == "block": return "Block Device" if self.devclass in ("usb", "pci"): @@ -848,7 +833,7 @@ def __str__(self): return repr(self) @staticmethod - def _load_classes(bus: str): + def _load_classes(bus: str) -> dict: """ List of known device classes, subclasses and programming interfaces. """ @@ -884,7 +869,7 @@ def _load_classes(bus: str): return result - def matches(self, other: "DeviceInterface") -> bool: + def matches(self, other: DeviceInterface) -> bool: """ Check if this `DeviceInterface` (pattern) matches given one. @@ -913,15 +898,15 @@ def __init__( self, port: Port, *, - vendor: Optional[str] = None, - product: Optional[str] = None, - manufacturer: Optional[str] = None, - name: Optional[str] = None, - serial: Optional[str] = None, - interfaces: Optional[List[DeviceInterface]] = None, - parent: Optional["DeviceInfo"] = None, - attachment: Optional[QubesVM] = None, - device_id: Optional[str] = None, + vendor: str | None = None, + product: str | None = None, + manufacturer: str | None = None, + name: str | None = None, + serial: str | None = None, + interfaces: list[DeviceInterface] | None = None, + parent: DeviceInfo | None = None, + attachment: QubesVM | None = None, + device_id: str | None = None, **kwargs, ): super().__init__(port, device_id) @@ -1044,7 +1029,7 @@ def description(self) -> str: return f"{cat}: {vendor} {prod}" @property - def interfaces(self) -> List[DeviceInterface]: + def interfaces(self) -> list[DeviceInterface]: """ Non-empty list of device interfaces. @@ -1055,7 +1040,7 @@ def interfaces(self) -> List[DeviceInterface]: return self._interfaces @property - def parent_device(self) -> Optional[VirtualDevice]: + def parent_device(self) -> VirtualDevice | None: """ The parent device, if any. @@ -1065,7 +1050,7 @@ def parent_device(self) -> Optional[VirtualDevice]: return self._parent @property - def subdevices(self) -> List[VirtualDevice]: + def subdevices(self) -> list[VirtualDevice]: """ The list of children devices if any. @@ -1082,7 +1067,7 @@ def subdevices(self) -> List[VirtualDevice]: ] @property - def attachment(self) -> Optional[QubesVM]: + def attachment(self) -> QubesVM | None: """ VM to which device is attached (frontend domain). """ @@ -1135,8 +1120,8 @@ def deserialize( cls, serialization: bytes, expected_backend_domain: QubesVM, - expected_devclass: Optional[str] = None, - ) -> "DeviceInfo": + expected_devclass: str | None = None, + ) -> DeviceInfo: """ Recovers a serialized object, see: :py:meth:`serialize`. """ @@ -1160,7 +1145,7 @@ def deserialize( @classmethod def _deserialize( cls, untrusted_serialization: bytes, expected_device: VirtualDevice - ) -> "DeviceInfo": + ) -> DeviceInfo: """ Actually deserializes the object. """ @@ -1217,7 +1202,7 @@ def device_id(self) -> str: return self._device_id @device_id.setter - def device_id(self, value): + def device_id(self, value: str) -> None: # Do not auto-override value like in super class self._device_id = value @@ -1226,7 +1211,7 @@ class UnknownDevice(DeviceInfo): """Unknown device - for example, exposed by domain not running currently""" @staticmethod - def from_device(device: VirtualDevice) -> "UnknownDevice": + def from_device(device: VirtualDevice) -> UnknownDevice: """ Return `UnknownDevice` based on any virtual device. """ @@ -1252,9 +1237,9 @@ class DeviceAssignment: def __init__( self, device: VirtualDevice, - frontend_domain=None, - options=None, - mode: Union[str, AssignmentMode] = "manual", + frontend_domain: QubesVM | None=None, + options: dict[str, object] | None=None, + mode: str | AssignmentMode = AssignmentMode.MANUAL, ): if isinstance(device, DeviceInfo): device = VirtualDevice(device.port, device.device_id) @@ -1272,12 +1257,12 @@ def new( backend_domain: QubesVM, port_id: str, devclass: str, - device_id: Optional[str] = None, + device_id: str | None = None, *, - frontend_domain: Optional[QubesVM] = None, - options=None, - mode: Union[str, AssignmentMode] = "manual", - ) -> "DeviceAssignment": + frontend_domain: QubesVM | None = None, + options: dict[str, object] | None=None, + mode: str | AssignmentMode = AssignmentMode.MANUAL, + ) -> DeviceAssignment: """Helper method to create a DeviceAssignment object.""" return cls( VirtualDevice(Port(backend_domain, port_id, devclass), device_id), @@ -1286,7 +1271,7 @@ def new( mode, ) - def clone(self, **kwargs): + def clone(self, **kwargs) -> DeviceAssignment: """ Clone object and substitute attributes with explicitly given. """ @@ -1297,23 +1282,23 @@ def clone(self, **kwargs): "frontend_domain": self.frontend_domain, } attr.update(kwargs) - return self.__class__(**attr) + return self.__class__(**attr) # type: ignore - def __repr__(self): + def __repr__(self) -> str: return f"{self.virtual_device!r}" @property - def repr_for_qarg(self): + def repr_for_qarg(self) -> str: """Object representation for qrexec argument""" return self.virtual_device.repr_for_qarg - def __str__(self): + def __str__(self) -> str: return f"{self.virtual_device}" - def __hash__(self): + def __hash__(self) -> int: return hash(self.virtual_device) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, (VirtualDevice, DeviceAssignment)): result = ( self.port == other.port and self.device_id == other.device_id @@ -1321,7 +1306,7 @@ def __eq__(self, other): return result return False - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if isinstance(other, DeviceAssignment): return self.virtual_device < other.virtual_device if isinstance(other, VirtualDevice): @@ -1332,7 +1317,7 @@ def __lt__(self, other): ) @property - def backend_domain(self) -> Optional[QubesVM]: + def backend_domain(self) -> QubesVM | None: # pylint: disable=missing-function-docstring return self.virtual_device.backend_domain @@ -1357,9 +1342,9 @@ def device_id(self) -> str: return self.virtual_device.device_id @property - def devices(self) -> List[DeviceInfo]: + def devices(self) -> list[DeviceInfo]: """Get DeviceInfo objects corresponding to this DeviceAssignment""" - result: List[DeviceInfo] = [] + result: list[DeviceInfo] = [] if not self.backend_domain: return result if self.port_id != "*": @@ -1399,17 +1384,17 @@ def port(self) -> Port: return Port(self.backend_domain, self.port_id, self.devclass) @property - def frontend_domain(self) -> Optional[QubesVM]: + def frontend_domain(self) -> QubesVM | None: """Which domain the device is attached/assigned to.""" return self.__frontend_domain @frontend_domain.setter - def frontend_domain(self, frontend_domain: Optional[Union[str, QubesVM]]): + def frontend_domain(self, frontend_domain: str | QubesVM | None) -> None: """Which domain the device is attached/assigned to.""" if isinstance(frontend_domain, str): if not self.backend_domain: raise ProtocolError("Cannot determine backend domain") - self.__frontend_domain: Optional[QubesVM] = ( + self.__frontend_domain: QubesVM | None = ( self.backend_domain.app.domains[frontend_domain] ) else: @@ -1448,12 +1433,12 @@ def attach_automatically(self) -> bool: ) @property - def options(self) -> Dict[str, Any]: + def options(self) -> dict[str, object]: """Device options (same as in the legacy API).""" return self.__options @options.setter - def options(self, options: Optional[Dict[str, Any]]): + def options(self, options: dict[str, object] | None) -> None: """Device options (same as in the legacy API).""" self.__options = options or {} @@ -1482,7 +1467,7 @@ def deserialize( cls, serialization: bytes, expected_device: VirtualDevice, - ) -> "DeviceAssignment": + ) -> DeviceAssignment: """ Recovers a serialized object, see: :py:meth:`serialize`. """ @@ -1497,7 +1482,7 @@ def _deserialize( cls, untrusted_serialization: bytes, expected_device: VirtualDevice, - ) -> "DeviceAssignment": + ) -> DeviceAssignment: """ Actually deserializes the object. """ diff --git a/qubesadmin/events/__init__.py b/qubesadmin/events/__init__.py index 9015f580..d2581867 100644 --- a/qubesadmin/events/__init__.py +++ b/qubesadmin/events/__init__.py @@ -246,6 +246,7 @@ def handle(self, subject_name: str | None, event: str, **kwargs) -> None: elif event in ('domain-pre-start', 'domain-start', 'domain-shutdown', 'domain-paused', 'domain-unpaused', 'domain-start-failed'): + assert subject is not None self.app._update_power_state_cache(subject, event, **kwargs) subject.devices.clear_cache() elif event == 'connection-established': @@ -257,6 +258,7 @@ def handle(self, subject_name: str | None, event: str, **kwargs) -> None: "device-unassign", "device-assignment-changed" ): + assert subject is not None devclass = event.split(":")[1] subject.devices[devclass]._assignment_cache = None elif event.split(":")[0] in ( @@ -264,9 +266,11 @@ def handle(self, subject_name: str | None, event: str, **kwargs) -> None: "device-detach", "device-removed" ): + assert subject is not None devclass = event.split(":")[1] subject.devices[devclass]._attachment_cache = None if event.split(":")[0] in ("device-removed",): + assert subject is not None devclass = event.split(":")[1] port_id = kwargs.get("port", ":").split(":")[1] try: diff --git a/qubesadmin/events/utils.py b/qubesadmin/events/utils.py index 9bcf288e..a946c3c4 100644 --- a/qubesadmin/events/utils.py +++ b/qubesadmin/events/utils.py @@ -24,7 +24,6 @@ from typing import Iterable import qubesadmin.events -import qubesadmin.exc from qubesadmin.events import EventsDispatcher from qubesadmin.vm import QubesVM diff --git a/qubesadmin/tests/__init__.py b/qubesadmin/tests/__init__.py index 87a0978e..dcdad57d 100644 --- a/qubesadmin/tests/__init__.py +++ b/qubesadmin/tests/__init__.py @@ -150,6 +150,9 @@ class QubesTest(qubesadmin.app.QubesBase): expected_calls = None actual_calls = None service_calls = None + # This is a special value used only for tests + # This is not valid / expected outside of tests + qubesd_connection_type: str = "none" # type: ignore def __init__(self): super().__init__() diff --git a/qubesadmin/tests/app.py b/qubesadmin/tests/app.py index 3b8796cd..37e8ba12 100644 --- a/qubesadmin/tests/app.py +++ b/qubesadmin/tests/app.py @@ -186,6 +186,71 @@ def test_012_getitem_cached_object(self): self.assertIsNot(vm1, vm4) self.assertAllCalled() + def test_013_get_blind_not_in_membership(self): + """Blind objects must not appear in membership operations.""" + self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00test-vm class=AppVM state=Running\n' + self.assertIn('test-vm', self.app.domains) + self.app.domains.get_blind('other-vm') + self.assertNotIn('other-vm', self.app.domains) + self.assertNotIn('other-vm', self.app.domains.keys()) + self.assertEqual([vm.name for vm in self.app.domains], ['test-vm']) + self.assertAllCalled() + + def test_014_refresh_cache_forces_reload(self): + """refresh_cache() must trigger a new admin.vm.List call even if + already initialised.""" + self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00test-vm class=AppVM state=Running\n' + self.assertIn('test-vm', self.app.domains) + self.assertNotIn('test-vm2', self.app.domains) + self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00test-vm2 class=AppVM state=Running\n' + self.app.domains.refresh_cache() + self.assertNotIn('test-vm', self.app.domains) + self.assertIn('test-vm2', self.app.domains) + self.assertAllCalled() + + def test_015_delitem_targeted_cleanup(self): + """del domains[name] must remove the VM immediately without a + further admin.vm.List call.""" + self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00test-vm class=AppVM state=Running\n' + self.app.expected_calls[('test-vm', 'admin.vm.Remove', None, None)] = \ + b'0\x00' + self.assertIn('test-vm', self.app.domains) + del self.app.domains['test-vm'] + self.assertNotIn('test-vm', self.app.domains) + self.assertAllCalled() + + def test_016_rename_preserves_identity(self): + """After a rename, refresh_cache() must return the same object + under the new name.""" + self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00test-vm class=AppVM state=Running\n' + vm = self.app.domains['test-vm'] + # pylint: disable=protected-access + vm._method_dest = 'new-name' + self.app.domains.clear_cache() + self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00new-name class=AppVM state=Running\n' + vm2 = self.app.domains['new-name'] + self.assertIs(vm, vm2) + self.assertAllCalled() + + def test_017_clear_cache_invalidate_name(self): + """clear_cache(invalidate_name) must drop that VM's object from the + cache so a fresh one is created on next access.""" + self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00test-vm class=AppVM state=Running\n' + vm1 = self.app.domains['test-vm'] + self.app.domains.clear_cache() + vm2 = self.app.domains['test-vm'] + self.assertIs(vm1, vm2) + self.app.domains.clear_cache(invalidate_name = 'test-vm') + vm3 = self.app.domains['test-vm'] + self.assertIsNot(vm1, vm3) + self.assertAllCalled() class TC_10_QubesBase(qubesadmin.tests.QubesTestCase): diff --git a/qubesadmin/tests/tools/qvm_ls.py b/qubesadmin/tests/tools/qvm_ls.py index 23a1d6f0..27e47ed2 100644 --- a/qubesadmin/tests/tools/qvm_ls.py +++ b/qubesadmin/tests/tools/qvm_ls.py @@ -380,17 +380,10 @@ def test_101_list_selected(self): b'0\x00vm1 class=AppVM state=Running\n' \ b'template1 class=TemplateVM state=Halted\n' \ b'sys-net class=AppVM state=Running\n' - self.app.expected_calls[ - ('vm1', 'admin.vm.CurrentState', None, None)] = \ - b'0\x00power_state=Running' - self.app.expected_calls[ - ('sys-net', 'admin.vm.CurrentState', None, None)] = \ - b'0\x00power_state=Running' props = { 'label': 'type=label green', 'template': 'type=vm template1', 'netvm': 'type=vm sys-net', -# 'virt_mode': b'type=str pv', } self.app.expected_calls[ ('vm1', 'admin.vm.property.GetAll', None, None)] = \ diff --git a/qubesadmin/tools/qvm_template.py b/qubesadmin/tools/qvm_template.py index 95a2b77d..6f2a5758 100644 --- a/qubesadmin/tools/qvm_template.py +++ b/qubesadmin/tools/qvm_template.py @@ -1200,7 +1200,7 @@ def verify(rpmfile, reponame, package_hdr=None): name, target + PATH_PREFIX + '/' + name]) - app.domains.refresh_cache(force=True) + app.domains.refresh_cache() tpl = app.domains[name] tpl.features['template-name'] = name diff --git a/qubesadmin/utils.py b/qubesadmin/utils.py index 68b89af5..9f6d5982 100644 --- a/qubesadmin/utils.py +++ b/qubesadmin/utils.py @@ -29,6 +29,7 @@ import re import qubesadmin.exc +from qubesadmin.exc import QubesValueError def parse_size(size): @@ -210,3 +211,25 @@ def release(self): """Unlock the file and close the file object""" fcntl.lockf(self.file, fcntl.LOCK_UN) self.file.close() + + +def qbool(value: str | int | bool) -> bool: + """ + Property setter for boolean properties. + + It accepts (case-insensitive) ``'0'``, ``'no'`` and ``false`` as + :py:obj:`False` and ``'1'``, ``'yes'`` and ``'true'`` as + :py:obj:`True`. + """ + + if isinstance(value, str): + lcvalue = value.lower() + if lcvalue in ("0", "no", "false", "off"): + return False + if lcvalue in ("1", "yes", "true", "on"): + return True + raise QubesValueError( + "Invalid literal for boolean property: {!r}".format(value) + ) + + return bool(value) diff --git a/qubesadmin/vm/__init__.py b/qubesadmin/vm/__init__.py index 00785db8..13e544d6 100644 --- a/qubesadmin/vm/__init__.py +++ b/qubesadmin/vm/__init__.py @@ -19,35 +19,38 @@ # with this program; if not, see . """Qubes VM objects.""" - +from __future__ import annotations import logging import shlex import subprocess +import typing import warnings from logging import Logger +from typing import Literal -import qubesadmin.base import qubesadmin.exc import qubesadmin.storage import qubesadmin.features import qubesadmin.devices import qubesadmin.device_protocol import qubesadmin.firewall -import qubesadmin.tags + +if typing.TYPE_CHECKING: + import qubesadmin.base + +Klass = Literal["AppVM", "AdminVM", "TemplateVM", "DispVM", "StandaloneVM"] +PowerState = Literal["Transient", "Running", "Halted", "Paused", +"Suspended", "Halting", "Dying", "Crashed", "NA"] class QubesVM(qubesadmin.base.PropertyHolder): """Qubes domain.""" log: Logger - tags: qubesadmin.tags.Tags - features: qubesadmin.features.Features - devices: qubesadmin.devices.DeviceManager - firewall: qubesadmin.firewall.Firewall def __init__(self, app, name, klass=None, power_state=None): @@ -64,7 +67,7 @@ def __init__(self, app, name, klass=None, power_state=None): self.firewall = qubesadmin.firewall.Firewall(self) @property - def name(self): + def name(self) -> str: """Domain name""" return self._method_dest @@ -80,7 +83,7 @@ def name(self, new_value): self._volumes = None self.app.domains.clear_cache() - def __str__(self): + def __str__(self) -> str: return self._method_dest def __lt__(self, other): @@ -205,7 +208,7 @@ def get_power_state(self): """ - if self._power_state_cache is not None: + if self._power_state_cache is not None and self.app.cache_enabled: return self._power_state_cache try: power_state = self._get_current_state()["power_state"]