Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,13 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
return self._handle_response(route=route, response=response)

def _handle_response(self, *, route: Route, response: Response):
# Process the response body if it exists
if response.body:
# Validate and serialize the response, if it's JSON
if response.is_json():
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
)
# Check if we have a return type defined
if route.dependant.return_param:
# Validate and serialize the response, including None
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
)

return response

Expand All @@ -164,15 +163,6 @@ def _serialize_response(
"""
if field:
errors: list[dict[str, Any]] = []
# MAINTENANCE: remove this when we drop pydantic v1
if not hasattr(field, "serializable"):
response_content = self._prepare_response_content(
response_content,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
if errors:
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
Expand All @@ -187,7 +177,6 @@ def _serialize_response(
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

return jsonable_encoder(
value,
include=include,
Expand All @@ -199,7 +188,7 @@ def _serialize_response(
custom_serializer=self._validation_serializer,
)
else:
# Just serialize the response content returned from the handler
# Just serialize the response content returned from the handler.
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)

def _prepare_response_content(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import json
from typing import Dict
from dataclasses import dataclass
from typing import Dict, Optional, Set

import pytest
from pydantic import BaseModel

from aws_lambda_powertools.event_handler import APIGatewayRestResolver


@dataclass
class Person:
name: str
birth_date: str
scores: Set[int]


def test_openapi_duplicated_serialization():
# GIVEN APIGatewayRestResolver is initialized with enable_validation=True
app = APIGatewayRestResolver(enable_validation=True)
Expand Down Expand Up @@ -61,3 +70,124 @@ def handler():

# THEN we the custom serializer should be used
assert response["body"] == "hello world"


def test_valid_model_returned_for_optional_type(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

@app.get("/valid_optional")
def handler_valid_optional() -> Optional[Model]:
return Model(name="John", age=30)

# WHEN returning a valid model for an Optional type
gw_event["path"] = "/valid_optional"
result = app(gw_event, {})

# THEN it should succeed and return the serialized model
assert result["statusCode"] == 200
assert json.loads(result["body"]) == {"name": "John", "age": 30}


def test_serialize_response_without_field(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler is defined without return type annotation
@app.get("/test")
def handler():
return {"message": "Hello, World!"}

gw_event["path"] = "/test"

# THEN the handler should be invoked and return 200
# AND the body must be a JSON object
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '{"message":"Hello, World!"}'


def test_serialize_response_list(gw_event):
"""Test serialization of list responses containing complex types"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a list containing various types
@app.get("/test")
def handler():
return [{"set": [1, 2, 3]}, {"simple": "value"}]

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '[{"set":[1,2,3]},{"simple":"value"}]'


def test_serialize_response_nested_dict(gw_event):
"""Test serialization of nested dictionary responses"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a nested dictionary with complex types
@app.get("/test")
def handler():
return {"nested": {"date": "2000-01-01", "set": [1, 2, 3]}, "simple": "value"}

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '{"nested":{"date":"2000-01-01","set":[1,2,3]},"simple":"value"}'


def test_serialize_response_dataclass(gw_event):
"""Test serialization of dataclass responses"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a dataclass instance
@app.get("/test")
def handler():
return Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91])

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '{"name":"John Doe","birth_date":"1990-01-01","scores":[95,87,91]}'


def test_serialize_response_mixed_types(gw_event):
"""Test serialization of mixed type responses"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a response with mixed types
@app.get("/test")
def handler():
person = Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91])
return {
"person": person,
"records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}],
"metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]},
}

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
expected = {
"person": {"name": "John Doe", "birth_date": "1990-01-01", "scores": [95, 87, 91]},
"records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}],
"metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]},
}
assert json.loads(response["body"]) == expected
Original file line number Diff line number Diff line change
Expand Up @@ -1128,3 +1128,76 @@ def handler(user_id: int = 123):
# THEN the handler should be invoked and return 200
result = app(minimal_event, {})
assert result["statusCode"] == 200


def test_validation_error_none_returned_non_optional_type(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

@app.get("/none_not_allowed")
def handler_none_not_allowed() -> Model:
return None # type: ignore

# WHEN returning None for a non-Optional type
gw_event["path"] = "/none_not_allowed"
result = app(gw_event, {})

# THEN it should return a validation error
assert result["statusCode"] == 422
body = json.loads(result["body"])
assert "model_attributes_type" in body["detail"][0]["type"]


def test_none_returned_for_optional_type(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

@app.get("/none_allowed")
def handler_none_allowed() -> Optional[Model]:
return None

# WHEN returning None for an Optional type
gw_event["path"] = "/none_allowed"
result = app(gw_event, {})

# THEN it should succeed
assert result["statusCode"] == 200
assert result["body"] == "null"


@pytest.mark.parametrize(
"path, body",
[
("/empty_dict", {}),
("/empty_list", []),
("/none", "null"),
("/empty_string", ""),
],
ids=["empty_dict", "empty_list", "none", "empty_string"],
)
def test_none_returned_for_falsy_return(gw_event, path, body):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

@app.get(path)
def handler_none_allowed() -> Model:
return body

# WHEN returning None for an Optional type
gw_event["path"] = path
result = app(gw_event, {})

# THEN it should succeed
assert result["statusCode"] == 422
Loading