diff --git a/qubesadmin/events/__init__.py b/qubesadmin/events/__init__.py index af76d446..bd2826f7 100644 --- a/qubesadmin/events/__init__.py +++ b/qubesadmin/events/__init__.py @@ -257,10 +257,11 @@ def handle(self, subject, event, **kwargs): ): devclass = event.split(":")[1] subject.devices[devclass]._attachment_cache = None - elif event.split(":")[0] in ("device-removed",): + if event.split(":")[0] in ("device-removed",): devclass = event.split(":")[1] + port_id = kwargs.get("port", ":").split(":")[1] try: - subject.devices[devclass]._dev_cache[kwargs["port"]] + del subject.devices[devclass]._dev_cache[port_id] except KeyError: pass diff --git a/qubesadmin/tests/devices.py b/qubesadmin/tests/devices.py index 6b923856..da143a9a 100644 --- a/qubesadmin/tests/devices.py +++ b/qubesadmin/tests/devices.py @@ -20,6 +20,7 @@ # pylint: disable=missing-docstring +from unittest import mock import qubesadmin.tests import qubesadmin.device_protocol @@ -27,7 +28,7 @@ from qubesadmin.device_protocol import ( DeviceAssignment, DeviceInfo, UnknownDevice, AssignmentMode) - +from qubesadmin.events import EventsDispatcher serialized_test_device = ( b"0\0dev1 port_id='dev1' devclass='test' vendor='itl' product='test-device'" @@ -416,3 +417,54 @@ def test_085_allow_device_multiple(self): qubesadmin.device_protocol.DeviceInterface("m******"), ) self.assertAllCalled() + + def test_100_cache_invalidate(self): + # this also enables caching + dispatcher = EventsDispatcher(self.app) + handler = mock.Mock() + dispatcher.add_handler("device-added:test", handler) + self.app.expected_calls[ + ("test-vm", "admin.vm.device.test.Available", None, None) + ] = ( + serialized_test_device + + b"device_id='1234:5678:0123456789:?*******'\n" + ) + vm = self.app.domains.get("test-vm") + # this also populates cache + dev = vm.devices["test"]["dev1"] + self.assertIsInstance(dev, DeviceInfo) + self.assertNotIsInstance(dev, UnknownDevice) + self.assertEqual(dev.device_id, "1234:5678:0123456789:?*******") + dispatcher.handle( + "test-vm", + "device-added:test", + device="test-vm:dev1:1234:5678:0123456789:?*******", + ) + handler.assert_called_once_with(vm, "device-added:test", device=dev) + handler.reset_mock() + dispatcher.handle("test-vm", "device-removed:test", port="test-vm:dev1") + self.app.expected_calls[ + ("test-vm", "admin.vm.device.test.Available", None, None) + ] = ( + serialized_test_device + + b"device_id='8765:4321:0123456789:?*******'\n" + ) + dispatcher.handle( + "test-vm", + "device-added:test", + device="test-vm:dev1:8765:4321:0123456789:?*******", + ) + handler.assert_called_once_with( + vm, "device-added:test", device=mock.ANY + ) + self.assertIsInstance( + handler.mock_calls[0].kwargs["device"], DeviceInfo + ) + self.assertNotIsInstance( + handler.mock_calls[0].kwargs["device"], UnknownDevice + ) + self.assertEqual( + handler.mock_calls[0].kwargs["device"].device_id, + "8765:4321:0123456789:?*******", + ) + self.assertAllCalled()