Skip to content

Commit b5314a2

Browse files
authored
Add RetryingCaller and AsyncRetryingCaller (#56)
* Add RetryingCaller and AsyncRetryingCaller Implements #45 * Add PR link * Update CHANGELOG.md
1 parent 6ddcc7d commit b5314a2

File tree

8 files changed

+210
-2
lines changed

8 files changed

+210
-2
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/
1515

1616
## [Unreleased](https://github.com/hynek/stamina/compare/24.1.0...HEAD)
1717

18+
### Added
19+
20+
- `stamina.RetryingCaller` and `stamina.AsyncRetryingCaller` that allow even easier retries of single callables.
21+
[#56](https://github.com/hynek/stamina/pull/56)
22+
1823

1924
## [24.1.0](https://github.com/hynek/stamina/compare/23.3.0...24.1.0) - 2024-01-03
2025

docs/api.md

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
.. autofunction:: retry_context
88
.. autoclass:: Attempt
99
:members: num
10+
.. autoclass:: RetryingCaller
11+
.. autoclass:: AsyncRetryingCaller
1012
```
1113

1214

docs/tutorial.md

+15
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ for attempt in stamina.retry_context(on=httpx.HTTPError):
4646
resp.raise_for_status()
4747
```
4848

49+
If you want to retry just one function call, *stamina* comes with an even easier way in the shape of {class}`stamina.RetryingCaller` and {class}`stamina.AsyncRetryingCaller`:
50+
51+
```python
52+
def do_something_with_url(url, some_kw):
53+
resp = httpx.get(url)
54+
resp.raise_for_status()
55+
...
56+
57+
rc = stamina.RetryingCaller(on=httpx.HTTPError)
58+
59+
rc(do_something_with_url, f"https://httpbin.org/status/404", some_kw=42)
60+
```
61+
62+
The last line calls `do_something_with_url(f"https://httpbin.org/status/404", some_kw=42)` and retries on `httpx.HTTPError`.
63+
4964

5065
## Async
5166

src/stamina/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44

55
from . import instrumentation
66
from ._config import is_active, set_active
7-
from ._core import Attempt, retry, retry_context
7+
from ._core import (
8+
AsyncRetryingCaller,
9+
Attempt,
10+
RetryingCaller,
11+
retry,
12+
retry_context,
13+
)
814

915

1016
__all__ = [
@@ -14,6 +20,8 @@
1420
"is_active",
1521
"set_active",
1622
"instrumentation",
23+
"RetryingCaller",
24+
"AsyncRetryingCaller",
1725
]
1826

1927

src/stamina/_core.py

+90-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from functools import wraps
1313
from inspect import iscoroutinefunction
1414
from types import TracebackType
15-
from typing import AsyncIterator, Iterator, TypeVar
15+
from typing import AsyncIterator, Awaitable, Iterator, TypedDict, TypeVar
1616

1717
import tenacity as _t
1818

@@ -126,6 +126,95 @@ def __exit__(
126126
)
127127

128128

129+
class RetryKWs(TypedDict):
130+
on: type[Exception] | tuple[type[Exception], ...]
131+
attempts: int | None
132+
timeout: float | dt.timedelta | None
133+
wait_initial: float | dt.timedelta
134+
wait_max: float | dt.timedelta
135+
wait_jitter: float | dt.timedelta
136+
wait_exp_base: float
137+
138+
139+
class BaseRetryingCaller:
140+
"""
141+
.. versionadded:: 24.2.0
142+
"""
143+
144+
__slots__ = ("_context_kws",)
145+
146+
_context_kws: RetryKWs
147+
148+
def __init__(
149+
self,
150+
on: type[Exception] | tuple[type[Exception], ...],
151+
attempts: int | None = 10,
152+
timeout: float | dt.timedelta | None = 45.0,
153+
wait_initial: float | dt.timedelta = 0.1,
154+
wait_max: float | dt.timedelta = 5.0,
155+
wait_jitter: float | dt.timedelta = 1.0,
156+
wait_exp_base: float = 2.0,
157+
):
158+
self._context_kws = {
159+
"on": on,
160+
"attempts": attempts,
161+
"timeout": timeout,
162+
"wait_initial": wait_initial,
163+
"wait_max": wait_max,
164+
"wait_jitter": wait_jitter,
165+
"wait_exp_base": wait_exp_base,
166+
}
167+
168+
def __repr__(self) -> str:
169+
on = guess_name(self._context_kws["on"])
170+
kws = ", ".join(
171+
f"{k}={self._context_kws[k]!r}" # type: ignore[literal-required]
172+
for k in sorted(self._context_kws)
173+
if k != "on"
174+
)
175+
return f"<{self.__class__.__name__}(on={on}, {kws})>"
176+
177+
178+
class RetryingCaller(BaseRetryingCaller):
179+
"""
180+
An object that will call your callable with retries.
181+
182+
Instances of `RetryingCaller` may be reused because they create a new
183+
:func:`retry_context` iterator on each call.
184+
185+
.. versionadded:: 24.2.0
186+
"""
187+
188+
def __call__(
189+
self, func: Callable[P, T], /, *args: P.args, **kw: P.kwargs
190+
) -> T:
191+
for attempt in retry_context(**self._context_kws):
192+
with attempt:
193+
return func(*args, **kw)
194+
195+
raise SystemError("unreachable") # pragma: no cover # noqa: EM101
196+
197+
198+
class AsyncRetryingCaller(BaseRetryingCaller):
199+
"""
200+
An object that will call your async callable with retries.
201+
202+
Instances of `AsyncRetryingCaller` may be reused because they create a new
203+
:func:`retry_context` iterator on each call.
204+
205+
.. versionadded:: 24.2.0
206+
"""
207+
208+
async def __call__(
209+
self, func: Callable[P, Awaitable[T]], /, *args: P.args, **kw: P.kwargs
210+
) -> T:
211+
async for attempt in retry_context(**self._context_kws):
212+
with attempt:
213+
return await func(*args, **kw)
214+
215+
raise SystemError("unreachable") # pragma: no cover # noqa: EM101
216+
217+
129218
_STOP_NO_RETRY = _t.stop_after_attempt(1)
130219

131220

tests/test_async.py

+24
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,27 @@ async def test_retry_blocks_can_be_disabled():
209209
raise Exception("passed")
210210

211211
assert 1 == num_called
212+
213+
214+
class TestAsyncRetryingCaller:
215+
async def test_retries(self):
216+
"""
217+
Retries if the specific error is raised. Arguments are passed through.
218+
"""
219+
i = 0
220+
221+
async def f(*args, **kw):
222+
nonlocal i
223+
if i < 1:
224+
i += 1
225+
raise ValueError
226+
227+
return args, kw
228+
229+
arc = stamina.AsyncRetryingCaller(on=ValueError)
230+
231+
args, kw = await arc(f, 42, foo="bar")
232+
233+
assert 1 == i
234+
assert (42,) == args
235+
assert {"foo": "bar"} == kw

tests/test_sync.py

+44
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,47 @@ def test_never(self):
172172
If all conditions are None, return stop_never.
173173
"""
174174
assert tenacity.stop_never is _make_stop(attempts=None, timeout=None)
175+
176+
177+
class TestRetryingCaller:
178+
def test_retries(self):
179+
"""
180+
Retries if the specific error is raised. Arguments are passed through.
181+
"""
182+
i = 0
183+
184+
def f(*args, **kw):
185+
nonlocal i
186+
if i < 1:
187+
i += 1
188+
raise ValueError
189+
190+
return args, kw
191+
192+
rc = stamina.RetryingCaller(on=ValueError)
193+
194+
args, kw = rc(f, 42, foo="bar")
195+
196+
assert 1 == i
197+
assert (42,) == args
198+
assert {"foo": "bar"} == kw
199+
200+
def test_repr(self):
201+
"""
202+
repr() is useful.
203+
"""
204+
rc = stamina.RetryingCaller(
205+
on=ValueError,
206+
attempts=42,
207+
timeout=13.0,
208+
wait_initial=23,
209+
wait_max=123,
210+
wait_jitter=0.42,
211+
wait_exp_base=666,
212+
)
213+
214+
assert (
215+
"<RetryingCaller(on=ValueError, attempts=42, timeout=13.0, "
216+
"wait_exp_base=666, wait_initial=23, wait_jitter=0.42, "
217+
"wait_max=123)>"
218+
) == repr(rc)

tests/typing/api.py

+21
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import datetime as dt
1313

1414
from stamina import (
15+
AsyncRetryingCaller,
16+
RetryingCaller,
1517
is_active,
1618
retry,
1719
retry_context,
@@ -125,3 +127,22 @@ async def f() -> None:
125127
):
126128
with attempt:
127129
pass
130+
131+
132+
def sync_f(x: int, foo: str) -> bool:
133+
return True
134+
135+
136+
rc = RetryingCaller(on=ValueError, timeout=13.0, attempts=10)
137+
b: bool = rc(sync_f, 1, foo="bar")
138+
139+
140+
async def async_f(x: int, foo: str) -> bool:
141+
return True
142+
143+
144+
arc = AsyncRetryingCaller(on=ValueError, timeout=13.0, attempts=10)
145+
146+
147+
async def g() -> bool:
148+
return await arc(async_f, 1, foo="bar")

0 commit comments

Comments
 (0)