Skip to content

Commit 214d061

Browse files
Adding custom serializer
1 parent c992463 commit 214d061

File tree

2 files changed

+51
-17
lines changed

2 files changed

+51
-17
lines changed

aws_lambda_powertools/event_handler/bedrock_agent_function.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import inspect
4+
import json
45
import logging
56
import warnings
67
from collections.abc import Callable
@@ -73,7 +74,7 @@ class BedrockFunctionsResponseBuilder:
7374
def __init__(self, result: BedrockFunctionResponse | Any) -> None:
7475
self.result = result
7576

76-
def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]:
77+
def build(self, event: BedrockAgentFunctionEvent, serializer: Callable) -> dict[str, Any]:
7778
result_obj = self.result
7879

7980
# Extract attributes from BedrockFunctionResponse or use defaults
@@ -92,7 +93,7 @@ def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]:
9293
"actionGroup": event.action_group,
9394
"function": event.function,
9495
"functionResponse": {
95-
"responseBody": {"TEXT": {"body": str(body if body is not None else "")}},
96+
"responseBody": {"TEXT": {"body": serializer(body if body is not None else "")}},
9697
},
9798
},
9899
"sessionAttributes": session_attributes or event.session_attributes or {},
@@ -119,7 +120,7 @@ class BedrockAgentFunctionResolver:
119120
120121
app = BedrockAgentFunctionResolver()
121122
122-
@app.tool(description="Gets the current UTC time")
123+
@app.tool(name="get_current_time", description="Gets the current UTC time")
123124
def get_current_time():
124125
from datetime import datetime
125126
return datetime.utcnow().isoformat()
@@ -131,11 +132,12 @@ def lambda_handler(event, context):
131132

132133
context: dict
133134

134-
def __init__(self) -> None:
135+
def __init__(self, serializer: Callable | None = None) -> None:
135136
self._tools: dict[str, dict[str, Any]] = {}
136137
self.current_event: BedrockAgentFunctionEvent | None = None
137138
self.context = {}
138139
self._response_builder_class = BedrockFunctionsResponseBuilder
140+
self.serializer = serializer or json.dumps
139141

140142
def tool(
141143
self,
@@ -230,12 +232,12 @@ def _resolve(self) -> dict[str, Any]:
230232
self.clear_context()
231233

232234
# Build and return the response
233-
return BedrockFunctionsResponseBuilder(result).build(self.current_event)
235+
return BedrockFunctionsResponseBuilder(result).build(self.current_event, serializer=self.serializer)
234236
except Exception as error:
235237
# Return a formatted error response
236238
logger.error(f"Error processing function: {function_name}", exc_info=True)
237239
error_response = BedrockFunctionResponse(body=f"Error: {error.__class__.__name__}: {str(error)}")
238-
return BedrockFunctionsResponseBuilder(error_response).build(self.current_event)
240+
return BedrockFunctionsResponseBuilder(error_response).build(self.current_event, serializer=self.serializer)
239241

240242
def append_context(self, **additional_context):
241243
"""Append key=value data as routing context"""

tests/functional/event_handler/required_dependencies/test_bedrock_agent_functions.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
import decimal
4+
import json
5+
36
import pytest
47

58
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver, BedrockFunctionResponse
@@ -37,7 +40,7 @@ def test_function():
3740
assert result["messageVersion"] == "1.0"
3841
assert result["response"]["actionGroup"] == raw_event["actionGroup"]
3942
assert result["response"]["function"] == "test_function"
40-
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello from string"
43+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == json.dumps("Hello from string")
4144
assert "responseState" not in result["response"]["functionResponse"]
4245

4346

@@ -55,7 +58,7 @@ def none_response_function():
5558
result = app.resolve(raw_event, {})
5659

5760
# THEN process event correctly with empty string body
58-
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == ""
61+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == json.dumps("")
5962

6063

6164
def test_bedrock_agent_function_error_handling():
@@ -106,7 +109,7 @@ def second_function():
106109
result = app.resolve(raw_event, {})
107110

108111
# The second function should be used
109-
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "second test"
112+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == json.dumps("second test")
110113

111114

112115
def test_bedrock_agent_function_with_optional_fields():
@@ -133,7 +136,7 @@ def test_function():
133136
result = app.resolve(raw_event, {})
134137

135138
# THEN include all optional fields in response
136-
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == "Hello"
139+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == json.dumps("Hello")
137140
assert result["sessionAttributes"] == {"userId": "123"}
138141
assert result["promptSessionAttributes"] == {"context": "test"}
139142
assert result["knowledgeBasesConfiguration"][0]["knowledgeBaseId"] == "kb1"
@@ -300,9 +303,10 @@ def complex_response():
300303
# THEN complex object should be converted to string representation
301304
response_body = result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
302305
# Check that it contains the expected string representation
303-
assert "{'key1': 'value1'" in response_body
304-
assert "'key2': 123" in response_body
305-
assert "'nested': {'inner': 'value'}" in response_body
306+
307+
assert response_body == json.dumps(
308+
{"key1": "value1", "key2": 123, "nested": {"inner": "value"}},
309+
)
306310

307311

308312
def test_bedrock_agent_function_append_context():
@@ -383,7 +387,7 @@ def vacation_request(month: int, payment: float, approved: bool):
383387
result = app.resolve(raw_event, {})
384388

385389
# THEN parameters should be correctly passed to the function
386-
assert "Vacation request" == result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
390+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == json.dumps("Vacation request")
387391

388392

389393
def test_bedrock_agent_function_with_parameters_casting_errors():
@@ -418,7 +422,35 @@ def process_data(id_product: str, quantity: int, price: float, available: bool,
418422
result = app.resolve(raw_event, {})
419423

420424
# THEN parameters should be handled properly despite casting errors
421-
assert (
422-
"Processed with casting errors handled"
423-
== result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
425+
assert result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"] == json.dumps(
426+
"Processed with casting errors handled",
424427
)
428+
429+
430+
def test_bedrock_agent_function_with_custom_serializer():
431+
"""Test BedrockAgentFunctionResolver with a custom serializer for non-standard JSON types."""
432+
433+
def decimal_serializer(obj):
434+
if isinstance(obj, decimal.Decimal):
435+
return float(obj)
436+
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
437+
438+
# GIVEN a Bedrock Agent Function resolver with that custom serializer
439+
app = BedrockAgentFunctionResolver(serializer=lambda obj: json.dumps(obj, default=decimal_serializer))
440+
441+
@app.tool()
442+
def decimal_response():
443+
# Return a response with Decimal type that standard JSON can't serialize
444+
return {"price": decimal.Decimal("99.99")}
445+
446+
# WHEN calling with a response containing non-standard JSON types
447+
raw_event = load_event("bedrockAgentFunctionEvent.json")
448+
raw_event["function"] = "decimal_response"
449+
result = app.resolve(raw_event, {})
450+
451+
# THEN non-standard types should be properly serialized
452+
response_body = result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
453+
parsed_response = json.loads(response_body)
454+
455+
# VERIFY that decimal was converted to float and datetime to ISO string
456+
assert parsed_response["price"] == 99.99

0 commit comments

Comments
 (0)