Skip to content

Commit 4307da6

Browse files
committed
Allow instrumentation hooks to be contextmanagers
fixes #84
1 parent 51b8072 commit 4307da6

File tree

6 files changed

+118
-9
lines changed

6 files changed

+118
-9
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/
2525
- `stamina.set_testing()` can now be used as a context manager.
2626
[#94](https://github.com/hynek/stamina/pull/94)
2727

28+
- Instrumentation hooks can now can return context managers.
29+
If they do, they are entered before going to sleep and exited after waking up.
30+
2831

2932
## [24.3.0](https://github.com/hynek/stamina/compare/24.2.0...24.3.0) - 2024-08-27
3033

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
2727
dependencies = ["tenacity", "typing-extensions; python_version < '3.10'"]
2828

2929
[project.optional-dependencies]
30-
tests = ["pytest", "anyio"]
30+
tests = ["pytest", "anyio", "dirty-equals"]
3131
typing = ["mypy >= 1.4"]
3232
docs = [
3333
"sphinx >= 7.2.2",

src/stamina/_core.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
from __future__ import annotations
66

7+
import contextlib
78
import datetime as dt
89
import random
910
import sys
1011

12+
from contextlib import AbstractContextManager
1113
from dataclasses import dataclass, replace
1214
from functools import wraps
1315
from inspect import iscoroutinefunction
@@ -400,6 +402,7 @@ class _RetryContextIterator:
400402
__slots__ = (
401403
"_args",
402404
"_attempts",
405+
"_cms_to_exit",
403406
"_kw",
404407
"_name",
405408
"_t_a_retrying",
@@ -421,6 +424,8 @@ class _RetryContextIterator:
421424
_wait_max: float
422425
_wait_exp_base: float
423426

427+
_cms_to_exit: list[AbstractContextManager[None]]
428+
424429
@classmethod
425430
def from_params(
426431
cls,
@@ -473,6 +478,7 @@ def from_params(
473478
"reraise": True,
474479
},
475480
_t_a_retrying=_LAZY_NO_ASYNC_RETRY,
481+
_cms_to_exit=[],
476482
)
477483

478484
inst._t_kw["wait"] = inst._jittered_backoff_for_rcs
@@ -501,17 +507,24 @@ def _apply_maybe_test_mode_to_tenacity_kw(
501507

502508
return t_kw
503509

510+
def _exit_cms(self, _: _t.RetryCallState | None) -> None:
511+
for cm in reversed(self._cms_to_exit):
512+
cm.__exit__(None, None, None)
513+
504514
def __iter__(self) -> Iterator[Attempt]:
505515
if not CONFIG.is_active:
506516
for r in _t.Retrying(reraise=True, stop=_STOP_NO_RETRY):
507517
yield Attempt(r, None)
508518

509519
return
510520

521+
before_sleep = _make_before_sleep(
522+
self._name, CONFIG, self._args, self._kw, self._cms_to_exit
523+
)
524+
511525
for r in _t.Retrying(
512-
before_sleep=_make_before_sleep(
513-
self._name, CONFIG, self._args, self._kw
514-
),
526+
before=self._exit_cms,
527+
before_sleep=before_sleep,
515528
**self._apply_maybe_test_mode_to_tenacity_kw(CONFIG.testing),
516529
):
517530
yield Attempt(r, self._backoff_for_attempt_number)
@@ -520,8 +533,9 @@ def __aiter__(self) -> AsyncIterator[Attempt]:
520533
if CONFIG.is_active:
521534
self._t_a_retrying = _t.AsyncRetrying(
522535
sleep=_smart_sleep,
536+
before=self._exit_cms,
523537
before_sleep=_make_before_sleep(
524-
self._name, CONFIG, self._args, self._kw
538+
self._name, CONFIG, self._args, self._kw, self._cms_to_exit
525539
),
526540
**self._apply_maybe_test_mode_to_tenacity_kw(CONFIG.testing),
527541
)
@@ -583,10 +597,15 @@ def _make_before_sleep(
583597
config: _Config,
584598
args: tuple[object, ...],
585599
kw: dict[str, object],
600+
hook_cms: list[contextlib.AbstractContextManager[None]],
586601
) -> Callable[[_t.RetryCallState], None]:
587602
"""
588603
Create a `before_sleep` callback function that runs our `RetryHook`s with
589604
the necessary arguments.
605+
606+
If a hook returns a context manager, it's entered before retries start and
607+
exited after they finish by keeping track of the context managers in
608+
*hook_cms*.
590609
"""
591610

592611
last_idle_for = 0.0
@@ -607,7 +626,11 @@ def before_sleep(rcs: _t.RetryCallState) -> None:
607626
)
608627

609628
for hook in config.on_retry:
610-
hook(details)
629+
maybe_cm = hook(details)
630+
631+
if isinstance(maybe_cm, AbstractContextManager):
632+
maybe_cm.__enter__()
633+
hook_cms.append(maybe_cm)
611634

612635
last_idle_for = rcs.idle_for
613636

src/stamina/instrumentation/_data.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
from contextlib import AbstractContextManager
78
from dataclasses import dataclass
89
from typing import Callable, Protocol
910

@@ -75,7 +76,9 @@ class RetryHook(Protocol):
7576
.. versionadded:: 23.2.0
7677
"""
7778

78-
def __call__(self, details: RetryDetails) -> None: ...
79+
def __call__(
80+
self, details: RetryDetails
81+
) -> None | AbstractContextManager[None]: ...
7982

8083

8184
@dataclass(frozen=True)

tests/test_instrumentation.py

+73-1
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,20 @@
44

55
from __future__ import annotations
66

7+
from contextlib import contextmanager
8+
79
import pytest
810

11+
from dirty_equals import IsInstance
12+
913
import stamina
1014

1115
from stamina.instrumentation import (
1216
RetryHookFactory,
1317
get_on_retry_hooks,
1418
set_on_retry_hooks,
1519
)
16-
from stamina.instrumentation._data import guess_name
20+
from stamina.instrumentation._data import RetryDetails, guess_name
1721
from stamina.instrumentation._hooks import get_default_hooks
1822
from stamina.instrumentation._structlog import init_structlog
1923

@@ -191,3 +195,71 @@ def init():
191195
hook,
192196
delayed_hook,
193197
) == get_on_retry_hooks()
198+
199+
def test_context_manager_hooks(self):
200+
"""
201+
If a hook is a context manager, it's entered before retries start and exited
202+
after they finish.
203+
"""
204+
entered = False
205+
exited = False
206+
deets = []
207+
208+
@contextmanager
209+
def cm1(details):
210+
nonlocal entered
211+
entered = True
212+
213+
deets.append(details)
214+
yield
215+
216+
nonlocal exited
217+
exited = True
218+
219+
class CM:
220+
def __init__(self):
221+
self.entered = False
222+
self.exited = False
223+
self.deets = []
224+
225+
def __call__(self, details):
226+
self.deets.append(details)
227+
return self
228+
229+
def __enter__(self):
230+
self.entered = True
231+
232+
def __exit__(self, *_):
233+
self.exited = True
234+
235+
cm2 = CM()
236+
237+
set_on_retry_hooks([cm1, cm2])
238+
239+
@stamina.retry(on=ValueError, wait_max=0, attempts=2)
240+
def f():
241+
raise ValueError
242+
243+
with pytest.raises(ValueError):
244+
f()
245+
246+
assert entered
247+
assert exited
248+
assert cm2.entered
249+
assert cm2.exited
250+
251+
assert (
252+
[
253+
RetryDetails(
254+
name="tests.test_instrumentation.TestSetOnRetryHooks.test_context_manager_hooks.<locals>.f",
255+
args=(),
256+
kwargs={},
257+
retry_num=1,
258+
wait_for=0.0,
259+
waited_so_far=0.0,
260+
caused_by=IsInstance(ValueError),
261+
)
262+
]
263+
== cm2.deets
264+
== deets
265+
)

tests/typing/api.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
import datetime as dt
1313

14+
from collections.abc import Generator
15+
from contextlib import contextmanager
16+
1417
from stamina import (
1518
AsyncRetryingCaller,
1619
BoundAsyncRetryingCaller,
@@ -109,11 +112,16 @@ def hook(details: RetryDetails) -> None:
109112
return None
110113

111114

115+
@contextmanager
116+
def cm_hook(details: RetryDetails) -> Generator[None]:
117+
yield
118+
119+
112120
def init() -> RetryHook:
113121
return hook
114122

115123

116-
set_on_retry_hooks([hook, RetryHookFactory(init)])
124+
set_on_retry_hooks([hook, RetryHookFactory(init), cm_hook])
117125

118126
hooks: tuple[RetryHook, ...] = get_on_retry_hooks()
119127

0 commit comments

Comments
 (0)