1
- from typing import List , Union
1
+ from typing import List , Optional , Union
2
2
3
3
from redis .commands .search .dialect import DEFAULT_DIALECT
4
4
@@ -26,10 +26,10 @@ class Reducer:
26
26
27
27
NAME = None
28
28
29
- def __init__ (self , * args : List [ str ] ) -> None :
30
- self ._args = args
31
- self ._field = None
32
- self ._alias = None
29
+ def __init__ (self , * args : str ) -> None :
30
+ self ._args : tuple [ str , ...] = args
31
+ self ._field : Optional [ str ] = None
32
+ self ._alias : Optional [ str ] = None
33
33
34
34
def alias (self , alias : str ) -> "Reducer" :
35
35
"""
@@ -49,13 +49,14 @@ def alias(self, alias: str) -> "Reducer":
49
49
if alias is FIELDNAME :
50
50
if not self ._field :
51
51
raise ValueError ("Cannot use FIELDNAME alias with no field" )
52
- # Chop off initial '@'
53
- alias = self ._field [1 :]
52
+ else :
53
+ # Chop off initial '@'
54
+ alias = self ._field [1 :]
54
55
self ._alias = alias
55
56
return self
56
57
57
58
@property
58
- def args (self ) -> List [str ]:
59
+ def args (self ) -> tuple [str , ... ]:
59
60
return self ._args
60
61
61
62
@@ -64,7 +65,7 @@ class SortDirection:
64
65
This special class is used to indicate sort direction.
65
66
"""
66
67
67
- DIRSTRING = None
68
+ DIRSTRING : Optional [ str ] = None
68
69
69
70
def __init__ (self , field : str ) -> None :
70
71
self .field = field
@@ -104,19 +105,19 @@ def __init__(self, query: str = "*") -> None:
104
105
All member methods (except `build_args()`)
105
106
return the object itself, making them useful for chaining.
106
107
"""
107
- self ._query = query
108
- self ._aggregateplan = []
109
- self ._loadfields = []
110
- self ._loadall = False
111
- self ._max = 0
112
- self ._with_schema = False
113
- self ._verbatim = False
114
- self ._cursor = []
115
- self ._dialect = DEFAULT_DIALECT
116
- self ._add_scores = False
117
- self ._scorer = "TFIDF"
118
-
119
- def load (self , * fields : List [ str ] ) -> "AggregateRequest" :
108
+ self ._query : str = query
109
+ self ._aggregateplan : List [ str ] = []
110
+ self ._loadfields : List [ str ] = []
111
+ self ._loadall : bool = False
112
+ self ._max : int = 0
113
+ self ._with_schema : bool = False
114
+ self ._verbatim : bool = False
115
+ self ._cursor : List [ str ] = []
116
+ self ._dialect : int = DEFAULT_DIALECT
117
+ self ._add_scores : bool = False
118
+ self ._scorer : str = "TFIDF"
119
+
120
+ def load (self , * fields : str ) -> "AggregateRequest" :
120
121
"""
121
122
Indicate the fields to be returned in the response. These fields are
122
123
returned in addition to any others implicitly specified.
@@ -133,7 +134,7 @@ def load(self, *fields: List[str]) -> "AggregateRequest":
133
134
return self
134
135
135
136
def group_by (
136
- self , fields : List [str ], * reducers : Union [ Reducer , List [ Reducer ]]
137
+ self , fields : Union [ str , List [str ]] , * reducers : Reducer
137
138
) -> "AggregateRequest" :
138
139
"""
139
140
Specify by which fields to group the aggregation.
@@ -147,7 +148,6 @@ def group_by(
147
148
`aggregation` module.
148
149
"""
149
150
fields = [fields ] if isinstance (fields , str ) else fields
150
- reducers = [reducers ] if isinstance (reducers , Reducer ) else reducers
151
151
152
152
ret = ["GROUPBY" , str (len (fields )), * fields ]
153
153
for reducer in reducers :
@@ -223,7 +223,7 @@ def limit(self, offset: int, num: int) -> "AggregateRequest":
223
223
self ._aggregateplan .extend (_limit .build_args ())
224
224
return self
225
225
226
- def sort_by (self , * fields : List [ str ] , ** kwargs ) -> "AggregateRequest" :
226
+ def sort_by (self , * fields : str , ** kwargs ) -> "AggregateRequest" :
227
227
"""
228
228
Indicate how the results should be sorted. This can also be used for
229
229
*top-N* style queries
@@ -251,12 +251,10 @@ def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
251
251
.sort_by(Desc("@paid"), max=10)
252
252
```
253
253
"""
254
- if isinstance (fields , (str , SortDirection )):
255
- fields = [fields ]
256
254
257
255
fields_args = []
258
256
for f in fields :
259
- if isinstance (f , SortDirection ):
257
+ if isinstance (f , ( Asc , Desc ) ):
260
258
fields_args += [f .field , f .DIRSTRING ]
261
259
else :
262
260
fields_args += [f ]
@@ -356,7 +354,7 @@ def build_args(self) -> List[str]:
356
354
ret .extend (self ._loadfields )
357
355
358
356
if self ._dialect :
359
- ret .extend (["DIALECT" , self ._dialect ])
357
+ ret .extend (["DIALECT" , str ( self ._dialect ) ])
360
358
361
359
ret .extend (self ._aggregateplan )
362
360
@@ -393,7 +391,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None:
393
391
self .cursor = cursor
394
392
self .schema = schema
395
393
396
- def __repr__ (self ) -> ( str , str ) :
394
+ def __repr__ (self ) -> str :
397
395
cid = self .cursor .cid if self .cursor else - 1
398
396
return (
399
397
f"<{ self .__class__ .__name__ } at 0x{ id (self ):x} "
0 commit comments