Skip to content

Commit 455ce7e

Browse files
author
anon
committed
type-hint base.py
1 parent 1aa82a5 commit 455ce7e

1 file changed

Lines changed: 78 additions & 45 deletions

File tree

qubesadmin/base.py

Lines changed: 78 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,24 @@
1919
# with this program; if not, see <http://www.gnu.org/licenses/>.
2020

2121
'''Base classes for managed objects'''
22+
from __future__ import annotations
23+
24+
import typing
25+
from typing import BinaryIO, Any, TypeAlias, TypeVar, Generic
26+
from collections.abc import Generator
2227

2328
import qubesadmin.exc
2429

30+
if typing.TYPE_CHECKING:
31+
from qubesadmin.vm import QubesVM
32+
from qubesadmin.app import QubesBase
33+
2534
DEFAULT = object()
2635

36+
# We use Any because the dynamic metatada handling of the current code
37+
# is too complex for type checkers otherwise
38+
VMProperty: TypeAlias = Any # noqa: ANN401
39+
2740

2841
class PropertyHolder:
2942
'''A base class for object having properties retrievable using mgmt API.
@@ -34,28 +47,29 @@ class PropertyHolder:
3447
'''
3548
#: a place for appropriate Qubes() object (QubesLocal or QubesRemote),
3649
# use None for self
37-
app = None
50+
app: QubesBase
3851

39-
def __init__(self, app, method_prefix, method_dest):
52+
def __init__(self, app: QubesBase, method_prefix: str, method_dest: str):
4053
#: appropriate Qubes() object (QubesLocal or QubesRemote), use None
4154
# for self
4255
self.app = app
4356
self._method_prefix = method_prefix
4457
self._method_dest = method_dest
45-
self._properties = None
58+
self._properties: list[str] | None = None
4659
self._properties_help = None
4760
# the cache is maintained by EventsDispatcher(),
4861
# through helper functions in QubesBase()
4962
self._properties_cache = {}
5063

51-
def clear_cache(self):
64+
def clear_cache(self) -> None:
5265
"""
5366
Clear property cache.
5467
"""
5568
self._properties_cache = {}
5669

57-
def qubesd_call(self, dest, method, arg=None, payload=None,
58-
payload_stream=None):
70+
def qubesd_call(self, dest: str | None, method: str,
71+
arg: str | None=None, payload: bytes | None=None,
72+
payload_stream: BinaryIO | None=None) -> bytes:
5973
'''
6074
Call into qubesd using appropriate mechanism. This method should be
6175
defined by a subclass.
@@ -69,26 +83,27 @@ def qubesd_call(self, dest, method, arg=None, payload=None,
6983
:param payload_stream: file-like object to read payload from
7084
:return: Data returned by qubesd (string)
7185
'''
72-
if not self.app:
73-
raise NotImplementedError
74-
if dest is None:
75-
dest = self._method_dest
86+
dest: str = dest or self._method_dest
7687
if (
7788
getattr(self, "_redirect_dispvm_calls", False)
7889
and dest.startswith("@dispvm")
7990
):
8091
if dest.startswith("@dispvm:"):
8192
dest = dest[len("@dispvm:") :]
8293
else:
83-
dest = getattr(self.app, "default_dispvm", None)
94+
# TODO what if `dest` remains None here ?
95+
# qubesd_call expects a non-None `dest` arg
96+
dest: QubesVM | None = getattr(self.app, "default_dispvm", None)
97+
assert dest is not None
8498
if dest:
85-
dest = dest.name
99+
dest: str = dest.name
100+
assert isinstance(dest, str)
86101
# have the actual implementation at Qubes() instance
87102
return self.app.qubesd_call(dest, method, arg, payload,
88103
payload_stream)
89104

90105
@staticmethod
91-
def _parse_qubesd_response(response_data):
106+
def _parse_qubesd_response(response_data: bytes) -> bytes:
92107
'''Parse response from qubesd.
93108
94109
In case of success, return actual data. In case of error,
@@ -122,7 +137,7 @@ def _parse_qubesd_response(response_data):
122137
raise qubesadmin.exc.QubesDaemonCommunicationError(
123138
'Invalid response format')
124139

125-
def property_list(self):
140+
def property_list(self) -> list[str]:
126141
'''
127142
List available properties (their names).
128143
@@ -138,7 +153,7 @@ def property_list(self):
138153
# TODO: make it somehow immutable
139154
return self._properties
140155

141-
def property_help(self, name):
156+
def property_help(self, name: str) -> str:
142157
'''
143158
Get description of a property.
144159
@@ -151,7 +166,7 @@ def property_help(self, name):
151166
None)
152167
return help_text.decode('ascii')
153168

154-
def property_is_default(self, item):
169+
def property_is_default(self, item: str) -> bool:
155170
'''
156171
Check if given property have default value
157172
@@ -183,7 +198,7 @@ def property_is_default(self, item):
183198
self._properties_cache[item] = (is_default, value)
184199
return is_default
185200

186-
def property_get_default(self, item):
201+
def property_get_default(self, item: str) -> VMProperty:
187202
'''
188203
Get default property value, regardless of the current value
189204
@@ -206,7 +221,8 @@ def property_get_default(self, item):
206221
(prop_type, value) = property_str.split(b' ', 1)
207222
return self._parse_type_value(prop_type, value)
208223

209-
def clone_properties(self, src, proplist=None):
224+
def clone_properties(self, src: PropertyHolder,
225+
proplist: list[str] | None=None) -> None:
210226
'''Clone properties from other object.
211227
212228
:param PropertyHolder src: source object
@@ -223,7 +239,7 @@ def clone_properties(self, src, proplist=None):
223239
except AttributeError:
224240
continue
225241

226-
def __getattr__(self, item):
242+
def __getattr__(self, item: str) -> VMProperty:
227243
if item.startswith('_'):
228244
raise AttributeError(item)
229245
# pre-fill cache if enabled
@@ -254,7 +270,8 @@ def __getattr__(self, item):
254270
raise AttributeError(item)
255271
return value
256272

257-
def _deserialize_property(self, api_response):
273+
def _deserialize_property(self, api_response: bytes) \
274+
-> tuple[bool, VMProperty]:
258275
"""
259276
Deserialize property.Get response format
260277
:param api_response: bytes, as retrieved from qubesd
@@ -267,7 +284,7 @@ def _deserialize_property(self, api_response):
267284
value = self._parse_type_value(prop_type, value)
268285
return is_default, value
269286

270-
def _parse_type_value(self, prop_type, value):
287+
def _parse_type_value(self, prop_type: bytes, value: bytes) -> VMProperty:
271288
'''
272289
Parse `type=... ...` qubesd response format. Return a value of
273290
appropriate type.
@@ -278,20 +295,23 @@ def _parse_type_value(self, prop_type, value):
278295
:return: parsed value
279296
'''
280297
# pylint: disable=too-many-return-statements
281-
prop_type = prop_type.decode('ascii')
298+
prop_type: str = prop_type.decode('ascii')
282299
if not prop_type.startswith('type='):
283300
raise qubesadmin.exc.QubesDaemonCommunicationError(
284301
'Invalid type prefix received: {}'.format(prop_type))
285302
(_, prop_type) = prop_type.split('=', 1)
286-
value = value.decode()
303+
value: str = value.decode()
287304
if prop_type == 'str':
288305
return str(value)
289306
if prop_type == 'bool':
290307
if value == '':
308+
# TODO shouldn't that at least be ValueError ?
309+
# but then we need to properly propagate that modification
291310
return AttributeError
292311
return value == "True"
293312
if prop_type == 'int':
294313
if value == '':
314+
# TODO same as above
295315
return AttributeError
296316
return int(value)
297317
if prop_type == 'vm':
@@ -305,7 +325,7 @@ def _parse_type_value(self, prop_type, value):
305325
raise qubesadmin.exc.QubesDaemonCommunicationError(
306326
'Received invalid value type: {}'.format(prop_type))
307327

308-
def _fetch_all_properties(self):
328+
def _fetch_all_properties(self) -> None:
309329
"""
310330
Retrieve all properties values at once using (prefix).property.GetAll
311331
method. If it succeed, save retrieved values in the properties cache.
@@ -315,7 +335,7 @@ def _fetch_all_properties(self):
315335
:return: None
316336
"""
317337

318-
def unescape(line):
338+
def unescape(line: bytes) -> Generator[int]:
319339
"""Handle \\-escaped values, generates a list of character codes"""
320340
escaped = False
321341
for char in line:
@@ -342,15 +362,15 @@ def unescape(line):
342362
return
343363
for line in properties_str.splitlines():
344364
# decode newlines
345-
line = bytes(unescape(line))
346-
name, property_str = line.split(b' ', 1)
365+
line_bytes = bytes(list(unescape(line)))
366+
name, property_str = line_bytes.split(b' ', 1)
347367
name = name.decode()
348368
is_default, value = self._deserialize_property(property_str)
349369
self._properties_cache[name] = (is_default, value)
350370
self._properties = list(self._properties_cache.keys())
351371

352372
@classmethod
353-
def _local_properties(cls):
373+
def _local_properties(cls) -> set:
354374
'''
355375
Get set of property names that are properties on the Python object,
356376
and must not be set on the remote object
@@ -367,7 +387,7 @@ def _local_properties(cls):
367387

368388
return cls._local_properties_set
369389

370-
def __setattr__(self, key, value):
390+
def __setattr__(self, key: str, value: typing.Any) -> None: # noqa: ANN401
371391
if key.startswith('_') or key in self._local_properties():
372392
return super().__setattr__(key, value)
373393
if value is qubesadmin.DEFAULT:
@@ -381,7 +401,9 @@ def __setattr__(self, key, value):
381401
qubesadmin.exc.QubesVMNotFoundError):
382402
raise qubesadmin.exc.QubesPropertyAccessError(key)
383403
else:
384-
if isinstance(value, qubesadmin.vm.QubesVM):
404+
# Dynamic import because qubesadmin.vm imports base.py
405+
from qubesadmin.vm import QubesVM
406+
if isinstance(value, QubesVM):
385407
value = value.name
386408
if value is None:
387409
value = ''
@@ -395,7 +417,7 @@ def __setattr__(self, key, value):
395417
qubesadmin.exc.QubesVMNotFoundError):
396418
raise qubesadmin.exc.QubesPropertyAccessError(key)
397419

398-
def __delattr__(self, name):
420+
def __delattr__(self, name: str) -> None:
399421
if name.startswith('_') or name in self._local_properties():
400422
return super().__delattr__(name)
401423
try:
@@ -408,10 +430,13 @@ def __delattr__(self, name):
408430
qubesadmin.exc.QubesVMNotFoundError):
409431
raise qubesadmin.exc.QubesPropertyAccessError(name)
410432

433+
WrapperObjectsCollectionKey: TypeAlias = int | str
434+
T = TypeVar('T')
411435

412-
class WrapperObjectsCollection:
436+
class WrapperObjectsCollection(Generic[T]):
413437
'''Collection of simple named objects'''
414-
def __init__(self, app, list_method, object_class):
438+
def __init__(self, app: QubesBase,
439+
list_method: str, object_class: type[T]):
415440
'''
416441
Construct manager of named wrapper objects.
417442
@@ -425,11 +450,13 @@ def __init__(self, app, list_method, object_class):
425450
self._list_method = list_method
426451
self._object_class = object_class
427452
#: names cache
428-
self._names_list = None
453+
self._names_list: list[WrapperObjectsCollectionKey] | None = None
429454
#: returned objects cache
430-
self._objects = {}
455+
self._objects: dict[WrapperObjectsCollectionKey, T] = {}
431456

432-
def clear_cache(self, invalidate_name=None):
457+
def clear_cache(self,
458+
invalidate_name: WrapperObjectsCollectionKey | None=None)\
459+
-> None:
433460
"""Clear cached list of names.
434461
If *invalidate_name* is given, remove that object from cache
435462
explicitly too.
@@ -438,7 +465,7 @@ def clear_cache(self, invalidate_name=None):
438465
if invalidate_name:
439466
self._objects.pop(invalidate_name, None)
440467

441-
def refresh_cache(self, force=False):
468+
def refresh_cache(self, force: bool=False) -> None:
442469
'''Refresh cached list of names'''
443470
if not force and self._names_list is not None:
444471
return
@@ -447,17 +474,18 @@ def refresh_cache(self, force=False):
447474
assert list_data[-1] == '\n'
448475
self._names_list = [str(name) for name in list_data[:-1].splitlines()]
449476

450-
for name, obj in list(self._objects.items()):
477+
for name, obj in self._objects.items():
478+
assert hasattr(obj, "name")
451479
if obj.name not in self._names_list:
452480
# Object no longer exists
453481
del self._objects[name]
454482

455-
def __getitem__(self, item):
483+
def __getitem__(self, item: WrapperObjectsCollectionKey) -> T:
456484
if not self.app.blind_mode and item not in self:
457485
raise KeyError(item)
458486
return self.get_blind(item)
459487

460-
def get_blind(self, item):
488+
def get_blind(self, item: WrapperObjectsCollectionKey) -> T:
461489
'''
462490
Get a property without downloading the list
463491
and checking if it's present
@@ -466,25 +494,30 @@ def get_blind(self, item):
466494
self._objects[item] = self._object_class(self.app, item)
467495
return self._objects[item]
468496

469-
def __contains__(self, item):
497+
def __contains__(self, item: WrapperObjectsCollectionKey) -> bool:
470498
self.refresh_cache()
499+
assert self._names_list is not None
471500
return item in self._names_list
472501

473-
def __iter__(self):
502+
def __iter__(self) -> Generator[WrapperObjectsCollectionKey]:
474503
self.refresh_cache()
504+
assert self._names_list is not None
475505
yield from self._names_list
476506

477-
def keys(self):
507+
def keys(self) -> list[WrapperObjectsCollectionKey]:
478508
'''Get list of names.'''
479509
self.refresh_cache()
510+
assert self._names_list is not None
480511
return list(self._names_list)
481512

482-
def items(self):
513+
def items(self) -> list[tuple[WrapperObjectsCollectionKey, T]]:
483514
'''Get list of (key, value) pairs'''
484515
self.refresh_cache()
516+
assert self._names_list is not None
485517
return [(key, self.get_blind(key)) for key in self._names_list]
486518

487-
def values(self):
519+
def values(self) -> list[T]:
488520
'''Get list of objects'''
489521
self.refresh_cache()
522+
assert self._names_list is not None
490523
return [self.get_blind(key) for key in self._names_list]

0 commit comments

Comments
 (0)