diff --git a/later/unittest/case.py b/later/unittest/case.py index e4c55e1..3d2b771 100644 --- a/later/unittest/case.py +++ b/later/unittest/case.py @@ -25,7 +25,6 @@ import asyncio.tasks import sys import unittest.mock as mock -import weakref from functools import wraps from typing import Any, Callable, TypeVar @@ -40,7 +39,6 @@ _IGNORE_TASK_LEAKS_ATTR = "__later_testcase_ignore_tasks__" _IGNORE_AIO_ERRS_ATTR = "__later_testcase_ignore_asyncio__" atleastpy38: bool = sys.version_info[:2] >= (3, 8) -_unmanaged_tasks: weakref.WeakSet[asyncio.Task] = weakref.WeakSet() class TestTask(asyncio.Task): @@ -50,7 +48,6 @@ class TestTask(asyncio.Task): def __init__(self, coro, *args, **kws) -> None: # pyre-fixme[16]: Module `coroutines` has no attribute `_format_coroutine`. self._coro_repr = asyncio.coroutines._format_coroutine(coro) - _unmanaged_tasks.add(self) super().__init__(coro, *args, **kws) def __repr__(self) -> str: @@ -63,29 +60,24 @@ def __repr__(self) -> str: repr_info[1] = coro # pragma: nocover return f"<{self.__class__.__name__} {' '.join(repr_info)}>" - def _mark_managed(self): - if not self._managed: - self._managed = True - _unmanaged_tasks.remove(self) - def __await__(self): - self._mark_managed() + self._managed = True return super().__await__() def result(self): if self.done(): - self._mark_managed() + self._managed = True return super().result() def exception(self): if self.done(): - self._mark_managed() + self._managed = True return super().exception() def add_done_callback(self, fn, *, context=None) -> None: @wraps(fn) def mark_managed(fut): - self._mark_managed() + self._managed = True return fn(fut) super().add_done_callback(mark_managed, context=context) @@ -143,7 +135,7 @@ def all_tasks(loop): i = 0 while True: try: - tasks = list(_all_tasks) + list(_unmanaged_tasks) + tasks = list(_all_tasks) except RuntimeError: # pragma: nocover i += 1 if i >= 1000: