22
33from collections import defaultdict
44import typing
5- from typing import Any , Optional , Type , Literal , Union
5+ from typing import cast , Any , Optional , Type , Literal , Union
66from inspect import isclass
77
8+ from pydantic import Discriminator
89from pydantic_core import InitErrorDetails
9- from pydantic .fields import FieldInfo
10+ from pydantic .fields import FieldInfo , ModelPrivateAttr
1011
1112from .._types import OpenJDModel , ResolutionScope
1213from .._format_strings import FormatString , FormatStringError
101102# of their fields variable definitions are passed in to.
102103# - Information on this is encoded in the model's `_template_variable_sources` field. See the comment for this field in the
103104# OpenJDModel base class for information on this property
104-
105-
106- # TODO: Rewrite this comment for Pydantic 2
107105# 4. Since this validation is a pre-validator, we basically have to re-implement a fragment of Pydantic's model parser for this
108- # depth first traversal. Thus, you'll need to know the following about Pydantic v1.x 's data model and parser to understand this
106+ # depth first traversal. Thus, you'll need to know the following about Pydantic v2 's data model and parser to understand this
109107# implementation:
110108# a) All models are derived from pydantic.BaseModel
111109# b) pydantic.BaseModel.model_fields: dict[str, FieldInfo] is injected into all BaseModels by pydantic's BaseModel metaclass.
112110# This member is what gives pydantic information about each of the fields defined in the model class. The key of the dict is the
113111# name of the field in the model.
114112# c) pydantic.FieldInfo describes the type information about a model's field:
115- # i) pydantic.FieldInfo.annotation is the type annotating the field
116- # Literal[...] means that it's a singleton type.
117- # list[...] means that it's a list type.
118- # dict[...] means that it's a dict type.
119- # etc.
120- # ii) pydantic.FieldInfo.annotation gives you the type of the field; this is only useful for scalar singleton fields.
121- # iii) pydantic.FieldInfo.sub_fields: Optional[list[pydantic.FieldInfo]] exists for list, dictionary, and union-typed singleton
122- # fields:
123- # 1. For SHAPE_LIST: sub_fields has length 1, and its element is the FieldInfo for the elements of the list.
124- # 2. For SHAPE_DICT: sub_fields has length 1, and its element is the FieldInfo for the value-type of the dict.
125- # 3. For SHAPE_SINGLETON:
126- # a) For scalar-typed fields: sub_fields is None
127- # b) For union-typed fields: sub_fields is a list of all of the types in the union
128- # iv) For discriminated unions:
129- # 1. pydantic.FieldInfo.discriminator_key: Optional[str] exists and it gives the name of the submodel field used to
130- # determine which type of the union a given data value is.
131- # 2. pydantic.sub_fields_mapping: Optional[dict[str,pydantic.FieldInfo]] exists and can be used to find the unioned type
132- # for a given discriminator value.
133- #
113+ # i) pydantic.FieldInfo.annotation gives you the type of the field; The structure of the field is contained
114+ # in this type, including typing.Annotated values for discriminated unions. Both the definition collection and
115+ # validation recursively unwraps these types along with the values. The pydantic.FieldInfo also includes a discriminator
116+ # value, so the code handles both cases.
134117
135118
136119class ScopedSymtabs (defaultdict ):
@@ -198,7 +181,7 @@ def _validate_model_template_variable_references(
198181 symbol_prefix : str ,
199182 symbols : ScopedSymtabs ,
200183 loc : tuple ,
201- discriminator : Optional [str ] = None ,
184+ discriminator : Union [str , Discriminator , None ] = None ,
202185) -> list [InitErrorDetails ]:
203186 """Inner implementation of prevalidate_model_template_variable_references().
204187
@@ -290,14 +273,18 @@ def _validate_model_template_variable_references(
290273
291274 # Unwrap a discriminated union to the selected type
292275 if model_origin is Union and discriminator is not None :
293- return _validate_model_template_variable_references (
294- _get_model_for_singleton_value (model , value , discriminator ),
295- value ,
296- current_scope ,
297- symbol_prefix ,
298- symbols ,
299- loc ,
300- )
276+ unioned_model = _get_model_for_singleton_value (model , value , discriminator )
277+ if unioned_model is not None :
278+ return _validate_model_template_variable_references (
279+ unioned_model ,
280+ value ,
281+ current_scope ,
282+ symbol_prefix ,
283+ symbols ,
284+ loc ,
285+ )
286+ else :
287+ return []
301288
302289 if isclass (model ) and issubclass (model , FormatString ):
303290 if isinstance (value , str ):
@@ -310,8 +297,9 @@ def _validate_model_template_variable_references(
310297
311298 # Does this cls change the variable reference scope for itself and its children? If so, then update
312299 # our scope.
313- if model ._template_variable_scope .get_default () is not None :
314- current_scope = model ._template_variable_scope .get_default ()
300+ model_override_scope = cast (ModelPrivateAttr , model ._template_variable_scope ).get_default ()
301+ if model_override_scope is not None :
302+ current_scope = model_override_scope
315303
316304 # Apply any changes that this node makes to the template variable prefix.
317305 # e.g. It may change "Env." to "Env.File."
@@ -331,9 +319,10 @@ def _validate_model_template_variable_references(
331319 # Recursively validate the contents of FormatStrings within the model.
332320 for field_name , field_info in model .model_fields .items ():
333321 field_value = value .get (field_name )
334- if field_value is None :
322+ field_model = field_info .annotation
323+ if field_value is None or field_model is None :
335324 continue
336- if typing .get_origin (field_info . annotation ) is Literal :
325+ if typing .get_origin (field_model ) is Literal :
337326 # Literals aren't format strings and cannot be recursed in to; skip them.
338327 continue
339328
@@ -350,7 +339,7 @@ def _validate_model_template_variable_references(
350339
351340 errors .extend (
352341 _validate_model_template_variable_references (
353- field_info . annotation ,
342+ field_model ,
354343 field_value ,
355344 current_scope ,
356345 symbol_prefix ,
@@ -381,13 +370,15 @@ def _check_format_string(
381370 try :
382371 expr .expression .validate_symbol_refs (symbols = scoped_symbols )
383372 except ValueError as exc :
384- errors .append (InitErrorDetails (type = "value_error" , loc = loc , ctx = {"error" : exc }))
373+ errors .append (
374+ InitErrorDetails (type = "value_error" , loc = loc , ctx = {"error" : exc }, input = value )
375+ )
385376 return errors
386377
387378
388379def _get_model_for_singleton_value (
389- model : Any , value : Any , discriminator : Optional [str ] = None
390- ) -> Optional [FieldInfo ]:
380+ model : Any , value : Any , discriminator : Union [str , Discriminator , None ] = None
381+ ) -> Optional [Type ]:
391382 """Given a FieldInfo and the value that we're given for that field, determine
392383 the actual Model for the value in the event that the FieldInfo may be for
393384 a discriminated union."""
@@ -453,7 +444,7 @@ def _collect_variable_definitions( # noqa: C901 (suppress: too complex)
453444 current_scope : ResolutionScope ,
454445 symbol_prefix : str ,
455446 recursive_pruning : bool = True ,
456- discriminator : Optional [str ] = None ,
447+ discriminator : Union [str , Discriminator , None ] = None ,
457448) -> dict [str , ScopedSymtabs ]:
458449 """Collects the names of variables that each field of this model object provides.
459450
@@ -518,12 +509,16 @@ def _collect_variable_definitions( # noqa: C901 (suppress: too complex)
518509
519510 # Unwrap a discriminated union to the selected type
520511 if model_origin is Union and discriminator is not None :
521- return _collect_variable_definitions (
522- _get_model_for_singleton_value (model , value , discriminator ),
523- value ,
524- current_scope ,
525- symbol_prefix ,
526- )
512+ unioned_model = _get_model_for_singleton_value (model , value , discriminator )
513+ if unioned_model is not None :
514+ return _collect_variable_definitions (
515+ unioned_model ,
516+ value ,
517+ current_scope ,
518+ symbol_prefix ,
519+ )
520+ else :
521+ return {"__export__" : ScopedSymtabs ()}
527522
528523 # Anything except for an OpenJDModel returns an empty result
529524 if not isclass (model ) or not issubclass (model , OpenJDModel ):
@@ -577,10 +572,10 @@ def _collect_variable_definitions( # noqa: C901 (suppress: too complex)
577572 # Collect the variable definitions exported by the fields of the model
578573 for field_name , field_info in model .model_fields .items ():
579574 field_value = value .get (field_name )
580- if field_value is None :
575+ field_model = field_info .annotation
576+ if field_value is None or field_model is None :
581577 continue
582578
583- field_model = field_info .annotation
584579 discriminator = field_info .discriminator
585580
586581 symbols [field_name ] = _collect_variable_definitions (
0 commit comments