Skip to content

Commit 5981d36

Browse files
feat: Enable multiagent session persistent in Graph/Swarm (#1110)
* feat: enable multiagent session persistent # Conflicts: # src/strands/multiagent/graph.py # src/strands/multiagent/swarm.py # tests/strands/multiagent/test_graph.py # tests/strands/multiagent/test_swarm.py # tests_integ/test_multiagent_graph.py # tests_integ/test_multiagent_swarm.py * fix: fix docstring * fix: rebase from main and address comments * fix: fix nit
1 parent 417ebea commit 5981d36

File tree

7 files changed

+730
-26
lines changed

7 files changed

+730
-26
lines changed

src/strands/multiagent/graph.py

Lines changed: 179 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@
2626
from .._async import run_async
2727
from ..agent import Agent
2828
from ..agent.state import AgentState
29+
from ..experimental.hooks.multiagent import (
30+
AfterMultiAgentInvocationEvent,
31+
AfterNodeCallEvent,
32+
BeforeMultiAgentInvocationEvent,
33+
BeforeNodeCallEvent,
34+
MultiAgentInitializedEvent,
35+
)
36+
from ..hooks import HookProvider, HookRegistry
37+
from ..session import SessionManager
2938
from ..telemetry import get_tracer
3039
from ..types._events import (
3140
MultiAgentHandoffEvent,
@@ -40,6 +49,8 @@
4049

4150
logger = logging.getLogger(__name__)
4251

52+
_DEFAULT_GRAPH_ID = "default_graph"
53+
4354

4455
@dataclass
4556
class GraphState:
@@ -223,6 +234,9 @@ def __init__(self) -> None:
223234
self._execution_timeout: Optional[float] = None
224235
self._node_timeout: Optional[float] = None
225236
self._reset_on_revisit: bool = False
237+
self._id: str = _DEFAULT_GRAPH_ID
238+
self._session_manager: Optional[SessionManager] = None
239+
self._hooks: Optional[list[HookProvider]] = None
226240

227241
def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
228242
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
@@ -313,6 +327,33 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder":
313327
self._node_timeout = timeout
314328
return self
315329

330+
def set_graph_id(self, graph_id: str) -> "GraphBuilder":
    """Record the identifier to give the graph when it is built.

    Args:
        graph_id: Unique graph id

    Returns:
        This builder, so that calls can be chained.
    """
    self._id = graph_id
    return self
338+
339+
def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder":
    """Record the session manager to hand to the graph when it is built.

    Args:
        session_manager: SessionManager instance

    Returns:
        This builder, so that calls can be chained.
    """
    self._session_manager = session_manager
    return self
347+
348+
def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder":
    """Record the hook providers to hand to the graph when it is built.

    Args:
        hooks: Hook providers supplied by the caller

    Returns:
        This builder, so that calls can be chained.
    """
    self._hooks = hooks
    return self
356+
316357
def build(self) -> "Graph":
317358
"""Build and validate the graph with configured settings."""
318359
if not self.nodes:
@@ -338,6 +379,9 @@ def build(self) -> "Graph":
338379
execution_timeout=self._execution_timeout,
339380
node_timeout=self._node_timeout,
340381
reset_on_revisit=self._reset_on_revisit,
382+
session_manager=self._session_manager,
383+
hooks=self._hooks,
384+
id=self._id,
341385
)
342386

343387
def _validate_graph(self) -> None:
@@ -365,6 +409,9 @@ def __init__(
365409
execution_timeout: Optional[float] = None,
366410
node_timeout: Optional[float] = None,
367411
reset_on_revisit: bool = False,
412+
session_manager: Optional[SessionManager] = None,
413+
hooks: Optional[list[HookProvider]] = None,
414+
id: str = _DEFAULT_GRAPH_ID,
368415
) -> None:
369416
"""Initialize Graph with execution limits and reset behavior.
370417
@@ -376,6 +423,9 @@ def __init__(
376423
execution_timeout: Total execution timeout in seconds (default: None - no limit)
377424
node_timeout: Individual node timeout in seconds (default: None - no limit)
378425
reset_on_revisit: Whether to reset node state when revisited (default: False)
426+
session_manager: Session manager for persisting graph state and execution history (default: None)
427+
hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
428+
id: Unique graph id (default: "default_graph")
379429
"""
380430
super().__init__()
381431

@@ -391,6 +441,19 @@ def __init__(
391441
self.reset_on_revisit = reset_on_revisit
392442
self.state = GraphState()
393443
self.tracer = get_tracer()
444+
self.session_manager = session_manager
445+
self.hooks = HookRegistry()
446+
if self.session_manager:
447+
self.hooks.add_hook(self.session_manager)
448+
if hooks:
449+
for hook in hooks:
450+
self.hooks.add_hook(hook)
451+
452+
self._resume_next_nodes: list[GraphNode] = []
453+
self._resume_from_session = False
454+
self.id = id
455+
456+
self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self))
394457

395458
def __call__(
396459
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
@@ -453,18 +516,25 @@ async def stream_async(
453516
if invocation_state is None:
454517
invocation_state = {}
455518

519+
self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state))
520+
456521
logger.debug("task=<%s> | starting graph execution", task)
457522

458523
# Initialize state
459524
start_time = time.time()
460-
self.state = GraphState(
461-
status=Status.EXECUTING,
462-
task=task,
463-
total_nodes=len(self.nodes),
464-
edges=[(edge.from_node, edge.to_node) for edge in self.edges],
465-
entry_points=list(self.entry_points),
466-
start_time=start_time,
467-
)
525+
if not self._resume_from_session:
526+
# Initialize state
527+
self.state = GraphState(
528+
status=Status.EXECUTING,
529+
task=task,
530+
total_nodes=len(self.nodes),
531+
edges=[(edge.from_node, edge.to_node) for edge in self.edges],
532+
entry_points=list(self.entry_points),
533+
start_time=start_time,
534+
)
535+
else:
536+
self.state.status = Status.EXECUTING
537+
self.state.start_time = start_time
468538

469539
span = self.tracer.start_multiagent_span(task, "graph")
470540
with trace_api.use_span(span, end_on_exit=True):
@@ -499,6 +569,9 @@ async def stream_async(
499569
raise
500570
finally:
501571
self.state.execution_time = round((time.time() - start_time) * 1000)
572+
self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self))
573+
self._resume_from_session = False
574+
self._resume_next_nodes.clear()
502575

503576
def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
504577
"""Validate graph nodes for duplicate instances."""
@@ -514,7 +587,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
514587

515588
async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
516589
"""Execute graph and yield TypedEvent objects."""
517-
ready_nodes = list(self.entry_points)
590+
ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)
518591

519592
while ready_nodes:
520593
# Check execution limits before continuing
@@ -703,7 +776,9 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
703776

704777
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
705778
"""Execute a single node and yield TypedEvent objects."""
706-
# Reset the node's state if reset_on_revisit is enabled and it's being revisited
779+
self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state))
780+
781+
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
707782
if self.reset_on_revisit and node in self.state.completed_nodes:
708783
logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id)
709784
node.reset_executor_state()
@@ -844,6 +919,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
844919
# Re-raise to stop graph execution (fail-fast behavior)
845920
raise
846921

922+
finally:
923+
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state))
924+
847925
def _accumulate_metrics(self, node_result: NodeResult) -> None:
848926
"""Accumulate metrics from a node result."""
849927
self.state.accumulated_usage["inputTokens"] += node_result.accumulated_usage.get("inputTokens", 0)
@@ -928,3 +1006,94 @@ def _build_result(self) -> GraphResult:
9281006
edges=self.state.edges,
9291007
entry_points=self.state.entry_points,
9301008
)
1009+
1010+
def serialize_state(self) -> dict[str, Any]:
    """Serialize the current graph state to a dictionary.

    Nodes are stored by their ``node_id`` only, and each node result is converted
    with its own ``to_dict()`` method.

    Returns:
        Dictionary with the keys ``type`` (always ``"graph"``), ``id``, ``status``,
        ``completed_nodes``, ``failed_nodes``, ``node_results``,
        ``next_nodes_to_execute``, ``current_task`` and ``execution_order``.
    """
    # The helper always returns a list (empty when nothing is ready), so the ids
    # can be collected directly without a separate emptiness guard.
    next_nodes = [node.node_id for node in self._compute_ready_nodes_for_resume()]
    state = self.state
    return {
        "type": "graph",
        "id": self.id,
        "status": state.status.value,
        "completed_nodes": [node.node_id for node in state.completed_nodes],
        "failed_nodes": [node.node_id for node in state.failed_nodes],
        # results may be unset, in which case an empty mapping is persisted
        "node_results": {node_id: result.to_dict() for node_id, result in (state.results or {}).items()},
        "next_nodes_to_execute": next_nodes,
        "current_task": state.task,
        "execution_order": [node.node_id for node in state.execution_order],
    }
1025+
1026+
def deserialize_state(self, payload: dict[str, Any]) -> None:
    """Load a persisted session payload and get the graph ready for its next run.

    Two situations are distinguished by the ``next_nodes_to_execute`` entry:

    * It is present and non-empty: the previous run was interrupted. The persisted
      state is restored and the graph is flagged to resume from those nodes.
    * It is missing or empty: the previous run ended (e.g. Completed, or Failed with
      dead-end nodes). Every node's executor state and the graph state are reset so
      the next run starts from the beginning.

    Args:
        payload: Dictionary containing persisted state data including status,
            completed nodes, results, and next nodes to execute.
    """
    if payload.get("next_nodes_to_execute"):
        # Interrupted run: hydrate from the payload and resume on next invocation
        self._from_dict(payload)
        self._resume_from_session = True
        return

    # Finished run: wipe every node, then the graph-level state
    for graph_node in self.nodes.values():
        graph_node.reset_executor_state()
    self.state = GraphState()
    self._resume_from_session = False
1050+
1051+
def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
    """Work out which nodes could run next, given the nodes completed so far.

    A node is ready when it has not completed and every edge pointing at it starts
    from a completed node and reports ``should_traverse(self.state)`` as true.
    Nodes without any incoming edge are therefore ready as soon as they are not
    completed.

    Returns:
        Ready nodes in the iteration order of ``self.nodes``; an empty list while
        the graph status is ``Status.PENDING``.
    """
    if self.state.status == Status.PENDING:
        return []

    completed_nodes = set(self.state.completed_nodes)
    ready_nodes: list[GraphNode] = []
    for node in self.nodes.values():
        if node in completed_nodes:
            continue
        incoming = [edge for edge in self.edges if edge.to_node is node]
        # all() over an empty sequence is True, which covers nodes that have no
        # incoming edges without needing a separate branch.
        if all(edge.from_node in completed_nodes and edge.should_traverse(self.state) for edge in incoming):
            ready_nodes.append(node)

    return ready_nodes
1066+
1067+
def _from_dict(self, payload: dict[str, Any]) -> None:
    """Hydrate ``self.state`` and the resume queue from a persisted payload.

    Node ids in the payload that are not keys of ``self.nodes`` are ignored.
    The status and node results are parsed before anything is assigned, so a
    payload that fails to parse does not leave the state half-updated.

    Args:
        payload: Dictionary with a required ``status`` entry and optional
            ``node_results``, ``failed_nodes``, ``completed_nodes``,
            ``execution_order``, ``current_task`` and ``next_nodes_to_execute``
            entries.

    Raises:
        KeyError: If ``payload`` has no ``status`` entry.
        Exception: Any error raised by ``NodeResult.from_dict`` is logged and
            re-raised unchanged.
    """

    def _known_nodes(key: str) -> list[GraphNode]:
        # Map persisted node ids back to node objects, dropping ids that no longer exist
        return [self.nodes[node_id] for node_id in (payload.get(key) or []) if node_id in self.nodes]

    status = Status(payload["status"])

    # Hydrate results for nodes that still exist
    results: dict[str, NodeResult] = {}
    for node_id, entry in (payload.get("node_results") or {}).items():
        if node_id not in self.nodes:
            continue
        try:
            results[node_id] = NodeResult.from_dict(entry)
        except Exception:
            # Record which node's result was unreadable, then propagate the error
            logger.exception("node_id=<%s> | failed to hydrate NodeResult from persisted state", node_id)
            raise

    # Parsing succeeded; apply everything to the live state
    self.state.status = status
    self.state.results = results
    self.state.failed_nodes = set(_known_nodes("failed_nodes"))
    self.state.completed_nodes = set(_known_nodes("completed_nodes"))
    self.state.execution_order = _known_nodes("execution_order")
    # Keep the current task when the payload does not carry one
    self.state.task = payload.get("current_task", self.state.task)

    # Nodes that execution should pick up from on resume
    self._resume_next_nodes = _known_nodes("next_nodes_to_execute")

0 commit comments

Comments
 (0)