Skip to content

Commit b66665c

Browse files
feat(bedrock_agent): add new Amazon Bedrock Agents Functions Resolver (#6564)
* feat(bedrock_agent): create bedrock agents functions data class * create resolver * mypy * add response * add name param to tool * add response optional fields * bedrockfunctionresponse and response state * remove body message * add parser * add test for required fields * add more tests for parser and resolver * add validation response state * params injection * doc event handler, parser and data class * fix doc typo * fix doc typo * mypy * Small refactor + documentation * Small refactor + documentation * Small refactor + documentation * Small refactor + documentation * Aligning Python implementation with TS * Adding custom serializer * Adding custom serializer * More documentation --------- Signed-off-by: Ana Falcão <[email protected]> Co-authored-by: Leandro Damascena <[email protected]>
1 parent 9c16b81 commit b66665c

File tree

22 files changed

+1252
-152
lines changed

22 files changed

+1252
-152
lines changed

aws_lambda_powertools/event_handler/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
)
1313
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
1414
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver, BedrockResponse
15+
from aws_lambda_powertools.event_handler.bedrock_agent_function import (
16+
BedrockAgentFunctionResolver,
17+
BedrockFunctionResponse,
18+
)
1519
from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver
1620
from aws_lambda_powertools.event_handler.lambda_function_url import (
1721
LambdaFunctionUrlResolver,
@@ -26,7 +30,9 @@
2630
"ALBResolver",
2731
"ApiGatewayResolver",
2832
"BedrockAgentResolver",
33+
"BedrockAgentFunctionResolver",
2934
"BedrockResponse",
35+
"BedrockFunctionResponse",
3036
"CORSConfig",
3137
"LambdaFunctionUrlResolver",
3238
"Response",
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
import json
5+
import logging
6+
import warnings
7+
from collections.abc import Callable
8+
from typing import Any, Literal, TypeVar
9+
10+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
11+
from aws_lambda_powertools.warnings import PowertoolsUserWarning
12+
13+
# Define a generic type for the function
14+
T = TypeVar("T", bound=Callable[..., Any])
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class BedrockFunctionResponse:
20+
"""Response class for Bedrock Agent Functions.
21+
22+
Parameters
23+
----------
24+
body : Any, optional
25+
Response body to be returned to the caller.
26+
session_attributes : dict[str, str] or None, optional
27+
Session attributes to include in the response for maintaining state.
28+
prompt_session_attributes : dict[str, str] or None, optional
29+
Prompt session attributes to include in the response.
30+
knowledge_bases : list[dict[str, Any]] or None, optional
31+
Knowledge bases to include in the response.
32+
response_state : {"FAILURE", "REPROMPT"} or None, optional
33+
Response state indicating if the function failed or needs reprompting.
34+
35+
Examples
36+
--------
37+
>>> @app.tool(description="Function that uses session attributes")
38+
>>> def test_function():
39+
... return BedrockFunctionResponse(
40+
... body="Hello",
41+
... session_attributes={"userId": "123"},
42+
... prompt_session_attributes={"lastAction": "login"}
43+
... )
44+
45+
Notes
46+
-----
47+
The `response_state` parameter can only be set to "FAILURE" or "REPROMPT".
48+
"""
49+
50+
def __init__(
51+
self,
52+
body: Any = None,
53+
session_attributes: dict[str, str] | None = None,
54+
prompt_session_attributes: dict[str, str] | None = None,
55+
knowledge_bases: list[dict[str, Any]] | None = None,
56+
response_state: Literal["FAILURE", "REPROMPT"] | None = None,
57+
) -> None:
58+
if response_state and response_state not in ["FAILURE", "REPROMPT"]:
59+
raise ValueError("responseState must be 'FAILURE' or 'REPROMPT'")
60+
61+
self.body = body
62+
self.session_attributes = session_attributes
63+
self.prompt_session_attributes = prompt_session_attributes
64+
self.knowledge_bases = knowledge_bases
65+
self.response_state = response_state
66+
67+
68+
class BedrockFunctionsResponseBuilder:
69+
"""
70+
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda
71+
when using Bedrock Agent Functions.
72+
"""
73+
74+
def __init__(self, result: BedrockFunctionResponse | Any) -> None:
75+
self.result = result
76+
77+
def build(self, event: BedrockAgentFunctionEvent, serializer: Callable) -> dict[str, Any]:
78+
result_obj = self.result
79+
80+
# Extract attributes from BedrockFunctionResponse or use defaults
81+
body = getattr(result_obj, "body", result_obj)
82+
session_attributes = getattr(result_obj, "session_attributes", None)
83+
prompt_session_attributes = getattr(result_obj, "prompt_session_attributes", None)
84+
knowledge_bases = getattr(result_obj, "knowledge_bases", None)
85+
response_state = getattr(result_obj, "response_state", None)
86+
87+
# Build base response structure
88+
# Per AWS Bedrock documentation, currently only "TEXT" is supported as the responseBody content type
89+
# https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html
90+
response: dict[str, Any] = {
91+
"messageVersion": "1.0",
92+
"response": {
93+
"actionGroup": event.action_group,
94+
"function": event.function,
95+
"functionResponse": {
96+
"responseBody": {"TEXT": {"body": serializer(body if body is not None else "")}},
97+
},
98+
},
99+
"sessionAttributes": session_attributes or event.session_attributes or {},
100+
"promptSessionAttributes": prompt_session_attributes or event.prompt_session_attributes or {},
101+
}
102+
103+
# Add optional fields when present
104+
if response_state:
105+
response["response"]["functionResponse"]["responseState"] = response_state
106+
107+
if knowledge_bases:
108+
response["knowledgeBasesConfiguration"] = knowledge_bases
109+
110+
return response
111+
112+
113+
class BedrockAgentFunctionResolver:
114+
"""Bedrock Agent Function resolver that handles function definitions
115+
116+
Examples
117+
--------
118+
```python
119+
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver
120+
121+
app = BedrockAgentFunctionResolver()
122+
123+
@app.tool(name="get_current_time", description="Gets the current UTC time")
124+
def get_current_time():
125+
from datetime import datetime
126+
return datetime.utcnow().isoformat()
127+
128+
def lambda_handler(event, context):
129+
return app.resolve(event, context)
130+
```
131+
"""
132+
133+
context: dict
134+
135+
def __init__(self, serializer: Callable | None = None) -> None:
136+
self._tools: dict[str, dict[str, Any]] = {}
137+
self.current_event: BedrockAgentFunctionEvent | None = None
138+
self.context = {}
139+
self._response_builder_class = BedrockFunctionsResponseBuilder
140+
self.serializer = serializer or json.dumps
141+
142+
def tool(
143+
self,
144+
name: str | None = None,
145+
description: str | None = None,
146+
) -> Callable[[T], T]:
147+
"""Decorator to register a tool function
148+
149+
Parameters
150+
----------
151+
name : str | None
152+
Custom name for the tool. If not provided, uses the function name
153+
description : str | None
154+
Description of what the tool does
155+
156+
Returns
157+
-------
158+
Callable
159+
Decorator function that registers and returns the original function
160+
"""
161+
162+
def decorator(func: T) -> T:
163+
function_name = name or func.__name__
164+
165+
logger.debug(f"Registering {function_name} tool")
166+
167+
if function_name in self._tools:
168+
warnings.warn(
169+
f"Tool '{function_name}' already registered. Overwriting with new definition.",
170+
PowertoolsUserWarning,
171+
stacklevel=2,
172+
)
173+
174+
self._tools[function_name] = {
175+
"function": func,
176+
"description": description,
177+
}
178+
return func
179+
180+
return decorator
181+
182+
def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]:
183+
"""Resolves the function call from Bedrock Agent event"""
184+
try:
185+
self.current_event = BedrockAgentFunctionEvent(event)
186+
return self._resolve()
187+
except KeyError as e:
188+
raise ValueError(f"Missing required field: {str(e)}") from e
189+
190+
def _resolve(self) -> dict[str, Any]:
191+
"""Internal resolution logic"""
192+
if self.current_event is None:
193+
raise ValueError("No event to process")
194+
195+
function_name = self.current_event.function
196+
197+
logger.debug(f"Resolving {function_name} tool")
198+
199+
try:
200+
parameters: dict[str, Any] = {}
201+
# Extract parameters from the event
202+
for param in getattr(self.current_event, "parameters", []):
203+
param_type = getattr(param, "type", None)
204+
if param_type == "string":
205+
parameters[param.name] = str(param.value)
206+
elif param_type == "integer":
207+
try:
208+
parameters[param.name] = int(param.value)
209+
except (ValueError, TypeError):
210+
parameters[param.name] = param.value
211+
elif param_type == "number":
212+
try:
213+
parameters[param.name] = float(param.value)
214+
except (ValueError, TypeError):
215+
parameters[param.name] = param.value
216+
elif param_type == "boolean":
217+
if isinstance(param.value, str):
218+
parameters[param.name] = param.value.lower() == "true"
219+
else:
220+
parameters[param.name] = bool(param.value)
221+
else: # "array" or any other type
222+
parameters[param.name] = param.value
223+
224+
func = self._tools[function_name]["function"]
225+
# Filter parameters to only include those expected by the function
226+
sig = inspect.signature(func)
227+
valid_params = {name: value for name, value in parameters.items() if name in sig.parameters}
228+
229+
# Call the function with the filtered parameters
230+
result = func(**valid_params)
231+
232+
self.clear_context()
233+
234+
# Build and return the response
235+
return BedrockFunctionsResponseBuilder(result).build(self.current_event, serializer=self.serializer)
236+
except Exception as error:
237+
# Return a formatted error response
238+
logger.error(f"Error processing function: {function_name}", exc_info=True)
239+
error_response = BedrockFunctionResponse(body=f"Error: {error.__class__.__name__}: {str(error)}")
240+
return BedrockFunctionsResponseBuilder(error_response).build(self.current_event, serializer=self.serializer)
241+
242+
def append_context(self, **additional_context):
243+
"""Append key=value data as routing context"""
244+
self.context.update(**additional_context)
245+
246+
def clear_context(self):
247+
"""Resets routing context"""
248+
self.context.clear()

aws_lambda_powertools/utilities/data_classes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .appsync_resolver_events_event import AppSyncResolverEventsEvent
1010
from .aws_config_rule_event import AWSConfigRuleEvent
1111
from .bedrock_agent_event import BedrockAgentEvent
12+
from .bedrock_agent_function_event import BedrockAgentFunctionEvent
1213
from .cloud_watch_alarm_event import (
1314
CloudWatchAlarmConfiguration,
1415
CloudWatchAlarmData,
@@ -59,6 +60,7 @@
5960
"AppSyncResolverEventsEvent",
6061
"ALBEvent",
6162
"BedrockAgentEvent",
63+
"BedrockAgentFunctionEvent",
6264
"CloudWatchAlarmData",
6365
"CloudWatchAlarmEvent",
6466
"CloudWatchAlarmMetric",
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
4+
5+
6+
class BedrockAgentInfo(DictWrapper):
7+
@property
8+
def name(self) -> str:
9+
return self["name"]
10+
11+
@property
12+
def id(self) -> str: # noqa: A003
13+
return self["id"]
14+
15+
@property
16+
def alias(self) -> str:
17+
return self["alias"]
18+
19+
@property
20+
def version(self) -> str:
21+
return self["version"]
22+
23+
24+
class BedrockAgentFunctionParameter(DictWrapper):
25+
@property
26+
def name(self) -> str:
27+
return self["name"]
28+
29+
@property
30+
def type(self) -> str: # noqa: A003
31+
return self["type"]
32+
33+
@property
34+
def value(self) -> str:
35+
return self["value"]
36+
37+
38+
class BedrockAgentFunctionEvent(DictWrapper):
39+
"""
40+
Bedrock Agent Function input event
41+
42+
Documentation:
43+
https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html
44+
"""
45+
46+
@property
47+
def message_version(self) -> str:
48+
return self["messageVersion"]
49+
50+
@property
51+
def input_text(self) -> str:
52+
return self["inputText"]
53+
54+
@property
55+
def session_id(self) -> str:
56+
return self["sessionId"]
57+
58+
@property
59+
def action_group(self) -> str:
60+
return self["actionGroup"]
61+
62+
@property
63+
def function(self) -> str:
64+
return self["function"]
65+
66+
@property
67+
def parameters(self) -> list[BedrockAgentFunctionParameter]:
68+
parameters = self.get("parameters") or []
69+
return [BedrockAgentFunctionParameter(x) for x in parameters]
70+
71+
@property
72+
def agent(self) -> BedrockAgentInfo:
73+
return BedrockAgentInfo(self["agent"])
74+
75+
@property
76+
def session_attributes(self) -> dict[str, str]:
77+
return self.get("sessionAttributes", {}) or {}
78+
79+
@property
80+
def prompt_session_attributes(self) -> dict[str, str]:
81+
return self.get("promptSessionAttributes", {}) or {}

aws_lambda_powertools/utilities/parser/envelopes/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .apigw_websocket import ApiGatewayWebSocketEnvelope
33
from .apigwv2 import ApiGatewayV2Envelope
44
from .base import BaseEnvelope
5-
from .bedrock_agent import BedrockAgentEnvelope
5+
from .bedrock_agent import BedrockAgentEnvelope, BedrockAgentFunctionEnvelope
66
from .cloudwatch import CloudWatchLogsEnvelope
77
from .dynamodb import DynamoDBStreamEnvelope
88
from .event_bridge import EventBridgeEnvelope
@@ -20,6 +20,7 @@
2020
"ApiGatewayV2Envelope",
2121
"ApiGatewayWebSocketEnvelope",
2222
"BedrockAgentEnvelope",
23+
"BedrockAgentFunctionEnvelope",
2324
"CloudWatchLogsEnvelope",
2425
"DynamoDBStreamEnvelope",
2526
"EventBridgeEnvelope",

0 commit comments

Comments
 (0)