diff --git a/later/unittest/case.py b/later/unittest/case.py index d3a636b..b28b018 100644 --- a/later/unittest/case.py +++ b/later/unittest/case.py @@ -27,6 +27,7 @@ import reprlib import sys import unittest.mock as mock +import weakref from functools import wraps from typing import Any, Callable, TypeVar @@ -41,6 +42,7 @@ _IGNORE_TASK_LEAKS_ATTR = "__later_testcase_ignore_tasks__" _IGNORE_AIO_ERRS_ATTR = "__later_testcase_ignore_asyncio__" atleastpy38: bool = sys.version_info[:2] >= (3, 8) +_unmanaged_tasks: weakref.WeakSet[asyncio.Task] = weakref.WeakSet() class TestTask(asyncio.Task): @@ -50,6 +52,7 @@ class TestTask(asyncio.Task): def __init__(self, coro, *args, **kws) -> None: # pyre-fixme[16]: Module `coroutines` has no attribute `_format_coroutine`. self._coro_repr = asyncio.coroutines._format_coroutine(coro) + _unmanaged_tasks.add(self) super().__init__(coro, *args, **kws) @reprlib.recursive_repr() @@ -63,24 +66,29 @@ def __repr__(self) -> str: repr_info[1] = coro # pragma: nocover return f"<{self.__class__.__name__} {' '.join(repr_info)}>" + def _mark_managed(self): + if not self._managed: + self._managed = True + _unmanaged_tasks.discard(self) + def __await__(self): - self._managed = True + self._mark_managed() return super().__await__() def result(self): if self.done(): - self._managed = True + self._mark_managed() return super().result() def exception(self): if self.done(): - self._managed = True + self._mark_managed() return super().exception() def add_done_callback(self, fn, *, context=None) -> None: @wraps(fn) def mark_managed(fut): - self._managed = True + self._mark_managed() return fn(fut) super().add_done_callback(mark_managed, context=context) @@ -89,11 +97,12 @@ def was_managed(self) -> bool: if self._managed: return True # If the task is done() and the result is None, let it pass as managed + # We use super here so we don't manage ourselves. return ( self.done() and not self.cancelled() - and not self.exception() - and self.result() is None + and not super().exception() + and super().result() is None ) def __del__(self) -> None: @@ -101,13 +110,7 @@ def __del__(self) -> None: # we accept that as long as there was no result other than None # thrift-py3 uses this pattern to call rpc methods in ServiceInterfaces # where any result/execption is returned to the remote client. - managed = self.was_managed() - if not managed and not ( - self.done() - and not self.cancelled() - and not self.exception() - and self.result() is None - ): + if not self.was_managed(): context = { "task": self, "message": ( @@ -138,7 +141,7 @@ def all_tasks(loop): i = 0 while True: try: - tasks = list(_all_tasks) + tasks = list(_all_tasks) + list(_unmanaged_tasks) except RuntimeError: # pragma: nocover i += 1 if i >= 1000: