Skip to content

Commit 3fc381f

Browse files
committed
chore: Fix mypy linting errors
Signed-off-by: Mark Wiebe <[email protected]>
1 parent 177a15a commit 3fc381f

File tree

7 files changed

+250
-197
lines changed

7 files changed

+250
-197
lines changed

src/openjd/model/_convert_pydantic_error.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22

3-
from typing import Type
3+
from typing import Type, Union
44
from pydantic import BaseModel
55
from pydantic_core import ErrorDetails
66
from inspect import getmodule
@@ -50,7 +50,7 @@ def _error_dict_to_str(root_model: Type[BaseModel], error_details: ErrorDetails)
5050
return f"{_loc_to_str(root_model, loc)}:\n\t{msg}"
5151

5252

53-
def _loc_to_str(root_model: Type[BaseModel], loc: tuple[int | str, ...]) -> str:
53+
def _loc_to_str(root_model: Type[BaseModel], loc: tuple[Union[int, str], ...]) -> str:
5454
model_module = getmodule(root_model)
5555

5656
# If a nested error is from a root validator, then just report the error as being

src/openjd/model/_format_strings/_format_string.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class ExpressionInfo:
2020

2121
class FormatStringError(ValueError):
2222
def __init__(self, *, string: str, start: int, end: int, expr: str = "", details: str = ""):
23+
self.input = string
2324
expression = f"Expression: {expr}. " if expr else ""
2425
reason = f"Reason: {details}." if details else ""
2526
msg = (

src/openjd/model/_internal/_create_job.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22

3-
from typing import Any, Union
3+
from typing import cast, Any, Union
44

55
from pydantic import ValidationError
66
from pydantic_core import InitErrorDetails
@@ -71,14 +71,18 @@ def instantiate_model( # noqa: C901
7171
except ValidationError as exc:
7272
# Convert the ErrorDetails to InitErrorDetails by excluding the 'msg'
7373
for error_details in exc.errors():
74-
errors.append(
75-
InitErrorDetails(
76-
**{key: value for key, value in error_details.items() if key != "msg"}
77-
)
78-
)
74+
init_error_details = {
75+
key: value for key, value in error_details.items() if key != "msg"
76+
}
77+
errors.append(cast(InitErrorDetails, init_error_details))
7978
except FormatStringError as exc:
8079
errors.append(
81-
InitErrorDetails(type="value_error", loc=loc, ctx={"error": ValueError(str(exc))})
80+
InitErrorDetails(
81+
type="value_error",
82+
loc=loc,
83+
ctx={"error": ValueError(str(exc))},
84+
input=exc.input,
85+
)
8286
)
8387

8488
if errors:
@@ -105,10 +109,10 @@ def instantiate_model( # noqa: C901
105109
init_error_details = {}
106110
for key, value in error_details.items():
107111
if key == "loc":
108-
init_error_details["loc"] = loc + value
112+
init_error_details["loc"] = loc + cast(tuple, value)
109113
elif key != "msg":
110114
init_error_details[key] = value
111-
errors.append(InitErrorDetails(**init_error_details))
115+
errors.append(cast(InitErrorDetails, init_error_details))
112116
raise ValidationError.from_exception_data(
113117
title=model.__class__.__name__, line_errors=errors
114118
)
@@ -183,15 +187,17 @@ def _instantiate_list_field( # noqa: C901
183187
except ValidationError as exc:
184188
# Convert the ErrorDetails to InitErrorDetails by excluding the 'msg'
185189
for error_details in exc.errors():
186-
errors.append(
187-
InitErrorDetails(
188-
**{key: value for key, value in error_details.items() if key != "msg"}
189-
)
190-
)
190+
init_error_details = {
191+
key: value for key, value in error_details.items() if key != "msg"
192+
}
193+
errors.append(cast(InitErrorDetails, init_error_details))
191194
except FormatStringError as exc:
192195
errors.append(
193196
InitErrorDetails(
194-
type="value_error", loc=loc, ctx={"error": ValueError(str(exc))}
197+
type="value_error",
198+
loc=loc,
199+
ctx={"error": ValueError(str(exc))},
200+
input=exc.input,
195201
)
196202
)
197203
else:
@@ -211,15 +217,17 @@ def _instantiate_list_field( # noqa: C901
211217
except ValidationError as exc:
212218
# Convert the ErrorDetails to InitErrorDetails by excluding the 'msg'
213219
for error_details in exc.errors():
214-
errors.append(
215-
InitErrorDetails(
216-
**{key: value for key, value in error_details.items() if key != "msg"}
217-
)
218-
)
220+
init_error_details = {
221+
key: value for key, value in error_details.items() if key != "msg"
222+
}
223+
errors.append(cast(InitErrorDetails, init_error_details))
219224
except FormatStringError as exc:
220225
errors.append(
221226
InitErrorDetails(
222-
type="value_error", loc=loc, ctx={"error": ValueError(str(exc))}
227+
type="value_error",
228+
loc=loc,
229+
ctx={"error": ValueError(str(exc))},
230+
input=exc.input,
223231
)
224232
)
225233

@@ -266,14 +274,18 @@ def _instantiate_dict_field(
266274
except ValidationError as exc:
267275
# Convert the ErrorDetails to InitErrorDetails by excluding the 'msg'
268276
for error_details in exc.errors():
269-
errors.append(
270-
InitErrorDetails(
271-
**{key: value for key, value in error_details.items() if key != "msg"}
272-
)
273-
)
277+
init_error_details = {
278+
key: value for key, value in error_details.items() if key != "msg"
279+
}
280+
errors.append(cast(InitErrorDetails, init_error_details))
274281
except FormatStringError as exc:
275282
errors.append(
276-
InitErrorDetails(type="value_error", loc=loc, ctx={"error": ValueError(str(exc))})
283+
InitErrorDetails(
284+
type="value_error",
285+
loc=loc,
286+
ctx={"error": ValueError(str(exc))},
287+
input=exc.input,
288+
)
277289
)
278290

279291
if errors:

src/openjd/model/_internal/_variable_reference_validation.py

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
from collections import defaultdict
44
import typing
5-
from typing import Any, Optional, Type, Literal, Union
5+
from typing import cast, Any, Optional, Type, Literal, Union
66
from inspect import isclass
77

8+
from pydantic import Discriminator
89
from pydantic_core import InitErrorDetails
9-
from pydantic.fields import FieldInfo
10+
from pydantic.fields import FieldInfo, ModelPrivateAttr
1011

1112
from .._types import OpenJDModel, ResolutionScope
1213
from .._format_strings import FormatString, FormatStringError
@@ -101,36 +102,18 @@
101102
# of their fields variable definitions are passed in to.
102103
# - Information on this is encoded in the model's `_template_variable_sources` field. See the comment for this field in the
103104
# OpenJDModel base class for information on this property
104-
105-
106-
# TODO: Rewrite this comment for Pydantic 2
107105
# 4. Since this validation is a pre-validator, we basically have to re-implement a fragment of Pydantic's model parser for this
108-
# depth first traversal. Thus, you'll need to know the following about Pydantic v1.x's data model and parser to understand this
106+
# depth first traversal. Thus, you'll need to know the following about Pydantic v2's data model and parser to understand this
109107
# implementation:
110108
# a) All models are derived from pydantic.BaseModel
111109
# b) pydantic.BaseModel.model_fields: dict[str, FieldInfo] is injected into all BaseModels by pydantic's BaseModel metaclass.
112110
# This member is what gives pydantic information about each of the fields defined in the model class. The key of the dict is the
113111
# name of the field in the model.
114112
# c) pydantic.FieldInfo describes the type information about a model's field:
115-
# i) pydantic.FieldInfo.annotation is the type annotating the field
116-
# Literal[...] means that it's a singleton type.
117-
# list[...] means that it's a list type.
118-
# dict[...] means that it's a dict type.
119-
# etc.
120-
# ii) pydantic.FieldInfo.annotation gives you the type of the field; this is only useful for scalar singleton fields.
121-
# iii) pydantic.FieldInfo.sub_fields: Optional[list[pydantic.FieldInfo]] exists for list, dictionary, and union-typed singleton
122-
# fields:
123-
# 1. For SHAPE_LIST: sub_fields has length 1, and its element is the FieldInfo for the elements of the list.
124-
# 2. For SHAPE_DICT: sub_fields has length 1, and its element is the FieldInfo for the value-type of the dict.
125-
# 3. For SHAPE_SINGLETON:
126-
# a) For scalar-typed fields: sub_fields is None
127-
# b) For union-typed fields: sub_fields is a list of all of the types in the union
128-
# iv) For discriminated unions:
129-
# 1. pydantic.FieldInfo.discriminator_key: Optional[str] exists and it gives the name of the submodel field used to
130-
# determine which type of the union a given data value is.
131-
# 2. pydantic.sub_fields_mapping: Optional[dict[str,pydantic.FieldInfo]] exists and can be used to find the unioned type
132-
# for a given discriminator value.
133-
#
113+
# i) pydantic.FieldInfo.annotation gives you the type of the field; The structure of the field is contained
114+
# in this type, including typing.Annotated values for discriminated unions. Both the definition collection and
115+
# validation recursively unwraps these types along with the values. The pydantic.FieldInfo also includes a discriminator
116+
# value, so the code handles both cases.
134117

135118

136119
class ScopedSymtabs(defaultdict):
@@ -198,7 +181,7 @@ def _validate_model_template_variable_references(
198181
symbol_prefix: str,
199182
symbols: ScopedSymtabs,
200183
loc: tuple,
201-
discriminator: Optional[str] = None,
184+
discriminator: Union[str, Discriminator, None] = None,
202185
) -> list[InitErrorDetails]:
203186
"""Inner implementation of prevalidate_model_template_variable_references().
204187
@@ -290,14 +273,18 @@ def _validate_model_template_variable_references(
290273

291274
# Unwrap a discriminated union to the selected type
292275
if model_origin is Union and discriminator is not None:
293-
return _validate_model_template_variable_references(
294-
_get_model_for_singleton_value(model, value, discriminator),
295-
value,
296-
current_scope,
297-
symbol_prefix,
298-
symbols,
299-
loc,
300-
)
276+
unioned_model = _get_model_for_singleton_value(model, value, discriminator)
277+
if unioned_model is not None:
278+
return _validate_model_template_variable_references(
279+
unioned_model,
280+
value,
281+
current_scope,
282+
symbol_prefix,
283+
symbols,
284+
loc,
285+
)
286+
else:
287+
return []
301288

302289
if isclass(model) and issubclass(model, FormatString):
303290
if isinstance(value, str):
@@ -310,8 +297,9 @@ def _validate_model_template_variable_references(
310297

311298
# Does this cls change the variable reference scope for itself and its children? If so, then update
312299
# our scope.
313-
if model._template_variable_scope.get_default() is not None:
314-
current_scope = model._template_variable_scope.get_default()
300+
model_override_scope = cast(ModelPrivateAttr, model._template_variable_scope).get_default()
301+
if model_override_scope is not None:
302+
current_scope = model_override_scope
315303

316304
# Apply any changes that this node makes to the template variable prefix.
317305
# e.g. It may change "Env." to "Env.File."
@@ -331,9 +319,10 @@ def _validate_model_template_variable_references(
331319
# Recursively validate the contents of FormatStrings within the model.
332320
for field_name, field_info in model.model_fields.items():
333321
field_value = value.get(field_name)
334-
if field_value is None:
322+
field_model = field_info.annotation
323+
if field_value is None or field_model is None:
335324
continue
336-
if typing.get_origin(field_info.annotation) is Literal:
325+
if typing.get_origin(field_model) is Literal:
337326
# Literals aren't format strings and cannot be recursed in to; skip them.
338327
continue
339328

@@ -350,7 +339,7 @@ def _validate_model_template_variable_references(
350339

351340
errors.extend(
352341
_validate_model_template_variable_references(
353-
field_info.annotation,
342+
field_model,
354343
field_value,
355344
current_scope,
356345
symbol_prefix,
@@ -381,13 +370,15 @@ def _check_format_string(
381370
try:
382371
expr.expression.validate_symbol_refs(symbols=scoped_symbols)
383372
except ValueError as exc:
384-
errors.append(InitErrorDetails(type="value_error", loc=loc, ctx={"error": exc}))
373+
errors.append(
374+
InitErrorDetails(type="value_error", loc=loc, ctx={"error": exc}, input=value)
375+
)
385376
return errors
386377

387378

388379
def _get_model_for_singleton_value(
389-
model: Any, value: Any, discriminator: Optional[str] = None
390-
) -> Optional[FieldInfo]:
380+
model: Any, value: Any, discriminator: Union[str, Discriminator, None] = None
381+
) -> Optional[Type]:
391382
"""Given a FieldInfo and the value that we're given for that field, determine
392383
the actual Model for the value in the event that the FieldInfo may be for
393384
a discriminated union."""
@@ -453,7 +444,7 @@ def _collect_variable_definitions( # noqa: C901 (suppress: too complex)
453444
current_scope: ResolutionScope,
454445
symbol_prefix: str,
455446
recursive_pruning: bool = True,
456-
discriminator: Optional[str] = None,
447+
discriminator: Union[str, Discriminator, None] = None,
457448
) -> dict[str, ScopedSymtabs]:
458449
"""Collects the names of variables that each field of this model object provides.
459450
@@ -518,12 +509,16 @@ def _collect_variable_definitions( # noqa: C901 (suppress: too complex)
518509

519510
# Unwrap a discriminated union to the selected type
520511
if model_origin is Union and discriminator is not None:
521-
return _collect_variable_definitions(
522-
_get_model_for_singleton_value(model, value, discriminator),
523-
value,
524-
current_scope,
525-
symbol_prefix,
526-
)
512+
unioned_model = _get_model_for_singleton_value(model, value, discriminator)
513+
if unioned_model is not None:
514+
return _collect_variable_definitions(
515+
unioned_model,
516+
value,
517+
current_scope,
518+
symbol_prefix,
519+
)
520+
else:
521+
return {"__export__": ScopedSymtabs()}
527522

528523
# Anything except for an OpenJDModel returns an empty result
529524
if not isclass(model) or not issubclass(model, OpenJDModel):
@@ -577,10 +572,10 @@ def _collect_variable_definitions( # noqa: C901 (suppress: too complex)
577572
# Collect the variable definitions exported by the fields of the model
578573
for field_name, field_info in model.model_fields.items():
579574
field_value = value.get(field_name)
580-
if field_value is None:
575+
field_model = field_info.annotation
576+
if field_value is None or field_model is None:
581577
continue
582578

583-
field_model = field_info.annotation
584579
discriminator = field_info.discriminator
585580

586581
symbols[field_name] = _collect_variable_definitions(

src/openjd/model/_parse.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import yaml
1010
from pydantic import BaseModel
1111
from pydantic import ValidationError as PydanticValidationError
12-
from pydantic_core import ErrorDetails
12+
from pydantic_core import ErrorDetails, InitErrorDetails
1313

1414
from ._errors import DecodeValidationError
1515
from ._types import EnvironmentTemplate, JobTemplate, OpenJDModel, TemplateSpecificationVersion
@@ -58,8 +58,14 @@ def _parse_model(*, model: Type[T], obj: Any) -> T:
5858
result = cast(T, cast(BaseModel, model).model_validate(obj))
5959
except PydanticValidationError as exc:
6060
if prevalidator_error is not None:
61+
errors = list[InitErrorDetails]()
62+
for error_details in exc.errors() + prevalidator_error.errors():
63+
init_error_details = {
64+
key: value for key, value in error_details.items() if key != "msg"
65+
}
66+
errors.append(cast(InitErrorDetails, init_error_details))
6167
raise PydanticValidationError.from_exception_data(
62-
exc.title, exc.errors() + prevalidator_error.errors()
68+
title=exc.title, line_errors=errors
6369
)
6470
else:
6571
raise

0 commit comments

Comments
 (0)