1
1
from __future__ import annotations
2
2
3
+ import decimal
4
+ import json
5
+
3
6
import pytest
4
7
5
8
from aws_lambda_powertools .event_handler import BedrockAgentFunctionResolver , BedrockFunctionResponse
@@ -37,7 +40,7 @@ def test_function():
37
40
assert result ["messageVersion" ] == "1.0"
38
41
assert result ["response" ]["actionGroup" ] == raw_event ["actionGroup" ]
39
42
assert result ["response" ]["function" ] == "test_function"
40
- assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == "Hello from string"
43
+ assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == json . dumps ( "Hello from string" )
41
44
assert "responseState" not in result ["response" ]["functionResponse" ]
42
45
43
46
@@ -55,7 +58,7 @@ def none_response_function():
55
58
result = app .resolve (raw_event , {})
56
59
57
60
# THEN process event correctly with empty string body
58
- assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == ""
61
+ assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == json . dumps ( "" )
59
62
60
63
61
64
def test_bedrock_agent_function_error_handling ():
@@ -106,7 +109,7 @@ def second_function():
106
109
result = app .resolve (raw_event , {})
107
110
108
111
# The second function should be used
109
- assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == "second test"
112
+ assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == json . dumps ( "second test" )
110
113
111
114
112
115
def test_bedrock_agent_function_with_optional_fields ():
@@ -133,7 +136,7 @@ def test_function():
133
136
result = app .resolve (raw_event , {})
134
137
135
138
# THEN include all optional fields in response
136
- assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == "Hello"
139
+ assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == json . dumps ( "Hello" )
137
140
assert result ["sessionAttributes" ] == {"userId" : "123" }
138
141
assert result ["promptSessionAttributes" ] == {"context" : "test" }
139
142
assert result ["knowledgeBasesConfiguration" ][0 ]["knowledgeBaseId" ] == "kb1"
@@ -300,9 +303,10 @@ def complex_response():
300
303
# THEN complex object should be converted to string representation
301
304
response_body = result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ]
302
305
# Check that it contains the expected string representation
303
- assert "{'key1': 'value1'" in response_body
304
- assert "'key2': 123" in response_body
305
- assert "'nested': {'inner': 'value'}" in response_body
306
+
307
+ assert response_body == json .dumps (
308
+ {"key1" : "value1" , "key2" : 123 , "nested" : {"inner" : "value" }},
309
+ )
306
310
307
311
308
312
def test_bedrock_agent_function_append_context ():
@@ -383,7 +387,7 @@ def vacation_request(month: int, payment: float, approved: bool):
383
387
result = app .resolve (raw_event , {})
384
388
385
389
# THEN parameters should be correctly passed to the function
386
- assert "Vacation request" == result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ]
390
+ assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == json . dumps ( "Vacation request" )
387
391
388
392
389
393
def test_bedrock_agent_function_with_parameters_casting_errors ():
@@ -418,7 +422,35 @@ def process_data(id_product: str, quantity: int, price: float, available: bool,
418
422
result = app .resolve (raw_event , {})
419
423
420
424
# THEN parameters should be handled properly despite casting errors
421
- assert (
422
- "Processed with casting errors handled"
423
- == result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ]
425
+ assert result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ] == json .dumps (
426
+ "Processed with casting errors handled" ,
424
427
)
428
+
429
+
430
+ def test_bedrock_agent_function_with_custom_serializer ():
431
+ """Test BedrockAgentFunctionResolver with a custom serializer for non-standard JSON types."""
432
+
433
+ def decimal_serializer (obj ):
434
+ if isinstance (obj , decimal .Decimal ):
435
+ return float (obj )
436
+ raise TypeError (f"Object of type { type (obj )} is not JSON serializable" )
437
+
438
+ # GIVEN a Bedrock Agent Function resolver with that custom serializer
439
+ app = BedrockAgentFunctionResolver (serializer = lambda obj : json .dumps (obj , default = decimal_serializer ))
440
+
441
+ @app .tool ()
442
+ def decimal_response ():
443
+ # Return a response with Decimal type that standard JSON can't serialize
444
+ return {"price" : decimal .Decimal ("99.99" )}
445
+
446
+ # WHEN calling with a response containing non-standard JSON types
447
+ raw_event = load_event ("bedrockAgentFunctionEvent.json" )
448
+ raw_event ["function" ] = "decimal_response"
449
+ result = app .resolve (raw_event , {})
450
+
451
+ # THEN non-standard types should be properly serialized
452
+ response_body = result ["response" ]["functionResponse" ]["responseBody" ]["TEXT" ]["body" ]
453
+ parsed_response = json .loads (response_body )
454
+
455
+ # VERIFY that decimal was converted to float and datetime to ISO string
456
+ assert parsed_response ["price" ] == 99.99
0 commit comments