From 31fd94a55a2a6aa0740d14e54a2633e5bc933907 Mon Sep 17 00:00:00 2001 From: Jason Fried Date: Wed, 29 Nov 2023 18:23:59 -0800 Subject: [PATCH] include herd() decorator Summary: While working on later figured its been a while since we add to this toolkit so including the herd() decorator for controlling thundering herd of tasks. Reviewed By: itamaro Differential Revision: D51688031 fbshipit-source-id: 5c65b6c1fdf3d3cb9c8c38d9d445b6799321bbbb --- later/__init__.py | 5 +- later/task.py | 133 ++++++++++++++++++++++++++++++++++++++- later/tests/test_task.py | 113 +++++++++++++++++++++++++++++++++ 3 files changed, 248 insertions(+), 3 deletions(-) diff --git a/later/__init__.py b/later/__init__.py index 0b1fb38..2ba38f0 100644 --- a/later/__init__.py +++ b/later/__init__.py @@ -15,10 +15,10 @@ from async_timeout import timeout from .event import BiDirectionalEvent -from .task import as_task, cancel, START_TASK, Watcher, WatcherError +from .task import as_task, cancel, herd, START_TASK, Watcher, WatcherError -__version__ = "23.04.1" +__version__ = "23.11.29" __all__ = [ "BiDirectionalEvent", "START_TASK", @@ -27,4 +27,5 @@ "cancel", "timeout", "as_task", + "herd", ] diff --git a/later/task.py b/later/task.py index 3f312cb..9b98eb8 100644 --- a/later/task.py +++ b/later/task.py @@ -15,21 +15,29 @@ import asyncio import contextvars +import functools import logging +import threading from collections.abc import Coroutine, Generator from functools import partial, wraps from inspect import isawaitable from types import TracebackType from typing import ( + AbstractSet, Any, Awaitable, Callable, cast, Dict, + Hashable, List, + Mapping, + NewType, Optional, + overload, Sequence, + Tuple, Type, TypeVar, Union, @@ -44,7 +52,14 @@ T = TypeVar("T") F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) -__all__: Sequence[str] = ["Watcher", "START_TASK", "TaskSentinel", "cancel", "as_task"] +__all__: Sequence[str] = [ + "Watcher", + "START_TASK", + "TaskSentinel", + "cancel", + "as_task", + "herd", +] class TaskSentinel(asyncio.Task): @@ -391,3 +406,119 @@ async def _handle_cancel(self): raise WatcherError( "The following tasks didn't cancel cleanly or at all!", bad_tasks ) + + +CacheKey = NewType("CacheKey", Sequence[Hashable]) +ArgID = Union[int, str] + + +class _CountTask: + """So herd can track herd size and task together for cancellation""" + + task: Optional[asyncio.Task] = None + count: int = 0 + + +def _get_local(local: threading.local, field: str) -> Dict[CacheKey, object]: + """ + helper for attempting to fetch a named attr from a threading.local + """ + try: + return cast(Dict[CacheKey, object], getattr(local, field)) + except AttributeError: + container: Dict[CacheKey, object] = {} + setattr(local, field, container) + return container + + +def _build_key( + args: Tuple[object, ...], + kwargs: Mapping[str, object], + ignored_args: Optional[AbstractSet[ArgID]] = None, +) -> CacheKey: + """ + Build a key for caching Hashable args and kwargs. + Allow for not including certain fields from args or kwargs + """ + if not ignored_args: + # pyre-fixme[45]: Cannot instantiate abstract class `CacheKey`. + return CacheKey((args, tuple(sorted(kwargs.items())))) + + # If we do want to ignore something then do so + # pyre-fixme[45]: Cannot instantiate abstract class `CacheKey`. + return CacheKey( + ( + tuple((value for idx, value in enumerate(args) if idx not in ignored_args)), + tuple( + (item for item in sorted(kwargs.items()) if item[0] not in ignored_args) + ), + ) + ) + + +@overload # noqa: 811 +def herd( + fn: F, *, ignored_args: Optional[AbstractSet[ArgID]] = None +) -> F: # pragma: nocover + ... + + +@overload # noqa: 811 +def herd( + fn: Optional[F] = None, *, ignored_args: Optional[AbstractSet[ArgID]] = None +) -> Callable[[F], F]: # pragma: nocover + ... + + +def herd( + fn=None, + *, + ignored_args: Optional[AbstractSet[ArgID]] = None, +): # noqa: 811 + """ + Provide a simple thundering herd protection as a decorator. + if requests comes in while and existing request with those same args is pending, + wait for the pending request and return its results. + + ignored_args are arguments that should be ignored for matching with + existing requests. Use arg position or kwargs name. + Example: a client arg for when multiple clients exists but the request hits the same + backend. + + Each member of the herd is "shielded" from cancellation effecting other herd members + """ + + def decorator(fn: F) -> F: + local = threading.local() + + @functools.wraps(fn) + async def wrapped(*args, **kwargs): + pending = cast(Dict[CacheKey, _CountTask], _get_local(local, "pending")) + request = _build_key(tuple(args), kwargs, ignored_args) + count_task = pending.setdefault(request, _CountTask()) + count_task.count += 1 + task = count_task.task # thanks pyre + if task is None: + count_task.task = task = asyncio.create_task(fn(*args, **kwargs)) + try: + return await asyncio.shield(task) + except asyncio.CancelledError: + if count_task.count == 1: + await cancel(task) + raise # always re-raise CancelledError + finally: + count_task.count -= 1 + # Lets destroy the herd on last member exit or + # First success member exit. This is to mirror the original + # herd behavior that tore down the herd after the original call exited + if count_task.count == 0 or not task.cancelled(): + if request in pending and pending[request] is count_task: + del pending[request] + + return cast(F, wrapped) + + if fn and callable(fn): + # pyre-fixme[6]: For 1st param expected `F` but got `(...) -> object`. + return decorator(fn) + + return decorator diff --git a/later/tests/test_task.py b/later/tests/test_task.py index 7387b36..ca98185 100644 --- a/later/tests/test_task.py +++ b/later/tests/test_task.py @@ -358,3 +358,116 @@ async def test_watch_with_shield(self) -> None: watcher.cancel() await task + + +class HerdTests(TestCase): + async def test_herd_cancellation(self) -> None: + called = 0 + original_cancelled = False + + @later.herd + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + async def fun(event): + nonlocal called + nonlocal original_cancelled + called += 1 + try: + await event.wait() + except asyncio.CancelledError: + original_cancelled = True + raise + + event = asyncio.Event() + call1 = asyncio.create_task(fun(event)) + call2 = asyncio.create_task(fun(event)) + + done, pending = await asyncio.wait({call1, call2}, timeout=0.05) + self.assertFalse(done) + self.assertEqual(called, 1) + self.assertFalse(original_cancelled) + + await later.cancel(call1) + self.assertFalse(original_cancelled) + + # even with the first call cancelled call2 can continue + done, pending = await asyncio.wait([call2], timeout=0.05) + self.assertFalse(done) + + # Only after there is only one pending left do we allow the original task + # to be cancelled. + await later.cancel(call2) + self.assertTrue(original_cancelled) + + async def test_herd(self) -> None: + called = 0 + waited = 0 + + @later.herd(ignored_args={1}) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + async def fun(event, ignored_arg): + nonlocal called + called += 1 + await event.wait() + nonlocal waited + waited += 1 + + event = asyncio.Event() + call1 = asyncio.create_task(fun(event, object())) + call2 = asyncio.create_task(fun(event, object())) + call3 = asyncio.create_task(fun(event, object())) + + done, pending = await asyncio.wait({call1, call2, call3}, timeout=0.05) + self.assertFalse(done) + self.assertEqual(called, 1) + self.assertEqual(waited, 0) + + event.set() + done, pending = await asyncio.wait({call1, call2, call3}) + self.assertTrue(done) + self.assertFalse(pending) + self.assertEqual(called, 1) + self.assertEqual(waited, 1) + + async def test_herd_exception(self) -> None: + called = 0 + + @later.herd + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + async def fun(event): + nonlocal called + called += 1 + await event.wait() + raise RuntimeError + + event = asyncio.Event() + call1 = asyncio.create_task(fun(event)) + call2 = asyncio.create_task(fun(event)) + + done, pending = await asyncio.wait({call1, call2}, timeout=0.05) + self.assertFalse(done) + self.assertEqual(called, 1) + + event.set() + done, pending = await asyncio.wait({call1, call2}) + self.assertTrue(done) + self.assertFalse(pending) + self.assertEqual(called, 1) + # access the exception so we don't blow up the test + with self.assertRaises(RuntimeError): + await call1 + + with self.assertRaises(RuntimeError): + await call2 + + def test_herd_usage_no_args(self) -> None: + @later.herd + async def hax(test: int) -> None: + """docstring""" + ... + + # Testing wraps + self.assertEqual(hax.__name__, "hax") + self.assertEqual(hax.__doc__, "docstring")