Skip to content

Commit 0242f98

Browse files
committed
Cache field types of the given model
1 parent 2d513df commit 0242f98

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

pydantic_redis/model.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class Model(_AbstractModel):
1919
_nested_model_tuple_fields = {}
2020
_nested_model_list_fields = {}
2121
_nested_model_fields = {}
22+
_field_types = {}
2223

2324
@classmethod
2425
def __get_primary_key(cls, primary_key_value: Any):
@@ -37,13 +38,13 @@ def get_table_index_key(cls):
3738
@classmethod
3839
def initialize(cls):
3940
"""Initializes class-wide variables for performance's reasons e.g. it caches the nested model fields"""
40-
field_types = typing.get_type_hints(cls)
41+
cls._field_types = typing.get_type_hints(cls)
4142

4243
cls._nested_model_list_fields = {}
4344
cls._nested_model_tuple_fields = {}
4445
cls._nested_model_fields = {}
4546

46-
for field, field_type in field_types.items():
47+
for field, field_type in cls._field_types.items():
4748
try:
4849
# In case the annotation is Optional, an alias of Union[X, None], extract the X
4950
is_generic = hasattr(field_type, "__origin__")
@@ -199,23 +200,22 @@ def __parse_dict_list(cls, data: List[Dict[bytes, Any]]) -> List[Dict[str, Any]]
199200
cls.deserialize_partially(record) for record in data if record != {}
200201
]
201202
if len(parsed_data) > 0:
202-
field_types = typing.get_type_hints(cls)
203203
keys = [*parsed_data[0].keys()]
204204

205205
for k in keys:
206206
if k.startswith(NESTED_MODEL_LIST_FIELD_PREFIX):
207207
cls.__eager_load_nested_model_lists(
208-
prefixed_field=k, data=parsed_data, field_types=field_types
208+
prefixed_field=k, data=parsed_data
209209
)
210210

211211
elif k.startswith(NESTED_MODEL_TUPLE_FIELD_PREFIX):
212212
cls.__eager_load_nested_model_tuples(
213-
prefixed_field=k, data=parsed_data, field_types=field_types
213+
prefixed_field=k, data=parsed_data
214214
)
215215

216216
elif k.startswith(NESTED_MODEL_PREFIX):
217217
cls.__eager_load_nested_models(
218-
prefixed_field=k, data=parsed_data, field_types=field_types
218+
prefixed_field=k, data=parsed_data, field_types=cls._field_types
219219
)
220220
return parsed_data
221221

@@ -224,7 +224,6 @@ def __eager_load_nested_model_lists(
224224
cls,
225225
prefixed_field: str,
226226
data: List[Dict[str, Any]],
227-
field_types: Dict[str, Any],
228227
):
229228
"""
230229
Eagerly loads any properties that have `List[Model]` or Optional[List[Model]]` as their type annotations
@@ -234,24 +233,17 @@ def __eager_load_nested_model_lists(
234233
[{"___books": ["id1", "id2"]}] becomes [{"books": [Book{"id": "id1", ...}, Book{"id": "id2", ...}]}]
235234
"""
236235
field = strip_leading(prefixed_field, NESTED_MODEL_LIST_FIELD_PREFIX)
237-
field_type = field_types.get(field)
238-
model_type = field_type.__args__[0]
239-
240-
# in case the field is Optional e.g. books: Optional[List[Model]], an alias for Union[List[Model], None]
241-
is_optional = field_type.__origin__ == Union
242-
if is_optional:
243-
model_type = model_type.__args__[0]
236+
field_type = cls._nested_model_list_fields.get(field)
244237

245238
for record in data:
246239
ids = record.pop(prefixed_field, None)
247-
record[field] = model_type.select(ids=ids)
240+
record[field] = field_type.select(ids=ids)
248241

249242
@classmethod
250243
def __eager_load_nested_model_tuples(
251244
cls,
252245
prefixed_field: str,
253246
data: List[Dict[str, Any]],
254-
field_types: Dict[str, Any],
255247
):
256248
"""
257249
Eagerly loads any properties that have `Tuple[Model]` or `Optional[Tuple[Model]]` as their type annotations
@@ -287,7 +279,7 @@ def __eager_load_nested_models(
287279
[{"__book": "id1"}] becomes [{"book": Book{"id": "id1", ...}}]
288280
"""
289281
field = strip_leading(prefixed_field, NESTED_MODEL_PREFIX)
290-
model_type = field_types.get(field)
282+
model_type = cls._nested_model_fields.get(field)
291283

292284
ids: List[str] = [record.pop(prefixed_field, None) for record in data]
293285
# a bulk network request might be faster than eagerly loading for each record for many records
@@ -400,7 +392,7 @@ def __get_select_fields(cls, columns: Optional[List[str]]) -> Optional[List[str]
400392
if columns is None:
401393
return None
402394

403-
field_types = typing.get_type_hints(cls)
395+
field_types = cls._field_types
404396

405397
fields = []
406398
for col in columns:

0 commit comments

Comments
 (0)