Skip to content

Commit 4a284e1

Browse files
committed
Fixing mypy errors in redis/commands/search/aggregation.py
1 parent 4a8d37c commit 4a284e1

File tree

2 files changed

+30
-32
lines changed

2 files changed

+30
-32
lines changed

redis/commands/search/aggregation.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Union
1+
from typing import List, Optional, Union
22

33
from redis.commands.search.dialect import DEFAULT_DIALECT
44

@@ -26,10 +26,10 @@ class Reducer:
2626

2727
NAME = None
2828

29-
def __init__(self, *args: List[str]) -> None:
30-
self._args = args
31-
self._field = None
32-
self._alias = None
29+
def __init__(self, *args: str) -> None:
30+
self._args: tuple[str, ...] = args
31+
self._field: Optional[str] = None
32+
self._alias: Optional[str] = None
3333

3434
def alias(self, alias: str) -> "Reducer":
3535
"""
@@ -49,13 +49,14 @@ def alias(self, alias: str) -> "Reducer":
4949
if alias is FIELDNAME:
5050
if not self._field:
5151
raise ValueError("Cannot use FIELDNAME alias with no field")
52-
# Chop off initial '@'
53-
alias = self._field[1:]
52+
else:
53+
# Chop off initial '@'
54+
alias = self._field[1:]
5455
self._alias = alias
5556
return self
5657

5758
@property
58-
def args(self) -> List[str]:
59+
def args(self) -> tuple[str, ...]:
5960
return self._args
6061

6162

@@ -64,7 +65,7 @@ class SortDirection:
6465
This special class is used to indicate sort direction.
6566
"""
6667

67-
DIRSTRING = None
68+
DIRSTRING: Optional[str] = None
6869

6970
def __init__(self, field: str) -> None:
7071
self.field = field
@@ -104,19 +105,19 @@ def __init__(self, query: str = "*") -> None:
104105
All member methods (except `build_args()`)
105106
return the object itself, making them useful for chaining.
106107
"""
107-
self._query = query
108-
self._aggregateplan = []
109-
self._loadfields = []
110-
self._loadall = False
111-
self._max = 0
112-
self._with_schema = False
113-
self._verbatim = False
114-
self._cursor = []
115-
self._dialect = DEFAULT_DIALECT
116-
self._add_scores = False
117-
self._scorer = "TFIDF"
118-
119-
def load(self, *fields: List[str]) -> "AggregateRequest":
108+
self._query: str = query
109+
self._aggregateplan: List[str] = []
110+
self._loadfields: List[str] = []
111+
self._loadall: bool = False
112+
self._max: int = 0
113+
self._with_schema: bool = False
114+
self._verbatim: bool = False
115+
self._cursor: List[str] = []
116+
self._dialect: int = DEFAULT_DIALECT
117+
self._add_scores: bool = False
118+
self._scorer: str = "TFIDF"
119+
120+
def load(self, *fields: str) -> "AggregateRequest":
120121
"""
121122
Indicate the fields to be returned in the response. These fields are
122123
returned in addition to any others implicitly specified.
@@ -133,7 +134,7 @@ def load(self, *fields: List[str]) -> "AggregateRequest":
133134
return self
134135

135136
def group_by(
136-
self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
137+
self, fields: Union[str, List[str]], *reducers: Reducer
137138
) -> "AggregateRequest":
138139
"""
139140
Specify by which fields to group the aggregation.
@@ -147,7 +148,6 @@ def group_by(
147148
`aggregation` module.
148149
"""
149150
fields = [fields] if isinstance(fields, str) else fields
150-
reducers = [reducers] if isinstance(reducers, Reducer) else reducers
151151

152152
ret = ["GROUPBY", str(len(fields)), *fields]
153153
for reducer in reducers:
@@ -223,7 +223,7 @@ def limit(self, offset: int, num: int) -> "AggregateRequest":
223223
self._aggregateplan.extend(_limit.build_args())
224224
return self
225225

226-
def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
226+
def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
227227
"""
228228
Indicate how the results should be sorted. This can also be used for
229229
*top-N* style queries
@@ -251,12 +251,10 @@ def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
251251
.sort_by(Desc("@paid"), max=10)
252252
```
253253
"""
254-
if isinstance(fields, (str, SortDirection)):
255-
fields = [fields]
256254

257255
fields_args = []
258256
for f in fields:
259-
if isinstance(f, SortDirection):
257+
if isinstance(f, (Asc, Desc)):
260258
fields_args += [f.field, f.DIRSTRING]
261259
else:
262260
fields_args += [f]
@@ -356,7 +354,7 @@ def build_args(self) -> List[str]:
356354
ret.extend(self._loadfields)
357355

358356
if self._dialect:
359-
ret.extend(["DIALECT", self._dialect])
357+
ret.extend(["DIALECT", str(self._dialect)])
360358

361359
ret.extend(self._aggregateplan)
362360

@@ -393,7 +391,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None:
393391
self.cursor = cursor
394392
self.schema = schema
395393

396-
def __repr__(self) -> (str, str):
394+
def __repr__(self) -> str:
397395
cid = self.cursor.cid if self.cursor else -1
398396
return (
399397
f"<{self.__class__.__name__} at 0x{id(self):x} "

redis/commands/search/commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def info(self):
464464
return self._parse_results(INFO_CMD, res)
465465

466466
def get_params_args(
467-
self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None]
467+
self, query_params: Optional[Dict[str, Union[str, int, float, bytes]]]
468468
):
469469
if query_params is None:
470470
return []
@@ -543,7 +543,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa
543543
def aggregate(
544544
self,
545545
query: Union[str, Query],
546-
query_params: Dict[str, Union[str, int, float]] = None,
546+
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
547547
):
548548
"""
549549
Issue an aggregation query.

0 commit comments

Comments
 (0)