Skip to content

Commit

Permalink
Reapply cinder compatibility with safety this time
Browse files Browse the repository at this point in the history
Summary:
Cinder runtime's tasks.all_tasks has some mechanism where done() tasks don't show up in all_tasks but we need them to check managed tasks so instead we track our own tasks in our task class for tests.
This seems to make all the runtimes happy.

Reviewed By: itamaro

Differential Revision: D51713363

fbshipit-source-id: 3dc8007b615bb2e92a6e67cc041f2c0a916c172e
  • Loading branch information
fried authored and facebook-github-bot committed Nov 30, 2023
1 parent 475d2db commit 9b8dfb5
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions later/unittest/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import reprlib
import sys
import unittest.mock as mock
import weakref
from functools import wraps
from typing import Any, Callable, TypeVar

Expand All @@ -41,6 +42,7 @@
_IGNORE_TASK_LEAKS_ATTR = "__later_testcase_ignore_tasks__"
_IGNORE_AIO_ERRS_ATTR = "__later_testcase_ignore_asyncio__"
atleastpy38: bool = sys.version_info[:2] >= (3, 8)
_unmanaged_tasks: weakref.WeakSet[asyncio.Task] = weakref.WeakSet()


class TestTask(asyncio.Task):
Expand All @@ -50,6 +52,7 @@ class TestTask(asyncio.Task):
def __init__(self, coro, *args, **kws) -> None:
# pyre-fixme[16]: Module `coroutines` has no attribute `_format_coroutine`.
self._coro_repr = asyncio.coroutines._format_coroutine(coro)
_unmanaged_tasks.add(self)
super().__init__(coro, *args, **kws)

@reprlib.recursive_repr()
Expand All @@ -63,24 +66,29 @@ def __repr__(self) -> str:
repr_info[1] = coro # pragma: nocover
return f"<{self.__class__.__name__} {' '.join(repr_info)}>"

def _mark_managed(self):
if not self._managed:
self._managed = True
_unmanaged_tasks.discard(self)

def __await__(self):
self._managed = True
self._mark_managed()
return super().__await__()

def result(self):
if self.done():
self._managed = True
self._mark_managed()
return super().result()

def exception(self):
if self.done():
self._managed = True
self._mark_managed()
return super().exception()

def add_done_callback(self, fn, *, context=None) -> None:
@wraps(fn)
def mark_managed(fut):
self._managed = True
self._mark_managed()
return fn(fut)

super().add_done_callback(mark_managed, context=context)
Expand All @@ -89,25 +97,20 @@ def was_managed(self) -> bool:
if self._managed:
return True
# If the task is done() and the result is None, let it pass as managed
# We use super here so we don't manage ourselves.
return (
self.done()
and not self.cancelled()
and not self.exception()
and self.result() is None
and not super().exception()
and super().result() is None
)

def __del__(self) -> None:
# So a pattern is to create_task, and not save the results.
# we accept that as long as there was no result other than None
# thrift-py3 uses this pattern to call rpc methods in ServiceInterfaces
# where any result/execption is returned to the remote client.
managed = self.was_managed()
if not managed and not (
self.done()
and not self.cancelled()
and not self.exception()
and self.result() is None
):
if not self.was_managed():
context = {
"task": self,
"message": (
Expand Down Expand Up @@ -138,7 +141,7 @@ def all_tasks(loop):
i = 0
while True:
try:
tasks = list(_all_tasks)
tasks = list(_all_tasks) + list(_unmanaged_tasks)
except RuntimeError: # pragma: nocover
i += 1
if i >= 1000:
Expand Down

0 comments on commit 9b8dfb5

Please sign in to comment.