diff --git a/qubesadmin/app.py b/qubesadmin/app.py
index 808b0f7f..83f584e8 100644
--- a/qubesadmin/app.py
+++ b/qubesadmin/app.py
@@ -31,7 +31,11 @@
import sys
import logging
+import typing
from logging import Logger
+from subprocess import Popen
+from typing import IO
+from collections.abc import Generator, Iterable
import qubesadmin.base
import qubesadmin.exc
@@ -39,69 +43,94 @@
import qubesadmin.storage
import qubesadmin.utils
import qubesadmin.vm
+from qubesadmin.label import Label
+from qubesadmin.vm import Klass, PowerState
import qubesadmin.config
import qubesadmin.device_protocol
from qubesadmin.vm import QubesVM
try:
- import qubesdb
+ import qubesdb # type: ignore
has_qubesdb = True
except ImportError:
has_qubesdb = False
+DeviceClass = typing.Literal["mic", "block", "pci", "usb", "webcam"]
+
class VMCollection:
- """Collection of VMs objects"""
+ """A lazily-loaded, cached view of the VMs known to qubesd.
+
+ Membership (``__contains__``, ``__iter__``, ``keys()``, ``values()``)
+ reflects exclusively the VMs returned by the most recent
+ ``admin.vm.List`` call. That list is fetched on the first membership
+ query and cached; call :py:meth:`refresh_cache` to reload.
+
+ :py:meth:`get_blind` is **not** a membership operation.
+ It returns (or creates) a :py:class:`~qubesadmin.vm.QubesVM` handle for
+ the given name without querying the API and without registering the name
+ in the membership view. Its purpose is to provide a unique
+ :py:class:`~qubesadmin.vm.QubesVM` object for the same name on repeated
+ calls.
+ Blind objects desynced with ``vm.List`` are pruned from the internal cache
+ on the next :py:meth:`refresh_cache` call.
+
+ ``del domains[name]`` removes the VM by calling ``admin.vm.Remove``.
+ There is no ``__setitem__``; use
+ :py:meth:`~qubesadmin.app.QubesBase.add_new_vm` to create VMs.
+ """
- def __init__(self, app):
+ def __init__(self, app: "QubesBase"):
self.app = app
- self._vm_list = None
- self._vm_objects = {}
+ # object cache — may contain blind objects not confirmed by vm.List
+ self._vms: dict[str, QubesVM] = {}
+ # names returned by the last admin.vm.List call
+ self._known_names: set[str] = set()
+ self._initialized: bool = False
- def clear_cache(self, invalidate_name=None):
+ def clear_cache(self, invalidate_name: str | None=None) -> None:
"""Clear cached list of VMs
- If *invalidate_name* is given, remove that object from cache
- explicitly too.
+ If *invalidate_name* is given, remove that object from the collection.
"""
- self._vm_list = None
+ self._initialized = False
+ self._known_names = set()
if invalidate_name:
- self._vm_objects.pop(invalidate_name, None)
+ self._vms.pop(invalidate_name, None)
- def refresh_cache(self, force=False):
- """Refresh cached list of VMs"""
- if not force and self._vm_list is not None:
- return
+ def _ensure_cache_loaded(self) -> None:
+ """Load the VM list from qubesd if not already cached."""
+ if not self._initialized:
+ self.refresh_cache()
+
+ def refresh_cache(self) -> None:
+ """Reload the VM list from qubesd, discarding any cached data."""
vm_list_data = self.app.qubesd_call("dom0", "admin.vm.List")
- new_vm_list = {}
- # FIXME: this will probably change
+ # Handle renamed VMs
+ vms_by_current_name = {vm.name: vm for vm in self._vms.values()}
+ new_known_names: set[str] = set()
for vm_data in vm_list_data.splitlines():
vm_name, props = vm_data.decode("ascii").split(" ", 1)
- vm_name = str(vm_name)
- props = props.split(" ")
- new_vm_list[vm_name] = dict(
- [vm_prop.split("=", 1) for vm_prop in props]
- )
- # if cache not enabled, drop power state
- if not self.app.cache_enabled:
- try:
- del new_vm_list[vm_name]["state"]
- except KeyError:
- pass
-
- self._vm_list = new_vm_list
- for name, vm in list(self._vm_objects.items()):
- if vm.name not in self._vm_list:
- # VM no longer exists
- del self._vm_objects[name]
- elif vm.klass != self._vm_list[vm.name]["class"]:
- # VM class have changed
- del self._vm_objects[name]
- # TODO: some generation ID, to detect VM re-creation
- elif name != vm.name:
- # renamed
- self._vm_objects[vm.name] = vm
- del self._vm_objects[name]
+ props_dict = dict(vm_prop.split("=", 1)
+ for vm_prop in props.split(" "))
+ klass = typing.cast(Klass, props_dict["class"])
+ power_state = typing.cast(PowerState, props_dict.get("state"))
+ new_known_names.add(vm_name)
+ existing_vm = self._vms.get(vm_name) or\
+ vms_by_current_name.get(vm_name)
+ # TODO: some generation ID (e.g. uuid), to detect VM re-creation
+ if existing_vm is None or existing_vm.klass != klass:
+ self._vms[vm_name] = QubesVM(
+ self.app, vm_name, klass=klass, power_state=power_state
+ )
+ elif existing_vm is not self._vms.get(vm_name):
+ # renamed: existing_vm was found via vms_by_current_name
+ self._vms[vm_name] = existing_vm
+ # Drop objects for VMs that no longer exist
+ self._vms = {name: vm for name, vm in self._vms.items()
+ if name in new_known_names}
+ self._known_names = new_known_names
+ self._initialized = True
def __getitem__(self, item: str | QubesVM) -> QubesVM:
if isinstance(item, QubesVM):
@@ -112,25 +141,15 @@ def __getitem__(self, item: str | QubesVM) -> QubesVM:
def get_blind(self, item: str) -> QubesVM:
"""
- Get a vm without downloading the list
- and checking if exists
+ Get a vm from the collection. If the vm is not in the collection
+ already, a new basic entry will be created from the provided name.
"""
- if item not in self._vm_objects:
- cls = qubesadmin.vm.QubesVM
- # provide class name to constructor, if already cached (which can be
- # done by 'item not in self' check above, unless blind_mode is
- # enabled
- klass = None
- power_state = None
- if self._vm_list and item in self._vm_list:
- klass = self._vm_list[item]["class"]
- power_state = self._vm_list[item].get("state")
- self._vm_objects[item] = cls(
- self.app, item, klass=klass, power_state=power_state
- )
- return self._vm_objects[item]
+ if item not in self._vms:
+ self._vms[item] = QubesVM(self.app, item)
+ return self._vms[item]
- def get(self, item, default=None) -> QubesVM:
+ def get(self, item: str | QubesVM, default: QubesVM | None=None)\
+ -> QubesVM | None:
"""
Get a VM object, or return *default* if it can't be found.
"""
@@ -139,30 +158,31 @@ def get(self, item, default=None) -> QubesVM:
except KeyError:
return default
- def __contains__(self, item):
+ def __contains__(self, item: QubesVM | str) -> bool:
if isinstance(item, qubesadmin.vm.QubesVM):
item = item.name
- self.refresh_cache()
- return item in self._vm_list
+ self._ensure_cache_loaded()
+ return item in self._known_names
- def __delitem__(self, key):
+ def __delitem__(self, key: str) -> None:
self.app.qubesd_call(key, "admin.vm.Remove")
- self.clear_cache()
+ self._known_names.discard(key)
+ self._vms.pop(key, None)
- def __iter__(self):
- self.refresh_cache()
- for vm in sorted(self._vm_list):
- yield self[vm]
+ def __iter__(self) -> Generator[QubesVM, None, None]:
+ self._ensure_cache_loaded()
+ for vm in sorted(self._known_names):
+ yield self.get_blind(vm)
- def keys(self):
+ def keys(self) -> Iterable[str]:
"""Get list of VM names."""
- self.refresh_cache()
- return self._vm_list.keys()
+ self._ensure_cache_loaded()
+ return self._known_names
- def values(self):
+ def values(self) -> list[QubesVM]:
"""Get list of VM objects."""
- self.refresh_cache()
- return [self[name] for name in self._vm_list]
+ self._ensure_cache_loaded()
+ return [self.get_blind(name) for name in self._known_names]
class QubesBase(qubesadmin.base.PropertyHolder):
@@ -177,12 +197,11 @@ class in py:class:`qubesadmin.Qubes` instead, which points at
#: domains (VMs) collection
domains: VMCollection
#: labels collection
- labels: qubesadmin.base.WrapperObjectsCollection
+ labels: qubesadmin.base.WrapperObjectsCollection[Label]
#: storage pools
- pools: qubesadmin.base.WrapperObjectsCollection
+ pools: qubesadmin.base.WrapperObjectsCollection[qubesadmin.storage.Pool]
#: type of qubesd connection: either 'socket' or 'qrexec'
- qubesd_connection_type: str | None = None # See in PR#416 why we keep
- # =None here to not trip the CI
+ qubesd_connection_type: typing.Literal["socket", "qrexec"]
#: logger
log: Logger
#: do not check for object (VM, label etc) existence before really needed
@@ -190,7 +209,7 @@ class in py:class:`qubesadmin.Qubes` instead, which points at
#: cache retrieved properties values
cache_enabled: bool = False
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(self, "admin.property.", "dom0")
self.domains = VMCollection(self)
self.labels = qubesadmin.base.WrapperObjectsCollection(
@@ -200,27 +219,32 @@ def __init__(self):
self, "admin.pool.List", qubesadmin.storage.Pool
)
#: cache for available storage pool drivers and options to create them
- self._pool_drivers = None
+ self._pool_drivers: dict[str, list[str]] | None = None
self.log = logging.getLogger("app")
self._local_name = None
- def list_vmclass(self):
+ def list_vmclass(self) -> list[Klass]:
"""Call Qubesd in order to obtain the vm classes list"""
vmclass = (
self.qubesd_call("dom0", "admin.vmclass.List").decode().splitlines()
)
- return sorted(vmclass)
+ for e in vmclass:
+ assert e in typing.get_args(Klass)
+ return typing.cast(list[Klass], sorted(vmclass))
- def list_deviceclass(self):
+ def list_deviceclass(self) -> list[DeviceClass]:
"""Call Qubesd in order to obtain the device classes list"""
deviceclasses = (
self.qubesd_call("dom0", "admin.deviceclass.List")
.decode()
.splitlines()
)
- return sorted(deviceclasses)
+ for e in deviceclasses:
+ assert e in typing.get_args(DeviceClass)
+
+ return typing.cast(list[DeviceClass], sorted(deviceclasses))
- def _refresh_pool_drivers(self):
+ def _refresh_pool_drivers(self) -> None:
"""
Refresh cached storage pool drivers and their parameters.
@@ -240,17 +264,19 @@ def _refresh_pool_drivers(self):
self._pool_drivers = pool_drivers
@property
- def pool_drivers(self):
+ def pool_drivers(self) -> Iterable[str]:
"""Available storage pool drivers"""
self._refresh_pool_drivers()
+ assert self._pool_drivers is not None
return self._pool_drivers.keys()
- def pool_driver_parameters(self, driver):
+ def pool_driver_parameters(self, driver: str) -> list[str]:
"""Parameters to initialize storage pool using given driver"""
self._refresh_pool_drivers()
+ assert self._pool_drivers is not None
return self._pool_drivers[driver]
- def add_pool(self, name, driver, **kwargs):
+ def add_pool(self, name: str, driver: str, **kwargs) -> None:
"""Add a storage pool to config
:param name: name of storage pool to create
@@ -268,12 +294,12 @@ def add_pool(self, name, driver, **kwargs):
"dom0", "admin.pool.Add", driver, payload.encode("utf-8")
)
- def remove_pool(self, name):
+ def remove_pool(self, name: str) -> None:
"""Remove a storage pool"""
self.qubesd_call("dom0", "admin.pool.Remove", name, None)
@property
- def local_name(self):
+ def local_name(self) -> str:
"""Get localhost name"""
if not self._local_name:
local_name = None
@@ -291,7 +317,7 @@ def local_name(self):
return self._local_name
- def get_label(self, label):
+ def get_label(self, label: str | int) -> Label:
"""Get label as identified by index or name
:throws QubesLabelNotFoundError: when label is not found
@@ -308,10 +334,11 @@ def get_label(self, label):
for i in self.labels.values():
if i.index == int(label):
return i
+ assert isinstance(label, str)
raise qubesadmin.exc.QubesLabelNotFoundError(label)
@staticmethod
- def get_vm_class(clsname):
+ def get_vm_class(clsname: str) -> str:
"""Find the class for a domain.
Compatibility function, client tools use str to identify domain classes.
@@ -323,8 +350,10 @@ def get_vm_class(clsname):
return clsname
def add_new_vm(
- self, cls, name, label, template=None, pool=None, pools=None
- ):
+ self, cls: str | type[QubesVM], name: str, label: str,
+ template: str | QubesVM | None=None, pool: str | None=None,
+ pools: dict | None=None
+ ) -> QubesVM:
"""Create new Virtual Machine
Example usage with custom storage pools:
@@ -380,16 +409,16 @@ def add_new_vm(
def clone_vm(
self,
- src_vm,
- new_name,
- new_cls=None,
+ src_vm: str | QubesVM,
+ new_name: str,
+ new_cls: str | None=None,
*,
- pool=None,
- pools=None,
- ignore_errors=False,
- ignore_volumes=None,
- ignore_devices=False,
- ):
+ pool: str | None=None,
+ pools: dict | None=None,
+ ignore_errors: bool=False,
+ ignore_volumes: list | None=None,
+ ignore_devices: bool=False,
+ ) -> QubesVM:
# pylint: disable=too-many-statements
# pylint: disable=too-many-branches
"""Clone Virtual Machine
@@ -625,8 +654,9 @@ def clone_vm(
return dst_vm
def qubesd_call(
- self, dest, method, arg=None, payload=None, payload_stream=None
- ):
+ self, dest: str | None, method: str, arg: str | None=None,
+ payload: bytes | None=None, payload_stream: IO | None=None
+ ) -> bytes:
"""
Execute Admin API method.
@@ -649,16 +679,16 @@ def qubesd_call(
def run_service(
self,
- dest,
- service,
- user=None,
+ dest: str,
+ service: str,
+ user: str | None=None,
*,
- filter_esc=False,
- localcmd=None,
- wait=True,
- autostart=True,
+ filter_esc: bool=False,
+ localcmd: str | None=None,
+ wait: bool=True,
+ autostart: bool=True,
**kwargs,
- ):
+ ) -> Popen:
"""Run qrexec service in a given destination
*kwargs* are passed verbatim to :py:meth:`subprocess.Popen`.
@@ -680,7 +710,9 @@ def run_service(
)
@staticmethod
- def _call_with_stream(command, payload, payload_stream):
+ def _call_with_stream(command: str | list[str], payload: bytes | None,
+ payload_stream: IO)\
+ -> tuple[Popen, bytes, bytes]:
"""Helper method to pass data to qubesd. Calls a command with
payload and payload_stream as input.
@@ -701,6 +733,7 @@ def _call_with_stream(command, payload, payload_stream):
# because the process can get blocked on stdout or stderr pipe.
# However, in practice the output should be always smaller
# than 4K.
+ assert proc.stdin is not None
proc.stdin.write(payload)
try:
shutil.copyfileobj(payload_stream, proc.stdin)
@@ -713,7 +746,8 @@ def _call_with_stream(command, payload, payload_stream):
stdout, stderr = proc.communicate()
return proc, stdout, stderr
- def _invalidate_cache(self, subject, event, name, **kwargs):
+ def _invalidate_cache(self, subject: QubesVM | None,
+ event: str, name: str, **kwargs) -> None:
"""Invalidate cached value of a property.
This method is designed to be hooked as an event handler for:
@@ -734,15 +768,18 @@ def _invalidate_cache(self, subject, event, name, **kwargs):
:return: none
""" # pylint: disable=unused-argument
if subject is None:
- subject = self
+ subject_or_self = self
+ else:
+ subject_or_self = subject
try:
# pylint: disable=protected-access
- del subject._properties_cache[name]
+ del subject_or_self._properties_cache[name]
except KeyError:
pass
- def _update_power_state_cache(self, subject, event, **kwargs):
+ def _update_power_state_cache(self, subject: QubesVM,
+ event: str, **kwargs) -> None:
"""Update cached VM power state.
This method is designed to be hooked as an event handler for:
@@ -784,7 +821,7 @@ def _update_power_state_cache(self, subject, event, **kwargs):
# pylint: disable=protected-access
subject._power_state_cache = power_state
- def _invalidate_cache_all(self):
+ def _invalidate_cache_all(self) -> None:
"""Invalidate all cached data
@@ -800,7 +837,7 @@ def _invalidate_cache_all(self):
"""
# pylint: disable=protected-access
self.domains.clear_cache()
- for vm in self.domains._vm_objects.values():
+ for vm in self.domains._vms.values():
assert isinstance(vm, qubesadmin.vm.QubesVM)
vm._power_state_cache = None
vm._properties_cache = {}
@@ -817,8 +854,9 @@ class QubesLocal(QubesBase):
qubesd_connection_type = "socket"
def qubesd_call(
- self, dest, method, arg=None, payload=None, payload_stream=None
- ):
+ self, dest: str | None, method: str, arg: str | None=None,
+ payload: bytes | None=None, payload_stream: IO | None=None
+ ) -> bytes:
"""
Execute Admin API method.
@@ -846,13 +884,16 @@ def qubesd_call(
raise qubesadmin.exc.QubesDaemonCommunicationError(
"{} not found".format(method_path)
)
+ assert arg is not None
+ assert dest is not None
command = [
"env",
"QREXEC_REMOTE_DOMAIN=dom0",
"QREXEC_REQUESTED_TARGET=" + dest,
method_path,
- arg,
]
+ if arg is not None:
+ command.append(arg)
if os.getuid() != 0:
command.insert(0, "sudo")
(_, stdout, _) = self._call_with_stream(
@@ -881,16 +922,16 @@ def qubesd_call(
def run_service(
self,
- dest,
- service,
- user=None,
+ dest: str,
+ service: str,
+ user: str | None=None,
*,
- filter_esc=False,
- localcmd=None,
- wait=True,
- autostart=True,
+ filter_esc: bool=False,
+ localcmd: str | None=None,
+ wait: bool=True,
+ autostart: bool=True,
**kwargs,
- ):
+ ) -> Popen:
"""Run qrexec service in a given destination
:param str dest: Destination - may be a VM name or empty
@@ -990,8 +1031,9 @@ class QubesRemote(QubesBase):
qubesd_connection_type = "qrexec"
def qubesd_call(
- self, dest, method, arg=None, payload=None, payload_stream=None
- ):
+ self, dest: str | None, method: str, arg: str | None=None,
+ payload: bytes | None=None, payload_stream: IO | None=None
+ ) -> bytes:
"""
Execute Admin API method.
@@ -1008,6 +1050,7 @@ def qubesd_call(
.. warning:: *payload_stream* will get closed by this function
"""
service_name = method
+ assert dest is not None
if arg is not None:
service_name += "+" + arg
command = [qubesadmin.config.QREXEC_CLIENT_VM, dest, service_name]
@@ -1032,16 +1075,16 @@ def qubesd_call(
def run_service(
self,
- dest,
- service,
- user=None,
+ dest: str,
+ service: str,
+ user: str | None=None,
*,
- filter_esc=False,
- localcmd=None,
- wait=True,
- autostart=True,
+ filter_esc: bool=False,
+ localcmd: str | None=None,
+ wait: bool=True,
+ autostart: bool=True,
**kwargs,
- ):
+ ) -> Popen:
"""Run qrexec service in a given destination
:param str dest: Destination - may be a VM name or empty
diff --git a/qubesadmin/backup/core3.py b/qubesadmin/backup/core3.py
index bb540680..ce47d27f 100644
--- a/qubesadmin/backup/core3.py
+++ b/qubesadmin/backup/core3.py
@@ -28,12 +28,10 @@
import qubesadmin.backup
import qubesadmin.firewall
-from qubesadmin import device_protocol
+from qubesadmin import utils
from qubesadmin.vm import QubesVM
-
-
class Core3VM(qubesadmin.backup.BackupVM):
'''VM object'''
@property
@@ -138,7 +136,7 @@ def import_core3_vm(self, element: _Element) -> None:
for opt_node in node.findall('./option'):
opt_name = opt_node.get('name')
options[opt_name] = opt_node.text
- options['required'] = device_protocol.qbool(
+ options['required'] = utils.qbool(
node.get('required', 'yes'))
vm.devices[bus_name][(backend_domain, port_id)] = options
diff --git a/qubesadmin/base.py b/qubesadmin/base.py
index ca61ba22..6f1f8500 100644
--- a/qubesadmin/base.py
+++ b/qubesadmin/base.py
@@ -19,11 +19,24 @@
# with this program; if not, see .
'''Base classes for managed objects'''
+from __future__ import annotations
+
+import typing
+from typing import BinaryIO, Any, TypeAlias, TypeVar, Generic
+from collections.abc import Generator
import qubesadmin.exc
+if typing.TYPE_CHECKING:
+ from qubesadmin.vm import QubesVM
+ from qubesadmin.app import QubesBase
+
DEFAULT = object()
+# We use Any because the dynamic metatada handling of the current code
+# is too complex for type checkers otherwise
+VMProperty: TypeAlias = Any # noqa: ANN401
+
class PropertyHolder:
'''A base class for object having properties retrievable using mgmt API.
@@ -34,28 +47,29 @@ class PropertyHolder:
'''
#: a place for appropriate Qubes() object (QubesLocal or QubesRemote),
# use None for self
- app = None
+ app: QubesBase
- def __init__(self, app, method_prefix, method_dest):
+ def __init__(self, app: QubesBase, method_prefix: str, method_dest: str):
#: appropriate Qubes() object (QubesLocal or QubesRemote), use None
# for self
self.app = app
self._method_prefix = method_prefix
self._method_dest = method_dest
- self._properties = None
+ self._properties: list[str] | None = None
self._properties_help = None
# the cache is maintained by EventsDispatcher(),
# through helper functions in QubesBase()
self._properties_cache = {}
- def clear_cache(self):
+ def clear_cache(self) -> None:
"""
Clear property cache.
"""
self._properties_cache = {}
- def qubesd_call(self, dest, method, arg=None, payload=None,
- payload_stream=None):
+ def qubesd_call(self, dest: str | None, method: str,
+ arg: str | None=None, payload: bytes | None=None,
+ payload_stream: BinaryIO | None=None) -> bytes:
'''
Call into qubesd using appropriate mechanism. This method should be
defined by a subclass.
@@ -69,10 +83,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None,
:param payload_stream: file-like object to read payload from
:return: Data returned by qubesd (string)
'''
- if not self.app:
- raise NotImplementedError
- if dest is None:
- dest = self._method_dest
+ dest: str = dest or self._method_dest
if (
getattr(self, "_redirect_dispvm_calls", False)
and dest.startswith("@dispvm")
@@ -80,7 +91,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None,
if dest.startswith("@dispvm:"):
dest = dest[len("@dispvm:") :]
else:
- dest = getattr(self.app, "default_dispvm", None)
+ dest: QubesVM | None = getattr(self.app, "default_dispvm", None)
if dest:
dest = dest.name
# have the actual implementation at Qubes() instance
@@ -88,7 +99,7 @@ def qubesd_call(self, dest, method, arg=None, payload=None,
payload_stream)
@staticmethod
- def _parse_qubesd_response(response_data):
+ def _parse_qubesd_response(response_data: bytes) -> bytes:
'''Parse response from qubesd.
In case of success, return actual data. In case of error,
@@ -122,7 +133,7 @@ def _parse_qubesd_response(response_data):
raise qubesadmin.exc.QubesDaemonCommunicationError(
'Invalid response format')
- def property_list(self):
+ def property_list(self) -> list[str]:
'''
List available properties (their names).
@@ -138,7 +149,7 @@ def property_list(self):
# TODO: make it somehow immutable
return self._properties
- def property_help(self, name):
+ def property_help(self, name: str) -> str:
'''
Get description of a property.
@@ -151,7 +162,7 @@ def property_help(self, name):
None)
return help_text.decode('ascii')
- def property_is_default(self, item):
+ def property_is_default(self, item: str) -> bool:
'''
Check if given property have default value
@@ -183,7 +194,7 @@ def property_is_default(self, item):
self._properties_cache[item] = (is_default, value)
return is_default
- def property_get_default(self, item):
+ def property_get_default(self, item: str) -> VMProperty:
'''
Get default property value, regardless of the current value
@@ -206,7 +217,8 @@ def property_get_default(self, item):
(prop_type, value) = property_str.split(b' ', 1)
return self._parse_type_value(prop_type, value)
- def clone_properties(self, src, proplist=None):
+ def clone_properties(self, src: PropertyHolder,
+ proplist: list[str] | None=None) -> None:
'''Clone properties from other object.
:param PropertyHolder src: source object
@@ -223,7 +235,7 @@ def clone_properties(self, src, proplist=None):
except AttributeError:
continue
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> VMProperty:
if item.startswith('_'):
raise AttributeError(item)
# pre-fill cache if enabled
@@ -254,7 +266,8 @@ def __getattr__(self, item):
raise AttributeError(item)
return value
- def _deserialize_property(self, api_response):
+ def _deserialize_property(self, api_response: bytes) \
+ -> tuple[bool, VMProperty]:
"""
Deserialize property.Get response format
:param api_response: bytes, as retrieved from qubesd
@@ -267,7 +280,7 @@ def _deserialize_property(self, api_response):
value = self._parse_type_value(prop_type, value)
return is_default, value
- def _parse_type_value(self, prop_type, value):
+ def _parse_type_value(self, prop_type: bytes, value: bytes) -> VMProperty:
'''
Parse `type=... ...` qubesd response format. Return a value of
appropriate type.
@@ -278,20 +291,23 @@ def _parse_type_value(self, prop_type, value):
:return: parsed value
'''
# pylint: disable=too-many-return-statements
- prop_type = prop_type.decode('ascii')
+ prop_type: str = prop_type.decode('ascii')
if not prop_type.startswith('type='):
raise qubesadmin.exc.QubesDaemonCommunicationError(
'Invalid type prefix received: {}'.format(prop_type))
(_, prop_type) = prop_type.split('=', 1)
- value = value.decode()
+ value: str = value.decode()
if prop_type == 'str':
return str(value)
if prop_type == 'bool':
if value == '':
+ # TODO shouldn't that at least be ValueError ?
+ # but then we need to properly propagate that modification
return AttributeError
return value == "True"
if prop_type == 'int':
if value == '':
+ # TODO same as above
return AttributeError
return int(value)
if prop_type == 'vm':
@@ -305,7 +321,7 @@ def _parse_type_value(self, prop_type, value):
raise qubesadmin.exc.QubesDaemonCommunicationError(
'Received invalid value type: {}'.format(prop_type))
- def _fetch_all_properties(self):
+ def _fetch_all_properties(self) -> None:
"""
Retrieve all properties values at once using (prefix).property.GetAll
method. If it succeed, save retrieved values in the properties cache.
@@ -315,7 +331,7 @@ def _fetch_all_properties(self):
:return: None
"""
- def unescape(line):
+ def unescape(line: bytes) -> Generator[int]:
"""Handle \\-escaped values, generates a list of character codes"""
escaped = False
for char in line:
@@ -342,15 +358,15 @@ def unescape(line):
return
for line in properties_str.splitlines():
# decode newlines
- line = bytes(unescape(line))
- name, property_str = line.split(b' ', 1)
+ line_bytes = bytes(list(unescape(line)))
+ name, property_str = line_bytes.split(b' ', 1)
name = name.decode()
is_default, value = self._deserialize_property(property_str)
self._properties_cache[name] = (is_default, value)
self._properties = list(self._properties_cache.keys())
@classmethod
- def _local_properties(cls):
+ def _local_properties(cls) -> set:
'''
Get set of property names that are properties on the Python object,
and must not be set on the remote object
@@ -367,7 +383,7 @@ def _local_properties(cls):
return cls._local_properties_set
- def __setattr__(self, key, value):
+ def __setattr__(self, key: str, value: typing.Any) -> None: # noqa: ANN401
if key.startswith('_') or key in self._local_properties():
return super().__setattr__(key, value)
if value is qubesadmin.DEFAULT:
@@ -381,7 +397,9 @@ def __setattr__(self, key, value):
qubesadmin.exc.QubesVMNotFoundError):
raise qubesadmin.exc.QubesPropertyAccessError(key)
else:
- if isinstance(value, qubesadmin.vm.QubesVM):
+ # Dynamic import because qubesadmin.vm imports base.py
+ from qubesadmin.vm import QubesVM
+ if isinstance(value, QubesVM):
value = value.name
if value is None:
value = ''
@@ -395,7 +413,7 @@ def __setattr__(self, key, value):
qubesadmin.exc.QubesVMNotFoundError):
raise qubesadmin.exc.QubesPropertyAccessError(key)
- def __delattr__(self, name):
+ def __delattr__(self, name: str) -> None:
if name.startswith('_') or name in self._local_properties():
return super().__delattr__(name)
try:
@@ -408,10 +426,13 @@ def __delattr__(self, name):
qubesadmin.exc.QubesVMNotFoundError):
raise qubesadmin.exc.QubesPropertyAccessError(name)
+WrapperObjectsCollectionKey: TypeAlias = int | str
+T = TypeVar('T')
-class WrapperObjectsCollection:
+class WrapperObjectsCollection(Generic[T]):
'''Collection of simple named objects'''
- def __init__(self, app, list_method, object_class):
+ def __init__(self, app: QubesBase,
+ list_method: str, object_class: type[T]):
'''
Construct manager of named wrapper objects.
@@ -425,11 +446,13 @@ def __init__(self, app, list_method, object_class):
self._list_method = list_method
self._object_class = object_class
#: names cache
- self._names_list = None
+ self._names_list: list[WrapperObjectsCollectionKey] | None = None
#: returned objects cache
- self._objects = {}
+ self._objects: dict[WrapperObjectsCollectionKey, T] = {}
- def clear_cache(self, invalidate_name=None):
+ def clear_cache(self,
+ invalidate_name: WrapperObjectsCollectionKey | None=None)\
+ -> None:
"""Clear cached list of names.
If *invalidate_name* is given, remove that object from cache
explicitly too.
@@ -438,7 +461,7 @@ def clear_cache(self, invalidate_name=None):
if invalidate_name:
self._objects.pop(invalidate_name, None)
- def refresh_cache(self, force=False):
+ def refresh_cache(self, force: bool=False) -> None:
'''Refresh cached list of names'''
if not force and self._names_list is not None:
return
@@ -447,17 +470,18 @@ def refresh_cache(self, force=False):
assert list_data[-1] == '\n'
self._names_list = [str(name) for name in list_data[:-1].splitlines()]
- for name, obj in list(self._objects.items()):
+ for name, obj in self._objects.items():
+ assert hasattr(obj, "name")
if obj.name not in self._names_list:
# Object no longer exists
del self._objects[name]
- def __getitem__(self, item):
+ def __getitem__(self, item: WrapperObjectsCollectionKey) -> T:
if not self.app.blind_mode and item not in self:
raise KeyError(item)
return self.get_blind(item)
- def get_blind(self, item):
+ def get_blind(self, item: WrapperObjectsCollectionKey) -> T:
'''
Get a property without downloading the list
and checking if it's present
@@ -466,25 +490,30 @@ def get_blind(self, item):
self._objects[item] = self._object_class(self.app, item)
return self._objects[item]
- def __contains__(self, item):
+ def __contains__(self, item: WrapperObjectsCollectionKey) -> bool:
self.refresh_cache()
+ assert self._names_list is not None
return item in self._names_list
- def __iter__(self):
+ def __iter__(self) -> Generator[WrapperObjectsCollectionKey]:
self.refresh_cache()
+ assert self._names_list is not None
yield from self._names_list
- def keys(self):
+ def keys(self) -> list[WrapperObjectsCollectionKey]:
'''Get list of names.'''
self.refresh_cache()
+ assert self._names_list is not None
return list(self._names_list)
- def items(self):
+ def items(self) -> list[tuple[WrapperObjectsCollectionKey, T]]:
'''Get list of (key, value) pairs'''
self.refresh_cache()
+ assert self._names_list is not None
return [(key, self.get_blind(key)) for key in self._names_list]
- def values(self):
+ def values(self) -> list[T]:
'''Get list of objects'''
self.refresh_cache()
+ assert self._names_list is not None
return [self.get_blind(key) for key in self._names_list]
diff --git a/qubesadmin/device_protocol.py b/qubesadmin/device_protocol.py
index c7b60a54..6234266d 100644
--- a/qubesadmin/device_protocol.py
+++ b/qubesadmin/device_protocol.py
@@ -29,15 +29,20 @@
The same in `qubes-core-admin` and `qubes-core-admin-client`,
should be moved to one place.
"""
-
+from __future__ import annotations
import string
import sys
from enum import Enum
-from typing import Optional, Dict, Any, List, Union, Tuple, Callable
+from typing import Any, TYPE_CHECKING
+from collections.abc import Callable
import qubesadmin.exc
+
from qubesadmin.exc import QubesValueError
+if TYPE_CHECKING:
+ from qubesadmin.vm import QubesVM
+ from qubesadmin.app import VMCollection
class ProtocolError(AssertionError):
@@ -46,51 +51,27 @@ class ProtocolError(AssertionError):
"""
-QubesVM = 'qubesadmin.vm.QubesVM'
-
-
class UnexpectedDeviceProperty(qubesadmin.exc.QubesException, ValueError):
"""
Device has unexpected property such as backend_domain, devclass etc.
"""
-def qbool(value):
- """
- Property setter for boolean properties.
-
- It accepts (case-insensitive) ``'0'``, ``'no'`` and ``false`` as
- :py:obj:`False` and ``'1'``, ``'yes'`` and ``'true'`` as
- :py:obj:`True`.
- """
-
- if isinstance(value, str):
- lcvalue = value.lower()
- if lcvalue in ("0", "no", "false", "off"):
- return False
- if lcvalue in ("1", "yes", "true", "on"):
- return True
- raise QubesValueError(
- "Invalid literal for boolean property: {!r}".format(value)
- )
-
- return bool(value)
-
-
class DeviceSerializer:
"""
Group of method for serialization of device properties.
"""
- ALLOWED_CHARS_KEY = set(
+ ALLOWED_CHARS_KEY: set[str] = set(
string.digits + string.ascii_letters + r"!#$%&()*+,-./:;<>?@[\]^_{|}~"
)
- ALLOWED_CHARS_PARAM = ALLOWED_CHARS_KEY.union(set(string.punctuation + " "))
+ ALLOWED_CHARS_PARAM: set[str]\
+ = ALLOWED_CHARS_KEY.union(set(string.punctuation + " "))
@classmethod
def unpack_properties(
cls, untrusted_serialization: bytes
- ) -> Tuple[Dict, Dict]:
+ ) -> tuple[dict, dict]:
"""
Unpacks basic port properties from a serialized encoded string.
@@ -106,8 +87,8 @@ def unpack_properties(
"ascii", errors="strict"
).strip()
- properties: Dict[str, str] = {}
- options: Dict[str, str] = {}
+ properties: dict[str, str] = {}
+ options: dict[str, str] = {}
if not ut_decoded:
return properties, options
@@ -155,7 +136,7 @@ def unpack_properties(
return properties, options
@classmethod
- def pack_property(cls, key: str, value: Optional[str]):
+ def pack_property(cls, key: str, value: object) -> bytes:
"""
Add property `key=value` to serialization.
"""
@@ -175,8 +156,8 @@ def pack_property(cls, key: str, value: Optional[str]):
@staticmethod
def parse_basic_device_properties(
- expected_device: "VirtualDevice", properties: Dict[str, Any]
- ):
+ expected_device: VirtualDevice, properties: dict
+ ) -> None:
"""
Validates properties against an expected port configuration.
@@ -228,7 +209,7 @@ def parse_basic_device_properties(
properties["port"] = expected
@staticmethod
- def serialize_str(value: str) -> str:
+ def serialize_str(value: object) -> str:
"""
Serialize python string to ensure consistency.
"""
@@ -245,7 +226,7 @@ def deserialize_str(value: str) -> str:
def sanitize_str(
untrusted_value: str,
allowed_chars: set,
- replace_char: Optional[str] = None,
+ replace_char: str | None = None,
error_message: str = "",
) -> str:
"""
@@ -281,18 +262,18 @@ class Port:
def __init__(
self,
- backend_domain: Optional[QubesVM],
- port_id: Optional[str],
- devclass: Optional[str],
+ backend_domain: QubesVM | None,
+ port_id: str | None,
+ devclass: str | None,
):
self.__backend_domain = backend_domain
self.__port_id = port_id
self.__devclass = devclass
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.backend_name, self.port_id, self.devclass))
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
if isinstance(other, Port):
return (
self.backend_name == other.backend_name
@@ -301,7 +282,7 @@ def __eq__(self, other):
)
return False
- def __lt__(self, other):
+ def __lt__(self, other: object) -> bool:
if isinstance(other, Port):
return (self.backend_name, self.devclass, self.port_id) < (
other.backend_name,
@@ -313,10 +294,10 @@ def __lt__(self, other):
"is not supported"
)
- def __repr__(self):
+ def __repr__(self) -> str:
return f"{self.backend_name}+{self.port_id}"
- def __str__(self):
+ def __str__(self) -> str:
return f"{self.backend_name}:{self.port_id}"
@property
@@ -328,8 +309,9 @@ def backend_name(self) -> str:
@classmethod
def from_qarg(
- cls, representation: str, devclass, domains, blind=False
- ) -> "Port":
+ cls, representation: str, devclass: str,
+ domains: VMCollection, blind: bool=False
+ ) -> Port:
"""
Parse qrexec argument + to retrieve Port.
"""
@@ -341,8 +323,9 @@ def from_qarg(
@classmethod
def from_str(
- cls, representation: str, devclass, domains, blind=False
- ) -> "Port":
+ cls, representation: str, devclass: str,
+ domains: VMCollection, blind: bool=False
+ ) -> Port:
"""
Parse string : to retrieve Port.
"""
@@ -355,7 +338,7 @@ def from_str(
@classmethod
def _parse(
cls, representation: str, devclass: str, get_domain: Callable, sep: str
- ) -> "Port":
+ ) -> Port:
"""
Parse string representation and return instance of Port.
"""
@@ -375,7 +358,7 @@ def port_id(self) -> str:
return "*"
@property
- def backend_domain(self) -> Optional[QubesVM]:
+ def backend_domain(self) -> QubesVM | None:
"""Which domain exposed this port. (immutable)"""
return self.__backend_domain
@@ -390,7 +373,7 @@ def devclass(self) -> str:
return "peripheral"
@property
- def has_devclass(self):
+ def has_devclass(self) -> bool:
"""Returns True if devclass is set."""
return self.__devclass is not None
@@ -401,10 +384,10 @@ class AnyPort(Port):
def __init__(self, devclass: str):
super().__init__(None, "*", devclass)
- def __repr__(self):
+ def __repr__(self) -> str:
return "*"
- def __str__(self):
+ def __str__(self) -> str:
return "*"
@@ -419,18 +402,18 @@ class VirtualDevice:
def __init__(
self,
- port: Optional[Port] = None,
- device_id: Optional[str] = None,
+ port: Port | None = None,
+ device_id: str | None = None,
):
assert not isinstance(port, AnyPort) or device_id is not None
- self.port: Optional[Port] = port # type: ignore
+ self.port: Port | None = port
self._device_id = device_id
- def clone(self, **kwargs) -> "VirtualDevice":
+ def clone(self, **kwargs) -> VirtualDevice:
"""
Clone object and substitute attributes with explicitly given.
"""
- attr: Dict[str, Any] = {
+ attr: dict[str, Any] = { # noqa:ANN401
"port": self.port,
"device_id": self.device_id,
}
@@ -443,7 +426,7 @@ def port(self) -> Port:
return self._port
@port.setter
- def port(self, value: Union[Port, str, None]):
+ def port(self, value: Port | str | None) -> None:
# pylint: disable=missing-function-docstring
if isinstance(value, Port):
self._port = value
@@ -469,7 +452,7 @@ def is_device_id_set(self) -> bool:
return self._device_id is not None
@property
- def backend_domain(self) -> Optional[QubesVM]:
+ def backend_domain(self) -> QubesVM | None:
# pylint: disable=missing-function-docstring
return self.port.backend_domain
@@ -499,10 +482,10 @@ def description(self) -> str:
return "any device"
return self.device_id
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.port, self.device_id))
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
if isinstance(other, (VirtualDevice, DeviceAssignment)):
result = (
self.port == other.port and self.device_id == other.device_id
@@ -512,7 +495,7 @@ def __eq__(self, other):
return self.port == other and self.device_id == "*"
return super().__eq__(other)
- def __lt__(self, other):
+ def __lt__(self, other: object) -> bool:
"""
Desired order (important for auto-attachment):
@@ -543,11 +526,11 @@ def __lt__(self, other):
"is not supported"
)
- def __repr__(self):
+ def __repr__(self) -> str:
return f"{self.port!r}:{self.device_id}"
@property
- def repr_for_qarg(self):
+ def repr_for_qarg(self) -> str:
"""Object representation for qrexec argument"""
res = repr(self).replace(":", "+")
# replace '?' in category
@@ -555,18 +538,18 @@ def repr_for_qarg(self):
res = res.replace(unknown_dev, "_" * len(unknown_dev))
return res.replace("*", "_")
- def __str__(self):
+ def __str__(self) -> str:
return f"{self.port}:{self.device_id}"
@classmethod
def from_qarg(
cls,
representation: str,
- devclass: Optional[str],
- domains,
+ devclass: str | None,
+ domains: VMCollection,
blind: bool = False,
- backend: Optional[QubesVM] = None,
- ) -> "VirtualDevice":
+ backend: QubesVM | None = None,
+ ) -> VirtualDevice:
"""
Parse qrexec argument +: to get device info
"""
@@ -583,15 +566,16 @@ def from_qarg(
def from_str(
cls,
representation: str,
- devclass: Optional[str],
- domains,
+ devclass: str | None,
+ domains: VMCollection | None,
blind: bool = False,
- backend: Optional[QubesVM] = None,
- ) -> "VirtualDevice":
+ backend: QubesVM | None = None,
+ ) -> VirtualDevice:
"""
Parse string +: to get device info
"""
if backend is None:
+ assert domains is not None
if blind:
get_domain = domains.get_blind
else:
@@ -604,15 +588,16 @@ def from_str(
def _parse(
cls,
representation: str,
- devclass: Optional[str],
- get_domain: Callable,
- backend: Optional[QubesVM],
+ devclass: str | None,
+ get_domain: Callable | None,
+ backend: QubesVM | None,
sep: str,
- ) -> "VirtualDevice":
+ ) -> VirtualDevice:
"""
Parse string representation and return instance of VirtualDevice.
"""
if backend is None:
+ assert get_domain is not None
backend_name, identity = representation.split(sep, 1)
if backend_name == "_":
backend_name = "*"
@@ -690,7 +675,7 @@ class DeviceCategory(Enum):
PCI_USB = ("p0c03**",)
@staticmethod
- def from_str(interface_encoding: str) -> "DeviceCategory":
+ def from_str(interface_encoding: str) -> DeviceCategory:
"""
Returns `DeviceCategory` from data encoded in string.
"""
@@ -721,7 +706,7 @@ class DeviceInterface:
Peripheral device interface wrapper.
"""
- def __init__(self, interface_encoding: str, devclass: Optional[str] = None):
+ def __init__(self, interface_encoding: str, devclass: str | None = None):
ifc_padded = interface_encoding.ljust(6, "*")
if devclass:
if len(ifc_padded) > 6:
@@ -762,7 +747,7 @@ def __init__(self, interface_encoding: str, devclass: Optional[str] = None):
self._category = DeviceCategory.from_str(self._interface_encoding)
@property
- def devclass(self) -> Optional[str]:
+ def devclass(self) -> str | None:
"""Immutable Device class such like: 'usb', 'pci' etc."""
return self._devclass
@@ -772,12 +757,12 @@ def category(self) -> DeviceCategory:
return self._category
@classmethod
- def unknown(cls) -> "DeviceInterface":
+ def unknown(cls) -> DeviceInterface:
"""Value for unknown device interface."""
return cls("?******")
@staticmethod
- def from_str_bulk(interfaces: Optional[str]) -> List["DeviceInterface"]:
+ def from_str_bulk(interfaces: str | None) -> list[DeviceInterface]:
"""Interprets string of interfaces as list of `DeviceInterface`.
Examples:
@@ -796,18 +781,18 @@ def from_str_bulk(interfaces: Optional[str]) -> List["DeviceInterface"]:
for i in range(0, len(interfaces), 7)
]
- def __repr__(self):
+ def __repr__(self) -> str:
return self._interface_encoding
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(repr(self))
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
if not isinstance(other, DeviceInterface):
return False
return repr(self) == repr(other)
- def __str__(self):
+ def __str__(self) -> str:
if self.devclass == "block":
return "Block Device"
if self.devclass in ("usb", "pci"):
@@ -848,7 +833,7 @@ def __str__(self):
return repr(self)
@staticmethod
- def _load_classes(bus: str):
+ def _load_classes(bus: str) -> dict:
"""
List of known device classes, subclasses and programming interfaces.
"""
@@ -884,7 +869,7 @@ def _load_classes(bus: str):
return result
- def matches(self, other: "DeviceInterface") -> bool:
+ def matches(self, other: DeviceInterface) -> bool:
"""
Check if this `DeviceInterface` (pattern) matches given one.
@@ -913,15 +898,15 @@ def __init__(
self,
port: Port,
*,
- vendor: Optional[str] = None,
- product: Optional[str] = None,
- manufacturer: Optional[str] = None,
- name: Optional[str] = None,
- serial: Optional[str] = None,
- interfaces: Optional[List[DeviceInterface]] = None,
- parent: Optional["DeviceInfo"] = None,
- attachment: Optional[QubesVM] = None,
- device_id: Optional[str] = None,
+ vendor: str | None = None,
+ product: str | None = None,
+ manufacturer: str | None = None,
+ name: str | None = None,
+ serial: str | None = None,
+ interfaces: list[DeviceInterface] | None = None,
+ parent: DeviceInfo | None = None,
+ attachment: QubesVM | None = None,
+ device_id: str | None = None,
**kwargs,
):
super().__init__(port, device_id)
@@ -1044,7 +1029,7 @@ def description(self) -> str:
return f"{cat}: {vendor} {prod}"
@property
- def interfaces(self) -> List[DeviceInterface]:
+ def interfaces(self) -> list[DeviceInterface]:
"""
Non-empty list of device interfaces.
@@ -1055,7 +1040,7 @@ def interfaces(self) -> List[DeviceInterface]:
return self._interfaces
@property
- def parent_device(self) -> Optional[VirtualDevice]:
+ def parent_device(self) -> VirtualDevice | None:
"""
The parent device, if any.
@@ -1065,7 +1050,7 @@ def parent_device(self) -> Optional[VirtualDevice]:
return self._parent
@property
- def subdevices(self) -> List[VirtualDevice]:
+ def subdevices(self) -> list[VirtualDevice]:
"""
The list of children devices if any.
@@ -1082,7 +1067,7 @@ def subdevices(self) -> List[VirtualDevice]:
]
@property
- def attachment(self) -> Optional[QubesVM]:
+ def attachment(self) -> QubesVM | None:
"""
VM to which device is attached (frontend domain).
"""
@@ -1135,8 +1120,8 @@ def deserialize(
cls,
serialization: bytes,
expected_backend_domain: QubesVM,
- expected_devclass: Optional[str] = None,
- ) -> "DeviceInfo":
+ expected_devclass: str | None = None,
+ ) -> DeviceInfo:
"""
Recovers a serialized object, see: :py:meth:`serialize`.
"""
@@ -1160,7 +1145,7 @@ def deserialize(
@classmethod
def _deserialize(
cls, untrusted_serialization: bytes, expected_device: VirtualDevice
- ) -> "DeviceInfo":
+ ) -> DeviceInfo:
"""
Actually deserializes the object.
"""
@@ -1217,7 +1202,7 @@ def device_id(self) -> str:
return self._device_id
@device_id.setter
- def device_id(self, value):
+ def device_id(self, value: str) -> None:
# Do not auto-override value like in super class
self._device_id = value
@@ -1226,7 +1211,7 @@ class UnknownDevice(DeviceInfo):
"""Unknown device - for example, exposed by domain not running currently"""
@staticmethod
- def from_device(device: VirtualDevice) -> "UnknownDevice":
+ def from_device(device: VirtualDevice) -> UnknownDevice:
"""
Return `UnknownDevice` based on any virtual device.
"""
@@ -1252,9 +1237,9 @@ class DeviceAssignment:
def __init__(
self,
device: VirtualDevice,
- frontend_domain=None,
- options=None,
- mode: Union[str, AssignmentMode] = "manual",
+ frontend_domain: QubesVM | None=None,
+ options: dict[str, object] | None=None,
+ mode: str | AssignmentMode = AssignmentMode.MANUAL,
):
if isinstance(device, DeviceInfo):
device = VirtualDevice(device.port, device.device_id)
@@ -1272,12 +1257,12 @@ def new(
backend_domain: QubesVM,
port_id: str,
devclass: str,
- device_id: Optional[str] = None,
+ device_id: str | None = None,
*,
- frontend_domain: Optional[QubesVM] = None,
- options=None,
- mode: Union[str, AssignmentMode] = "manual",
- ) -> "DeviceAssignment":
+ frontend_domain: QubesVM | None = None,
+ options: dict[str, object] | None=None,
+ mode: str | AssignmentMode = AssignmentMode.MANUAL,
+ ) -> DeviceAssignment:
"""Helper method to create a DeviceAssignment object."""
return cls(
VirtualDevice(Port(backend_domain, port_id, devclass), device_id),
@@ -1286,7 +1271,7 @@ def new(
mode,
)
- def clone(self, **kwargs):
+ def clone(self, **kwargs) -> DeviceAssignment:
"""
Clone object and substitute attributes with explicitly given.
"""
@@ -1297,23 +1282,23 @@ def clone(self, **kwargs):
"frontend_domain": self.frontend_domain,
}
attr.update(kwargs)
- return self.__class__(**attr)
+ return self.__class__(**attr) # type: ignore
- def __repr__(self):
+ def __repr__(self) -> str:
return f"{self.virtual_device!r}"
@property
- def repr_for_qarg(self):
+ def repr_for_qarg(self) -> str:
"""Object representation for qrexec argument"""
return self.virtual_device.repr_for_qarg
- def __str__(self):
+ def __str__(self) -> str:
return f"{self.virtual_device}"
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(self.virtual_device)
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
if isinstance(other, (VirtualDevice, DeviceAssignment)):
result = (
self.port == other.port and self.device_id == other.device_id
@@ -1321,7 +1306,7 @@ def __eq__(self, other):
return result
return False
- def __lt__(self, other):
+ def __lt__(self, other: object) -> bool:
if isinstance(other, DeviceAssignment):
return self.virtual_device < other.virtual_device
if isinstance(other, VirtualDevice):
@@ -1332,7 +1317,7 @@ def __lt__(self, other):
)
@property
- def backend_domain(self) -> Optional[QubesVM]:
+ def backend_domain(self) -> QubesVM | None:
# pylint: disable=missing-function-docstring
return self.virtual_device.backend_domain
@@ -1357,9 +1342,9 @@ def device_id(self) -> str:
return self.virtual_device.device_id
@property
- def devices(self) -> List[DeviceInfo]:
+ def devices(self) -> list[DeviceInfo]:
"""Get DeviceInfo objects corresponding to this DeviceAssignment"""
- result: List[DeviceInfo] = []
+ result: list[DeviceInfo] = []
if not self.backend_domain:
return result
if self.port_id != "*":
@@ -1399,17 +1384,17 @@ def port(self) -> Port:
return Port(self.backend_domain, self.port_id, self.devclass)
@property
- def frontend_domain(self) -> Optional[QubesVM]:
+ def frontend_domain(self) -> QubesVM | None:
"""Which domain the device is attached/assigned to."""
return self.__frontend_domain
@frontend_domain.setter
- def frontend_domain(self, frontend_domain: Optional[Union[str, QubesVM]]):
+ def frontend_domain(self, frontend_domain: str | QubesVM | None) -> None:
"""Which domain the device is attached/assigned to."""
if isinstance(frontend_domain, str):
if not self.backend_domain:
raise ProtocolError("Cannot determine backend domain")
- self.__frontend_domain: Optional[QubesVM] = (
+ self.__frontend_domain: QubesVM | None = (
self.backend_domain.app.domains[frontend_domain]
)
else:
@@ -1448,12 +1433,12 @@ def attach_automatically(self) -> bool:
)
@property
- def options(self) -> Dict[str, Any]:
+ def options(self) -> dict[str, object]:
"""Device options (same as in the legacy API)."""
return self.__options
@options.setter
- def options(self, options: Optional[Dict[str, Any]]):
+ def options(self, options: dict[str, object] | None) -> None:
"""Device options (same as in the legacy API)."""
self.__options = options or {}
@@ -1482,7 +1467,7 @@ def deserialize(
cls,
serialization: bytes,
expected_device: VirtualDevice,
- ) -> "DeviceAssignment":
+ ) -> DeviceAssignment:
"""
Recovers a serialized object, see: :py:meth:`serialize`.
"""
@@ -1497,7 +1482,7 @@ def _deserialize(
cls,
untrusted_serialization: bytes,
expected_device: VirtualDevice,
- ) -> "DeviceAssignment":
+ ) -> DeviceAssignment:
"""
Actually deserializes the object.
"""
diff --git a/qubesadmin/events/__init__.py b/qubesadmin/events/__init__.py
index 9015f580..d2581867 100644
--- a/qubesadmin/events/__init__.py
+++ b/qubesadmin/events/__init__.py
@@ -246,6 +246,7 @@ def handle(self, subject_name: str | None, event: str, **kwargs) -> None:
elif event in ('domain-pre-start', 'domain-start', 'domain-shutdown',
'domain-paused', 'domain-unpaused',
'domain-start-failed'):
+ assert subject is not None
self.app._update_power_state_cache(subject, event, **kwargs)
subject.devices.clear_cache()
elif event == 'connection-established':
@@ -257,6 +258,7 @@ def handle(self, subject_name: str | None, event: str, **kwargs) -> None:
"device-unassign",
"device-assignment-changed"
):
+ assert subject is not None
devclass = event.split(":")[1]
subject.devices[devclass]._assignment_cache = None
elif event.split(":")[0] in (
@@ -264,9 +266,11 @@ def handle(self, subject_name: str | None, event: str, **kwargs) -> None:
"device-detach",
"device-removed"
):
+ assert subject is not None
devclass = event.split(":")[1]
subject.devices[devclass]._attachment_cache = None
if event.split(":")[0] in ("device-removed",):
+ assert subject is not None
devclass = event.split(":")[1]
port_id = kwargs.get("port", ":").split(":")[1]
try:
diff --git a/qubesadmin/events/utils.py b/qubesadmin/events/utils.py
index 9bcf288e..a946c3c4 100644
--- a/qubesadmin/events/utils.py
+++ b/qubesadmin/events/utils.py
@@ -24,7 +24,6 @@
from typing import Iterable
import qubesadmin.events
-import qubesadmin.exc
from qubesadmin.events import EventsDispatcher
from qubesadmin.vm import QubesVM
diff --git a/qubesadmin/tests/__init__.py b/qubesadmin/tests/__init__.py
index 87a0978e..dcdad57d 100644
--- a/qubesadmin/tests/__init__.py
+++ b/qubesadmin/tests/__init__.py
@@ -150,6 +150,9 @@ class QubesTest(qubesadmin.app.QubesBase):
expected_calls = None
actual_calls = None
service_calls = None
+ # This is a special value used only for tests
+ # This is not valid / expected outside of tests
+ qubesd_connection_type: str = "none" # type: ignore
def __init__(self):
super().__init__()
diff --git a/qubesadmin/tests/app.py b/qubesadmin/tests/app.py
index 3b8796cd..37e8ba12 100644
--- a/qubesadmin/tests/app.py
+++ b/qubesadmin/tests/app.py
@@ -186,6 +186,71 @@ def test_012_getitem_cached_object(self):
self.assertIsNot(vm1, vm4)
self.assertAllCalled()
+ def test_013_get_blind_not_in_membership(self):
+ """Blind objects must not appear in membership operations."""
+ self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \
+ b'0\x00test-vm class=AppVM state=Running\n'
+ self.assertIn('test-vm', self.app.domains)
+ self.app.domains.get_blind('other-vm')
+ self.assertNotIn('other-vm', self.app.domains)
+ self.assertNotIn('other-vm', self.app.domains.keys())
+ self.assertEqual([vm.name for vm in self.app.domains], ['test-vm'])
+ self.assertAllCalled()
+
+ def test_014_refresh_cache_forces_reload(self):
+ """refresh_cache() must trigger a new admin.vm.List call even if
+ already initialised."""
+ self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \
+ b'0\x00test-vm class=AppVM state=Running\n'
+ self.assertIn('test-vm', self.app.domains)
+ self.assertNotIn('test-vm2', self.app.domains)
+ self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \
+ b'0\x00test-vm2 class=AppVM state=Running\n'
+ self.app.domains.refresh_cache()
+ self.assertNotIn('test-vm', self.app.domains)
+ self.assertIn('test-vm2', self.app.domains)
+ self.assertAllCalled()
+
+ def test_015_delitem_targeted_cleanup(self):
+ """del domains[name] must remove the VM immediately without a
+ further admin.vm.List call."""
+ self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \
+ b'0\x00test-vm class=AppVM state=Running\n'
+ self.app.expected_calls[('test-vm', 'admin.vm.Remove', None, None)] = \
+ b'0\x00'
+ self.assertIn('test-vm', self.app.domains)
+ del self.app.domains['test-vm']
+ self.assertNotIn('test-vm', self.app.domains)
+ self.assertAllCalled()
+
+ def test_016_rename_preserves_identity(self):
+ """After a rename, refresh_cache() must return the same object
+ under the new name."""
+ self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \
+ b'0\x00test-vm class=AppVM state=Running\n'
+ vm = self.app.domains['test-vm']
+ # pylint: disable=protected-access
+ vm._method_dest = 'new-name'
+ self.app.domains.clear_cache()
+ self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \
+ b'0\x00new-name class=AppVM state=Running\n'
+ vm2 = self.app.domains['new-name']
+ self.assertIs(vm, vm2)
+ self.assertAllCalled()
+
+ def test_017_clear_cache_invalidate_name(self):
+ """clear_cache(invalidate_name) must drop that VM's object from the
+ cache so a fresh one is created on next access."""
+ self.app.expected_calls[('dom0', 'admin.vm.List', None, None)] = \
+ b'0\x00test-vm class=AppVM state=Running\n'
+ vm1 = self.app.domains['test-vm']
+ self.app.domains.clear_cache()
+ vm2 = self.app.domains['test-vm']
+ self.assertIs(vm1, vm2)
+ self.app.domains.clear_cache(invalidate_name = 'test-vm')
+ vm3 = self.app.domains['test-vm']
+ self.assertIsNot(vm1, vm3)
+ self.assertAllCalled()
class TC_10_QubesBase(qubesadmin.tests.QubesTestCase):
diff --git a/qubesadmin/tests/tools/qvm_ls.py b/qubesadmin/tests/tools/qvm_ls.py
index 23a1d6f0..27e47ed2 100644
--- a/qubesadmin/tests/tools/qvm_ls.py
+++ b/qubesadmin/tests/tools/qvm_ls.py
@@ -380,17 +380,10 @@ def test_101_list_selected(self):
b'0\x00vm1 class=AppVM state=Running\n' \
b'template1 class=TemplateVM state=Halted\n' \
b'sys-net class=AppVM state=Running\n'
- self.app.expected_calls[
- ('vm1', 'admin.vm.CurrentState', None, None)] = \
- b'0\x00power_state=Running'
- self.app.expected_calls[
- ('sys-net', 'admin.vm.CurrentState', None, None)] = \
- b'0\x00power_state=Running'
props = {
'label': 'type=label green',
'template': 'type=vm template1',
'netvm': 'type=vm sys-net',
-# 'virt_mode': b'type=str pv',
}
self.app.expected_calls[
('vm1', 'admin.vm.property.GetAll', None, None)] = \
diff --git a/qubesadmin/tools/qvm_template.py b/qubesadmin/tools/qvm_template.py
index 95a2b77d..6f2a5758 100644
--- a/qubesadmin/tools/qvm_template.py
+++ b/qubesadmin/tools/qvm_template.py
@@ -1200,7 +1200,7 @@ def verify(rpmfile, reponame, package_hdr=None):
name,
target + PATH_PREFIX + '/' + name])
- app.domains.refresh_cache(force=True)
+ app.domains.refresh_cache()
tpl = app.domains[name]
tpl.features['template-name'] = name
diff --git a/qubesadmin/utils.py b/qubesadmin/utils.py
index 68b89af5..9f6d5982 100644
--- a/qubesadmin/utils.py
+++ b/qubesadmin/utils.py
@@ -29,6 +29,7 @@
import re
import qubesadmin.exc
+from qubesadmin.exc import QubesValueError
def parse_size(size):
@@ -210,3 +211,25 @@ def release(self):
"""Unlock the file and close the file object"""
fcntl.lockf(self.file, fcntl.LOCK_UN)
self.file.close()
+
+
+def qbool(value: str | int | bool) -> bool:
+ """
+ Property setter for boolean properties.
+
+ It accepts (case-insensitive) ``'0'``, ``'no'`` and ``false`` as
+ :py:obj:`False` and ``'1'``, ``'yes'`` and ``'true'`` as
+ :py:obj:`True`.
+ """
+
+ if isinstance(value, str):
+ lcvalue = value.lower()
+ if lcvalue in ("0", "no", "false", "off"):
+ return False
+ if lcvalue in ("1", "yes", "true", "on"):
+ return True
+ raise QubesValueError(
+ "Invalid literal for boolean property: {!r}".format(value)
+ )
+
+ return bool(value)
diff --git a/qubesadmin/vm/__init__.py b/qubesadmin/vm/__init__.py
index 00785db8..13e544d6 100644
--- a/qubesadmin/vm/__init__.py
+++ b/qubesadmin/vm/__init__.py
@@ -19,35 +19,38 @@
# with this program; if not, see .
"""Qubes VM objects."""
-
+from __future__ import annotations
import logging
import shlex
import subprocess
+import typing
import warnings
from logging import Logger
+from typing import Literal
-import qubesadmin.base
import qubesadmin.exc
import qubesadmin.storage
import qubesadmin.features
import qubesadmin.devices
import qubesadmin.device_protocol
import qubesadmin.firewall
-import qubesadmin.tags
+
+if typing.TYPE_CHECKING:
+ import qubesadmin.base
+
+Klass = Literal["AppVM", "AdminVM", "TemplateVM", "DispVM", "StandaloneVM"]
+PowerState = Literal["Transient", "Running", "Halted", "Paused",
+"Suspended", "Halting", "Dying", "Crashed", "NA"]
class QubesVM(qubesadmin.base.PropertyHolder):
"""Qubes domain."""
log: Logger
-
tags: qubesadmin.tags.Tags
-
features: qubesadmin.features.Features
-
devices: qubesadmin.devices.DeviceManager
-
firewall: qubesadmin.firewall.Firewall
def __init__(self, app, name, klass=None, power_state=None):
@@ -64,7 +67,7 @@ def __init__(self, app, name, klass=None, power_state=None):
self.firewall = qubesadmin.firewall.Firewall(self)
@property
- def name(self):
+ def name(self) -> str:
"""Domain name"""
return self._method_dest
@@ -80,7 +83,7 @@ def name(self, new_value):
self._volumes = None
self.app.domains.clear_cache()
- def __str__(self):
+ def __str__(self) -> str:
return self._method_dest
def __lt__(self, other):
@@ -205,7 +208,7 @@ def get_power_state(self):
"""
- if self._power_state_cache is not None:
+ if self._power_state_cache is not None and self.app.cache_enabled:
return self._power_state_cache
try:
power_state = self._get_current_state()["power_state"]