diff --git a/qubesadmin/devices.py b/qubesadmin/devices.py
index 228969d1..a4d254a5 100644
--- a/qubesadmin/devices.py
+++ b/qubesadmin/devices.py
@@ -31,8 +31,10 @@ class is implemented by an extension.
Devices are identified by pair of (backend domain, `port_id`), where `port_id`
is :py:class:`str`.
"""
+from __future__ import annotations
import itertools
-from typing import Iterable
+from typing import TYPE_CHECKING
+from collections.abc import Iterable, Iterator
import qubesadmin.exc
from qubesadmin.device_protocol import (
@@ -43,6 +45,8 @@ class is implemented by an extension.
VirtualDevice,
AssignmentMode, DeviceInterface,
)
+if TYPE_CHECKING:
+ from qubesadmin.vm import QubesVM
class DeviceCollection:
@@ -55,7 +59,7 @@ class DeviceCollection:
"""
- def __init__(self, vm, class_):
+ def __init__(self, vm: QubesVM, class_: str):
self._vm = vm
self._class = class_
self._dev_cache = {}
@@ -268,7 +272,7 @@ def get_exposed_devices(self) -> Iterable[DeviceInfo]:
def update_assignment(
self, device: VirtualDevice, required: AssignmentMode
- ):
+ ) -> None:
"""
Update assignment of already attached device.
@@ -288,7 +292,7 @@ def update_assignment(
__iter__ = get_exposed_devices
- def clear_cache(self):
+ def clear_cache(self) -> None:
"""
Clear cache of available devices.
"""
@@ -296,7 +300,7 @@ def clear_cache(self):
self._assignment_cache = None
self._attachment_cache = None
- def __getitem__(self, item):
+ def __getitem__(self, item: object) -> DeviceInfo:
"""Get device object with given port_id.
:returns: py:class:`DeviceInfo`
@@ -316,6 +320,8 @@ def __getitem__(self, item):
return dev
# if still nothing, return UnknownDevice instance for the reason
# explained in docstring, but don't cache it
+ if not isinstance(item, str | None):
+ raise NotImplementedError
return UnknownDevice(Port(self._vm, item, devclass=self._class))
@@ -325,21 +331,22 @@ class DeviceManager(dict):
:param vm: VM for which we manage devices
"""
- def __init__(self, vm):
+ def __init__(self, vm: QubesVM):
super().__init__()
self._vm = vm
- def __missing__(self, key):
+ def __missing__(self, key: str) -> DeviceCollection:
self[key] = DeviceCollection(self._vm, key)
return self[key]
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return iter(self._vm.app.list_deviceclass())
- def keys(self):
+
+ def keys(self) -> list[str]: # type: ignore[override]
return self._vm.app.list_deviceclass()
- def deny(self, *interfaces: Iterable[DeviceInterface]):
+ def deny(self, *interfaces: Iterable[DeviceInterface]) -> None:
"""
Deny a device with any of the given interfaces from attaching to the VM.
"""
@@ -350,7 +357,7 @@ def deny(self, *interfaces: Iterable[DeviceInterface]):
"".join(repr(ifc) for ifc in interfaces).encode('ascii'),
)
- def allow(self, *interfaces: Iterable[DeviceInterface]):
+ def allow(self, *interfaces: Iterable[DeviceInterface]) -> None:
"""
Remove given interfaces from denied list.
"""
@@ -361,7 +368,7 @@ def allow(self, *interfaces: Iterable[DeviceInterface]):
"".join(repr(ifc) for ifc in interfaces).encode('ascii'),
)
- def clear_cache(self):
+ def clear_cache(self) -> None:
"""Clear cache of all available device classes"""
for devclass in self.values():
devclass.clear_cache()
diff --git a/qubesadmin/exc.py b/qubesadmin/exc.py
index 737bd6d9..282d318b 100644
--- a/qubesadmin/exc.py
+++ b/qubesadmin/exc.py
@@ -26,7 +26,7 @@
class QubesException(Exception):
"""Exception that can be shown to the user"""
- def __init__(self, message_format, *args, **kwargs):
+ def __init__(self, message_format: str, *args, **kwargs):
# TODO: handle translations
super().__init__(
message_format % tuple(int(d) if d.isdigit() else d for d in args),
@@ -37,7 +37,7 @@ def __init__(self, message_format, *args, **kwargs):
class QubesVMNotFoundError(QubesException, KeyError):
"""Domain cannot be found in the system"""
- def __str__(self):
+ def __str__(self) -> str:
# KeyError overrides __str__ method
return QubesException.__str__(self)
@@ -139,7 +139,7 @@ class QubesMemoryError(QubesVMError, MemoryError):
class QubesFeatureNotFoundError(QubesException, KeyError):
"""Feature not set for a given domain"""
- def __str__(self):
+ def __str__(self) -> str:
# KeyError overrides __str__ method
return QubesException.__str__(self)
@@ -147,7 +147,7 @@ def __str__(self):
class QubesTagNotFoundError(QubesException, KeyError):
"""Tag not set for a given domain"""
- def __str__(self):
+ def __str__(self) -> str:
# KeyError overrides __str__ method
return QubesException.__str__(self)
@@ -155,7 +155,7 @@ def __str__(self):
class QubesLabelNotFoundError(QubesException, KeyError):
"""Label does not exists"""
- def __str__(self):
+ def __str__(self) -> str:
# KeyError overrides __str__ method
return QubesException.__str__(self)
@@ -213,7 +213,7 @@ class QubesDaemonCommunicationError(QubesException):
class BackupRestoreError(QubesException):
"""Restoring a backup failed"""
- def __init__(self, msg, backup_log=None):
+ def __init__(self, msg: str, backup_log: bytes | None=None):
super().__init__(msg)
self.backup_log = backup_log
@@ -228,7 +228,7 @@ class QubesPropertyAccessError(QubesDaemonAccessError, AttributeError):
"""Failed to read/write property value, cause is unknown (insufficient
permissions, no such property, invalid value, other)"""
- def __init__(self, prop):
+ def __init__(self, prop: str):
super().__init__("Failed to access '%s' property" % prop)
diff --git a/qubesadmin/features.py b/qubesadmin/features.py
index 261b8571..9cc45660 100644
--- a/qubesadmin/features.py
+++ b/qubesadmin/features.py
@@ -19,7 +19,16 @@
# with this program; if not, see .
'''VM features interface'''
+from __future__ import annotations
+import typing
+from typing import TypeVar
+from collections.abc import Iterator, Generator
+
+if typing.TYPE_CHECKING:
+ from qubesadmin.vm import QubesVM
+
+T = TypeVar('T')
class Features:
'''Manager of the features.
@@ -33,14 +42,14 @@ class Features:
false in Python) will result in string `'0'`, which is considered true.
'''
- def __init__(self, vm):
+ def __init__(self, vm: QubesVM):
super().__init__()
self.vm = vm
- def __delitem__(self, key):
+ def __delitem__(self, key: str) -> None:
self.vm.qubesd_call(self.vm.name, 'admin.vm.feature.Remove', key)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: str, value: object) -> None:
if isinstance(value, bool):
# False value needs to be serialized as empty string
self.vm.qubesd_call(self.vm.name, 'admin.vm.feature.Set', key,
@@ -49,25 +58,30 @@ def __setitem__(self, key, value):
self.vm.qubesd_call(self.vm.name, 'admin.vm.feature.Set', key,
str(value).encode())
- def __getitem__(self, item):
+ def __getitem__(self, item: str) -> str:
return self.vm.qubesd_call(
self.vm.name, 'admin.vm.feature.Get', item).decode('utf-8')
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
qubesd_response = self.vm.qubesd_call(self.vm.name,
'admin.vm.feature.List')
return iter(qubesd_response.decode('utf-8').splitlines())
keys = __iter__
- def items(self):
+ def items(self) -> Generator[tuple[str, str]]:
'''Return iterable of pairs (feature, value)'''
for key in self:
yield key, self[key]
NO_DEFAULT = object()
- def get(self, item, default=None):
+ @typing.overload
+ def get(self, item: str) -> str | None: ...
+ @typing.overload
+ def get(self, item: str, default: T) -> str | T: ...
+ # Overloaded to handle default None return type
+ def get(self, item: str, default: object = None) -> object:
'''Get a feature, return default value if missing.'''
try:
return self[item]
@@ -76,7 +90,13 @@ def get(self, item, default=None):
raise
return default
- def check_with_template(self, feature, default=None):
+ @typing.overload
+ def check_with_template(self, item: str) -> str | None: ...
+ @typing.overload
+ def check_with_template(self, item: str, default: T) -> str | T: ...
+ # Overloaded to handle default None return type
+ def check_with_template(self, feature: str,
+ default: object = None) -> object:
''' Check if the vm's template has the specified feature. '''
try:
qubesd_response = self.vm.qubesd_call(
diff --git a/qubesadmin/firewall.py b/qubesadmin/firewall.py
index 438d489f..8fd37159 100644
--- a/qubesadmin/firewall.py
+++ b/qubesadmin/firewall.py
@@ -19,31 +19,35 @@
# with this program; if not, see .
'''Firewall configuration interface'''
-
+from __future__ import annotations
import datetime
import socket
import string
+from typing import SupportsInt, TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from qubesadmin.vm import QubesVM
class RuleOption:
'''Base class for a single rule element'''
- def __init__(self, value):
+ def __init__(self, value: object) -> None:
self._value = str(value)
@property
- def rule(self):
+ def rule(self) -> str:
'''API representation of this rule element'''
raise NotImplementedError
@property
- def pretty_value(self):
+ def pretty_value(self) -> str:
'''Human readable representation'''
return str(self)
- def __str__(self):
+ def __str__(self) -> str:
return self._value
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
return str(self) == other
@@ -51,7 +55,7 @@ def __eq__(self, other):
class RuleChoice(RuleOption):
'''Base class for multiple-choices rule elements'''
# pylint: disable=abstract-method
- def __init__(self, value):
+ def __init__(self, value: object) -> None:
super().__init__(value)
self.allowed_values = \
[v for k, v in self.__class__.__dict__.items()
@@ -63,30 +67,30 @@ def __init__(self, value):
class Action(RuleChoice):
'''Rule action'''
- accept = 'accept'
- drop = 'drop'
+ accept: str = 'accept'
+ drop: str = 'drop'
@property
- def rule(self):
+ def rule(self) -> str:
'''API representation of this rule element'''
return 'action=' + str(self)
class Proto(RuleChoice):
'''Protocol name'''
- tcp = 'tcp'
- udp = 'udp'
- icmp = 'icmp'
+ tcp: str = 'tcp'
+ udp: str = 'udp'
+ icmp: str = 'icmp'
@property
- def rule(self):
+ def rule(self) -> str:
'''API representation of this rule element'''
return 'proto=' + str(self)
class DstHost(RuleOption):
'''Represent host/network address: either IPv4, IPv6, or DNS name'''
- def __init__(self, value, prefixlen=None):
+ def __init__(self, value: str, prefixlen: int | None=None):
# TODO: in python >= 3.3 ipaddress module could be used
if value.count('/') > 1:
raise ValueError('Too many /: ' + value)
@@ -125,8 +129,8 @@ def __init__(self, value, prefixlen=None):
if not all(c in safe_set for c in value):
raise ValueError('Invalid hostname')
else:
- host, prefixlen = value.split('/', 1)
- prefixlen = int(prefixlen)
+ host, prefixlen_str = value.split('/', 1)
+ prefixlen = int(prefixlen_str)
if prefixlen < 0:
raise ValueError('netmask must be non-negative')
self.prefixlen = prefixlen
@@ -150,7 +154,7 @@ def __init__(self, value, prefixlen=None):
super().__init__(value)
@property
- def rule(self):
+ def rule(self) -> str | None:
'''API representation of this rule element'''
if self.prefixlen == 0 and self.type != 'dsthost':
# 0.0.0.0/0 or ::/0, doesn't limit to any particular host,
@@ -161,9 +165,8 @@ def rule(self):
class DstPorts(RuleOption):
'''Destination port(s), for TCP/UDP only'''
- def __init__(self, value):
- if isinstance(value, int):
- value = str(value)
+ def __init__(self, value: int | str) -> None:
+ value: str = str(value) if isinstance(value, int) else value
if value.count('-') == 1:
self.range = [int(x) for x in value.split('-', 1)]
elif not value.count('-'):
@@ -179,54 +182,54 @@ def __init__(self, value):
else '{!s}-{!s}'.format(*self.range))
@property
- def rule(self):
+ def rule(self) -> str:
'''API representation of this rule element'''
return 'dstports=' + '{!s}-{!s}'.format(*self.range)
class IcmpType(RuleOption):
'''ICMP packet type'''
- def __init__(self, value):
+ def __init__(self, value: SupportsInt) -> None:
super().__init__(value)
value = int(value)
if value < 0 or value > 255:
raise ValueError('ICMP type out of range')
@property
- def rule(self):
+ def rule(self) -> str:
'''API representation of this rule element'''
return 'icmptype=' + str(self)
class SpecialTarget(RuleChoice):
'''Special destination'''
- dns = 'dns'
+ dns: str = 'dns'
@property
- def rule(self):
+ def rule(self) -> str:
'''API representation of this rule element'''
return 'specialtarget=' + str(self)
class Expire(RuleOption):
'''Rule expire time'''
- def __init__(self, value):
+ def __init__(self, value: SupportsInt) -> None:
super().__init__(value)
self.datetime = datetime.datetime.fromtimestamp(int(value),
datetime.timezone.utc)
@property
- def rule(self):
+ def rule(self) -> str:
'''API representation of this rule element'''
return 'expire=' + str(self)
@property
- def expired(self):
+ def expired(self) -> bool:
'''Has this rule expired already?'''
return self.datetime < datetime.datetime.now(datetime.timezone.utc)
@property
- def pretty_value(self):
+ def pretty_value(self) -> str:
'''Human readable representation'''
now = datetime.datetime.now(datetime.timezone.utc)
duration = (self.datetime - now).total_seconds()
@@ -236,7 +239,7 @@ def pretty_value(self):
class Comment(RuleOption):
'''User comment'''
@property
- def rule(self):
+ def rule(self) -> str:
'''API representation of this rule element'''
return 'comment=' + str(self)
@@ -244,20 +247,20 @@ def rule(self):
class Rule:
'''A single firewall rule'''
- def __init__(self, rule, **kwargs):
+ def __init__(self, rule: str | None, **kwargs):
'''Single firewall rule
:param xml: XML element describing rule, or None
:param kwargs: rule elements
'''
- self._action = None
- self._proto = None
- self._dsthost = None
- self._dstports = None
- self._icmptype = None
- self._specialtarget = None
- self._expire = None
- self._comment = None
+ self._action: Action | None = None
+ self._proto: Proto | None = None
+ self._dsthost: DstHost | None = None
+ self._dstports: DstPorts | None = None
+ self._icmptype: IcmpType | None = None
+ self._specialtarget: SpecialTarget | None = None
+ self._expire: Expire | None = None
+ self._comment: Comment | None = None
rule_dict = {}
if rule is not None:
@@ -287,23 +290,23 @@ def __init__(self, rule, **kwargs):
raise ValueError('missing action=')
@property
- def action(self):
+ def action(self) -> Action | None:
'''rule action'''
return self._action
@action.setter
- def action(self, value):
+ def action(self, value: object) -> None:
if not isinstance(value, Action):
value = Action(value)
self._action = value
@property
- def proto(self):
+ def proto(self) -> Proto | None:
'''protocol to match'''
return self._proto
@proto.setter
- def proto(self, value):
+ def proto(self, value: object) -> None:
if value is not None and not isinstance(value, Proto):
value = Proto(value)
if value not in ('tcp', 'udp'):
@@ -313,23 +316,23 @@ def proto(self, value):
self._proto = value
@property
- def dsthost(self):
+ def dsthost(self) -> DstHost | None:
'''destination host/network'''
return self._dsthost
@dsthost.setter
- def dsthost(self, value):
+ def dsthost(self, value: str | DstHost | None) -> None:
if value is not None and not isinstance(value, DstHost):
value = DstHost(value)
self._dsthost = value
@property
- def dstports(self):
+ def dstports(self) -> DstPorts | None:
''''Destination port(s) (for \'tcp\' and \'udp\' protocol only)'''
return self._dstports
@dstports.setter
- def dstports(self, value):
+ def dstports(self, value: str | int | DstPorts | None) -> None:
if value is not None:
if self.proto not in ('tcp', 'udp'):
raise ValueError(
@@ -339,12 +342,12 @@ def dstports(self, value):
self._dstports = value
@property
- def icmptype(self):
+ def icmptype(self) -> IcmpType | None:
'''ICMP packet type (for \'icmp\' protocol only)'''
return self._icmptype
@icmptype.setter
- def icmptype(self, value):
+ def icmptype(self, value: IcmpType | SupportsInt | None) -> None:
if value is not None:
if self.proto not in ('icmp',):
raise ValueError('icmptype valid only for \'icmp\' protocol')
@@ -353,40 +356,40 @@ def icmptype(self, value):
self._icmptype = value
@property
- def specialtarget(self):
+ def specialtarget(self) -> SpecialTarget | None:
'''Special target, for now only \'dns\' supported'''
return self._specialtarget
@specialtarget.setter
- def specialtarget(self, value):
+ def specialtarget(self, value: object) -> None:
if not isinstance(value, SpecialTarget):
value = SpecialTarget(value)
self._specialtarget = value
@property
- def expire(self):
+ def expire(self) -> Expire | None:
'''Timestamp (UNIX epoch) on which this rule expire'''
return self._expire
@expire.setter
- def expire(self, value):
+ def expire(self, value: Expire | SupportsInt) -> None:
if not isinstance(value, Expire):
value = Expire(value)
self._expire = value
@property
- def comment(self):
+ def comment(self) -> Comment | None:
'''User comment'''
return self._comment
@comment.setter
- def comment(self, value):
+ def comment(self, value: object) -> None:
if not isinstance(value, Comment):
value = Comment(value)
self._comment = value
@property
- def rule(self):
+ def rule(self) -> str:
'''API representation of this rule'''
values = []
# comment must be the last one
@@ -400,36 +403,36 @@ def rule(self):
values.append(value.rule)
return ' '.join(values)
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
if isinstance(other, Rule):
return self.rule == other.rule
if isinstance(other, str):
return self.rule == str
return NotImplemented
- def __repr__(self):
+ def __repr__(self) -> str:
return 'Rule(\'{}\')'.format(self.rule)
class Firewall:
'''Firewal manager for a VM'''
- def __init__(self, vm):
+ def __init__(self, vm: QubesVM):
self.vm = vm
self._rules: list[Rule] = []
self._policy = None
self._loaded = False
- def load_rules(self):
+ def load_rules(self) -> None:
'''Force (re-)loading firewall rules'''
rules_str = self.vm.qubesd_call(None, 'admin.vm.firewall.Get')
- rules = []
+ rules: list[Rule] = []
for rule_str in rules_str.decode().splitlines():
rules.append(Rule(rule_str))
self._rules = rules
self._loaded = True
@property
- def rules(self):
+ def rules(self) -> list[Rule]:
'''Firewall rules
You can either copy them, edit and then assign new rules list to this
@@ -442,11 +445,11 @@ def rules(self):
return self._rules
@rules.setter
- def rules(self, value):
+ def rules(self, value: list[Rule]) -> None:
self.save_rules(value)
self._rules = value
- def save_rules(self, rules=None):
+ def save_rules(self, rules: list[Rule] | None=None) -> None:
'''Save firewall rules. Needs to be called after in-place editing
:py:attr:`rules`.
'''
@@ -457,11 +460,11 @@ def save_rules(self, rules=None):
for rule in rules)).encode('ascii'))
@property
- def policy(self):
+ def policy(self) -> Action:
'''Default action to take if no rule matches'''
return Action('drop')
- def reload(self):
+ def reload(self) -> None:
'''Force reload the same firewall rules.
Can be used for example to force again names resolution.