4
4
5
5
from __future__ import annotations
6
6
7
+ import contextlib
7
8
import datetime as dt
8
9
import random
9
10
import sys
10
11
12
+ from contextlib import AbstractContextManager
11
13
from dataclasses import dataclass , replace
12
14
from functools import wraps
13
15
from inspect import iscoroutinefunction
@@ -400,6 +402,7 @@ class _RetryContextIterator:
400
402
__slots__ = (
401
403
"_args" ,
402
404
"_attempts" ,
405
+ "_cms_to_exit" ,
403
406
"_kw" ,
404
407
"_name" ,
405
408
"_t_a_retrying" ,
@@ -421,6 +424,8 @@ class _RetryContextIterator:
421
424
_wait_max : float
422
425
_wait_exp_base : float
423
426
427
+ _cms_to_exit : list [AbstractContextManager [None ]]
428
+
424
429
@classmethod
425
430
def from_params (
426
431
cls ,
@@ -473,6 +478,7 @@ def from_params(
473
478
"reraise" : True ,
474
479
},
475
480
_t_a_retrying = _LAZY_NO_ASYNC_RETRY ,
481
+ _cms_to_exit = [],
476
482
)
477
483
478
484
inst ._t_kw ["wait" ] = inst ._jittered_backoff_for_rcs
@@ -501,17 +507,24 @@ def _apply_maybe_test_mode_to_tenacity_kw(
501
507
502
508
return t_kw
503
509
510
+ def _exit_cms (self , _ : _t .RetryCallState | None ) -> None :
511
+ for cm in reversed (self ._cms_to_exit ):
512
+ cm .__exit__ (None , None , None )
513
+
504
514
def __iter__ (self ) -> Iterator [Attempt ]:
505
515
if not CONFIG .is_active :
506
516
for r in _t .Retrying (reraise = True , stop = _STOP_NO_RETRY ):
507
517
yield Attempt (r , None )
508
518
509
519
return
510
520
521
+ before_sleep = _make_before_sleep (
522
+ self ._name , CONFIG , self ._args , self ._kw , self ._cms_to_exit
523
+ )
524
+
511
525
for r in _t .Retrying (
512
- before_sleep = _make_before_sleep (
513
- self ._name , CONFIG , self ._args , self ._kw
514
- ),
526
+ before = self ._exit_cms ,
527
+ before_sleep = before_sleep ,
515
528
** self ._apply_maybe_test_mode_to_tenacity_kw (CONFIG .testing ),
516
529
):
517
530
yield Attempt (r , self ._backoff_for_attempt_number )
@@ -520,8 +533,9 @@ def __aiter__(self) -> AsyncIterator[Attempt]:
520
533
if CONFIG .is_active :
521
534
self ._t_a_retrying = _t .AsyncRetrying (
522
535
sleep = _smart_sleep ,
536
+ before = self ._exit_cms ,
523
537
before_sleep = _make_before_sleep (
524
- self ._name , CONFIG , self ._args , self ._kw
538
+ self ._name , CONFIG , self ._args , self ._kw , self . _cms_to_exit
525
539
),
526
540
** self ._apply_maybe_test_mode_to_tenacity_kw (CONFIG .testing ),
527
541
)
@@ -583,10 +597,15 @@ def _make_before_sleep(
583
597
config : _Config ,
584
598
args : tuple [object , ...],
585
599
kw : dict [str , object ],
600
+ hook_cms : list [contextlib .AbstractContextManager [None ]],
586
601
) -> Callable [[_t .RetryCallState ], None ]:
587
602
"""
588
603
Create a `before_sleep` callback function that runs our `RetryHook`s with
589
604
the necessary arguments.
605
+
606
+ If a hook returns a context manager, it's entered before retries start and
607
+ exited after they finish by keeping track of the context managers in
608
+ *hook_cms*.
590
609
"""
591
610
592
611
last_idle_for = 0.0
@@ -607,7 +626,11 @@ def before_sleep(rcs: _t.RetryCallState) -> None:
607
626
)
608
627
609
628
for hook in config .on_retry :
610
- hook (details )
629
+ maybe_cm = hook (details )
630
+
631
+ if isinstance (maybe_cm , AbstractContextManager ):
632
+ maybe_cm .__enter__ ()
633
+ hook_cms .append (maybe_cm )
611
634
612
635
last_idle_for = rcs .idle_for
613
636
0 commit comments