diff --git a/qubesadmin/tests/tools/qvm_shutdown.py b/qubesadmin/tests/tools/qvm_shutdown.py index e4257967..2a9da86e 100644 --- a/qubesadmin/tests/tools/qvm_shutdown.py +++ b/qubesadmin/tests/tools/qvm_shutdown.py @@ -194,9 +194,6 @@ def test_015_wait_all_kill_timeout(self): self.app.expected_calls[ ('sys-net', 'admin.vm.Shutdown', 'force', None)] = \ b'0\x00' - self.app.expected_calls[ - ('sys-net', 'admin.vm.Kill', None, None)] = \ - b'2\x00QubesVMNotStartedError\x00\x00Domain is powered off\x00' self.app.expected_calls[ ('dom0', 'admin.vm.List', None, None)] = \ b'0\x00' \ @@ -207,16 +204,203 @@ def test_015_wait_all_kill_timeout(self): ('some-vm', 'admin.vm.CurrentState', None, None)] = [ b'0\x00power_state=Running', b'0\x00power_state=Running', + b'0\x00power_state=Running', ] self.app.expected_calls[ ('other-vm', 'admin.vm.CurrentState', None, None)] = [ b'0\x00power_state=Running', b'0\x00power_state=Running', + b'0\x00power_state=Running', ] self.app.expected_calls[ - ('sys-net', 'admin.vm.CurrentState', None, None)] = \ - b'0\x00power_state=Halted' + ('sys-net', 'admin.vm.CurrentState', None, None)] = [ + b'0\x00power_state=Halted', + b'0\x00power_state=Halted', + b'0\x00power_state=Halted', + ] with self.assertRaisesRegex(SystemExit, '2'): qubesadmin.tools.qvm_shutdown.main( ['--wait', '--all', '--timeout=1'], app=self.app) self.assertAllCalled() + + def test_005_force(self): + '''test --force sends force flag to shutdown call''' + self.app.expected_calls[ + ('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00some-vm class=AppVM state=Running\n' + self.app.expected_calls[ + ('some-vm', 'admin.vm.Shutdown', 'force', None)] = b'0\x00' + qubesadmin.tools.qvm_shutdown.main( + ['--force', 'some-vm'], app=self.app) + self.assertAllCalled() + + def test_006_dry_run(self): + '''test --dry-run skips shutdown calls''' + self.app.expected_calls[ + ('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00some-vm class=AppVM state=Running\n' + qubesadmin.tools.qvm_shutdown.main( + ['--dry-run', 'some-vm'], app=self.app) + self.assertAllCalled() + + @unittest.skipUnless(qubesadmin.tools.qvm_shutdown.have_events, + 'Events not present') + def test_011_wait_retry(self): + '''test --wait retries VMs whose shutdown request failed''' + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + mock_events = unittest.mock.AsyncMock() + patch = unittest.mock.patch( + 'qubesadmin.events.EventsDispatcher._get_events_reader', + mock_events) + patch.start() + self.addCleanup(patch.stop) + mock_events.side_effect = qubesadmin.tests.tools.MockEventsReader([ + # round 1: wait for some-vm + b'1\0\0connection-established\0\0', + b'1\0some-vm\0domain-shutdown\0\0', + # round 2: wait for other-vm + b'1\0\0connection-established\0\0', + b'1\0other-vm\0domain-shutdown\0\0', + ]) + + self.app.expected_calls[ + ('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00' \ + b'some-vm class=AppVM state=Running\n' \ + b'other-vm class=AppVM state=Running\n' + self.app.expected_calls[ + ('some-vm', 'admin.vm.Shutdown', None, None)] = \ + b'0\x00' + # other-vm fails first attempt, succeeds on retry + self.app.expected_calls[ + ('other-vm', 'admin.vm.Shutdown', None, None)] = [ + b'2\x00QubesException\x00\x00Shutdown refused\x00', + b'0\x00', + ] + self.app.expected_calls[ + ('some-vm', 'admin.vm.CurrentState', None, None)] = [ + b'0\x00power_state=Running', + b'0\x00power_state=Halted', + ] + self.app.expected_calls[ + ('other-vm', 'admin.vm.CurrentState', None, None)] = [ + b'0\x00power_state=Running', + b'0\x00power_state=Halted', + ] + qubesadmin.tools.qvm_shutdown.main( + ['--wait', 'some-vm', 'other-vm'], app=self.app) + self.assertAllCalled() + + @unittest.skipUnless(qubesadmin.tools.qvm_shutdown.have_events, + 'Events not present') + def test_013_wait_all_shutdown_fail(self): + '''test --wait exits with error when all shutdown requests fail''' + self.app.expected_calls[ + ('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00some-vm class=AppVM state=Running\n' + self.app.expected_calls[ + ('some-vm', 'admin.vm.Shutdown', None, None)] = \ + b'2\x00QubesException\x00\x00Shutdown refused\x00' + self.app.expected_calls[ + ('some-vm', 'admin.vm.CurrentState', None, None)] = \ + b'0\x00power_state=Running' + with self.assertRaises(SystemExit): + qubesadmin.tools.qvm_shutdown.main( + ['--wait', 'some-vm'], app=self.app) + self.assertAllCalled() + + @unittest.skipUnless(qubesadmin.tools.qvm_shutdown.have_events, + 'Events not present') + def test_016_wait_kill_exception(self): + '''test --wait timeout where kill raises QubesException''' + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + mock_events = unittest.mock.AsyncMock() + patch = unittest.mock.patch( + 'qubesadmin.events.EventsDispatcher._get_events_reader', + mock_events) + patch.start() + self.addCleanup(patch.stop) + mock_events.side_effect = qubesadmin.tests.tools.MockEventsReader([ + b'1\0\0connection-established\0\0', + ]) + + self.app.expected_calls[ + ('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00some-vm class=AppVM state=Running\n' + self.app.expected_calls[ + ('some-vm', 'admin.vm.Shutdown', None, None)] = \ + b'0\x00' + self.app.expected_calls[ + ('some-vm', 'admin.vm.Kill', None, None)] = \ + b'2\x00QubesException\x00\x00Kill failed\x00' + self.app.expected_calls[ + ('some-vm', 'admin.vm.CurrentState', None, None)] = [ + b'0\x00power_state=Running', + b'0\x00power_state=Running', + ] + with self.assertRaises(SystemExit): + qubesadmin.tools.qvm_shutdown.main( + ['--wait', '--timeout=1', 'some-vm'], app=self.app) + self.assertAllCalled() + + @unittest.skipUnless(qubesadmin.tools.qvm_shutdown.have_events, + 'Events not present') + def test_017_wait_dispvm_na(self): + '''test --wait treats DispVM with NA power state as shut down''' + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + mock_events = unittest.mock.AsyncMock() + patch = unittest.mock.patch( + 'qubesadmin.events.EventsDispatcher._get_events_reader', + mock_events) + patch.start() + self.addCleanup(patch.stop) + mock_events.side_effect = qubesadmin.tests.tools.MockEventsReader([ + b'1\0\0connection-established\0\0', + b'1\0disp123\0domain-shutdown\0\0', + ]) + + self.app.expected_calls[ + ('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00disp123 class=DispVM state=Running\n' + self.app.expected_calls[ + ('disp123', 'admin.vm.Shutdown', None, None)] = \ + b'0\x00' + self.app.expected_calls[ + ('disp123', 'admin.vm.CurrentState', None, None)] = [ + b'0\x00power_state=Running', + # failed_domains: first get_power_state() != 'Halted', + # then klass == 'DispVM' triggers second get_power_state() + b'0\x00power_state=NA', + b'0\x00power_state=NA', + ] + qubesadmin.tools.qvm_shutdown.main( + ['--wait', 'disp123'], app=self.app) + self.assertAllCalled() + + def test_018_wait_polling_fallback(self): + '''test --wait uses polling when events are unavailable''' + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + self.app.expected_calls[ + ('dom0', 'admin.vm.List', None, None)] = \ + b'0\x00some-vm class=AppVM state=Running\n' + self.app.expected_calls[ + ('some-vm', 'admin.vm.Shutdown', None, None)] = \ + b'0\x00' + self.app.expected_calls[ + ('some-vm', 'admin.vm.CurrentState', None, None)] = [ + b'0\x00power_state=Halted', + b'0\x00power_state=Halted', + ] + with unittest.mock.patch.object( + qubesadmin.tools.qvm_shutdown, 'have_events', False): + qubesadmin.tools.qvm_shutdown.main( + ['--wait', 'some-vm'], app=self.app) + self.assertAllCalled() diff --git a/qubesadmin/tools/qvm_shutdown.py b/qubesadmin/tools/qvm_shutdown.py index 765b1039..246179c5 100644 --- a/qubesadmin/tools/qvm_shutdown.py +++ b/qubesadmin/tools/qvm_shutdown.py @@ -25,9 +25,12 @@ from __future__ import print_function import sys -import time import asyncio +from typing import Iterable + +from qubesadmin.app import QubesBase +from qubesadmin.vm import QubesVM try: import qubesadmin.events.utils @@ -69,22 +72,32 @@ def failed_domains(vms): if not (vm.get_power_state() == 'Halted' or (vm.klass == 'DispVM' and vm.get_power_state() == 'NA'))] +async def _wait_for_shutdown_polling(vms: Iterable[QubesVM], app: QubesBase)\ + -> None: + """Fallback polling coroutine when events are not available.""" + current_vms = list(vms) + while True: + current_vms = failed_domains(current_vms) + if not current_vms: + break + app.log.info('Waiting for shutdown: {}'.format( + ', '.join([str(vm) for vm in current_vms]))) + await asyncio.sleep(1) + def main(args=None, app=None): # pylint: disable=missing-docstring args = parser.parse_args(args, app=app) force = args.force or bool(args.all_domains) - if have_events: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - remaining_domains = args.domains + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + remaining_domains = set(args.domains) for _ in range(len(args.domains)): - this_round_domains = set(remaining_domains) - if not this_round_domains: + if not remaining_domains: break - remaining_domains = set() + shutdown_failed = set() if not args.dry_run: - for vm in this_round_domains: + for vm in remaining_domains: try: vm.shutdown(force=force) except qubesadmin.exc.QubesVMNotStartedError: @@ -92,48 +105,29 @@ def main(args=None, app=None): # pylint: disable=missing-docstring except qubesadmin.exc.QubesException as e: if not args.wait: vm.log.error('Shutdown error: {}'.format(e)) - else: - remaining_domains.add(vm) + shutdown_failed.add(vm) if not args.wait: - if remaining_domains: - parser.error_runtime( - 'Failed to shut down: ' + - ', '.join(vm.name for vm in remaining_domains), - len(remaining_domains)) + assert not shutdown_failed return - this_round_domains.difference_update(remaining_domains) - if not this_round_domains: - # no VM shutdown request succeed, no sense to try again + awaiting = remaining_domains - shutdown_failed + remaining_domains = shutdown_failed + if not awaiting: + # no VM shutdown request succeeded, no sense to try again break + if have_events: - try: - # pylint: disable=no-member - loop.run_until_complete(asyncio.wait_for( - qubesadmin.events.utils.wait_for_domain_shutdown( - this_round_domains), - args.timeout)) - except asyncio.TimeoutError: - if not args.dry_run: - for vm in this_round_domains: - try: - vm.kill() - except qubesadmin.exc.QubesVMNotStartedError: - # already shut down - pass - except qubesadmin.exc.QubesException as e: - parser.error_runtime(e) + wait_coro = qubesadmin.events.utils.wait_for_domain_shutdown( + awaiting) else: - timeout = args.timeout - current_vms = list(sorted(this_round_domains)) - while timeout >= 0: - current_vms = failed_domains(current_vms) - if not current_vms: - break - args.app.log.info('Waiting for shutdown ({}): {}'.format( - timeout, ', '.join([str(vm) for vm in current_vms]))) - time.sleep(1) - timeout -= 1 + wait_coro = _wait_for_shutdown_polling(awaiting, args.app) + + try: + # pylint: disable=no-member + loop.run_until_complete(asyncio.wait_for( + wait_coro, args.timeout)) + except (TimeoutError, asyncio.TimeoutError): if not args.dry_run: + current_vms = failed_domains(awaiting) if current_vms: args.app.log.info( 'Killing remaining qubes: {}' @@ -147,15 +141,13 @@ def main(args=None, app=None): # pylint: disable=missing-docstring except qubesadmin.exc.QubesException as e: parser.error_runtime(e) - if args.wait: - if have_events: - loop.close() - failed = failed_domains(args.domains) - if failed: - parser.error_runtime( - 'Failed to shut down: ' + - ', '.join(vm.name for vm in failed), - len(failed)) + loop.close() + failed = failed_domains(args.domains) + if failed: + parser.error_runtime( + 'Failed to shut down: ' + + ', '.join(vm.name for vm in failed), + len(failed)) if __name__ == '__main__':