Skip to content

Commit 896c0ee

Browse files
fix(event_handler): fix bug regression in Annotated field (#7904)
1 parent ca44496 commit 896c0ee

File tree

2 files changed

+134
-2
lines changed

2 files changed

+134
-2
lines changed

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,11 +1112,23 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
11121112
"""
11131113
annotated_args = get_args(annotation)
11141114
type_annotation = annotated_args[0]
1115-
powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)]
1115+
1116+
# Handle both FieldInfo instances and FieldInfo subclasses (e.g., Body vs Body())
1117+
powertools_annotations: list[FieldInfo] = []
1118+
for arg in annotated_args[1:]:
1119+
if isinstance(arg, FieldInfo):
1120+
powertools_annotations.append(arg)
1121+
elif isinstance(arg, type) and issubclass(arg, FieldInfo):
1122+
# If it's a class (e.g., Body instead of Body()), instantiate it
1123+
powertools_annotations.append(arg())
11161124

11171125
# Preserve non-FieldInfo metadata (like annotated_types constraints)
11181126
# This is important for constraints like Interval, Gt, Lt, etc.
1119-
other_metadata = [arg for arg in annotated_args[1:] if not isinstance(arg, FieldInfo)]
1127+
other_metadata = [
1128+
arg
1129+
for arg in annotated_args[1:]
1130+
if not isinstance(arg, FieldInfo) and not (isinstance(arg, type) and issubclass(arg, FieldInfo))
1131+
]
11201132

11211133
# Determine which annotation to use
11221134
powertools_annotation: FieldInfo | None = None

tests/functional/event_handler/_pydantic/test_openapi_params.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,3 +1267,123 @@ def list_items(limit: Annotated[constrained_int, Query()] = 10):
12671267
assert limit_param.schema_.type == "integer"
12681268
assert limit_param.schema_.default == 10
12691269
assert limit_param.required is False
1270+
1271+
1272+
def test_body_class_annotation_without_parentheses():
1273+
"""
1274+
GIVEN an endpoint using Body class (not instance) in Annotated
1275+
WHEN sending a valid request body
1276+
THEN the request should be validated correctly
1277+
"""
1278+
app = APIGatewayRestResolver(enable_validation=True)
1279+
1280+
class MyRequest(BaseModel):
1281+
foo: str
1282+
bar: str = "default_bar"
1283+
1284+
class MyResponse(BaseModel):
1285+
concatenated: str
1286+
1287+
# Using Body (class) instead of Body() (instance)
1288+
@app.patch("/test")
1289+
def handler(body: Annotated[MyRequest, Body]) -> MyResponse:
1290+
return MyResponse(concatenated=body.foo + body.bar)
1291+
1292+
event = {
1293+
"resource": "/test",
1294+
"path": "/test",
1295+
"httpMethod": "PATCH",
1296+
"body": '{"foo": "hello"}',
1297+
"isBase64Encoded": False,
1298+
}
1299+
1300+
result = app(event, {})
1301+
assert result["statusCode"] == 200
1302+
response_body = json.loads(result["body"])
1303+
assert response_body["concatenated"] == "hellodefault_bar"
1304+
1305+
1306+
def test_body_instance_annotation_with_parentheses():
1307+
"""
1308+
GIVEN an endpoint using Body() instance in Annotated
1309+
WHEN sending a valid request body
1310+
THEN the request should be validated correctly
1311+
"""
1312+
app = APIGatewayRestResolver(enable_validation=True)
1313+
1314+
class MyRequest(BaseModel):
1315+
foo: str
1316+
bar: str = "default_bar"
1317+
1318+
class MyResponse(BaseModel):
1319+
concatenated: str
1320+
1321+
# Using Body() (instance)
1322+
@app.patch("/test")
1323+
def handler(body: Annotated[MyRequest, Body()]) -> MyResponse:
1324+
return MyResponse(concatenated=body.foo + body.bar)
1325+
1326+
event = {
1327+
"resource": "/test",
1328+
"path": "/test",
1329+
"httpMethod": "PATCH",
1330+
"body": '{"foo": "hello"}',
1331+
"isBase64Encoded": False,
1332+
}
1333+
1334+
result = app(event, {})
1335+
assert result["statusCode"] == 200
1336+
response_body = json.loads(result["body"])
1337+
assert response_body["concatenated"] == "hellodefault_bar"
1338+
1339+
1340+
def test_query_class_annotation_without_parentheses():
1341+
"""
1342+
GIVEN an endpoint using Query class (not instance) in Annotated
1343+
WHEN sending a valid query parameter
1344+
THEN the request should be validated correctly
1345+
"""
1346+
app = APIGatewayRestResolver(enable_validation=True)
1347+
1348+
@app.get("/test")
1349+
def handler(name: Annotated[str, Query]) -> dict:
1350+
return {"name": name}
1351+
1352+
event = {
1353+
"resource": "/test",
1354+
"path": "/test",
1355+
"httpMethod": "GET",
1356+
"queryStringParameters": {"name": "hello"},
1357+
"isBase64Encoded": False,
1358+
}
1359+
1360+
result = app(event, {})
1361+
assert result["statusCode"] == 200
1362+
response_body = json.loads(result["body"])
1363+
assert response_body["name"] == "hello"
1364+
1365+
1366+
def test_header_class_annotation_without_parentheses():
1367+
"""
1368+
GIVEN an endpoint using Header class (not instance) in Annotated
1369+
WHEN sending a valid header
1370+
THEN the request should be validated correctly
1371+
"""
1372+
app = APIGatewayRestResolver(enable_validation=True)
1373+
1374+
@app.get("/test")
1375+
def handler(x_custom: Annotated[str, Header]) -> dict:
1376+
return {"header": x_custom}
1377+
1378+
event = {
1379+
"resource": "/test",
1380+
"path": "/test",
1381+
"httpMethod": "GET",
1382+
"headers": {"x-custom": "my-value"},
1383+
"isBase64Encoded": False,
1384+
}
1385+
1386+
result = app(event, {})
1387+
assert result["statusCode"] == 200
1388+
response_body = json.loads(result["body"])
1389+
assert response_body["header"] == "my-value"

0 commit comments

Comments
 (0)