Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""State management for agents, teams and termination conditions."""

from ._message_store import InMemoryMessageStore, MessageStore
from ._states import (
AssistantAgentState,
BaseGroupChatManagerState,
Expand All @@ -18,6 +19,8 @@
"AssistantAgentState",
"BaseGroupChatManagerState",
"ChatAgentContainerState",
"InMemoryMessageStore",
"MessageStore",
"RoundRobinManagerState",
"SelectorManagerState",
"SwarmManagerState",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Message store abstraction for storing message threads in teams.

This module provides the :class:`MessageStore` abstract base class and an
:class:`InMemoryMessageStore` implementation that serves as the default,
backwards-compatible storage backend for group chat message threads.
"""

import time
from abc import ABC, abstractmethod
from typing import Any, List, Mapping, Sequence

from ..messages import BaseAgentEvent, BaseChatMessage


class MessageStore(ABC):
"""Abstract base class for message thread storage.

A ``MessageStore`` is responsible for persisting the message thread that
accumulates during a group chat session. Implementations may choose to
keep messages in memory, write them to a database, or use any other
persistence strategy.

The optional *ttl* (time-to-live) parameter specifies how long messages
should be retained (in seconds). A value of ``None`` means messages
never expire. It is the responsibility of each concrete implementation
to honour the TTL policy.
"""

def __init__(self, *, ttl: float | None = None) -> None:
if ttl is not None and ttl <= 0:
raise ValueError("TTL must be a positive number or None.")
self._ttl = ttl

@property
def ttl(self) -> float | None:
"""Return the TTL (time-to-live) in seconds, or ``None`` for no expiry."""
return self._ttl

@abstractmethod
async def add_messages(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
"""Append one or more messages to the store.

Args:
messages: A sequence of messages to add.
"""
...

@abstractmethod
async def get_messages(self) -> List[BaseAgentEvent | BaseChatMessage]:
"""Return all non-expired messages in insertion order.

Returns:
A list of messages.
"""
...

@abstractmethod
async def clear(self) -> None:
"""Remove **all** messages from the store (including unexpired ones)."""
...

@abstractmethod
async def save_state(self) -> Mapping[str, Any]:
"""Serialise the current store contents so they can be persisted externally.

Returns:
A JSON-serialisable mapping.
"""
...

@abstractmethod
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore the store from a previously saved state.

Args:
state: A mapping previously returned by :meth:`save_state`.
"""
...


class InMemoryMessageStore(MessageStore):
"""A simple in-memory message store that is fully backwards compatible with
the previous ``List``-based storage used in :class:`BaseGroupChatManager`.

When a *ttl* is configured, messages older than *ttl* seconds are
automatically pruned on every read (:meth:`get_messages`). Timestamps are
captured at the time :meth:`add_messages` is called.

Example:

.. code-block:: python

store = InMemoryMessageStore(ttl=300) # 5-minute TTL
await store.add_messages([msg1, msg2])
messages = await store.get_messages()
"""

def __init__(self, *, ttl: float | None = None) -> None:
super().__init__(ttl=ttl)
self._messages: List[BaseAgentEvent | BaseChatMessage] = []
# Parallel list of insertion timestamps (epoch seconds) when TTL is enabled.
self._timestamps: List[float] = []

async def add_messages(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
now = time.monotonic()
for msg in messages:
self._messages.append(msg)
self._timestamps.append(now)

async def get_messages(self) -> List[BaseAgentEvent | BaseChatMessage]:
if self._ttl is not None:
self._prune_expired()
return list(self._messages)

async def clear(self) -> None:
self._messages.clear()
self._timestamps.clear()

async def save_state(self) -> Mapping[str, Any]:
return {
"messages": [msg.dump() for msg in self._messages],
"timestamps": list(self._timestamps),
}

async def load_state(self, state: Mapping[str, Any]) -> None:
from ..messages import MessageFactory

factory = MessageFactory()
self._messages = [factory.create(m) for m in state.get("messages", [])]
self._timestamps = list(state.get("timestamps", []))
# Ensure timestamps list is the same length as messages.
while len(self._timestamps) < len(self._messages):
self._timestamps.append(time.monotonic())

# ------------------------------------------------------------------
# Convenience helpers for the group chat manager integration
# ------------------------------------------------------------------

def extend(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
"""Synchronous helper that mirrors ``list.extend`` for easy migration."""
now = time.monotonic()
for msg in messages:
self._messages.append(msg)
self._timestamps.append(now)

@property
def messages(self) -> List[BaseAgentEvent | BaseChatMessage]:
"""Direct access to the underlying list (backwards compatibility)."""
if self._ttl is not None:
self._prune_expired()
return self._messages

# ------------------------------------------------------------------
# Internal
# ------------------------------------------------------------------

def _prune_expired(self) -> None:
"""Remove messages that have exceeded the TTL."""
if self._ttl is None:
return
cutoff = time.monotonic() - self._ttl
# Walk from the front; messages are in insertion order so timestamps
# are monotonically non-decreasing.
first_valid = 0
for i, ts in enumerate(self._timestamps):
if ts >= cutoff:
first_valid = i
break
else:
# All messages are expired.
first_valid = len(self._messages)
if first_valid > 0:
del self._messages[:first_valid]
del self._timestamps[:first_valid]
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Each team inherits from the BaseGroupChat class.
"""

from ..state import InMemoryMessageStore, MessageStore
from ._group_chat._base_group_chat import BaseGroupChat
from ._group_chat._graph import (
DiGraph,
Expand All @@ -18,6 +19,8 @@

__all__ = [
"BaseGroupChat",
"InMemoryMessageStore",
"MessageStore",
"RoundRobinGroupChat",
"SelectorGroupChat",
"Swarm",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
StructuredMessage,
TextMessage,
)
from ...state import TeamState
from ...state import MessageStore, TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import (
GroupChatPause,
Expand Down Expand Up @@ -75,6 +75,7 @@ def __init__(
runtime: AgentRuntime | None = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
message_store: MessageStore | None = None,
):
self._name = name
self._description = description
Expand Down Expand Up @@ -150,6 +151,9 @@ def __init__(
# Flag to track if the team events should be emitted.
self._emit_team_events = emit_team_events

# Optional message store for the group chat manager.
self._message_store = message_store

@property
def name(self) -> str:
"""The name of the group chat team."""
Expand All @@ -173,6 +177,7 @@ def _create_group_chat_manager_factory(
termination_condition: TerminationCondition | None,
max_turns: int | None,
message_factory: MessageFactory,
message_store: MessageStore | None = None,
) -> Callable[[], SequentialRoutedAgent]: ...

def _create_participant_factory(
Expand Down Expand Up @@ -224,6 +229,7 @@ async def _init(self, runtime: AgentRuntime) -> None:
termination_condition=self._termination_condition,
max_turns=self._max_turns,
message_factory=self._message_factory,
message_store=self._message_store,
),
)
# Add subscriptions for the group chat manager.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ...base import TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage
from ...state import InMemoryMessageStore, MessageStore
from ._events import (
GroupChatAgentResponse,
GroupChatError,
Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(
max_turns: int | None,
message_factory: MessageFactory,
emit_team_events: bool = False,
message_store: MessageStore | None = None,
):
super().__init__(
description="Group chat manager",
Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(
name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True)
}
self._participant_descriptions = participant_descriptions
self._message_thread: List[BaseAgentEvent | BaseChatMessage] = []
self._message_store: MessageStore = message_store if message_store is not None else InMemoryMessageStore()
self._output_message_queue = output_message_queue
self._termination_condition = termination_condition
self._max_turns = max_turns
Expand All @@ -83,6 +85,43 @@ def __init__(
self._emit_team_events = emit_team_events
self._active_speakers: List[str] = []

@property
def _message_thread(self) -> List[BaseAgentEvent | BaseChatMessage]:
"""Backwards-compatible access to the message thread via the message store.

Subclasses that previously accessed ``self._message_thread`` directly
will continue to work transparently.
"""
if isinstance(self._message_store, InMemoryMessageStore):
return self._message_store.messages
# For non-InMemoryMessageStore implementations, we cannot provide a
# live mutable list. Return a snapshot instead.
import asyncio as _asyncio

loop = _asyncio.get_event_loop()
if loop.is_running():
# We are inside an async context; callers should prefer
# get_messages() but this keeps sync access working for
# simple attribute reads used in select_speaker() etc.
# Fall back to the store's synchronous snapshot if available.
raise RuntimeError(
"Cannot access _message_thread synchronously with a non-InMemoryMessageStore. "
"Use 'await self._message_store.get_messages()' instead."
)
return loop.run_until_complete(self._message_store.get_messages())

@_message_thread.setter
def _message_thread(self, value: List[BaseAgentEvent | BaseChatMessage]) -> None:
"""Allow ``self._message_thread = [...]`` for backwards compatibility."""
if isinstance(self._message_store, InMemoryMessageStore):
self._message_store._messages = value
self._message_store._timestamps = []
else:
raise RuntimeError(
"Cannot set _message_thread directly with a non-InMemoryMessageStore. "
"Use the message store API instead."
)

@rpc
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
"""Handle the start of a group chat by selecting a speaker to start the conversation."""
Expand Down Expand Up @@ -170,7 +209,8 @@ async def handle_agent_response(
raise

async def _transition_to_next_speakers(self, cancellation_token: CancellationToken) -> None:
speaker_names_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
thread = await self._message_store.get_messages()
speaker_names_future = asyncio.ensure_future(self.select_speaker(thread))
# Link the select speaker future to the cancellation token.
cancellation_token.link_future(speaker_names_future)
speaker_names = await speaker_names_future
Expand Down Expand Up @@ -300,7 +340,7 @@ async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseCh
This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event,
before calling the select_speakers method.
"""
self._message_thread.extend(messages)
await self._message_store.add_messages(messages)

@abstractmethod
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,9 @@ def _reset_execution_state(self) -> None:

async def save_state(self) -> Mapping[str, Any]:
"""Save the execution state."""
messages = await self._message_store.get_messages()
state = {
"message_thread": [message.dump() for message in self._message_thread],
"message_thread": [message.dump() for message in messages],
"current_turn": self._current_turn,
"remaining": {target: dict(counter) for target, counter in self._remaining.items()},
"enqueued_any": dict(self._enqueued_any),
Expand All @@ -522,7 +523,9 @@ async def save_state(self) -> Mapping[str, Any]:

async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore execution state from saved data."""
self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]]
loaded_messages = [self._message_factory.create(msg) for msg in state["message_thread"]]
await self._message_store.clear()
await self._message_store.add_messages(loaded_messages)
self._current_turn = state["current_turn"]
self._remaining = {target: Counter(groups) for target, groups in state["remaining"].items()}
self._enqueued_any = state["enqueued_any"]
Expand All @@ -531,7 +534,7 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
async def reset(self) -> None:
"""Reset execution state to the start of the graph."""
self._current_turn = 0
self._message_thread.clear()
await self._message_store.clear()
if self._termination_condition:
await self._termination_condition.reset()
self._reset_execution_state()
Expand Down
Loading