2626from .._async import run_async
2727from ..agent import Agent
2828from ..agent .state import AgentState
29+ from ..experimental .hooks .multiagent import (
30+ AfterMultiAgentInvocationEvent ,
31+ AfterNodeCallEvent ,
32+ BeforeMultiAgentInvocationEvent ,
33+ BeforeNodeCallEvent ,
34+ MultiAgentInitializedEvent ,
35+ )
36+ from ..hooks import HookProvider , HookRegistry
37+ from ..session import SessionManager
2938from ..telemetry import get_tracer
3039from ..types ._events import (
3140 MultiAgentHandoffEvent ,
4049
4150logger = logging .getLogger (__name__ )
4251
# Fallback graph identifier, used as the default for both the builder's ``_id``
# and the ``id`` parameter of ``Graph.__init__`` when no explicit id is supplied.
_DEFAULT_GRAPH_ID = "default_graph"
53+
4354
4455@dataclass
4556class GraphState :
@@ -223,6 +234,9 @@ def __init__(self) -> None:
223234 self ._execution_timeout : Optional [float ] = None
224235 self ._node_timeout : Optional [float ] = None
225236 self ._reset_on_revisit : bool = False
237+ self ._id : str = _DEFAULT_GRAPH_ID
238+ self ._session_manager : Optional [SessionManager ] = None
239+ self ._hooks : Optional [list [HookProvider ]] = None
226240
227241 def add_node (self , executor : Agent | MultiAgentBase , node_id : str | None = None ) -> GraphNode :
228242 """Add an Agent or MultiAgentBase instance as a node to the graph."""
@@ -313,6 +327,33 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder":
313327 self ._node_timeout = timeout
314328 return self
315329
def set_graph_id(self, graph_id: str) -> "GraphBuilder":
    """Record the identifier to use for the graph being configured.

    Args:
        graph_id: Unique graph id.

    Returns:
        This builder instance, so configuration calls can be chained.
    """
    self._id = graph_id
    return self
338+
def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder":
    """Attach the session manager that will be handed to the graph.

    Args:
        session_manager: SessionManager instance.

    Returns:
        This builder instance, so configuration calls can be chained.
    """
    self._session_manager = session_manager
    return self
347+
def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder":
    """Store the hook providers that will be handed to the graph.

    The list object is kept by reference; it is not copied.

    Args:
        hooks: Customer hooks user passes in.

    Returns:
        This builder instance, so configuration calls can be chained.
    """
    self._hooks = hooks
    return self
356+
316357 def build (self ) -> "Graph" :
317358 """Build and validate the graph with configured settings."""
318359 if not self .nodes :
@@ -338,6 +379,9 @@ def build(self) -> "Graph":
338379 execution_timeout = self ._execution_timeout ,
339380 node_timeout = self ._node_timeout ,
340381 reset_on_revisit = self ._reset_on_revisit ,
382+ session_manager = self ._session_manager ,
383+ hooks = self ._hooks ,
384+ id = self ._id ,
341385 )
342386
343387 def _validate_graph (self ) -> None :
@@ -365,6 +409,9 @@ def __init__(
365409 execution_timeout : Optional [float ] = None ,
366410 node_timeout : Optional [float ] = None ,
367411 reset_on_revisit : bool = False ,
412+ session_manager : Optional [SessionManager ] = None ,
413+ hooks : Optional [list [HookProvider ]] = None ,
414+ id : str = _DEFAULT_GRAPH_ID ,
368415 ) -> None :
369416 """Initialize Graph with execution limits and reset behavior.
370417
@@ -376,6 +423,9 @@ def __init__(
376423 execution_timeout: Total execution timeout in seconds (default: None - no limit)
377424 node_timeout: Individual node timeout in seconds (default: None - no limit)
378425 reset_on_revisit: Whether to reset node state when revisited (default: False)
426+ session_manager: Session manager for persisting graph state and execution history (default: None)
427+ hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
428+ id: Unique graph id (default: None)
379429 """
380430 super ().__init__ ()
381431
@@ -391,6 +441,19 @@ def __init__(
391441 self .reset_on_revisit = reset_on_revisit
392442 self .state = GraphState ()
393443 self .tracer = get_tracer ()
444+ self .session_manager = session_manager
445+ self .hooks = HookRegistry ()
446+ if self .session_manager :
447+ self .hooks .add_hook (self .session_manager )
448+ if hooks :
449+ for hook in hooks :
450+ self .hooks .add_hook (hook )
451+
452+ self ._resume_next_nodes : list [GraphNode ] = []
453+ self ._resume_from_session = False
454+ self .id = id
455+
456+ self .hooks .invoke_callbacks (MultiAgentInitializedEvent (self ))
394457
395458 def __call__ (
396459 self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
@@ -453,18 +516,25 @@ async def stream_async(
453516 if invocation_state is None :
454517 invocation_state = {}
455518
519+ self .hooks .invoke_callbacks (BeforeMultiAgentInvocationEvent (self , invocation_state ))
520+
456521 logger .debug ("task=<%s> | starting graph execution" , task )
457522
458523 # Initialize state
459524 start_time = time .time ()
460- self .state = GraphState (
461- status = Status .EXECUTING ,
462- task = task ,
463- total_nodes = len (self .nodes ),
464- edges = [(edge .from_node , edge .to_node ) for edge in self .edges ],
465- entry_points = list (self .entry_points ),
466- start_time = start_time ,
467- )
525+ if not self ._resume_from_session :
526+ # Initialize state
527+ self .state = GraphState (
528+ status = Status .EXECUTING ,
529+ task = task ,
530+ total_nodes = len (self .nodes ),
531+ edges = [(edge .from_node , edge .to_node ) for edge in self .edges ],
532+ entry_points = list (self .entry_points ),
533+ start_time = start_time ,
534+ )
535+ else :
536+ self .state .status = Status .EXECUTING
537+ self .state .start_time = start_time
468538
469539 span = self .tracer .start_multiagent_span (task , "graph" )
470540 with trace_api .use_span (span , end_on_exit = True ):
@@ -499,6 +569,9 @@ async def stream_async(
499569 raise
500570 finally :
501571 self .state .execution_time = round ((time .time () - start_time ) * 1000 )
572+ self .hooks .invoke_callbacks (AfterMultiAgentInvocationEvent (self ))
573+ self ._resume_from_session = False
574+ self ._resume_next_nodes .clear ()
502575
503576 def _validate_graph (self , nodes : dict [str , GraphNode ]) -> None :
504577 """Validate graph nodes for duplicate instances."""
@@ -514,7 +587,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
514587
515588 async def _execute_graph (self , invocation_state : dict [str , Any ]) -> AsyncIterator [Any ]:
516589 """Execute graph and yield TypedEvent objects."""
517- ready_nodes = list (self .entry_points )
590+ ready_nodes = self . _resume_next_nodes if self . _resume_from_session else list (self .entry_points )
518591
519592 while ready_nodes :
520593 # Check execution limits before continuing
@@ -703,7 +776,9 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
703776
704777 async def _execute_node (self , node : GraphNode , invocation_state : dict [str , Any ]) -> AsyncIterator [Any ]:
705778 """Execute a single node and yield TypedEvent objects."""
706- # Reset the node's state if reset_on_revisit is enabled and it's being revisited
779+ self .hooks .invoke_callbacks (BeforeNodeCallEvent (self , node .node_id , invocation_state ))
780+
781+ # Reset the node's state if reset_on_revisit is enabled, and it's being revisited
707782 if self .reset_on_revisit and node in self .state .completed_nodes :
708783 logger .debug ("node_id=<%s> | resetting node state for revisit" , node .node_id )
709784 node .reset_executor_state ()
@@ -844,6 +919,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
844919 # Re-raise to stop graph execution (fail-fast behavior)
845920 raise
846921
922+ finally :
923+ self .hooks .invoke_callbacks (AfterNodeCallEvent (self , node .node_id , invocation_state ))
924+
847925 def _accumulate_metrics (self , node_result : NodeResult ) -> None :
848926 """Accumulate metrics from a node result."""
849927 self .state .accumulated_usage ["inputTokens" ] += node_result .accumulated_usage .get ("inputTokens" , 0 )
@@ -928,3 +1006,94 @@ def _build_result(self) -> GraphResult:
9281006 edges = self .state .edges ,
9291007 entry_points = self .state .entry_points ,
9301008 )
1009+
def serialize_state(self) -> dict[str, Any]:
    """Produce a dictionary snapshot of the graph's execution progress.

    Nodes are referenced by their ``node_id``; each stored result is converted
    through its ``to_dict`` method. The nodes returned by
    ``_compute_ready_nodes_for_resume`` are recorded under
    ``next_nodes_to_execute``.

    Returns:
        Dictionary with the keys ``type``, ``id``, ``status``,
        ``completed_nodes``, ``failed_nodes``, ``node_results``,
        ``next_nodes_to_execute``, ``current_task`` and ``execution_order``.
    """

    def _ids(members):
        # Collapse a collection of nodes to the list of their identifiers.
        return [member.node_id for member in members]

    # Computed first, before any other state is read.
    pending = self._compute_ready_nodes_for_resume()
    state = self.state

    snapshot: dict[str, Any] = {}
    snapshot["type"] = "graph"
    snapshot["id"] = self.id
    snapshot["status"] = state.status.value
    snapshot["completed_nodes"] = _ids(state.completed_nodes)
    snapshot["failed_nodes"] = _ids(state.failed_nodes)
    snapshot["node_results"] = {
        node_id: result.to_dict() for node_id, result in (state.results or {}).items()
    }
    # A falsy value is serialized as an empty list.
    snapshot["next_nodes_to_execute"] = _ids(pending or [])
    snapshot["current_task"] = state.task
    snapshot["execution_order"] = _ids(state.execution_order)
    return snapshot
1025+
def deserialize_state(self, payload: dict[str, Any]) -> None:
    """Load a persisted session payload and decide between resume and restart.

    Two outcomes are possible, depending on ``payload["next_nodes_to_execute"]``:

    * Non-empty: the previous run stopped part-way. The payload is applied via
      ``_from_dict`` and the graph is flagged to resume from the persisted
      next nodes.
    * Missing or empty: the previous run ended (e.g. completed, or failed with
      nothing left to run). Every node's executor state is reset, the graph
      state is replaced by a fresh ``GraphState`` and the resume flag is
      cleared, so execution starts from the beginning.

    Args:
        payload: Dictionary containing persisted state data including status,
            completed nodes, results, and next nodes to execute.
    """
    if payload.get("next_nodes_to_execute"):
        # Interrupted run: restore the persisted progress and mark for resume.
        self._from_dict(payload)
        self._resume_from_session = True
        return

    # Finished run: wipe every node, then start over with a clean graph state.
    for graph_node in self.nodes.values():
        graph_node.reset_executor_state()
    self.state = GraphState()
    self._resume_from_session = False
1050+
def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
    """List the nodes that could run next given the current graph state.

    A node is selected when it is not among the completed nodes and every edge
    pointing at it both originates from a completed node and reports
    ``should_traverse(self.state)`` as true. Nodes with no incoming edges
    satisfy that condition trivially.

    Returns:
        The selected nodes in ``self.nodes`` iteration order, or an empty list
        when the graph status is ``Status.PENDING``.
    """
    if self.state.status == Status.PENDING:
        return []

    done = set(self.state.completed_nodes)
    return [
        candidate
        for candidate in self.nodes.values()
        if candidate not in done
        # all() over zero incoming edges is True, which covers source nodes.
        and all(
            edge.from_node in done and edge.should_traverse(self.state)
            for edge in self.edges
            if edge.to_node is candidate
        )
    ]
1066+
def _from_dict(self, payload: dict[str, Any]) -> None:
    """Restore ``self.state`` and the resume queue from a persisted payload.

    Node ids in the payload that are not present in ``self.nodes`` are ignored
    everywhere (results, failed/completed sets, execution order, next nodes).

    The payload is parsed completely before anything is written, so a failure
    while hydrating leaves ``self.state`` and ``self._resume_next_nodes``
    exactly as they were.

    Args:
        payload: Persisted state. ``status`` is required; ``node_results``,
            ``failed_nodes``, ``completed_nodes``, ``execution_order``,
            ``current_task`` and ``next_nodes_to_execute`` are optional.

    Raises:
        KeyError: If ``payload`` has no ``status`` entry.
        Exception: Any error raised by ``Status(...)`` or by
            ``NodeResult.from_dict``; the latter is logged before being
            re-raised.
    """
    known_nodes = self.nodes

    def _resolve(key: str) -> list:
        # Map persisted node ids to live node objects, dropping unknown ids.
        return [known_nodes[node_id] for node_id in (payload.get(key) or []) if node_id in known_nodes]

    # ---- Parse phase: nothing on self is modified here. ----
    status = Status(payload["status"])

    results: dict[str, NodeResult] = {}
    for node_id, entry in (payload.get("node_results") or {}).items():
        if node_id not in known_nodes:
            continue
        try:
            results[node_id] = NodeResult.from_dict(entry)
        except Exception:
            # The error is propagated (not skipped); log with the offending node for context.
            logger.exception("node_id=<%s> | failed to hydrate NodeResult from persisted state", node_id)
            raise

    failed_nodes = set(_resolve("failed_nodes"))
    completed_nodes = set(_resolve("completed_nodes"))
    execution_order = _resolve("execution_order")
    next_nodes = _resolve("next_nodes_to_execute")

    # ---- Commit phase: apply everything in one go. ----
    state = self.state
    state.status = status
    state.results = results
    state.failed_nodes = failed_nodes
    state.completed_nodes = completed_nodes
    state.execution_order = execution_order
    # Keep the current task when the payload does not carry one.
    state.task = payload.get("current_task", state.task)
    self._resume_next_nodes = next_nodes
0 commit comments