|
12 | 12 | from functools import wraps
|
13 | 13 | from inspect import iscoroutinefunction
|
14 | 14 | from types import TracebackType
|
15 |
| -from typing import AsyncIterator, Iterator, TypeVar |
| 15 | +from typing import AsyncIterator, Awaitable, Iterator, TypedDict, TypeVar |
16 | 16 |
|
17 | 17 | import tenacity as _t
|
18 | 18 |
|
@@ -126,6 +126,95 @@ def __exit__(
|
126 | 126 | )
|
127 | 127 |
|
128 | 128 |
|
| 129 | +class RetryKWs(TypedDict): |
| 130 | + on: type[Exception] | tuple[type[Exception], ...] |
| 131 | + attempts: int | None |
| 132 | + timeout: float | dt.timedelta | None |
| 133 | + wait_initial: float | dt.timedelta |
| 134 | + wait_max: float | dt.timedelta |
| 135 | + wait_jitter: float | dt.timedelta |
| 136 | + wait_exp_base: float |
| 137 | + |
| 138 | + |
| 139 | +class BaseRetryingCaller: |
| 140 | + """ |
| 141 | + .. versionadded:: 24.2.0 |
| 142 | + """ |
| 143 | + |
| 144 | + __slots__ = ("_context_kws",) |
| 145 | + |
| 146 | + _context_kws: RetryKWs |
| 147 | + |
| 148 | + def __init__( |
| 149 | + self, |
| 150 | + on: type[Exception] | tuple[type[Exception], ...], |
| 151 | + attempts: int | None = 10, |
| 152 | + timeout: float | dt.timedelta | None = 45.0, |
| 153 | + wait_initial: float | dt.timedelta = 0.1, |
| 154 | + wait_max: float | dt.timedelta = 5.0, |
| 155 | + wait_jitter: float | dt.timedelta = 1.0, |
| 156 | + wait_exp_base: float = 2.0, |
| 157 | + ): |
| 158 | + self._context_kws = { |
| 159 | + "on": on, |
| 160 | + "attempts": attempts, |
| 161 | + "timeout": timeout, |
| 162 | + "wait_initial": wait_initial, |
| 163 | + "wait_max": wait_max, |
| 164 | + "wait_jitter": wait_jitter, |
| 165 | + "wait_exp_base": wait_exp_base, |
| 166 | + } |
| 167 | + |
| 168 | + def __repr__(self) -> str: |
| 169 | + on = guess_name(self._context_kws["on"]) |
| 170 | + kws = ", ".join( |
| 171 | + f"{k}={self._context_kws[k]!r}" # type: ignore[literal-required] |
| 172 | + for k in sorted(self._context_kws) |
| 173 | + if k != "on" |
| 174 | + ) |
| 175 | + return f"<{self.__class__.__name__}(on={on}, {kws})>" |
| 176 | + |
| 177 | + |
| 178 | +class RetryingCaller(BaseRetryingCaller): |
| 179 | + """ |
| 180 | + An object that will call your callable with retries. |
| 181 | +
|
| 182 | + Instances of `RetryingCaller` may be reused because they create a new |
| 183 | + :func:`retry_context` iterator on each call. |
| 184 | +
|
| 185 | + .. versionadded:: 24.2.0 |
| 186 | + """ |
| 187 | + |
| 188 | + def __call__( |
| 189 | + self, func: Callable[P, T], /, *args: P.args, **kw: P.kwargs |
| 190 | + ) -> T: |
| 191 | + for attempt in retry_context(**self._context_kws): |
| 192 | + with attempt: |
| 193 | + return func(*args, **kw) |
| 194 | + |
| 195 | + raise SystemError("unreachable") # pragma: no cover # noqa: EM101 |
| 196 | + |
| 197 | + |
| 198 | +class AsyncRetryingCaller(BaseRetryingCaller): |
| 199 | + """ |
| 200 | + An object that will call your async callable with retries. |
| 201 | +
|
| 202 | + Instances of `AsyncRetryingCaller` may be reused because they create a new |
| 203 | + :func:`retry_context` iterator on each call. |
| 204 | +
|
| 205 | + .. versionadded:: 24.2.0 |
| 206 | + """ |
| 207 | + |
| 208 | + async def __call__( |
| 209 | + self, func: Callable[P, Awaitable[T]], /, *args: P.args, **kw: P.kwargs |
| 210 | + ) -> T: |
| 211 | + async for attempt in retry_context(**self._context_kws): |
| 212 | + with attempt: |
| 213 | + return await func(*args, **kw) |
| 214 | + |
| 215 | + raise SystemError("unreachable") # pragma: no cover # noqa: EM101 |
| 216 | + |
| 217 | + |
129 | 218 | _STOP_NO_RETRY = _t.stop_after_attempt(1)
|
130 | 219 |
|
131 | 220 |
|
|
0 commit comments