|
3 | 3 | import dataclasses |
4 | 4 | import json |
5 | 5 | import logging |
6 | | -from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, cast |
| 6 | +from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence, Union, cast |
7 | 7 | from urllib.parse import parse_qs |
8 | 8 |
|
9 | 9 | from pydantic import BaseModel |
| 10 | +from typing_extensions import get_args, get_origin |
10 | 11 |
|
11 | 12 | from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler |
12 | 13 | from aws_lambda_powertools.event_handler.openapi.compat import ( |
|
25 | 26 | ResponseValidationError, |
26 | 27 | ) |
27 | 28 | from aws_lambda_powertools.event_handler.openapi.params import Param |
| 29 | +from aws_lambda_powertools.event_handler.openapi.types import UnionType |
28 | 30 |
|
29 | 31 | if TYPE_CHECKING: |
30 | 32 | from pydantic.fields import FieldInfo |
@@ -431,9 +433,41 @@ def _handle_missing_field_value( |
431 | 433 | values[field.name] = field.get_default() |
432 | 434 |
|
433 | 435 |
|
| 436 | +def _is_or_contains_sequence(annotation: Any) -> bool: |
| 437 | + """ |
| 438 | + Check if annotation is a sequence or Union/RootModel containing a sequence. |
| 439 | +
|
| 440 | + This function handles complex type annotations like: |
| 441 | + - List[Model] - direct sequence |
| 442 | + - Union[Model, List[Model]] - checks if any Union member is a sequence |
| 443 | + - Optional[List[Model]] - Union[List[Model], None] |
| 444 | + - RootModel[List[Model]] - checks if the RootModel wraps a sequence |
| 445 | + - Optional[RootModel[List[Model]]] - Union member that is a RootModel |
| 446 | + - RootModel[Union[Model, List[Model]]] - RootModel wrapping a Union with a sequence |
| 447 | + """ |
| 448 | + # Direct sequence check |
| 449 | + if field_annotation_is_sequence(annotation): |
| 450 | + return True |
| 451 | + |
| 452 | + # Check Union members — recurse so we catch RootModel inside Union |
| 453 | + origin = get_origin(annotation) |
| 454 | + if origin is Union or origin is UnionType: |
| 455 | + for arg in get_args(annotation): |
| 456 | + if _is_or_contains_sequence(arg): |
| 457 | + return True |
| 458 | + |
| 459 | + # Check if it's a RootModel wrapping a sequence (or Union containing a sequence) |
| 460 | + if lenient_issubclass(annotation, BaseModel) and getattr(annotation, "__pydantic_root_model__", False): |
| 461 | + if hasattr(annotation, "model_fields") and "root" in annotation.model_fields: |
| 462 | + root_annotation = annotation.model_fields["root"].annotation |
| 463 | + return _is_or_contains_sequence(root_annotation) |
| 464 | + |
| 465 | + return False |
| 466 | + |
| 467 | + |
434 | 468 | def _normalize_field_value(value: Any, field_info: FieldInfo) -> Any: |
435 | 469 | """Normalize field value, converting lists to single values for non-sequence fields.""" |
436 | | - if field_annotation_is_sequence(field_info.annotation): |
| 470 | + if _is_or_contains_sequence(field_info.annotation): |
437 | 471 | return value |
438 | 472 | elif isinstance(value, list) and value: |
439 | 473 | return value[0] |
|
0 commit comments