1
1
import abc
2
2
import socket
3
3
from time import sleep
4
- from typing import TYPE_CHECKING , Any , Callable , Iterable , Tuple , Type , TypeVar , Union
4
+ from typing import TYPE_CHECKING , Any , Callable , Generic , Iterable , Tuple , Type , TypeVar
5
5
6
6
from redis .exceptions import ConnectionError , TimeoutError
7
7
8
8
T = TypeVar ("T" )
9
+ E = TypeVar ("E" , bound = Exception , covariant = True )
9
10
10
11
if TYPE_CHECKING :
11
12
from redis .backoff import AbstractBackoff
12
13
13
14
14
- class AbstractRetry (abc .ABC ):
15
+ class AbstractRetry (Generic [ E ], abc .ABC ):
15
16
"""Retry a specific number of times after a failure"""
16
17
17
- _supported_errors : Tuple [Type [Exception ], ...]
18
+ _supported_errors : Tuple [Type [E ], ...]
18
19
19
20
def __init__ (
20
21
self ,
21
22
backoff : "AbstractBackoff" ,
22
23
retries : int ,
23
- supported_errors : Union [ Tuple [Type [Exception ], ...], None ] = None ,
24
+ supported_errors : Tuple [Type [E ], ...],
24
25
):
25
26
"""
26
27
Initialize a `Retry` object with a `Backoff` object
@@ -31,8 +32,7 @@ def __init__(
31
32
"""
32
33
self ._backoff = backoff
33
34
self ._retries = retries
34
- if supported_errors :
35
- self ._supported_errors = supported_errors
35
+ self ._supported_errors = supported_errors
36
36
37
37
@abc .abstractmethod
38
38
def __eq__ (self , other : Any ) -> bool :
@@ -41,9 +41,7 @@ def __eq__(self, other: Any) -> bool:
41
41
def __hash__ (self ) -> int :
42
42
return hash ((self ._backoff , self ._retries , frozenset (self ._supported_errors )))
43
43
44
- def update_supported_errors (
45
- self , specified_errors : Iterable [Type [Exception ]]
46
- ) -> None :
44
+ def update_supported_errors (self , specified_errors : Iterable [Type [E ]]) -> None :
47
45
"""
48
46
Updates the supported errors with the specified error types
49
47
"""
@@ -64,14 +62,23 @@ def update_retries(self, value: int) -> None:
64
62
self ._retries = value
65
63
66
64
67
- class Retry (AbstractRetry ):
68
- _supported_errors : Tuple [Type [Exception ], ...] = (
69
- ConnectionError ,
70
- TimeoutError ,
71
- socket .timeout ,
72
- )
65
+ class Retry (AbstractRetry [Exception ]):
73
66
__hash__ = AbstractRetry .__hash__
74
67
68
+ def __init__ (
69
+ self ,
70
+ backoff : "AbstractBackoff" ,
71
+ retries : int ,
72
+ supported_errors : Tuple [Type [Exception ], ...] = (
73
+ ConnectionError ,
74
+ TimeoutError ,
75
+ socket .timeout ,
76
+ ),
77
+ ):
78
+ super ().__init__ (backoff , retries , supported_errors )
79
+
80
+ __init__ .__doc__ = AbstractRetry .__init__ .__doc__
81
+
75
82
def __eq__ (self , other : Any ) -> bool :
76
83
if not isinstance (other , Retry ):
77
84
return NotImplemented
0 commit comments