diff --git a/qubesusbproxy/core3ext.py b/qubesusbproxy/core3ext.py index 950c3f8..7eee767 100644 --- a/qubesusbproxy/core3ext.py +++ b/qubesusbproxy/core3ext.py @@ -521,6 +521,37 @@ def __init__(self): "/etc/qubes-rpc/qubes.USB" ) self.devices_cache = collections.defaultdict(dict) + self.autoattach_locks = collections.defaultdict(asyncio.Lock) + + async def _auto_attach_devices(self, vm): + async with self.autoattach_locks[vm.uuid]: + to_attach = {} + assignments = get_assigned_devices(vm.devices["usb"]) + # the most specific assignments first + for assignment in reversed(sorted(assignments)): + for device in assignment.devices: + if isinstance(device, qubes.device_protocol.UnknownDevice): + continue + if device.attachment: + continue + if not assignment.matches(device): + print( + "Unrecognized identity, skipping attachment of device " + f"from the port {assignment}", + file=sys.stderr, + ) + continue + # chose first assignment (the most specific) and ignore rest + if device not in to_attach: + # make it unique + to_attach[device] = assignment.clone(device=device) + in_progress = set() + for assignment in to_attach.values(): + in_progress.add( + asyncio.ensure_future(self.attach_and_notify(vm, assignment)) + ) + if in_progress: + await asyncio.wait(in_progress) @qubes.ext.handler("domain-init", "domain-load") def on_domain_init_load(self, vm, event): @@ -744,41 +775,22 @@ async def on_device_assign_usb(self, vm, event, device, options): @qubes.ext.handler("domain-start") async def on_domain_start(self, vm, _event, **_kwargs): # pylint: disable=unused-argument - to_attach = {} - assignments = get_assigned_devices(vm.devices["usb"]) - # the most specific assignments first - for assignment in reversed(sorted(assignments)): - for device in assignment.devices: - if isinstance(device, qubes.device_protocol.UnknownDevice): - continue - if device.attachment: - continue - if not assignment.matches(device): - print( - "Unrecognized identity, skipping attachment of device " - f"from the port {assignment}", - file=sys.stderr, - ) - continue - # chose first assignment (the most specific) and ignore rest - if device not in to_attach: - # make it unique - to_attach[device] = assignment.clone(device=device) - in_progress = set() - for assignment in to_attach.values(): - in_progress.add( - asyncio.ensure_future(self.attach_and_notify(vm, assignment)) - ) - if in_progress: - await asyncio.wait(in_progress) + await self._auto_attach_devices(vm) @qubes.ext.handler("domain-shutdown") async def on_domain_shutdown(self, vm, _event, **_kwargs): # pylint: disable=unused-argument vm.fire_event("device-list-change:usb") utils.device_list_change(self, {}, vm, None, USBDevice) + del self.autoattach_locks[vm.uuid] + + @qubes.ext.handler("domain-resumed") + async def on_domain_resumed(self, vm, _event, **_kwargs): + # pylint: disable=unused-argument + await self._auto_attach_devices(vm) @qubes.ext.handler("qubes-close", system=True) def on_qubes_close(self, app, event): # pylint: disable=unused-argument self.devices_cache.clear() + self.autoattach_locks.clear() diff --git a/qubesusbproxy/tests.py b/qubesusbproxy/tests.py index d44b6ef..2ac01ff 100644 --- a/qubesusbproxy/tests.py +++ b/qubesusbproxy/tests.py @@ -22,6 +22,7 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # import time +import uuid import unittest from unittest import mock from unittest.mock import Mock, AsyncMock @@ -620,6 +621,7 @@ class TestVM(qubes.tests.TestEmitter): def __init__(self, qdb, running=True, name="test-vm", **kwargs): super().__init__(**kwargs) self.name = name + self.uuid = uuid.uuid4() self.klass = "AdminVM" if name == "dom0" else "AppVM" self.icon = "red" self.untrusted_qdb = TestQubesDB(qdb)