Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
323 changes: 183 additions & 140 deletions qubesadmin/app.py

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions qubesadmin/backup/core3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@

import qubesadmin.backup
import qubesadmin.firewall
from qubesadmin import device_protocol
from qubesadmin import utils
from qubesadmin.vm import QubesVM




class Core3VM(qubesadmin.backup.BackupVM):
'''VM object'''
@property
Expand Down Expand Up @@ -138,7 +136,7 @@ def import_core3_vm(self, element: _Element) -> None:
for opt_node in node.findall('./option'):
opt_name = opt_node.get('name')
options[opt_name] = opt_node.text
options['required'] = device_protocol.qbool(
options['required'] = utils.qbool(
node.get('required', 'yes'))
vm.devices[bus_name][(backend_domain, port_id)] = options

Expand Down
117 changes: 73 additions & 44 deletions qubesadmin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,24 @@
# with this program; if not, see <http://www.gnu.org/licenses/>.

'''Base classes for managed objects'''
from __future__ import annotations

import typing
from typing import BinaryIO, Any, TypeAlias, TypeVar, Generic
from collections.abc import Generator

import qubesadmin.exc

if typing.TYPE_CHECKING:
from qubesadmin.vm import QubesVM
from qubesadmin.app import QubesBase

DEFAULT = object()

# We use Any because the dynamic metatada handling of the current code
# is too complex for type checkers otherwise
VMProperty: TypeAlias = Any # noqa: ANN401


class PropertyHolder:
'''A base class for object having properties retrievable using mgmt API.
Expand All @@ -34,28 +47,29 @@ class PropertyHolder:
'''
#: a place for appropriate Qubes() object (QubesLocal or QubesRemote),
# use None for self
app = None
app: QubesBase

def __init__(self, app, method_prefix, method_dest):
def __init__(self, app: QubesBase, method_prefix: str, method_dest: str):
#: appropriate Qubes() object (QubesLocal or QubesRemote), use None
# for self
self.app = app
self._method_prefix = method_prefix
self._method_dest = method_dest
self._properties = None
self._properties: list[str] | None = None
self._properties_help = None
# the cache is maintained by EventsDispatcher(),
# through helper functions in QubesBase()
self._properties_cache = {}

def clear_cache(self):
def clear_cache(self) -> None:
"""
Clear property cache.
"""
self._properties_cache = {}

def qubesd_call(self, dest, method, arg=None, payload=None,
payload_stream=None):
def qubesd_call(self, dest: str | None, method: str,
arg: str | None=None, payload: bytes | None=None,
payload_stream: BinaryIO | None=None) -> bytes:
'''
Call into qubesd using appropriate mechanism. This method should be
defined by a subclass.
Expand All @@ -69,26 +83,23 @@ def qubesd_call(self, dest, method, arg=None, payload=None,
:param payload_stream: file-like object to read payload from
:return: Data returned by qubesd (string)
'''
if not self.app:
raise NotImplementedError
if dest is None:
dest = self._method_dest
dest: str = dest or self._method_dest
if (
getattr(self, "_redirect_dispvm_calls", False)
and dest.startswith("@dispvm")
):
if dest.startswith("@dispvm:"):
dest = dest[len("@dispvm:") :]
else:
dest = getattr(self.app, "default_dispvm", None)
dest: QubesVM | None = getattr(self.app, "default_dispvm", None)
if dest:
dest = dest.name
# have the actual implementation at Qubes() instance
return self.app.qubesd_call(dest, method, arg, payload,
payload_stream)

@staticmethod
def _parse_qubesd_response(response_data):
def _parse_qubesd_response(response_data: bytes) -> bytes:
'''Parse response from qubesd.

In case of success, return actual data. In case of error,
Expand Down Expand Up @@ -122,7 +133,7 @@ def _parse_qubesd_response(response_data):
raise qubesadmin.exc.QubesDaemonCommunicationError(
'Invalid response format')

def property_list(self):
def property_list(self) -> list[str]:
'''
List available properties (their names).

Expand All @@ -138,7 +149,7 @@ def property_list(self):
# TODO: make it somehow immutable
return self._properties

def property_help(self, name):
def property_help(self, name: str) -> str:
'''
Get description of a property.

Expand All @@ -151,7 +162,7 @@ def property_help(self, name):
None)
return help_text.decode('ascii')

def property_is_default(self, item):
def property_is_default(self, item: str) -> bool:
'''
Check if given property have default value

Expand Down Expand Up @@ -183,7 +194,7 @@ def property_is_default(self, item):
self._properties_cache[item] = (is_default, value)
return is_default

def property_get_default(self, item):
def property_get_default(self, item: str) -> VMProperty:
'''
Get default property value, regardless of the current value

Expand All @@ -206,7 +217,8 @@ def property_get_default(self, item):
(prop_type, value) = property_str.split(b' ', 1)
return self._parse_type_value(prop_type, value)

def clone_properties(self, src, proplist=None):
def clone_properties(self, src: PropertyHolder,
proplist: list[str] | None=None) -> None:
'''Clone properties from other object.

:param PropertyHolder src: source object
Expand All @@ -223,7 +235,7 @@ def clone_properties(self, src, proplist=None):
except AttributeError:
continue

def __getattr__(self, item):
def __getattr__(self, item: str) -> VMProperty:
if item.startswith('_'):
raise AttributeError(item)
# pre-fill cache if enabled
Expand Down Expand Up @@ -254,7 +266,8 @@ def __getattr__(self, item):
raise AttributeError(item)
return value

def _deserialize_property(self, api_response):
def _deserialize_property(self, api_response: bytes) \
-> tuple[bool, VMProperty]:
"""
Deserialize property.Get response format
:param api_response: bytes, as retrieved from qubesd
Expand All @@ -267,7 +280,7 @@ def _deserialize_property(self, api_response):
value = self._parse_type_value(prop_type, value)
return is_default, value

def _parse_type_value(self, prop_type, value):
def _parse_type_value(self, prop_type: bytes, value: bytes) -> VMProperty:
'''
Parse `type=... ...` qubesd response format. Return a value of
appropriate type.
Expand All @@ -278,20 +291,23 @@ def _parse_type_value(self, prop_type, value):
:return: parsed value
'''
# pylint: disable=too-many-return-statements
prop_type = prop_type.decode('ascii')
prop_type: str = prop_type.decode('ascii')
if not prop_type.startswith('type='):
raise qubesadmin.exc.QubesDaemonCommunicationError(
'Invalid type prefix received: {}'.format(prop_type))
(_, prop_type) = prop_type.split('=', 1)
value = value.decode()
value: str = value.decode()
if prop_type == 'str':
return str(value)
if prop_type == 'bool':
if value == '':
# TODO shouldn't that at least be ValueError ?
# but then we need to properly propagate that modification
return AttributeError
return value == "True"
if prop_type == 'int':
if value == '':
# TODO same as above
return AttributeError
return int(value)
if prop_type == 'vm':
Expand All @@ -305,7 +321,7 @@ def _parse_type_value(self, prop_type, value):
raise qubesadmin.exc.QubesDaemonCommunicationError(
'Received invalid value type: {}'.format(prop_type))

def _fetch_all_properties(self):
def _fetch_all_properties(self) -> None:
"""
Retrieve all properties values at once using (prefix).property.GetAll
method. If it succeed, save retrieved values in the properties cache.
Expand All @@ -315,7 +331,7 @@ def _fetch_all_properties(self):
:return: None
"""

def unescape(line):
def unescape(line: bytes) -> Generator[int]:
"""Handle \\-escaped values, generates a list of character codes"""
escaped = False
for char in line:
Expand All @@ -342,15 +358,15 @@ def unescape(line):
return
for line in properties_str.splitlines():
# decode newlines
line = bytes(unescape(line))
name, property_str = line.split(b' ', 1)
line_bytes = bytes(list(unescape(line)))
name, property_str = line_bytes.split(b' ', 1)
name = name.decode()
is_default, value = self._deserialize_property(property_str)
self._properties_cache[name] = (is_default, value)
self._properties = list(self._properties_cache.keys())

@classmethod
def _local_properties(cls):
def _local_properties(cls) -> set:
'''
Get set of property names that are properties on the Python object,
and must not be set on the remote object
Expand All @@ -367,7 +383,7 @@ def _local_properties(cls):

return cls._local_properties_set

def __setattr__(self, key, value):
def __setattr__(self, key: str, value: typing.Any) -> None: # noqa: ANN401
if key.startswith('_') or key in self._local_properties():
return super().__setattr__(key, value)
if value is qubesadmin.DEFAULT:
Expand All @@ -381,7 +397,9 @@ def __setattr__(self, key, value):
qubesadmin.exc.QubesVMNotFoundError):
raise qubesadmin.exc.QubesPropertyAccessError(key)
else:
if isinstance(value, qubesadmin.vm.QubesVM):
# Dynamic import because qubesadmin.vm imports base.py
from qubesadmin.vm import QubesVM
if isinstance(value, QubesVM):
value = value.name
if value is None:
value = ''
Expand All @@ -395,7 +413,7 @@ def __setattr__(self, key, value):
qubesadmin.exc.QubesVMNotFoundError):
raise qubesadmin.exc.QubesPropertyAccessError(key)

def __delattr__(self, name):
def __delattr__(self, name: str) -> None:
if name.startswith('_') or name in self._local_properties():
return super().__delattr__(name)
try:
Expand All @@ -408,10 +426,13 @@ def __delattr__(self, name):
qubesadmin.exc.QubesVMNotFoundError):
raise qubesadmin.exc.QubesPropertyAccessError(name)

WrapperObjectsCollectionKey: TypeAlias = int | str
T = TypeVar('T')

class WrapperObjectsCollection:
class WrapperObjectsCollection(Generic[T]):
'''Collection of simple named objects'''
def __init__(self, app, list_method, object_class):
def __init__(self, app: QubesBase,
list_method: str, object_class: type[T]):
'''
Construct manager of named wrapper objects.

Expand All @@ -425,11 +446,13 @@ def __init__(self, app, list_method, object_class):
self._list_method = list_method
self._object_class = object_class
#: names cache
self._names_list = None
self._names_list: list[WrapperObjectsCollectionKey] | None = None
#: returned objects cache
self._objects = {}
self._objects: dict[WrapperObjectsCollectionKey, T] = {}

def clear_cache(self, invalidate_name=None):
def clear_cache(self,
invalidate_name: WrapperObjectsCollectionKey | None=None)\
-> None:
"""Clear cached list of names.
If *invalidate_name* is given, remove that object from cache
explicitly too.
Expand All @@ -438,7 +461,7 @@ def clear_cache(self, invalidate_name=None):
if invalidate_name:
self._objects.pop(invalidate_name, None)

def refresh_cache(self, force=False):
def refresh_cache(self, force: bool=False) -> None:
'''Refresh cached list of names'''
if not force and self._names_list is not None:
return
Expand All @@ -447,17 +470,18 @@ def refresh_cache(self, force=False):
assert list_data[-1] == '\n'
self._names_list = [str(name) for name in list_data[:-1].splitlines()]

for name, obj in list(self._objects.items()):
for name, obj in self._objects.items():
assert hasattr(obj, "name")
if obj.name not in self._names_list:
# Object no longer exists
del self._objects[name]

def __getitem__(self, item):
def __getitem__(self, item: WrapperObjectsCollectionKey) -> T:
if not self.app.blind_mode and item not in self:
raise KeyError(item)
return self.get_blind(item)

def get_blind(self, item):
def get_blind(self, item: WrapperObjectsCollectionKey) -> T:
'''
Get a property without downloading the list
and checking if it's present
Expand All @@ -466,25 +490,30 @@ def get_blind(self, item):
self._objects[item] = self._object_class(self.app, item)
return self._objects[item]

def __contains__(self, item):
def __contains__(self, item: WrapperObjectsCollectionKey) -> bool:
self.refresh_cache()
assert self._names_list is not None
return item in self._names_list

def __iter__(self):
def __iter__(self) -> Generator[WrapperObjectsCollectionKey]:
self.refresh_cache()
assert self._names_list is not None
yield from self._names_list

def keys(self):
def keys(self) -> list[WrapperObjectsCollectionKey]:
'''Get list of names.'''
self.refresh_cache()
assert self._names_list is not None
return list(self._names_list)

def items(self):
def items(self) -> list[tuple[WrapperObjectsCollectionKey, T]]:
'''Get list of (key, value) pairs'''
self.refresh_cache()
assert self._names_list is not None
return [(key, self.get_blind(key)) for key in self._names_list]

def values(self):
def values(self) -> list[T]:
'''Get list of objects'''
self.refresh_cache()
assert self._names_list is not None
return [self.get_blind(key) for key in self._names_list]
Loading