diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py index 3cb3efa8145d..facf51b9b8d8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py @@ -1,5 +1,6 @@ """State management for agents, teams and termination conditions.""" +from ._message_store import InMemoryMessageStore, MessageStore from ._states import ( AssistantAgentState, BaseGroupChatManagerState, @@ -18,6 +19,8 @@ "AssistantAgentState", "BaseGroupChatManagerState", "ChatAgentContainerState", + "InMemoryMessageStore", + "MessageStore", "RoundRobinManagerState", "SelectorManagerState", "SwarmManagerState", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_message_store.py b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_message_store.py new file mode 100644 index 000000000000..cfe7036913f8 --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_message_store.py @@ -0,0 +1,174 @@ +"""Message store abstraction for storing message threads in teams. + +This module provides the :class:`MessageStore` abstract base class and an +:class:`InMemoryMessageStore` implementation that serves as the default, +backwards-compatible storage backend for group chat message threads. +""" + +import time +from abc import ABC, abstractmethod +from typing import Any, List, Mapping, Sequence + +from ..messages import BaseAgentEvent, BaseChatMessage + + +class MessageStore(ABC): + """Abstract base class for message thread storage. + + A ``MessageStore`` is responsible for persisting the message thread that + accumulates during a group chat session. Implementations may choose to + keep messages in memory, write them to a database, or use any other + persistence strategy. + + The optional *ttl* (time-to-live) parameter specifies how long messages + should be retained (in seconds). A value of ``None`` means messages + never expire. It is the responsibility of each concrete implementation + to honour the TTL policy. + """ + + def __init__(self, *, ttl: float | None = None) -> None: + if ttl is not None and ttl <= 0: + raise ValueError("TTL must be a positive number or None.") + self._ttl = ttl + + @property + def ttl(self) -> float | None: + """Return the TTL (time-to-live) in seconds, or ``None`` for no expiry.""" + return self._ttl + + @abstractmethod + async def add_messages(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + """Append one or more messages to the store. + + Args: + messages: A sequence of messages to add. + """ + ... + + @abstractmethod + async def get_messages(self) -> List[BaseAgentEvent | BaseChatMessage]: + """Return all non-expired messages in insertion order. + + Returns: + A list of messages. + """ + ... + + @abstractmethod + async def clear(self) -> None: + """Remove **all** messages from the store (including unexpired ones).""" + ... + + @abstractmethod + async def save_state(self) -> Mapping[str, Any]: + """Serialise the current store contents so they can be persisted externally. + + Returns: + A JSON-serialisable mapping. + """ + ... + + @abstractmethod + async def load_state(self, state: Mapping[str, Any]) -> None: + """Restore the store from a previously saved state. + + Args: + state: A mapping previously returned by :meth:`save_state`. + """ + ... + + +class InMemoryMessageStore(MessageStore): + """A simple in-memory message store that is fully backwards compatible with + the previous ``List``-based storage used in :class:`BaseGroupChatManager`. + + When a *ttl* is configured, messages older than *ttl* seconds are + automatically pruned on every read (:meth:`get_messages`). Timestamps are + captured at the time :meth:`add_messages` is called. + + Example: + + .. code-block:: python + + store = InMemoryMessageStore(ttl=300) # 5-minute TTL + await store.add_messages([msg1, msg2]) + messages = await store.get_messages() + """ + + def __init__(self, *, ttl: float | None = None) -> None: + super().__init__(ttl=ttl) + self._messages: List[BaseAgentEvent | BaseChatMessage] = [] + # Parallel list of insertion timestamps (epoch seconds) when TTL is enabled. + self._timestamps: List[float] = [] + + async def add_messages(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + now = time.monotonic() + for msg in messages: + self._messages.append(msg) + self._timestamps.append(now) + + async def get_messages(self) -> List[BaseAgentEvent | BaseChatMessage]: + if self._ttl is not None: + self._prune_expired() + return list(self._messages) + + async def clear(self) -> None: + self._messages.clear() + self._timestamps.clear() + + async def save_state(self) -> Mapping[str, Any]: + return { + "messages": [msg.dump() for msg in self._messages], + "timestamps": list(self._timestamps), + } + + async def load_state(self, state: Mapping[str, Any]) -> None: + from ..messages import MessageFactory + + factory = MessageFactory() + self._messages = [factory.create(m) for m in state.get("messages", [])] + self._timestamps = list(state.get("timestamps", [])) + # Ensure timestamps list is the same length as messages. + while len(self._timestamps) < len(self._messages): + self._timestamps.append(time.monotonic()) + + # ------------------------------------------------------------------ + # Convenience helpers for the group chat manager integration + # ------------------------------------------------------------------ + + def extend(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + """Synchronous helper that mirrors ``list.extend`` for easy migration.""" + now = time.monotonic() + for msg in messages: + self._messages.append(msg) + self._timestamps.append(now) + + @property + def messages(self) -> List[BaseAgentEvent | BaseChatMessage]: + """Direct access to the underlying list (backwards compatibility).""" + if self._ttl is not None: + self._prune_expired() + return self._messages + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + def _prune_expired(self) -> None: + """Remove messages that have exceeded the TTL.""" + if self._ttl is None: + return + cutoff = time.monotonic() - self._ttl + # Walk from the front; messages are in insertion order so timestamps + # are monotonically non-decreasing. + first_valid = 0 + for i, ts in enumerate(self._timestamps): + if ts >= cutoff: + first_valid = i + break + else: + # All messages are expired. + first_valid = len(self._messages) + if first_valid > 0: + del self._messages[:first_valid] + del self._timestamps[:first_valid] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py index 712976980ae5..1879ed6fb934 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py @@ -3,6 +3,7 @@ Each team inherits from the BaseGroupChat class. """ +from ..state import InMemoryMessageStore, MessageStore from ._group_chat._base_group_chat import BaseGroupChat from ._group_chat._graph import ( DiGraph, @@ -18,6 +19,8 @@ __all__ = [ "BaseGroupChat", + "InMemoryMessageStore", + "MessageStore", "RoundRobinGroupChat", "SelectorGroupChat", "Swarm", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 60f222912387..b98438c323a1 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -24,7 +24,7 @@ StructuredMessage, TextMessage, ) -from ...state import TeamState +from ...state import MessageStore, TeamState from ._chat_agent_container import ChatAgentContainer from ._events import ( GroupChatPause, @@ -75,6 +75,7 @@ def __init__( runtime: AgentRuntime | None = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, emit_team_events: bool = False, + message_store: MessageStore | None = None, ): self._name = name self._description = description @@ -150,6 +151,9 @@ def __init__( # Flag to track if the team events should be emitted. self._emit_team_events = emit_team_events + # Optional message store for the group chat manager. + self._message_store = message_store + @property def name(self) -> str: """The name of the group chat team.""" @@ -173,6 +177,7 @@ def _create_group_chat_manager_factory( termination_condition: TerminationCondition | None, max_turns: int | None, message_factory: MessageFactory, + message_store: MessageStore | None = None, ) -> Callable[[], SequentialRoutedAgent]: ... def _create_participant_factory( @@ -224,6 +229,7 @@ async def _init(self, runtime: AgentRuntime) -> None: termination_condition=self._termination_condition, max_turns=self._max_turns, message_factory=self._message_factory, + message_store=self._message_store, ), ) # Add subscriptions for the group chat manager. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index b0a0c1d55fc4..10eda5004a2b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -6,6 +6,7 @@ from ...base import TerminationCondition from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage +from ...state import InMemoryMessageStore, MessageStore from ._events import ( GroupChatAgentResponse, GroupChatError, @@ -47,6 +48,7 @@ def __init__( max_turns: int | None, message_factory: MessageFactory, emit_team_events: bool = False, + message_store: MessageStore | None = None, ): super().__init__( description="Group chat manager", @@ -74,7 +76,7 @@ def __init__( name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True) } self._participant_descriptions = participant_descriptions - self._message_thread: List[BaseAgentEvent | BaseChatMessage] = [] + self._message_store: MessageStore = message_store if message_store is not None else InMemoryMessageStore() self._output_message_queue = output_message_queue self._termination_condition = termination_condition self._max_turns = max_turns @@ -83,6 +85,43 @@ def __init__( self._emit_team_events = emit_team_events self._active_speakers: List[str] = [] + @property + def _message_thread(self) -> List[BaseAgentEvent | BaseChatMessage]: + """Backwards-compatible access to the message thread via the message store. + + Subclasses that previously accessed ``self._message_thread`` directly + will continue to work transparently. + """ + if isinstance(self._message_store, InMemoryMessageStore): + return self._message_store.messages + # For non-InMemoryMessageStore implementations, we cannot provide a + # live mutable list. Return a snapshot instead. + import asyncio as _asyncio + + loop = _asyncio.get_event_loop() + if loop.is_running(): + # We are inside an async context; callers should prefer + # get_messages() but this keeps sync access working for + # simple attribute reads used in select_speaker() etc. + # Fall back to the store's synchronous snapshot if available. + raise RuntimeError( + "Cannot access _message_thread synchronously with a non-InMemoryMessageStore. " + "Use 'await self._message_store.get_messages()' instead." + ) + return loop.run_until_complete(self._message_store.get_messages()) + + @_message_thread.setter + def _message_thread(self, value: List[BaseAgentEvent | BaseChatMessage]) -> None: + """Allow ``self._message_thread = [...]`` for backwards compatibility.""" + if isinstance(self._message_store, InMemoryMessageStore): + self._message_store._messages = value + self._message_store._timestamps = [] + else: + raise RuntimeError( + "Cannot set _message_thread directly with a non-InMemoryMessageStore. " + "Use the message store API instead." + ) + @rpc async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: """Handle the start of a group chat by selecting a speaker to start the conversation.""" @@ -170,7 +209,8 @@ async def handle_agent_response( raise async def _transition_to_next_speakers(self, cancellation_token: CancellationToken) -> None: - speaker_names_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) + thread = await self._message_store.get_messages() + speaker_names_future = asyncio.ensure_future(self.select_speaker(thread)) # Link the select speaker future to the cancellation token. cancellation_token.link_future(speaker_names_future) speaker_names = await speaker_names_future @@ -300,7 +340,7 @@ async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseCh This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event, before calling the select_speakers method. """ - self._message_thread.extend(messages) + await self._message_store.add_messages(messages) @abstractmethod async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py index d77b42dd17f2..17c94c92559b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -511,8 +511,9 @@ def _reset_execution_state(self) -> None: async def save_state(self) -> Mapping[str, Any]: """Save the execution state.""" + messages = await self._message_store.get_messages() state = { - "message_thread": [message.dump() for message in self._message_thread], + "message_thread": [message.dump() for message in messages], "current_turn": self._current_turn, "remaining": {target: dict(counter) for target, counter in self._remaining.items()}, "enqueued_any": dict(self._enqueued_any), @@ -522,7 +523,9 @@ async def save_state(self) -> Mapping[str, Any]: async def load_state(self, state: Mapping[str, Any]) -> None: """Restore execution state from saved data.""" - self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]] + loaded_messages = [self._message_factory.create(msg) for msg in state["message_thread"]] + await self._message_store.clear() + await self._message_store.add_messages(loaded_messages) self._current_turn = state["current_turn"] self._remaining = {target: Counter(groups) for target, groups in state["remaining"].items()} self._enqueued_any = state["enqueued_any"] @@ -531,7 +534,7 @@ async def load_state(self, state: Mapping[str, Any]) -> None: async def reset(self) -> None: """Reset execution state to the start of the graph.""" self._current_turn = 0 - self._message_thread.clear() + await self._message_store.clear() if self._termination_condition: await self._termination_condition.reset() self._reset_execution_state() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 176789257ba7..0b3cb87ea352 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -223,8 +223,9 @@ async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> pass async def save_state(self) -> Mapping[str, Any]: + messages = await self._message_store.get_messages() state = MagenticOneOrchestratorState( - message_thread=[msg.dump() for msg in self._message_thread], + message_thread=[msg.dump() for msg in messages], current_turn=self._current_turn, task=self._task, facts=self._facts, @@ -236,7 +237,9 @@ async def save_state(self) -> Mapping[str, Any]: async def load_state(self, state: Mapping[str, Any]) -> None: orchestrator_state = MagenticOneOrchestratorState.model_validate(state) - self._message_thread = [self._message_factory.create(message) for message in orchestrator_state.message_thread] + loaded_messages = [self._message_factory.create(message) for message in orchestrator_state.message_thread] + await self._message_store.clear() + await self._message_store.add_messages(loaded_messages) self._current_turn = orchestrator_state.current_turn self._task = orchestrator_state.task self._facts = orchestrator_state.facts @@ -250,7 +253,7 @@ async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage async def reset(self) -> None: """Reset the group chat manager.""" - self._message_thread.clear() + await self._message_store.clear() if self._termination_condition is not None: await self._termination_condition.reset() self._n_rounds = 0 @@ -269,7 +272,7 @@ async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> No cancellation_token=cancellation_token, ) # Reset partially the group chat manager - self._message_thread.clear() + await self._message_store.clear() # Prepare the ledger ledger_message = TextMessage( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index 3f529f0c4474..f3b8c38e69d6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -7,7 +7,7 @@ from ...base import ChatAgent, Team, TerminationCondition from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory -from ...state import RoundRobinManagerState +from ...state import MessageStore, RoundRobinManagerState from ._base_group_chat import BaseGroupChat from ._base_group_chat_manager import BaseGroupChatManager from ._events import GroupChatTermination @@ -29,6 +29,7 @@ def __init__( max_turns: int | None, message_factory: MessageFactory, emit_team_events: bool, + message_store: MessageStore | None = None, ) -> None: super().__init__( name, @@ -42,6 +43,7 @@ def __init__( max_turns, message_factory, emit_team_events, + message_store=message_store, ) self._next_speaker_index = 0 @@ -50,14 +52,15 @@ async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> async def reset(self) -> None: self._current_turn = 0 - self._message_thread.clear() + await self._message_store.clear() if self._termination_condition is not None: await self._termination_condition.reset() self._next_speaker_index = 0 async def save_state(self) -> Mapping[str, Any]: + messages = await self._message_store.get_messages() state = RoundRobinManagerState( - message_thread=[message.dump() for message in self._message_thread], + message_thread=[message.dump() for message in messages], current_turn=self._current_turn, next_speaker_index=self._next_speaker_index, ) @@ -65,7 +68,9 @@ async def save_state(self) -> Mapping[str, Any]: async def load_state(self, state: Mapping[str, Any]) -> None: round_robin_state = RoundRobinManagerState.model_validate(state) - self._message_thread = [self._message_factory.create(message) for message in round_robin_state.message_thread] + loaded_messages = [self._message_factory.create(message) for message in round_robin_state.message_thread] + await self._message_store.clear() + await self._message_store.add_messages(loaded_messages) self._current_turn = round_robin_state.current_turn self._next_speaker_index = round_robin_state.next_speaker_index @@ -250,6 +255,7 @@ def __init__( runtime: AgentRuntime | None = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, emit_team_events: bool = False, + message_store: MessageStore | None = None, ) -> None: super().__init__( name=name or self.DEFAULT_NAME, @@ -262,6 +268,7 @@ def __init__( runtime=runtime, custom_message_types=custom_message_types, emit_team_events=emit_team_events, + message_store=message_store, ) def _create_group_chat_manager_factory( @@ -276,6 +283,7 @@ def _create_group_chat_manager_factory( termination_condition: TerminationCondition | None, max_turns: int | None, message_factory: MessageFactory, + message_store: MessageStore | None = None, ) -> Callable[[], RoundRobinGroupChatManager]: def _factory() -> RoundRobinGroupChatManager: return RoundRobinGroupChatManager( @@ -290,6 +298,7 @@ def _factory() -> RoundRobinGroupChatManager: max_turns, message_factory, self._emit_team_events, + message_store=message_store, ) return _factory diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 480dc6b71641..6f60fce93bab 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -107,15 +107,16 @@ async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> async def reset(self) -> None: self._current_turn = 0 - self._message_thread.clear() + await self._message_store.clear() await self._model_context.clear() if self._termination_condition is not None: await self._termination_condition.reset() self._previous_speaker = None async def save_state(self) -> Mapping[str, Any]: + messages = await self._message_store.get_messages() state = SelectorManagerState( - message_thread=[msg.dump() for msg in self._message_thread], + message_thread=[msg.dump() for msg in messages], current_turn=self._current_turn, previous_speaker=self._previous_speaker, ) @@ -123,9 +124,11 @@ async def save_state(self) -> Mapping[str, Any]: async def load_state(self, state: Mapping[str, Any]) -> None: selector_state = SelectorManagerState.model_validate(state) - self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread] + loaded_messages = [self._message_factory.create(msg) for msg in selector_state.message_thread] + await self._message_store.clear() + await self._message_store.add_messages(loaded_messages) await self._add_messages_to_context( - self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)] + self._model_context, [msg for msg in loaded_messages if isinstance(msg, BaseChatMessage)] ) self._current_turn = selector_state.current_turn self._previous_speaker = selector_state.previous_speaker @@ -145,7 +148,7 @@ async def _add_messages_to_context( await model_context.add_message(msg.to_model_message()) async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: - self._message_thread.extend(messages) + await self._message_store.add_messages(messages) base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)] await self._add_messages_to_context(self._model_context, base_chat_messages) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index c9b495083939..bad0a58f548a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -74,7 +74,7 @@ async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> async def reset(self) -> None: self._current_turn = 0 - self._message_thread.clear() + await self._message_store.clear() if self._termination_condition is not None: await self._termination_condition.reset() self._current_speaker = self._participant_names[0] @@ -98,8 +98,9 @@ async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage return self._current_speaker async def save_state(self) -> Mapping[str, Any]: + messages = await self._message_store.get_messages() state = SwarmManagerState( - message_thread=[msg.dump() for msg in self._message_thread], + message_thread=[msg.dump() for msg in messages], current_turn=self._current_turn, current_speaker=self._current_speaker, ) @@ -107,7 +108,9 @@ async def save_state(self) -> Mapping[str, Any]: async def load_state(self, state: Mapping[str, Any]) -> None: swarm_state = SwarmManagerState.model_validate(state) - self._message_thread = [self._message_factory.create(message) for message in swarm_state.message_thread] + loaded_messages = [self._message_factory.create(message) for message in swarm_state.message_thread] + await self._message_store.clear() + await self._message_store.add_messages(loaded_messages) self._current_turn = swarm_state.current_turn self._current_speaker = swarm_state.current_speaker diff --git a/python/packages/autogen-agentchat/tests/test_message_store.py b/python/packages/autogen-agentchat/tests/test_message_store.py new file mode 100644 index 000000000000..a8e4cd32301d --- /dev/null +++ b/python/packages/autogen-agentchat/tests/test_message_store.py @@ -0,0 +1,311 @@ +"""Comprehensive tests for the MessageStore abstraction and InMemoryMessageStore implementation.""" + +import asyncio +import time +from typing import Any, List, Mapping, Sequence +from unittest.mock import patch + +import pytest + +from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage +from autogen_agentchat.state import InMemoryMessageStore, MessageStore + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_msg(content: str, source: str = "user") -> TextMessage: + """Create a simple TextMessage for testing.""" + return TextMessage(content=content, source=source) + + +class DummyMessageStore(MessageStore): + """Minimal concrete subclass used only to test the ABC contract.""" + + def __init__(self, *, ttl: float | None = None) -> None: + super().__init__(ttl=ttl) + self._msgs: List[BaseAgentEvent | BaseChatMessage] = [] + + async def add_messages(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + self._msgs.extend(messages) + + async def get_messages(self) -> List[BaseAgentEvent | BaseChatMessage]: + return list(self._msgs) + + async def clear(self) -> None: + self._msgs.clear() + + async def save_state(self) -> Mapping[str, Any]: + return {"messages": [m.dump() for m in self._msgs]} + + async def load_state(self, state: Mapping[str, Any]) -> None: + from autogen_agentchat.messages import MessageFactory + + factory = MessageFactory() + self._msgs = [factory.create(m) for m in state.get("messages", [])] + + +# --------------------------------------------------------------------------- +# Tests – ABC & construction +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_abc_cannot_be_instantiated() -> None: + """MessageStore is abstract and cannot be instantiated directly.""" + with pytest.raises(TypeError): + MessageStore() # type: ignore[abstract] + + +@pytest.mark.asyncio +async def test_ttl_must_be_positive() -> None: + """A non-positive TTL must raise ValueError.""" + with pytest.raises(ValueError, match="positive"): + InMemoryMessageStore(ttl=0) + with pytest.raises(ValueError, match="positive"): + InMemoryMessageStore(ttl=-5) + + +@pytest.mark.asyncio +async def test_ttl_none_is_allowed() -> None: + """TTL=None means messages never expire.""" + store = InMemoryMessageStore(ttl=None) + assert store.ttl is None + + +@pytest.mark.asyncio +async def test_ttl_property() -> None: + """The ttl property should return the configured value.""" + store = InMemoryMessageStore(ttl=60.0) + assert store.ttl == 60.0 + + +# --------------------------------------------------------------------------- +# Tests – basic CRUD (InMemoryMessageStore) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_and_get_messages() -> None: + """Messages added via add_messages should be retrievable via get_messages.""" + store = InMemoryMessageStore() + msg1 = _make_msg("hello") + msg2 = _make_msg("world") + await store.add_messages([msg1, msg2]) + result = await store.get_messages() + assert len(result) == 2 + assert result[0].content == "hello" + assert result[1].content == "world" + + +@pytest.mark.asyncio +async def test_add_messages_preserves_order() -> None: + """Multiple add_messages calls should preserve insertion order.""" + store = InMemoryMessageStore() + await store.add_messages([_make_msg("a")]) + await store.add_messages([_make_msg("b"), _make_msg("c")]) + contents = [m.content for m in await store.get_messages()] + assert contents == ["a", "b", "c"] + + +@pytest.mark.asyncio +async def test_get_messages_returns_copy() -> None: + """get_messages should return a new list each time (not a reference to internal state).""" + store = InMemoryMessageStore() + await store.add_messages([_make_msg("x")]) + list1 = await store.get_messages() + list2 = await store.get_messages() + assert list1 is not list2 + assert list1 == list2 + + +@pytest.mark.asyncio +async def test_clear_removes_all() -> None: + """clear() should remove all messages.""" + store = InMemoryMessageStore() + await store.add_messages([_make_msg("a"), _make_msg("b")]) + await store.clear() + result = await store.get_messages() + assert len(result) == 0 + + +@pytest.mark.asyncio +async def test_clear_on_empty_store() -> None: + """Clearing an already-empty store should not raise.""" + store = InMemoryMessageStore() + await store.clear() + assert await store.get_messages() == [] + + +@pytest.mark.asyncio +async def test_add_empty_sequence() -> None: + """Adding an empty sequence should be a no-op.""" + store = InMemoryMessageStore() + await store.add_messages([]) + assert await store.get_messages() == [] + + +# --------------------------------------------------------------------------- +# Tests – TTL expiration +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ttl_expires_messages() -> None: + """Messages older than the TTL should not appear in get_messages.""" + store = InMemoryMessageStore(ttl=1.0) + await store.add_messages([_make_msg("old")]) + + # Manually adjust the timestamp to simulate time passing. + store._timestamps[0] = time.monotonic() - 2.0 + + await store.add_messages([_make_msg("new")]) + result = await store.get_messages() + assert len(result) == 1 + assert result[0].content == "new" + + +@pytest.mark.asyncio +async def test_ttl_all_expired() -> None: + """When all messages are expired, get_messages returns an empty list.""" + store = InMemoryMessageStore(ttl=0.5) + await store.add_messages([_make_msg("a"), _make_msg("b")]) + + # Backdate all timestamps. + now = time.monotonic() + store._timestamps = [now - 1.0, now - 1.0] + + result = await store.get_messages() + assert result == [] + + +@pytest.mark.asyncio +async def test_ttl_no_expiry_when_none() -> None: + """When TTL is None, messages should never expire regardless of age.""" + store = InMemoryMessageStore(ttl=None) + await store.add_messages([_make_msg("forever")]) + # Backdate far into the past. + store._timestamps[0] = time.monotonic() - 999999 + result = await store.get_messages() + assert len(result) == 1 + + +@pytest.mark.asyncio +async def test_ttl_boundary() -> None: + """A message exactly at the TTL boundary should still be included.""" + store = InMemoryMessageStore(ttl=5.0) + await store.add_messages([_make_msg("boundary")]) + # Set timestamp to exactly the cutoff. + store._timestamps[0] = time.monotonic() - 5.0 + result = await store.get_messages() + # The message should be pruned because cutoff = now - ttl, and ts < cutoff. + # (ts == cutoff means it's exactly at the boundary -- still valid since we use >= cutoff) + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# Tests – save_state / load_state +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_save_and_load_state() -> None: + """State should round-trip through save_state and load_state.""" + store = InMemoryMessageStore() + await store.add_messages([_make_msg("saved1"), _make_msg("saved2")]) + state = await store.save_state() + + new_store = InMemoryMessageStore() + await new_store.load_state(state) + messages = await new_store.get_messages() + assert len(messages) == 2 + assert messages[0].content == "saved1" + assert messages[1].content == "saved2" + + +@pytest.mark.asyncio +async def test_load_state_replaces_existing() -> None: + """load_state should replace the current contents, not append.""" + store = InMemoryMessageStore() + await store.add_messages([_make_msg("original")]) + + other = InMemoryMessageStore() + await other.add_messages([_make_msg("replacement")]) + state = await other.save_state() + + await store.load_state(state) + messages = await store.get_messages() + assert len(messages) == 1 + assert messages[0].content == "replacement" + + +@pytest.mark.asyncio +async def test_load_empty_state() -> None: + """Loading an empty state dict should result in an empty store.""" + store = InMemoryMessageStore() + await store.add_messages([_make_msg("stuff")]) + await store.load_state({}) + assert await store.get_messages() == [] + + +# --------------------------------------------------------------------------- +# Tests – synchronous helpers (backwards compatibility) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_extend_sync_helper() -> None: + """The extend() sync helper should add messages identically to add_messages.""" + store = InMemoryMessageStore() + store.extend([_make_msg("sync1"), _make_msg("sync2")]) + result = await store.get_messages() + assert len(result) == 2 + assert result[0].content == "sync1" + + +@pytest.mark.asyncio +async def test_messages_property() -> None: + """The messages property should return the internal list.""" + store = InMemoryMessageStore() + await store.add_messages([_make_msg("via_property")]) + assert len(store.messages) == 1 + assert store.messages[0].content == "via_property" + + +# --------------------------------------------------------------------------- +# Tests – DummyMessageStore (custom implementation contract) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_custom_store_implements_contract() -> None: + """A custom MessageStore subclass should work through the standard API.""" + store = DummyMessageStore() + await store.add_messages([_make_msg("custom")]) + result = await store.get_messages() + assert len(result) == 1 + await store.clear() + assert await store.get_messages() == [] + + +@pytest.mark.asyncio +async def test_custom_store_ttl_property() -> None: + """Custom stores should inherit the TTL property from the base class.""" + store = DummyMessageStore(ttl=120.0) + assert store.ttl == 120.0 + + +@pytest.mark.asyncio +async def test_custom_store_save_load() -> None: + """Custom store save/load round-trip.""" + store = DummyMessageStore() + await store.add_messages([_make_msg("persist")]) + state = await store.save_state() + + store2 = DummyMessageStore() + await store2.load_state(state) + msgs = await store2.get_messages() + assert len(msgs) == 1 + assert msgs[0].content == "persist"