Skip to content

Commit 3e78858

Browse files
committed
Address feedback
1 parent 3be8f95 commit 3e78858

File tree

2 files changed

+131
-146
lines changed

2 files changed

+131
-146
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 72 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from abc import ABC, abstractmethod
77
from collections.abc import Awaitable, Callable, Sequence
88
from dataclasses import dataclass, field
9+
from functools import cached_property
910
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
1011

1112
from pydantic import Json, TypeAdapter, ValidationError
@@ -218,7 +219,6 @@ class OutputSchema(ABC, Generic[OutputDataT]):
218219
object_def: OutputObjectDefinition | None = None
219220
allows_deferred_tools: bool = False
220221
allows_image: bool = False
221-
json_schema: JsonSchema = field(init=False)
222222

223223
@property
224224
def mode(self) -> OutputMode:
@@ -228,6 +228,10 @@ def mode(self) -> OutputMode:
228228
def allows_text(self) -> bool:
229229
return self.text_processor is not None
230230

231+
@cached_property
232+
def json_schema(self) -> JsonSchema:
233+
raise NotImplementedError()
234+
231235
@classmethod
232236
def build( # noqa: C901
233237
cls,
@@ -382,37 +386,31 @@ def _build_processor(
382386
return UnionOutputProcessor(outputs=outputs, strict=strict, name=name, description=description)
383387

384388
@staticmethod
385-
def build_json_schema( # noqa: C901
389+
def build_json_schema(
386390
allows_deferred_tools: bool = False,
387391
allows_image: bool = False,
388392
allows_text: bool = False,
389393
base_processor: BaseObjectOutputProcessor[OutputDataT] | None = None,
390394
toolset_processors: dict[str, ObjectOutputProcessor[OutputDataT]] | None = None,
391395
) -> JsonSchema:
392396
# allow any output with {'type': 'string'} if no constraints
393-
if not any([allows_deferred_tools, allows_image, allows_text, base_processor, toolset_processors]):
397+
if not any([allows_deferred_tools, allows_image, base_processor, toolset_processors]):
394398
return TypeAdapter(str).json_schema()
395399

396400
object_keys: list[str] = []
397401
json_schemas: list[ObjectJsonSchema] = []
398402

399-
if base_processor:
400-
json_schema = base_processor.object_def.json_schema.copy()
401-
json_schema['title'] = base_processor.object_def.name
402-
if base_processor.object_def.description:
403-
json_schema['description'] = base_processor.object_def.description
404-
json_schemas.append(json_schema)
405-
object_keys.append(json_schema.get('title', 'result'))
406-
407403
if toolset_processors:
408404
for name, tool_processor in toolset_processors.items():
409405
json_schema = tool_processor.object_def.json_schema.copy()
410-
json_schema['title'] = name
411-
if tool_processor.object_def.description:
412-
json_schema['description'] = tool_processor.object_def.description
413406
json_schemas.append(json_schema)
414407
object_keys.append(name)
415408

409+
if base_processor:
410+
json_schema = base_processor.object_def.json_schema.copy()
411+
json_schemas.append(json_schema)
412+
object_keys.append(json_schema.get('title', 'result'))
413+
416414
special_output_types: list[type] = []
417415
if allows_text:
418416
special_output_types.append(str)
@@ -422,17 +420,16 @@ def build_json_schema( # noqa: C901
422420
special_output_types.append(_messages.BinaryImage)
423421
if special_output_types:
424422
for output_type in special_output_types:
425-
output_type_json_schema = TypeAdapter(output_type).json_schema()
423+
output_type_json_schema = TypeAdapter(output_type).json_schema(mode='serialization')
426424
json_schemas.append(output_type_json_schema)
427-
object_key = output_type.__name__
428-
object_keys.append(object_key)
429-
430-
json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas)
425+
object_keys.append(output_type.__name__)
431426

432-
# do not discriminate JSON if not needed
433-
if len(json_schemas) == 1 and not all_defs:
427+
# do not further process JSON if not needed
428+
if len(json_schemas) == 1:
434429
return json_schemas[0]
435430

431+
json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas)
432+
436433
unique_object_keys: list[str] = []
437434
for key in object_keys:
438435
count = 1
@@ -442,44 +439,7 @@ def build_json_schema( # noqa: C901
442439
new_key = f'{key}_{count}'
443440
unique_object_keys.append(new_key)
444441

445-
discriminated_json_schemas: list[ObjectJsonSchema] = []
446-
for object_key, json_schema in zip(unique_object_keys, json_schemas):
447-
title = json_schema.pop('title', None)
448-
description = json_schema.pop('description', None)
449-
450-
discriminated_json_schema = {
451-
'type': 'object',
452-
'properties': {
453-
'kind': {
454-
'type': 'string',
455-
'const': object_key,
456-
},
457-
'data': json_schema,
458-
},
459-
'required': ['kind', 'data'],
460-
'additionalProperties': False,
461-
}
462-
if title: # pragma: no branch
463-
discriminated_json_schema['title'] = title
464-
if description:
465-
discriminated_json_schema['description'] = description
466-
467-
discriminated_json_schemas.append(discriminated_json_schema)
468-
469-
json_schema = {
470-
'type': 'object',
471-
'properties': {
472-
'result': {
473-
'anyOf': discriminated_json_schemas,
474-
}
475-
},
476-
'required': ['result'],
477-
'additionalProperties': False,
478-
}
479-
if all_defs:
480-
json_schema['$defs'] = all_defs
481-
482-
return json_schema
442+
return UnionOutputProcessor.make_discriminated_json_schema_union(unique_object_keys, json_schemas, all_defs)
483443

484444

485445
@dataclass(init=False)
@@ -503,17 +463,20 @@ def __init__(
503463
allows_deferred_tools=allows_deferred_tools,
504464
allows_image=allows_image,
505465
)
506-
self.json_schema = OutputSchema[OutputDataT].build_json_schema(
507-
base_processor=processor,
508-
allows_deferred_tools=allows_deferred_tools,
509-
allows_image=allows_image,
510-
)
511466
self.processor = processor
512467

513468
@property
514469
def mode(self) -> OutputMode:
515470
return 'auto'
516471

472+
@cached_property
473+
def json_schema(self) -> JsonSchema:
474+
return OutputSchema[OutputDataT].build_json_schema(
475+
base_processor=self.processor,
476+
allows_deferred_tools=self.allows_deferred_tools,
477+
allows_image=self.allows_image,
478+
)
479+
517480

518481
@dataclass(init=False)
519482
class TextOutputSchema(OutputSchema[OutputDataT]):
@@ -529,17 +492,17 @@ def __init__(
529492
allows_deferred_tools=allows_deferred_tools,
530493
allows_image=allows_image,
531494
)
532-
if allows_deferred_tools or allows_image:
533-
self.json_schema = OutputSchema[OutputDataT].build_json_schema(
534-
allows_deferred_tools=allows_deferred_tools, allows_image=allows_image, allows_text=True
535-
)
536-
else:
537-
self.json_schema = OutputSchema[OutputDataT].build_json_schema()
538495

539496
@property
540497
def mode(self) -> OutputMode:
541498
return 'text'
542499

500+
@cached_property
501+
def json_schema(self) -> JsonSchema:
502+
return OutputSchema[OutputDataT].build_json_schema(
503+
allows_deferred_tools=self.allows_deferred_tools, allows_image=self.allows_image, allows_text=True
504+
)
505+
543506

544507
class ImageOutputSchema(OutputSchema[OutputDataT]):
545508
def __init__(self, *, allows_deferred_tools: bool):
@@ -566,13 +529,16 @@ def __init__(
566529
allows_deferred_tools=allows_deferred_tools,
567530
allows_image=allows_image,
568531
)
569-
self.json_schema = OutputSchema[OutputDataT].build_json_schema(
570-
base_processor=processor,
571-
allows_deferred_tools=allows_deferred_tools,
572-
allows_image=allows_image,
573-
)
574532
self.processor = processor
575533

534+
@cached_property
535+
def json_schema(self) -> JsonSchema:
536+
return OutputSchema[OutputDataT].build_json_schema(
537+
base_processor=self.processor,
538+
allows_deferred_tools=self.allows_deferred_tools,
539+
allows_image=self.allows_image,
540+
)
541+
576542

577543
class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
578544
@property
@@ -634,17 +600,20 @@ def __init__(
634600
text_processor=text_processor,
635601
allows_image=allows_image,
636602
)
637-
self.json_schema = OutputSchema[OutputDataT].build_json_schema(
603+
604+
@property
605+
def mode(self) -> OutputMode:
606+
return 'tool'
607+
608+
@cached_property
609+
def json_schema(self) -> JsonSchema:
610+
return OutputSchema[OutputDataT].build_json_schema(
638611
toolset_processors=self.toolset.processors, # pyright: ignore[reportOptionalMemberAccess]
639612
allows_deferred_tools=self.allows_deferred_tools,
640613
allows_image=self.allows_image,
641614
allows_text=self.allows_text,
642615
)
643616

644-
@property
645-
def mode(self) -> OutputMode:
646-
return 'tool'
647-
648617

649618
class BaseOutputProcessor(ABC, Generic[OutputDataT]):
650619
@abstractmethod
@@ -850,8 +819,29 @@ def __init__(
850819

851820
json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas)
852821

822+
json_schema = UnionOutputProcessor.make_discriminated_json_schema_union(
823+
object_keys=list(self._processors.keys()),
824+
json_schemas=json_schemas,
825+
all_defs=all_defs,
826+
)
827+
828+
super().__init__(
829+
object_def=OutputObjectDefinition(
830+
json_schema=json_schema,
831+
strict=strict,
832+
name=name,
833+
description=description,
834+
)
835+
)
836+
837+
@staticmethod
838+
def make_discriminated_json_schema_union(
839+
object_keys: Sequence[str],
840+
json_schemas: Sequence[JsonSchema],
841+
all_defs: JsonSchema,
842+
):
853843
discriminated_json_schemas: list[ObjectJsonSchema] = []
854-
for object_key, json_schema in zip(self._processors.keys(), json_schemas):
844+
for object_key, json_schema in zip(object_keys, json_schemas):
855845
title = json_schema.pop('title', None)
856846
description = json_schema.pop('description', None)
857847

@@ -887,14 +877,7 @@ def __init__(
887877
if all_defs:
888878
json_schema['$defs'] = all_defs
889879

890-
super().__init__(
891-
object_def=OutputObjectDefinition(
892-
json_schema=json_schema,
893-
strict=strict,
894-
name=name,
895-
description=description,
896-
)
897-
)
880+
return json_schema
898881

899882
async def process(
900883
self,

0 commit comments

Comments
 (0)