From c4e6cd9174ebcdff941ef59c26e92f768b4bc644 Mon Sep 17 00:00:00 2001 From: Jason Fried Date: Wed, 29 Nov 2023 15:34:53 -0800 Subject: [PATCH] cinder compatability for differences in all_tasks Summary: cinder runtime's tasks.all_tasks has some mechanism where done() tasks don't show up in all_tasks but we need them to check managed tasks so instead we track our own tasks in our task class for tests. This seems to make all the runtimes happy. Reviewed By: aleivag Differential Revision: D51681518 fbshipit-source-id: 455484fc9a52f7f07135ff2c52d569979ecd15f5 --- later/unittest/case.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/later/unittest/case.py b/later/unittest/case.py index 3d2b771..e4c55e1 100644 --- a/later/unittest/case.py +++ b/later/unittest/case.py @@ -25,6 +25,7 @@ import asyncio.tasks import sys import unittest.mock as mock +import weakref from functools import wraps from typing import Any, Callable, TypeVar @@ -39,6 +40,7 @@ _IGNORE_TASK_LEAKS_ATTR = "__later_testcase_ignore_tasks__" _IGNORE_AIO_ERRS_ATTR = "__later_testcase_ignore_asyncio__" atleastpy38: bool = sys.version_info[:2] >= (3, 8) +_unmanaged_tasks: weakref.WeakSet[asyncio.Task] = weakref.WeakSet() class TestTask(asyncio.Task): @@ -48,6 +50,7 @@ class TestTask(asyncio.Task): def __init__(self, coro, *args, **kws) -> None: # pyre-fixme[16]: Module `coroutines` has no attribute `_format_coroutine`. self._coro_repr = asyncio.coroutines._format_coroutine(coro) + _unmanaged_tasks.add(self) super().__init__(coro, *args, **kws) def __repr__(self) -> str: @@ -60,24 +63,29 @@ def __repr__(self) -> str: repr_info[1] = coro # pragma: nocover return f"<{self.__class__.__name__} {' '.join(repr_info)}>" + def _mark_managed(self): + if not self._managed: + self._managed = True + _unmanaged_tasks.remove(self) + def __await__(self): - self._managed = True + self._mark_managed() return super().__await__() def result(self): if self.done(): - self._managed = True + self._mark_managed() return super().result() def exception(self): if self.done(): - self._managed = True + self._mark_managed() return super().exception() def add_done_callback(self, fn, *, context=None) -> None: @wraps(fn) def mark_managed(fut): - self._managed = True + self._mark_managed() return fn(fut) super().add_done_callback(mark_managed, context=context) @@ -135,7 +143,7 @@ def all_tasks(loop): i = 0 while True: try: - tasks = list(_all_tasks) + tasks = list(_all_tasks) + list(_unmanaged_tasks) except RuntimeError: # pragma: nocover i += 1 if i >= 1000: