66from abc import ABC , abstractmethod
77from collections .abc import Awaitable , Callable , Sequence
88from dataclasses import dataclass , field
9+ from functools import cached_property
910from typing import TYPE_CHECKING , Any , Generic , Literal , cast , overload
1011
1112from pydantic import Json , TypeAdapter , ValidationError
@@ -218,7 +219,6 @@ class OutputSchema(ABC, Generic[OutputDataT]):
218219 object_def : OutputObjectDefinition | None = None
219220 allows_deferred_tools : bool = False
220221 allows_image : bool = False
221- json_schema : JsonSchema = field (init = False )
222222
223223 @property
224224 def mode (self ) -> OutputMode :
@@ -228,6 +228,10 @@ def mode(self) -> OutputMode:
228228 def allows_text (self ) -> bool :
229229 return self .text_processor is not None
230230
231+ @cached_property
232+ def json_schema (self ) -> JsonSchema :
233+ raise NotImplementedError ()
234+
231235 @classmethod
232236 def build ( # noqa: C901
233237 cls ,
@@ -382,37 +386,31 @@ def _build_processor(
382386 return UnionOutputProcessor (outputs = outputs , strict = strict , name = name , description = description )
383387
384388 @staticmethod
385- def build_json_schema ( # noqa: C901
389+ def build_json_schema (
386390 allows_deferred_tools : bool = False ,
387391 allows_image : bool = False ,
388392 allows_text : bool = False ,
389393 base_processor : BaseObjectOutputProcessor [OutputDataT ] | None = None ,
390394 toolset_processors : dict [str , ObjectOutputProcessor [OutputDataT ]] | None = None ,
391395 ) -> JsonSchema :
392396 # allow any output with {'type': 'string'} if no constraints
393- if not any ([allows_deferred_tools , allows_image , allows_text , base_processor , toolset_processors ]):
397+ if not any ([allows_deferred_tools , allows_image , base_processor , toolset_processors ]):
394398 return TypeAdapter (str ).json_schema ()
395399
396400 object_keys : list [str ] = []
397401 json_schemas : list [ObjectJsonSchema ] = []
398402
399- if base_processor :
400- json_schema = base_processor .object_def .json_schema .copy ()
401- json_schema ['title' ] = base_processor .object_def .name
402- if base_processor .object_def .description :
403- json_schema ['description' ] = base_processor .object_def .description
404- json_schemas .append (json_schema )
405- object_keys .append (json_schema .get ('title' , 'result' ))
406-
407403 if toolset_processors :
408404 for name , tool_processor in toolset_processors .items ():
409405 json_schema = tool_processor .object_def .json_schema .copy ()
410- json_schema ['title' ] = name
411- if tool_processor .object_def .description :
412- json_schema ['description' ] = tool_processor .object_def .description
413406 json_schemas .append (json_schema )
414407 object_keys .append (name )
415408
409+ if base_processor :
410+ json_schema = base_processor .object_def .json_schema .copy ()
411+ json_schemas .append (json_schema )
412+ object_keys .append (json_schema .get ('title' , 'result' ))
413+
416414 special_output_types : list [type ] = []
417415 if allows_text :
418416 special_output_types .append (str )
@@ -422,17 +420,16 @@ def build_json_schema( # noqa: C901
422420 special_output_types .append (_messages .BinaryImage )
423421 if special_output_types :
424422 for output_type in special_output_types :
425- output_type_json_schema = TypeAdapter (output_type ).json_schema ()
423+ output_type_json_schema = TypeAdapter (output_type ).json_schema (mode = 'serialization' )
426424 json_schemas .append (output_type_json_schema )
427- object_key = output_type .__name__
428- object_keys .append (object_key )
429-
430- json_schemas , all_defs = _utils .merge_json_schema_defs (json_schemas )
425+ object_keys .append (output_type .__name__ )
431426
432- # do not discriminate JSON if not needed
433- if len (json_schemas ) == 1 and not all_defs :
427+ # do not further process JSON if not needed
428+ if len (json_schemas ) == 1 :
434429 return json_schemas [0 ]
435430
431+ json_schemas , all_defs = _utils .merge_json_schema_defs (json_schemas )
432+
436433 unique_object_keys : list [str ] = []
437434 for key in object_keys :
438435 count = 1
@@ -442,44 +439,7 @@ def build_json_schema( # noqa: C901
442439 new_key = f'{ key } _{ count } '
443440 unique_object_keys .append (new_key )
444441
445- discriminated_json_schemas : list [ObjectJsonSchema ] = []
446- for object_key , json_schema in zip (unique_object_keys , json_schemas ):
447- title = json_schema .pop ('title' , None )
448- description = json_schema .pop ('description' , None )
449-
450- discriminated_json_schema = {
451- 'type' : 'object' ,
452- 'properties' : {
453- 'kind' : {
454- 'type' : 'string' ,
455- 'const' : object_key ,
456- },
457- 'data' : json_schema ,
458- },
459- 'required' : ['kind' , 'data' ],
460- 'additionalProperties' : False ,
461- }
462- if title : # pragma: no branch
463- discriminated_json_schema ['title' ] = title
464- if description :
465- discriminated_json_schema ['description' ] = description
466-
467- discriminated_json_schemas .append (discriminated_json_schema )
468-
469- json_schema = {
470- 'type' : 'object' ,
471- 'properties' : {
472- 'result' : {
473- 'anyOf' : discriminated_json_schemas ,
474- }
475- },
476- 'required' : ['result' ],
477- 'additionalProperties' : False ,
478- }
479- if all_defs :
480- json_schema ['$defs' ] = all_defs
481-
482- return json_schema
442+ return UnionOutputProcessor .make_discriminated_json_schema_union (unique_object_keys , json_schemas , all_defs )
483443
484444
485445@dataclass (init = False )
@@ -503,17 +463,20 @@ def __init__(
503463 allows_deferred_tools = allows_deferred_tools ,
504464 allows_image = allows_image ,
505465 )
506- self .json_schema = OutputSchema [OutputDataT ].build_json_schema (
507- base_processor = processor ,
508- allows_deferred_tools = allows_deferred_tools ,
509- allows_image = allows_image ,
510- )
511466 self .processor = processor
512467
513468 @property
514469 def mode (self ) -> OutputMode :
515470 return 'auto'
516471
472+ @cached_property
473+ def json_schema (self ) -> JsonSchema :
474+ return OutputSchema [OutputDataT ].build_json_schema (
475+ base_processor = self .processor ,
476+ allows_deferred_tools = self .allows_deferred_tools ,
477+ allows_image = self .allows_image ,
478+ )
479+
517480
518481@dataclass (init = False )
519482class TextOutputSchema (OutputSchema [OutputDataT ]):
@@ -529,17 +492,17 @@ def __init__(
529492 allows_deferred_tools = allows_deferred_tools ,
530493 allows_image = allows_image ,
531494 )
532- if allows_deferred_tools or allows_image :
533- self .json_schema = OutputSchema [OutputDataT ].build_json_schema (
534- allows_deferred_tools = allows_deferred_tools , allows_image = allows_image , allows_text = True
535- )
536- else :
537- self .json_schema = OutputSchema [OutputDataT ].build_json_schema ()
538495
539496 @property
540497 def mode (self ) -> OutputMode :
541498 return 'text'
542499
500+ @cached_property
501+ def json_schema (self ) -> JsonSchema :
502+ return OutputSchema [OutputDataT ].build_json_schema (
503+ allows_deferred_tools = self .allows_deferred_tools , allows_image = self .allows_image , allows_text = True
504+ )
505+
543506
544507class ImageOutputSchema (OutputSchema [OutputDataT ]):
545508 def __init__ (self , * , allows_deferred_tools : bool ):
@@ -566,13 +529,16 @@ def __init__(
566529 allows_deferred_tools = allows_deferred_tools ,
567530 allows_image = allows_image ,
568531 )
569- self .json_schema = OutputSchema [OutputDataT ].build_json_schema (
570- base_processor = processor ,
571- allows_deferred_tools = allows_deferred_tools ,
572- allows_image = allows_image ,
573- )
574532 self .processor = processor
575533
534+ @cached_property
535+ def json_schema (self ) -> JsonSchema :
536+ return OutputSchema [OutputDataT ].build_json_schema (
537+ base_processor = self .processor ,
538+ allows_deferred_tools = self .allows_deferred_tools ,
539+ allows_image = self .allows_image ,
540+ )
541+
576542
577543class NativeOutputSchema (StructuredTextOutputSchema [OutputDataT ]):
578544 @property
@@ -634,17 +600,20 @@ def __init__(
634600 text_processor = text_processor ,
635601 allows_image = allows_image ,
636602 )
637- self .json_schema = OutputSchema [OutputDataT ].build_json_schema (
603+
604+ @property
605+ def mode (self ) -> OutputMode :
606+ return 'tool'
607+
608+ @cached_property
609+ def json_schema (self ) -> JsonSchema :
610+ return OutputSchema [OutputDataT ].build_json_schema (
638611 toolset_processors = self .toolset .processors , # pyright: ignore[reportOptionalMemberAccess]
639612 allows_deferred_tools = self .allows_deferred_tools ,
640613 allows_image = self .allows_image ,
641614 allows_text = self .allows_text ,
642615 )
643616
644- @property
645- def mode (self ) -> OutputMode :
646- return 'tool'
647-
648617
649618class BaseOutputProcessor (ABC , Generic [OutputDataT ]):
650619 @abstractmethod
@@ -850,8 +819,29 @@ def __init__(
850819
851820 json_schemas , all_defs = _utils .merge_json_schema_defs (json_schemas )
852821
822+ json_schema = UnionOutputProcessor .make_discriminated_json_schema_union (
823+ object_keys = list (self ._processors .keys ()),
824+ json_schemas = json_schemas ,
825+ all_defs = all_defs ,
826+ )
827+
828+ super ().__init__ (
829+ object_def = OutputObjectDefinition (
830+ json_schema = json_schema ,
831+ strict = strict ,
832+ name = name ,
833+ description = description ,
834+ )
835+ )
836+
837+ @staticmethod
838+ def make_discriminated_json_schema_union (
839+ object_keys : Sequence [str ],
840+ json_schemas : Sequence [JsonSchema ],
841+ all_defs : JsonSchema ,
842+ ):
853843 discriminated_json_schemas : list [ObjectJsonSchema ] = []
854- for object_key , json_schema in zip (self . _processors . keys () , json_schemas ):
844+ for object_key , json_schema in zip (object_keys , json_schemas ):
855845 title = json_schema .pop ('title' , None )
856846 description = json_schema .pop ('description' , None )
857847
@@ -887,14 +877,7 @@ def __init__(
887877 if all_defs :
888878 json_schema ['$defs' ] = all_defs
889879
890- super ().__init__ (
891- object_def = OutputObjectDefinition (
892- json_schema = json_schema ,
893- strict = strict ,
894- name = name ,
895- description = description ,
896- )
897- )
880+ return json_schema
898881
899882 async def process (
900883 self ,
0 commit comments