Skip to content

Commit

Permalink
include herd() decorator
Browse files Browse the repository at this point in the history
Summary: While working on later figured its been a while since we add to this toolkit so including the herd() decorator for controlling thundering herd of tasks.

Reviewed By: itamaro

Differential Revision: D51688031

fbshipit-source-id: 5c65b6c1fdf3d3cb9c8c38d9d445b6799321bbbb
  • Loading branch information
fried authored and facebook-github-bot committed Nov 30, 2023
1 parent c4e6cd9 commit 31fd94a
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 3 deletions.
5 changes: 3 additions & 2 deletions later/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from async_timeout import timeout

from .event import BiDirectionalEvent
from .task import as_task, cancel, START_TASK, Watcher, WatcherError
from .task import as_task, cancel, herd, START_TASK, Watcher, WatcherError


__version__ = "23.04.1"
__version__ = "23.11.29"
__all__ = [
"BiDirectionalEvent",
"START_TASK",
Expand All @@ -27,4 +27,5 @@
"cancel",
"timeout",
"as_task",
"herd",
]
133 changes: 132 additions & 1 deletion later/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,29 @@

import asyncio
import contextvars
import functools
import logging
import threading

from collections.abc import Coroutine, Generator
from functools import partial, wraps
from inspect import isawaitable
from types import TracebackType
from typing import (
AbstractSet,
Any,
Awaitable,
Callable,
cast,
Dict,
Hashable,
List,
Mapping,
NewType,
Optional,
overload,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Expand All @@ -44,7 +52,14 @@
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])

__all__: Sequence[str] = ["Watcher", "START_TASK", "TaskSentinel", "cancel", "as_task"]
__all__: Sequence[str] = [
"Watcher",
"START_TASK",
"TaskSentinel",
"cancel",
"as_task",
"herd",
]


class TaskSentinel(asyncio.Task):
Expand Down Expand Up @@ -391,3 +406,119 @@ async def _handle_cancel(self):
raise WatcherError(
"The following tasks didn't cancel cleanly or at all!", bad_tasks
)


CacheKey = NewType("CacheKey", Sequence[Hashable])
ArgID = Union[int, str]


class _CountTask:
"""So herd can track herd size and task together for cancellation"""

task: Optional[asyncio.Task] = None
count: int = 0


def _get_local(local: threading.local, field: str) -> Dict[CacheKey, object]:
"""
helper for attempting to fetch a named attr from a threading.local
"""
try:
return cast(Dict[CacheKey, object], getattr(local, field))
except AttributeError:
container: Dict[CacheKey, object] = {}
setattr(local, field, container)
return container


def _build_key(
args: Tuple[object, ...],
kwargs: Mapping[str, object],
ignored_args: Optional[AbstractSet[ArgID]] = None,
) -> CacheKey:
"""
Build a key for caching Hashable args and kwargs.
Allow for not including certain fields from args or kwargs
"""
if not ignored_args:
# pyre-fixme[45]: Cannot instantiate abstract class `CacheKey`.
return CacheKey((args, tuple(sorted(kwargs.items()))))

# If we do want to ignore something then do so
# pyre-fixme[45]: Cannot instantiate abstract class `CacheKey`.
return CacheKey(
(
tuple((value for idx, value in enumerate(args) if idx not in ignored_args)),
tuple(
(item for item in sorted(kwargs.items()) if item[0] not in ignored_args)
),
)
)


@overload # noqa: 811
def herd(
fn: F, *, ignored_args: Optional[AbstractSet[ArgID]] = None
) -> F: # pragma: nocover
...


@overload # noqa: 811
def herd(
fn: Optional[F] = None, *, ignored_args: Optional[AbstractSet[ArgID]] = None
) -> Callable[[F], F]: # pragma: nocover
...


def herd(
fn=None,
*,
ignored_args: Optional[AbstractSet[ArgID]] = None,
): # noqa: 811
"""
Provide a simple thundering herd protection as a decorator.
if requests comes in while and existing request with those same args is pending,
wait for the pending request and return its results.
ignored_args are arguments that should be ignored for matching with
existing requests. Use arg position or kwargs name.
Example: a client arg for when multiple clients exists but the request hits the same
backend.
Each member of the herd is "shielded" from cancellation effecting other herd members
"""

def decorator(fn: F) -> F:
local = threading.local()

@functools.wraps(fn)
async def wrapped(*args, **kwargs):
pending = cast(Dict[CacheKey, _CountTask], _get_local(local, "pending"))
request = _build_key(tuple(args), kwargs, ignored_args)
count_task = pending.setdefault(request, _CountTask())
count_task.count += 1
task = count_task.task # thanks pyre
if task is None:
count_task.task = task = asyncio.create_task(fn(*args, **kwargs))
try:
return await asyncio.shield(task)
except asyncio.CancelledError:
if count_task.count == 1:
await cancel(task)
raise # always re-raise CancelledError
finally:
count_task.count -= 1
# Lets destroy the herd on last member exit or
# First success member exit. This is to mirror the original
# herd behavior that tore down the herd after the original call exited
if count_task.count == 0 or not task.cancelled():
if request in pending and pending[request] is count_task:
del pending[request]

return cast(F, wrapped)

if fn and callable(fn):
# pyre-fixme[6]: For 1st param expected `F` but got `(...) -> object`.
return decorator(fn)

return decorator
113 changes: 113 additions & 0 deletions later/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,116 @@ async def test_watch_with_shield(self) -> None:
watcher.cancel()

await task


class HerdTests(TestCase):
async def test_herd_cancellation(self) -> None:
called = 0
original_cancelled = False

@later.herd
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
async def fun(event):
nonlocal called
nonlocal original_cancelled
called += 1
try:
await event.wait()
except asyncio.CancelledError:
original_cancelled = True
raise

event = asyncio.Event()
call1 = asyncio.create_task(fun(event))
call2 = asyncio.create_task(fun(event))

done, pending = await asyncio.wait({call1, call2}, timeout=0.05)
self.assertFalse(done)
self.assertEqual(called, 1)
self.assertFalse(original_cancelled)

await later.cancel(call1)
self.assertFalse(original_cancelled)

# even with the first call cancelled call2 can continue
done, pending = await asyncio.wait([call2], timeout=0.05)
self.assertFalse(done)

# Only after there is only one pending left do we allow the original task
# to be cancelled.
await later.cancel(call2)
self.assertTrue(original_cancelled)

async def test_herd(self) -> None:
called = 0
waited = 0

@later.herd(ignored_args={1})
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
async def fun(event, ignored_arg):
nonlocal called
called += 1
await event.wait()
nonlocal waited
waited += 1

event = asyncio.Event()
call1 = asyncio.create_task(fun(event, object()))
call2 = asyncio.create_task(fun(event, object()))
call3 = asyncio.create_task(fun(event, object()))

done, pending = await asyncio.wait({call1, call2, call3}, timeout=0.05)
self.assertFalse(done)
self.assertEqual(called, 1)
self.assertEqual(waited, 0)

event.set()
done, pending = await asyncio.wait({call1, call2, call3})
self.assertTrue(done)
self.assertFalse(pending)
self.assertEqual(called, 1)
self.assertEqual(waited, 1)

async def test_herd_exception(self) -> None:
called = 0

@later.herd
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
async def fun(event):
nonlocal called
called += 1
await event.wait()
raise RuntimeError

event = asyncio.Event()
call1 = asyncio.create_task(fun(event))
call2 = asyncio.create_task(fun(event))

done, pending = await asyncio.wait({call1, call2}, timeout=0.05)
self.assertFalse(done)
self.assertEqual(called, 1)

event.set()
done, pending = await asyncio.wait({call1, call2})
self.assertTrue(done)
self.assertFalse(pending)
self.assertEqual(called, 1)
# access the exception so we don't blow up the test
with self.assertRaises(RuntimeError):
await call1

with self.assertRaises(RuntimeError):
await call2

def test_herd_usage_no_args(self) -> None:
@later.herd
async def hax(test: int) -> None:
"""docstring"""
...

# Testing wraps
self.assertEqual(hax.__name__, "hax")
self.assertEqual(hax.__doc__, "docstring")

0 comments on commit 31fd94a

Please sign in to comment.