Skip to content

Commit a87bc30

Browse files
glsukkiSukruth Gowdru Lingaraju
andauthored
feat: integrate metadata support for short-term-memory (STM) (#114)
Co-authored-by: Sukruth Gowdru Lingaraju <[email protected]>
1 parent 6334a24 commit a87bc30

File tree

4 files changed

+583
-10
lines changed

4 files changed

+583
-10
lines changed

src/bedrock_agentcore/memory/models/__init__.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33
from typing import Any, Dict
44

55
from .DictWrapper import DictWrapper
6-
6+
from .filters import (
7+
StringValue,
8+
MetadataValue,
9+
MetadataKey,
10+
LeftExpression,
11+
OperatorType,
12+
RightExpression,
13+
EventMetadataFilter,
14+
)
715

816
class ActorSummary(DictWrapper):
917
"""A class representing an actor summary."""
@@ -75,3 +83,20 @@ def __init__(self, session_summary: Dict[str, Any]):
7583
session_summary: Dictionary containing session summary data.
7684
"""
7785
super().__init__(session_summary)
86+
87+
__all__ = [
88+
"DictWrapper",
89+
"ActorSummary",
90+
"Branch",
91+
"Event",
92+
"EventMessage",
93+
"MemoryRecord",
94+
"SessionSummary",
95+
"StringValue",
96+
"MetadataValue",
97+
"MetadataKey",
98+
"LeftExpression",
99+
"OperatorType",
100+
"RightExpression",
101+
"EventMetadataFilter",
102+
]
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from enum import Enum
2+
from typing import Optional, TypedDict, Union, NotRequired
3+
4+
class StringValue(TypedDict):
5+
"""Value associated with the `eventMetadata` key."""
6+
stringValue: str
7+
8+
@staticmethod
9+
def build(value: str) -> 'StringValue':
10+
return {
11+
"stringValue": value
12+
}
13+
14+
MetadataValue = Union[StringValue]
15+
"""
16+
Union type representing metadata values.
17+
18+
Variants:
19+
- StringValue: {"stringValue": str} - String metadata value
20+
"""
21+
22+
MetadataKey = Union[str]
23+
"""
24+
Union type representing metadata key.
25+
"""
26+
27+
class LeftExpression(TypedDict):
28+
"""
29+
Left operand of the event metadata filter expression.
30+
"""
31+
metadataKey: MetadataKey
32+
33+
@staticmethod
34+
def build(key: str) -> 'LeftExpression':
35+
"""Builds the `metadataKey` for `LeftExpression`"""
36+
return {
37+
"metadataKey": key
38+
}
39+
40+
class OperatorType(Enum):
41+
"""
42+
Operator applied to the event metadata filter expression.
43+
44+
Currently supports:
45+
- `EQUALS_TO`
46+
- `EXISTS`
47+
- `NOT_EXISTS`
48+
"""
49+
EQUALS_TO = "EQUALS_TO"
50+
EXISTS = "EXISTS"
51+
NOT_EXISTS = "NOT_EXISTS"
52+
53+
class RightExpression(TypedDict):
54+
"""
55+
Right operand of the event metadata filter expression.
56+
57+
Variants:
58+
- StringValue: {"metadataValue": {"stringValue": str}}
59+
"""
60+
metadataValue: MetadataValue
61+
62+
@staticmethod
63+
def build(value: str) -> 'RightExpression':
64+
"""Builds the `RightExpression` for `stringValue` type"""
65+
return {"metadataValue": StringValue.build(value)}
66+
67+
class EventMetadataFilter(TypedDict):
68+
"""
69+
Filter expression for retrieving events based on metadata associated with an event.
70+
71+
Args:
72+
left: `LeftExpression` of the event metadata filter expression.
73+
operator: `OperatorType` applied to the event metadata filter expression.
74+
right: Optional `RightExpression` of the event metadata filter expression.
75+
"""
76+
left: LeftExpression
77+
operator: OperatorType
78+
right: NotRequired[RightExpression]
79+
80+
def build_expression(left_operand: LeftExpression, operator: OperatorType, right_operand: Optional[RightExpression] = None) -> 'EventMetadataFilter':
81+
"""
82+
This method builds the required event metadata filter expression into the `EventMetadataFilterExpression` type when querying listEvents.
83+
84+
Args:
85+
left_operand: Left operand of the event metadata filter expression
86+
operator: Operator applied to the event metadata filter expression
87+
right_operand: Optional right_operand of the event metadata filter expression.
88+
89+
Example:
90+
```
91+
left_operand = LeftExpression.build_key(key='location')
92+
operator = OperatorType.EQUALS_TO
93+
right_operand = RightExpression.build_string_value(value='NYC')
94+
```
95+
96+
#### Response Object:
97+
```
98+
{
99+
'left': {
100+
'metadataKey': 'location'
101+
},
102+
'operator': 'EQUALS_TO',
103+
'right': {
104+
'metadataValue': {
105+
'stringValue': 'NYC'
106+
}
107+
}
108+
}
109+
```
110+
"""
111+
filter = {
112+
'left': left_operand,
113+
'operator': operator.value
114+
}
115+
116+
if right_operand:
117+
filter['right'] = right_operand
118+
return filter

src/bedrock_agentcore/memory/session.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
EventMessage,
1919
MemoryRecord,
2020
SessionSummary,
21+
MetadataValue,
22+
EventMetadataFilter
2123
)
2224

2325
logger = logging.getLogger(__name__)
@@ -246,6 +248,7 @@ def process_turn_with_llm(
246248
user_input: str,
247249
llm_callback: Callable[[str, List[Dict[str, Any]]], str],
248250
retrieval_config: Optional[Dict[str, RetrievalConfig]],
251+
metadata: Optional[Dict[str, MetadataValue]] = None,
249252
event_timestamp: Optional[datetime] = None,
250253
) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]:
251254
r"""Complete conversation turn with LLM callback integration.
@@ -263,6 +266,7 @@ def process_turn_with_llm(
263266
retrieval_config: Optional dictionary mapping namespaces to RetrievalConfig objects.
264267
Each namespace can contain template variables like {actorId}, {sessionId},
265268
{memoryStrategyId} that will be resolved at runtime.
269+
metadata: Optional custom key-value metadata to attach to an event.
266270
event_timestamp: Optional timestamp for the event
267271
268272
Returns:
@@ -340,6 +344,7 @@ def my_llm(user_input: str, memories: List[Dict]) -> str:
340344
ConversationalMessage(user_input, MessageRole.USER),
341345
ConversationalMessage(agent_response, MessageRole.ASSISTANT),
342346
],
347+
metadata=metadata,
343348
event_timestamp=event_timestamp,
344349
)
345350

@@ -352,6 +357,7 @@ def add_turns(
352357
session_id: str,
353358
messages: List[Union[ConversationalMessage, BlobMessage]],
354359
branch: Optional[Dict[str, str]] = None,
360+
metadata: Optional[Dict[str, MetadataValue]] = None,
355361
event_timestamp: Optional[datetime] = None,
356362
) -> Event:
357363
"""Adds conversational turns or blob objects to short-term memory.
@@ -365,21 +371,31 @@ def add_turns(
365371
- ConversationalMessage objects for conversational messages
366372
- BlobMessage objects for blob data
367373
branch: Optional branch info
374+
metadata: Optional custom key-value metadata to attach to an event.
368375
event_timestamp: Optional timestamp for the event
369376
370377
Returns:
371378
Created event
372379
373380
Example:
381+
```
374382
manager.add_turns(
375383
actor_id="user-123",
376384
session_id="session-456",
377385
messages=[
378386
ConversationalMessage("Hello", USER),
379387
BlobMessage({"file_data": "base64_content"}),
380388
ConversationalMessage("How can I help?", ASSISTANT)
389+
],
390+
metadata=[
391+
{
392+
'location': {
393+
'stringValue': 'NYC'
394+
}
395+
}
381396
]
382397
)
398+
```
383399
"""
384400
logger.info(" -> Storing %d messages in short-term memory...", len(messages))
385401

@@ -412,6 +428,10 @@ def add_turns(
412428

413429
if branch:
414430
params["branch"] = branch
431+
432+
if metadata:
433+
params["metadata"] = metadata
434+
415435
try:
416436
response = self._data_plane_client.create_event(**params)
417437
logger.info(" ✅ Turn stored successfully with Event ID: %s", response.get("eventId"))
@@ -427,6 +447,7 @@ def fork_conversation(
427447
root_event_id: str,
428448
branch_name: str,
429449
messages: List[Union[ConversationalMessage, BlobMessage]],
450+
metadata: Optional[Dict[str, MetadataValue]] = None,
430451
event_timestamp: Optional[datetime] = None,
431452
) -> Dict[str, Any]:
432453
"""Fork a conversation from a specific event to create a new branch."""
@@ -439,6 +460,7 @@ def fork_conversation(
439460
messages=messages,
440461
event_timestamp=event_timestamp,
441462
branch=branch,
463+
metadata=metadata,
442464
)
443465

444466
logger.info("Created branch '%s' from event %s", branch_name, root_event_id)
@@ -454,6 +476,7 @@ def list_events(
454476
session_id: str,
455477
branch_name: Optional[str] = None,
456478
include_parent_branches: bool = False,
479+
eventMetadata: Optional[List[EventMetadataFilter]] = None,
457480
max_results: int = 100,
458481
include_payload: bool = True,
459482
) -> List[Event]:
@@ -482,6 +505,49 @@ def list_events(
482505
483506
# Get events from a specific branch
484507
branch_events = client.list_events(actor_id, session_id, branch_name="test-branch")
508+
509+
#### Get events with event metadata filter
510+
```
511+
filtered_events_with_metadata = client.list_events(
512+
actor_id=actor_id,
513+
session_id=session_id,
514+
eventMetadata=[
515+
{
516+
'left': {
517+
'metadataKey': 'location'
518+
},
519+
'operator': 'EQUALS_TO',
520+
'right': {
521+
'metadataValue': {
522+
'stringValue': 'NYC'
523+
}
524+
}
525+
}
526+
]
527+
)
528+
```
529+
530+
#### Get events with event metadata filter + specific branch filter
531+
```
532+
branch_with_metadata_filtered_events = client.list_events(
533+
actor_id=actor_id,
534+
session_id=session_id,
535+
branch_name="test-branch",
536+
eventMetadata=[
537+
{
538+
'left': {
539+
'metadataKey': 'location'
540+
},
541+
'operator': 'EQUALS_TO',
542+
'right': {
543+
'metadataValue': {
544+
'stringValue': 'NYC'
545+
}
546+
}
547+
}
548+
]
549+
)
550+
```
485551
"""
486552
try:
487553
all_events: List[Event] = []
@@ -509,6 +575,12 @@ def list_events(
509575
"branch": {"name": branch_name, "includeParentBranches": include_parent_branches}
510576
}
511577

578+
# Add eventMetadata filter if specified
579+
if eventMetadata:
580+
params["filter"] = {
581+
"eventMetadata": eventMetadata
582+
}
583+
512584
response = self._data_plane_client.list_events(**params)
513585

514586
events = response.get("events", [])
@@ -888,28 +960,31 @@ def add_turns(
888960
self,
889961
messages: List[Union[ConversationalMessage, BlobMessage]],
890962
branch: Optional[Dict[str, str]] = None,
963+
metadata: Optional[Dict[str, MetadataValue]] = None,
891964
event_timestamp: Optional[datetime] = None,
892965
) -> Event:
893966
"""Delegates to manager.add_turns."""
894-
return self._manager.add_turns(self._actor_id, self._session_id, messages, branch, event_timestamp)
967+
return self._manager.add_turns(self._actor_id, self._session_id, messages, branch, metadata, event_timestamp)
895968

896969
def fork_conversation(
897970
self,
898971
messages: List[Union[ConversationalMessage, BlobMessage]],
899972
root_event_id: str,
900973
branch_name: str,
974+
metadata: Optional[Dict[str, MetadataValue]] = None,
901975
event_timestamp: Optional[datetime] = None,
902976
) -> Event:
903977
"""Delegates to manager.fork_conversation."""
904978
return self._manager.fork_conversation(
905-
self._actor_id, self._session_id, root_event_id, branch_name, messages, event_timestamp
979+
self._actor_id, self._session_id, root_event_id, branch_name, messages, metadata, event_timestamp
906980
)
907981

908982
def process_turn_with_llm(
909983
self,
910984
user_input: str,
911985
llm_callback: Callable[[str, List[Dict[str, Any]]], str],
912986
retrieval_config: Optional[Dict[str, RetrievalConfig]],
987+
metadata: Optional[Dict[str, MetadataValue]] = None,
913988
event_timestamp: Optional[datetime] = None,
914989
) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]:
915990
"""Delegates to manager.process_turn_with_llm."""
@@ -919,6 +994,7 @@ def process_turn_with_llm(
919994
user_input,
920995
llm_callback,
921996
retrieval_config,
997+
metadata,
922998
event_timestamp,
923999
)
9241000

@@ -975,6 +1051,7 @@ def list_events(
9751051
self,
9761052
branch_name: Optional[str] = None,
9771053
include_parent_branches: bool = False,
1054+
eventMetadata: Optional[List[EventMetadataFilter]] = None,
9781055
max_results: int = 100,
9791056
include_payload: bool = True,
9801057
) -> List[Event]:
@@ -984,6 +1061,7 @@ def list_events(
9841061
session_id=self._session_id,
9851062
branch_name=branch_name,
9861063
include_parent_branches=include_parent_branches,
1064+
eventMetadata=eventMetadata,
9871065
include_payload=include_payload,
9881066
max_results=max_results,
9891067
)

0 commit comments

Comments
 (0)