diff --git a/qubesadmin/devices.py b/qubesadmin/devices.py index 228969d1..a4d254a5 100644 --- a/qubesadmin/devices.py +++ b/qubesadmin/devices.py @@ -31,8 +31,10 @@ class is implemented by an extension. Devices are identified by pair of (backend domain, `port_id`), where `port_id` is :py:class:`str`. """ +from __future__ import annotations import itertools -from typing import Iterable +from typing import TYPE_CHECKING +from collections.abc import Iterable, Iterator import qubesadmin.exc from qubesadmin.device_protocol import ( @@ -43,6 +45,8 @@ class is implemented by an extension. VirtualDevice, AssignmentMode, DeviceInterface, ) +if TYPE_CHECKING: + from qubesadmin.vm import QubesVM class DeviceCollection: @@ -55,7 +59,7 @@ class DeviceCollection: """ - def __init__(self, vm, class_): + def __init__(self, vm: QubesVM, class_: str): self._vm = vm self._class = class_ self._dev_cache = {} @@ -268,7 +272,7 @@ def get_exposed_devices(self) -> Iterable[DeviceInfo]: def update_assignment( self, device: VirtualDevice, required: AssignmentMode - ): + ) -> None: """ Update assignment of already attached device. @@ -288,7 +292,7 @@ def update_assignment( __iter__ = get_exposed_devices - def clear_cache(self): + def clear_cache(self) -> None: """ Clear cache of available devices. """ @@ -296,7 +300,7 @@ def clear_cache(self): self._assignment_cache = None self._attachment_cache = None - def __getitem__(self, item): + def __getitem__(self, item: object) -> DeviceInfo: """Get device object with given port_id. :returns: py:class:`DeviceInfo` @@ -316,6 +320,8 @@ def __getitem__(self, item): return dev # if still nothing, return UnknownDevice instance for the reason # explained in docstring, but don't cache it + if not isinstance(item, str | None): + raise NotImplementedError return UnknownDevice(Port(self._vm, item, devclass=self._class)) @@ -325,21 +331,22 @@ class DeviceManager(dict): :param vm: VM for which we manage devices """ - def __init__(self, vm): + def __init__(self, vm: QubesVM): super().__init__() self._vm = vm - def __missing__(self, key): + def __missing__(self, key: str) -> DeviceCollection: self[key] = DeviceCollection(self._vm, key) return self[key] - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._vm.app.list_deviceclass()) - def keys(self): + + def keys(self) -> list[str]: # type: ignore[override] return self._vm.app.list_deviceclass() - def deny(self, *interfaces: Iterable[DeviceInterface]): + def deny(self, *interfaces: Iterable[DeviceInterface]) -> None: """ Deny a device with any of the given interfaces from attaching to the VM. """ @@ -350,7 +357,7 @@ def deny(self, *interfaces: Iterable[DeviceInterface]): "".join(repr(ifc) for ifc in interfaces).encode('ascii'), ) - def allow(self, *interfaces: Iterable[DeviceInterface]): + def allow(self, *interfaces: Iterable[DeviceInterface]) -> None: """ Remove given interfaces from denied list. """ @@ -361,7 +368,7 @@ def allow(self, *interfaces: Iterable[DeviceInterface]): "".join(repr(ifc) for ifc in interfaces).encode('ascii'), ) - def clear_cache(self): + def clear_cache(self) -> None: """Clear cache of all available device classes""" for devclass in self.values(): devclass.clear_cache() diff --git a/qubesadmin/exc.py b/qubesadmin/exc.py index 737bd6d9..282d318b 100644 --- a/qubesadmin/exc.py +++ b/qubesadmin/exc.py @@ -26,7 +26,7 @@ class QubesException(Exception): """Exception that can be shown to the user""" - def __init__(self, message_format, *args, **kwargs): + def __init__(self, message_format: str, *args, **kwargs): # TODO: handle translations super().__init__( message_format % tuple(int(d) if d.isdigit() else d for d in args), @@ -37,7 +37,7 @@ def __init__(self, message_format, *args, **kwargs): class QubesVMNotFoundError(QubesException, KeyError): """Domain cannot be found in the system""" - def __str__(self): + def __str__(self) -> str: # KeyError overrides __str__ method return QubesException.__str__(self) @@ -139,7 +139,7 @@ class QubesMemoryError(QubesVMError, MemoryError): class QubesFeatureNotFoundError(QubesException, KeyError): """Feature not set for a given domain""" - def __str__(self): + def __str__(self) -> str: # KeyError overrides __str__ method return QubesException.__str__(self) @@ -147,7 +147,7 @@ def __str__(self): class QubesTagNotFoundError(QubesException, KeyError): """Tag not set for a given domain""" - def __str__(self): + def __str__(self) -> str: # KeyError overrides __str__ method return QubesException.__str__(self) @@ -155,7 +155,7 @@ def __str__(self): class QubesLabelNotFoundError(QubesException, KeyError): """Label does not exists""" - def __str__(self): + def __str__(self) -> str: # KeyError overrides __str__ method return QubesException.__str__(self) @@ -213,7 +213,7 @@ class QubesDaemonCommunicationError(QubesException): class BackupRestoreError(QubesException): """Restoring a backup failed""" - def __init__(self, msg, backup_log=None): + def __init__(self, msg: str, backup_log: bytes | None=None): super().__init__(msg) self.backup_log = backup_log @@ -228,7 +228,7 @@ class QubesPropertyAccessError(QubesDaemonAccessError, AttributeError): """Failed to read/write property value, cause is unknown (insufficient permissions, no such property, invalid value, other)""" - def __init__(self, prop): + def __init__(self, prop: str): super().__init__("Failed to access '%s' property" % prop) diff --git a/qubesadmin/features.py b/qubesadmin/features.py index 261b8571..9cc45660 100644 --- a/qubesadmin/features.py +++ b/qubesadmin/features.py @@ -19,7 +19,16 @@ # with this program; if not, see . '''VM features interface''' +from __future__ import annotations +import typing +from typing import TypeVar +from collections.abc import Iterator, Generator + +if typing.TYPE_CHECKING: + from qubesadmin.vm import QubesVM + +T = TypeVar('T') class Features: '''Manager of the features. @@ -33,14 +42,14 @@ class Features: false in Python) will result in string `'0'`, which is considered true. ''' - def __init__(self, vm): + def __init__(self, vm: QubesVM): super().__init__() self.vm = vm - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: self.vm.qubesd_call(self.vm.name, 'admin.vm.feature.Remove', key) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: object) -> None: if isinstance(value, bool): # False value needs to be serialized as empty string self.vm.qubesd_call(self.vm.name, 'admin.vm.feature.Set', key, @@ -49,25 +58,30 @@ def __setitem__(self, key, value): self.vm.qubesd_call(self.vm.name, 'admin.vm.feature.Set', key, str(value).encode()) - def __getitem__(self, item): + def __getitem__(self, item: str) -> str: return self.vm.qubesd_call( self.vm.name, 'admin.vm.feature.Get', item).decode('utf-8') - def __iter__(self): + def __iter__(self) -> Iterator[str]: qubesd_response = self.vm.qubesd_call(self.vm.name, 'admin.vm.feature.List') return iter(qubesd_response.decode('utf-8').splitlines()) keys = __iter__ - def items(self): + def items(self) -> Generator[tuple[str, str]]: '''Return iterable of pairs (feature, value)''' for key in self: yield key, self[key] NO_DEFAULT = object() - def get(self, item, default=None): + @typing.overload + def get(self, item: str) -> str | None: ... + @typing.overload + def get(self, item: str, default: T) -> str | T: ... + # Overloaded to handle default None return type + def get(self, item: str, default: object = None) -> object: '''Get a feature, return default value if missing.''' try: return self[item] @@ -76,7 +90,13 @@ def get(self, item, default=None): raise return default - def check_with_template(self, feature, default=None): + @typing.overload + def check_with_template(self, item: str) -> str | None: ... + @typing.overload + def check_with_template(self, item: str, default: T) -> str | T: ... + # Overloaded to handle default None return type + def check_with_template(self, feature: str, + default: object = None) -> object: ''' Check if the vm's template has the specified feature. ''' try: qubesd_response = self.vm.qubesd_call( diff --git a/qubesadmin/firewall.py b/qubesadmin/firewall.py index 438d489f..8fd37159 100644 --- a/qubesadmin/firewall.py +++ b/qubesadmin/firewall.py @@ -19,31 +19,35 @@ # with this program; if not, see . '''Firewall configuration interface''' - +from __future__ import annotations import datetime import socket import string +from typing import SupportsInt, TYPE_CHECKING + +if TYPE_CHECKING: + from qubesadmin.vm import QubesVM class RuleOption: '''Base class for a single rule element''' - def __init__(self, value): + def __init__(self, value: object) -> None: self._value = str(value) @property - def rule(self): + def rule(self) -> str: '''API representation of this rule element''' raise NotImplementedError @property - def pretty_value(self): + def pretty_value(self) -> str: '''Human readable representation''' return str(self) - def __str__(self): + def __str__(self) -> str: return self._value - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return str(self) == other @@ -51,7 +55,7 @@ def __eq__(self, other): class RuleChoice(RuleOption): '''Base class for multiple-choices rule elements''' # pylint: disable=abstract-method - def __init__(self, value): + def __init__(self, value: object) -> None: super().__init__(value) self.allowed_values = \ [v for k, v in self.__class__.__dict__.items() @@ -63,30 +67,30 @@ def __init__(self, value): class Action(RuleChoice): '''Rule action''' - accept = 'accept' - drop = 'drop' + accept: str = 'accept' + drop: str = 'drop' @property - def rule(self): + def rule(self) -> str: '''API representation of this rule element''' return 'action=' + str(self) class Proto(RuleChoice): '''Protocol name''' - tcp = 'tcp' - udp = 'udp' - icmp = 'icmp' + tcp: str = 'tcp' + udp: str = 'udp' + icmp: str = 'icmp' @property - def rule(self): + def rule(self) -> str: '''API representation of this rule element''' return 'proto=' + str(self) class DstHost(RuleOption): '''Represent host/network address: either IPv4, IPv6, or DNS name''' - def __init__(self, value, prefixlen=None): + def __init__(self, value: str, prefixlen: int | None=None): # TODO: in python >= 3.3 ipaddress module could be used if value.count('/') > 1: raise ValueError('Too many /: ' + value) @@ -125,8 +129,8 @@ def __init__(self, value, prefixlen=None): if not all(c in safe_set for c in value): raise ValueError('Invalid hostname') else: - host, prefixlen = value.split('/', 1) - prefixlen = int(prefixlen) + host, prefixlen_str = value.split('/', 1) + prefixlen = int(prefixlen_str) if prefixlen < 0: raise ValueError('netmask must be non-negative') self.prefixlen = prefixlen @@ -150,7 +154,7 @@ def __init__(self, value, prefixlen=None): super().__init__(value) @property - def rule(self): + def rule(self) -> str | None: '''API representation of this rule element''' if self.prefixlen == 0 and self.type != 'dsthost': # 0.0.0.0/0 or ::/0, doesn't limit to any particular host, @@ -161,9 +165,8 @@ def rule(self): class DstPorts(RuleOption): '''Destination port(s), for TCP/UDP only''' - def __init__(self, value): - if isinstance(value, int): - value = str(value) + def __init__(self, value: int | str) -> None: + value: str = str(value) if isinstance(value, int) else value if value.count('-') == 1: self.range = [int(x) for x in value.split('-', 1)] elif not value.count('-'): @@ -179,54 +182,54 @@ def __init__(self, value): else '{!s}-{!s}'.format(*self.range)) @property - def rule(self): + def rule(self) -> str: '''API representation of this rule element''' return 'dstports=' + '{!s}-{!s}'.format(*self.range) class IcmpType(RuleOption): '''ICMP packet type''' - def __init__(self, value): + def __init__(self, value: SupportsInt) -> None: super().__init__(value) value = int(value) if value < 0 or value > 255: raise ValueError('ICMP type out of range') @property - def rule(self): + def rule(self) -> str: '''API representation of this rule element''' return 'icmptype=' + str(self) class SpecialTarget(RuleChoice): '''Special destination''' - dns = 'dns' + dns: str = 'dns' @property - def rule(self): + def rule(self) -> str: '''API representation of this rule element''' return 'specialtarget=' + str(self) class Expire(RuleOption): '''Rule expire time''' - def __init__(self, value): + def __init__(self, value: SupportsInt) -> None: super().__init__(value) self.datetime = datetime.datetime.fromtimestamp(int(value), datetime.timezone.utc) @property - def rule(self): + def rule(self) -> str: '''API representation of this rule element''' return 'expire=' + str(self) @property - def expired(self): + def expired(self) -> bool: '''Has this rule expired already?''' return self.datetime < datetime.datetime.now(datetime.timezone.utc) @property - def pretty_value(self): + def pretty_value(self) -> str: '''Human readable representation''' now = datetime.datetime.now(datetime.timezone.utc) duration = (self.datetime - now).total_seconds() @@ -236,7 +239,7 @@ def pretty_value(self): class Comment(RuleOption): '''User comment''' @property - def rule(self): + def rule(self) -> str: '''API representation of this rule element''' return 'comment=' + str(self) @@ -244,20 +247,20 @@ def rule(self): class Rule: '''A single firewall rule''' - def __init__(self, rule, **kwargs): + def __init__(self, rule: str | None, **kwargs): '''Single firewall rule :param xml: XML element describing rule, or None :param kwargs: rule elements ''' - self._action = None - self._proto = None - self._dsthost = None - self._dstports = None - self._icmptype = None - self._specialtarget = None - self._expire = None - self._comment = None + self._action: Action | None = None + self._proto: Proto | None = None + self._dsthost: DstHost | None = None + self._dstports: DstPorts | None = None + self._icmptype: IcmpType | None = None + self._specialtarget: SpecialTarget | None = None + self._expire: Expire | None = None + self._comment: Comment | None = None rule_dict = {} if rule is not None: @@ -287,23 +290,23 @@ def __init__(self, rule, **kwargs): raise ValueError('missing action=') @property - def action(self): + def action(self) -> Action | None: '''rule action''' return self._action @action.setter - def action(self, value): + def action(self, value: object) -> None: if not isinstance(value, Action): value = Action(value) self._action = value @property - def proto(self): + def proto(self) -> Proto | None: '''protocol to match''' return self._proto @proto.setter - def proto(self, value): + def proto(self, value: object) -> None: if value is not None and not isinstance(value, Proto): value = Proto(value) if value not in ('tcp', 'udp'): @@ -313,23 +316,23 @@ def proto(self, value): self._proto = value @property - def dsthost(self): + def dsthost(self) -> DstHost | None: '''destination host/network''' return self._dsthost @dsthost.setter - def dsthost(self, value): + def dsthost(self, value: str | DstHost | None) -> None: if value is not None and not isinstance(value, DstHost): value = DstHost(value) self._dsthost = value @property - def dstports(self): + def dstports(self) -> DstPorts | None: ''''Destination port(s) (for \'tcp\' and \'udp\' protocol only)''' return self._dstports @dstports.setter - def dstports(self, value): + def dstports(self, value: str | int | DstPorts | None) -> None: if value is not None: if self.proto not in ('tcp', 'udp'): raise ValueError( @@ -339,12 +342,12 @@ def dstports(self, value): self._dstports = value @property - def icmptype(self): + def icmptype(self) -> IcmpType | None: '''ICMP packet type (for \'icmp\' protocol only)''' return self._icmptype @icmptype.setter - def icmptype(self, value): + def icmptype(self, value: IcmpType | SupportsInt | None) -> None: if value is not None: if self.proto not in ('icmp',): raise ValueError('icmptype valid only for \'icmp\' protocol') @@ -353,40 +356,40 @@ def icmptype(self, value): self._icmptype = value @property - def specialtarget(self): + def specialtarget(self) -> SpecialTarget | None: '''Special target, for now only \'dns\' supported''' return self._specialtarget @specialtarget.setter - def specialtarget(self, value): + def specialtarget(self, value: object) -> None: if not isinstance(value, SpecialTarget): value = SpecialTarget(value) self._specialtarget = value @property - def expire(self): + def expire(self) -> Expire | None: '''Timestamp (UNIX epoch) on which this rule expire''' return self._expire @expire.setter - def expire(self, value): + def expire(self, value: Expire | SupportsInt) -> None: if not isinstance(value, Expire): value = Expire(value) self._expire = value @property - def comment(self): + def comment(self) -> Comment | None: '''User comment''' return self._comment @comment.setter - def comment(self, value): + def comment(self, value: object) -> None: if not isinstance(value, Comment): value = Comment(value) self._comment = value @property - def rule(self): + def rule(self) -> str: '''API representation of this rule''' values = [] # comment must be the last one @@ -400,36 +403,36 @@ def rule(self): values.append(value.rule) return ' '.join(values) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Rule): return self.rule == other.rule if isinstance(other, str): return self.rule == str return NotImplemented - def __repr__(self): + def __repr__(self) -> str: return 'Rule(\'{}\')'.format(self.rule) class Firewall: '''Firewal manager for a VM''' - def __init__(self, vm): + def __init__(self, vm: QubesVM): self.vm = vm self._rules: list[Rule] = [] self._policy = None self._loaded = False - def load_rules(self): + def load_rules(self) -> None: '''Force (re-)loading firewall rules''' rules_str = self.vm.qubesd_call(None, 'admin.vm.firewall.Get') - rules = [] + rules: list[Rule] = [] for rule_str in rules_str.decode().splitlines(): rules.append(Rule(rule_str)) self._rules = rules self._loaded = True @property - def rules(self): + def rules(self) -> list[Rule]: '''Firewall rules You can either copy them, edit and then assign new rules list to this @@ -442,11 +445,11 @@ def rules(self): return self._rules @rules.setter - def rules(self, value): + def rules(self, value: list[Rule]) -> None: self.save_rules(value) self._rules = value - def save_rules(self, rules=None): + def save_rules(self, rules: list[Rule] | None=None) -> None: '''Save firewall rules. Needs to be called after in-place editing :py:attr:`rules`. ''' @@ -457,11 +460,11 @@ def save_rules(self, rules=None): for rule in rules)).encode('ascii')) @property - def policy(self): + def policy(self) -> Action: '''Default action to take if no rule matches''' return Action('drop') - def reload(self): + def reload(self) -> None: '''Force reload the same firewall rules. Can be used for example to force again names resolution.