From 8eac740df2066ac6fa944631fd551f13658cdeea Mon Sep 17 00:00:00 2001 From: Channing Conger Date: Mon, 22 Jun 2026 01:18:32 +0000 Subject: [PATCH] code-mode: make pending frontiers generation-resumable --- codex-rs/code-mode/src/cell_actor/mod.rs | 243 ++++++++------- codex-rs/code-mode/src/cell_actor/tests.rs | 213 +++++++++---- codex-rs/code-mode/src/cell_actor/types.rs | 48 ++- codex-rs/code-mode/src/runtime/mod.rs | 65 ++-- codex-rs/code-mode/src/service.rs | 89 ++++-- codex-rs/code-mode/src/service_tests.rs | 8 +- codex-rs/code-mode/src/session_runtime/mod.rs | 43 ++- .../code-mode/src/session_runtime/tests.rs | 279 +++++++++++++++++- .../code-mode/src/session_runtime/types.rs | 70 +++-- 9 files changed, 799 insertions(+), 259 deletions(-) diff --git a/codex-rs/code-mode/src/cell_actor/mod.rs b/codex-rs/code-mode/src/cell_actor/mod.rs index 019e0d990ead..c1ff0f27ed9e 100644 --- a/codex-rs/code-mode/src/cell_actor/mod.rs +++ b/codex-rs/code-mode/src/cell_actor/mod.rs @@ -40,6 +40,9 @@ use crate::session_runtime::CellExecutionPolicy; use crate::session_runtime::CreateCellRequest as CellRequest; use crate::session_runtime::ObserveMode; use crate::session_runtime::OutputItem; +use crate::session_runtime::PendingFrontier; +use crate::session_runtime::PendingGeneration; +use crate::session_runtime::ResumeOutcome; use crate::session_runtime::ToolName as CellToolName; pub(crate) struct CellActor; @@ -105,9 +108,11 @@ async fn run_cell( let cancellation_token = cell_state.cancellation_token(); let callback_cancellation_token = cancellation_token.child_token(); let mut content_items = Vec::new(); - let mut pending_tool_call_ids = Vec::new(); let mut pending_initial_yield_items: Option> = None; - let mut pending_frontier_ready = false; + let mut pending_frontier: Option = None; + let mut pending_frontier_observed = false; + let mut next_pending_generation = 1; + let mut last_resumed_generation = None; let mut observer: Option = None; let mut has_been_observed = false; let mut termination = false; @@ -126,10 +131,6 @@ async fn run_cell( _ = cancellation_token.cancelled(), if !termination => { termination = true; yield_timer = None; - if let Some(mut yielded_items) = pending_initial_yield_items.take() { - yielded_items.append(&mut content_items); - content_items = yielded_items; - } drop(command_rx.take()); begin_termination( &runtime_tx, @@ -148,7 +149,12 @@ async fn run_cell( &cell_state, observer.take().map(|observer| observer.response_tx), CellEvent::Terminated { - content_items: std::mem::take(&mut content_items), + content_items: take_termination_content( + &mut pending_frontier, + pending_frontier_observed, + &mut pending_initial_yield_items, + &mut content_items, + ), }, ); break; @@ -160,10 +166,50 @@ async fn run_cell( None => std::future::pending::>().await, } } => { - let Some(CellCommand::Observe { mode, response_tx }) = maybe_command else { + let Some(command) = maybe_command else { cancellation_token.cancel(); continue; }; + let (mode, response_tx) = match command { + CellCommand::Observe { mode, response_tx } => (mode, response_tx), + CellCommand::Resume { generation, response_tx } => { + let result = if termination { + Err(CellError::Closed) + } else if let Some(frontier) = pending_frontier.as_ref() { + let current = frontier.generation; + match generation.cmp(¤t) { + std::cmp::Ordering::Less => Ok(ResumeOutcome::AlreadyRunning), + std::cmp::Ordering::Greater => { + Err(CellError::InvalidGeneration { + requested: generation, + latest: Some(current), + }) + } + std::cmp::Ordering::Equal => { + pending_frontier = None; + pending_frontier_observed = false; + last_resumed_generation = Some(generation); + runtime_paused = false; + let _ = runtime_control_tx + .send(RuntimeControlCommand::Continue); + Ok(ResumeOutcome::Resumed) + } + } + } else { + let latest = last_resumed_generation; + if latest.is_some_and(|latest| generation <= latest) { + Ok(ResumeOutcome::AlreadyRunning) + } else { + Err(CellError::InvalidGeneration { + requested: generation, + latest, + }) + } + }; + let _ = response_tx.send(result); + continue; + } + }; if response_tx.is_closed() { continue; } @@ -204,38 +250,29 @@ async fn run_cell( } }; if delivered && runtime_paused { - pending_frontier_ready = false; - pending_tool_call_ids.clear(); - resume_for_observation( - mode, - &mut runtime_paused, - &runtime_tx, - &runtime_control_tx, - ); + pending_frontier = None; + pending_frontier_observed = false; + let _ = runtime_control_tx.send(RuntimeControlCommand::Continue); + runtime_paused = false; } continue; } - if matches!(mode, ObserveMode::PendingFrontier) && pending_frontier_ready { - pending_frontier_ready = !send_pending_event( - response_tx, - &mut pending_initial_yield_items, - &mut content_items, - &mut pending_tool_call_ids, - ); + if matches!(mode, ObserveMode::PendingFrontier) + && let Some(frontier) = pending_frontier.as_ref() + { + if send_cell_event(response_tx, CellEvent::Pending(frontier.clone())).is_ok() { + pending_frontier_observed = true; + } continue; } observer = Some(Observer { mode, response_tx }); yield_timer = observer.as_ref().and_then(observer_timer); if runtime_paused && matches!(mode, ObserveMode::YieldAfter(_)) { - pending_frontier_ready = false; - pending_tool_call_ids.clear(); + pending_frontier = None; + pending_frontier_observed = false; + let _ = runtime_control_tx.send(RuntimeControlCommand::Continue); + runtime_paused = false; } - resume_for_observation( - mode, - &mut runtime_paused, - &runtime_tx, - &runtime_control_tx, - ); } _ = async { if let Some(yield_timer) = yield_timer.as_mut() { @@ -265,7 +302,9 @@ async fn run_cell( let Some(event) = maybe_event else { runtime_closed = true; if termination || cancellation_token.is_cancelled() { - let termination_content_items = take_all_content( + let termination_content_items = take_termination_content( + &mut pending_frontier, + pending_frontier_observed, &mut pending_initial_yield_items, &mut content_items, ); @@ -292,14 +331,19 @@ async fn run_cell( ) .await; let event = CellEvent::Completed { - content_items: std::mem::take(&mut content_items), + content_items: take_termination_content( + &mut pending_frontier, + pending_frontier_observed, + &mut pending_initial_yield_items, + &mut content_items, + ), error_text: Some("exec runtime ended unexpectedly".to_string()), }; let rejected_event = match host .commit_completion( HashMap::new(), event, - pending_initial_yield_items.take(), + /*pending_initial_yield_items*/ None, Arc::clone(&cell_state), ) .await @@ -331,30 +375,46 @@ async fn run_cell( RuntimeEvent::Started => { yield_timer = observer.as_ref().and_then(observer_timer); } - RuntimeEvent::Pending => { + RuntimeEvent::Pending { + pending_tool_call_ids, + } => { runtime_paused = true; - if let Some(observer) = observer.take_if(|observer| { - observer.mode == ObserveMode::PendingFrontier - }) { - yield_timer = None; - pending_frontier_ready = !send_pending_event( - observer.response_tx, - &mut pending_initial_yield_items, - &mut content_items, - &mut pending_tool_call_ids, - ); - } else if observer.is_some() - || matches!( - execution_policy, - CellExecutionPolicy::ContinueWhenUnblocked - ) - { - pending_frontier_ready = false; - pending_tool_call_ids.clear(); + if matches!( + execution_policy, + CellExecutionPolicy::ContinueWhenUnblocked + ) { + pending_frontier = None; let _ = runtime_control_tx.send(RuntimeControlCommand::Continue); runtime_paused = false; } else { - pending_frontier_ready = true; + if pending_frontier.is_none() { + pending_frontier_observed = false; + } + let frontier = pending_frontier.get_or_insert_with(|| { + let generation = PendingGeneration::new(next_pending_generation); + next_pending_generation += 1; + PendingFrontier { + generation, + content_items: take_all_content( + &mut pending_initial_yield_items, + &mut content_items, + ), + pending_tool_call_ids, + } + }); + if let Some(observer) = observer.take_if(|observer| { + observer.mode == ObserveMode::PendingFrontier + }) { + yield_timer = None; + if send_cell_event( + observer.response_tx, + CellEvent::Pending(frontier.clone()), + ) + .is_ok() + { + pending_frontier_observed = true; + } + } } } RuntimeEvent::ContentItem(item) => content_items.push(output_item(item)), @@ -396,7 +456,6 @@ async fn run_cell( ); } RuntimeEvent::ToolCall { id, name, kind, input } => { - pending_tool_call_ids.push(id.clone()); spawn_tool( &mut tool_tasks, Arc::clone(&host), @@ -417,7 +476,9 @@ async fn run_cell( runtime_closed = true; yield_timer = None; if termination || cancellation_token.is_cancelled() { - let termination_content_items = take_all_content( + let termination_content_items = take_termination_content( + &mut pending_frontier, + pending_frontier_observed, &mut pending_initial_yield_items, &mut content_items, ); @@ -525,40 +586,6 @@ fn send_cell_event( } } -fn send_pending_event( - response_tx: oneshot::Sender>, - pending_initial_yield_items: &mut Option>, - content_items: &mut Vec, - pending_tool_call_ids: &mut Vec, -) -> bool { - let had_initial_yield = pending_initial_yield_items.is_some(); - let mut delivered_items = pending_initial_yield_items.take().unwrap_or_default(); - let initial_yield_len = delivered_items.len(); - delivered_items.append(content_items); - match send_cell_event( - response_tx, - CellEvent::Pending { - content_items: delivered_items, - pending_tool_call_ids: std::mem::take(pending_tool_call_ids), - }, - ) { - Ok(()) => true, - Err(CellEvent::Pending { - content_items: mut undelivered_items, - pending_tool_call_ids: undelivered_tool_call_ids, - }) => { - let following_items = undelivered_items.split_off(initial_yield_len); - if had_initial_yield { - *pending_initial_yield_items = Some(undelivered_items); - } - *content_items = following_items; - *pending_tool_call_ids = undelivered_tool_call_ids; - false - } - Err(event) => panic!("pending delivery returned an unexpected event: {event:?}"), - } -} - fn restore_undelivered_yield(delivery: Result<(), CellEvent>, content_items: &mut Vec) { match delivery { Ok(()) => {} @@ -591,6 +618,24 @@ fn take_all_content( yielded_items } +fn take_termination_content( + pending_frontier: &mut Option, + pending_frontier_observed: bool, + pending_initial_yield_items: &mut Option>, + content_items: &mut Vec, +) -> Vec { + let mut termination_content = match pending_frontier.take() { + Some(_) if pending_frontier_observed => Vec::new(), + Some(frontier) => frontier.content_items, + None => Vec::new(), + }; + termination_content.append(&mut take_all_content( + pending_initial_yield_items, + content_items, + )); + termination_content +} + fn finish_termination( cell_state: &CellState, observer_tx: Option>>, @@ -610,24 +655,6 @@ fn observer_timer(observer: &Observer) -> Option, - runtime_control_tx: &std::sync::mpsc::Sender, -) { - if *runtime_paused { - let control = match mode { - ObserveMode::YieldAfter(_) => RuntimeControlCommand::Continue, - ObserveMode::PendingFrontier => RuntimeControlCommand::Resume, - }; - let _ = runtime_control_tx.send(control); - *runtime_paused = false; - } else if matches!(mode, ObserveMode::PendingFrontier) { - let _ = runtime_tx.send(RuntimeCommand::ObservePendingFrontier); - } -} - fn begin_termination( runtime_tx: &std::sync::mpsc::Sender, runtime_control_tx: &std::sync::mpsc::Sender, diff --git a/codex-rs/code-mode/src/cell_actor/tests.rs b/codex-rs/code-mode/src/cell_actor/tests.rs index ec5c112c2c5a..5f91099347a4 100644 --- a/codex-rs/code-mode/src/cell_actor/tests.rs +++ b/codex-rs/code-mode/src/cell_actor/tests.rs @@ -168,6 +168,16 @@ async fn wait_for_notification(host: &RecordingHost) { .expect("notification barrier timed out"); } +async fn wait_for_completion(host: &RecordingHost) { + tokio::time::timeout(Duration::from_secs(1), async { + while !host.committed.load(Ordering::Acquire) { + tokio::task::yield_now().await; + } + }) + .await + .expect("completion barrier timed out"); +} + #[tokio::test] async fn completion_and_output_are_buffered_until_the_first_observation() { let host = Arc::new(RecordingHost::default()); @@ -210,7 +220,12 @@ async fn completion_and_output_are_buffered_until_the_first_observation() { async fn continuing_harness_advances_an_unobserved_pending_frontier() { let host = Arc::new(RecordingHost::default()); let harness = spawn_cell_actor_harness_with_host(Arc::clone(&host)); - harness.event_tx.send(RuntimeEvent::Pending).unwrap(); + harness + .event_tx + .send(RuntimeEvent::Pending { + pending_tool_call_ids: Vec::new(), + }) + .unwrap(); harness .event_tx .send(RuntimeEvent::Notify { @@ -220,16 +235,10 @@ async fn continuing_harness_advances_an_unobserved_pending_frontier() { .unwrap(); wait_for_notification(&host).await; - loop { - match harness.runtime_control_rx.try_recv() { - Ok(RuntimeControlCommand::Continue) => break, - Ok(command) => panic!("expected continue, got {command:?}"), - Err(std_mpsc::TryRecvError::Empty) => tokio::task::yield_now().await, - Err(std_mpsc::TryRecvError::Disconnected) => { - panic!("runtime control channel disconnected") - } - } - } + assert!(matches!( + harness.runtime_control_rx.try_recv(), + Ok(RuntimeControlCommand::Continue) + )); let termination = harness.handle.terminate(); drop(harness.event_tx); assert_eq!( @@ -245,7 +254,12 @@ async fn continuing_harness_advances_an_unobserved_pending_frontier() { async fn pending_frontier_is_buffered_while_runtime_commands_are_queued() { let host = Arc::new(RecordingHost::default()); let harness = spawn_pausable_cell_actor_harness_with_host(Arc::clone(&host)); - harness.event_tx.send(RuntimeEvent::Pending).unwrap(); + harness + .event_tx + .send(RuntimeEvent::Pending { + pending_tool_call_ids: Vec::new(), + }) + .unwrap(); harness .event_tx .send(RuntimeEvent::Notify { @@ -270,10 +284,11 @@ async fn pending_frontier_is_buffered_while_runtime_commands_are_queued() { )); assert_eq!( harness.handle.observe(ObserveMode::PendingFrontier).await, - Ok(CellEvent::Pending { + Ok(CellEvent::Pending(PendingFrontier { + generation: PendingGeneration::new(/*value*/ 1), content_items: Vec::new(), pending_tool_call_ids: Vec::new(), - }) + })) ); assert!(matches!( harness.runtime_control_rx.try_recv(), @@ -292,54 +307,75 @@ async fn pending_frontier_is_buffered_while_runtime_commands_are_queued() { } #[tokio::test] -async fn buffered_yield_observation_resumes_an_unobserved_pending_frontier() { +async fn termination_preserves_an_unobserved_pending_frontier() { let host = Arc::new(RecordingHost::default()); let harness = spawn_pausable_cell_actor_harness_with_host(Arc::clone(&host)); - harness.event_tx.send(RuntimeEvent::YieldRequested).unwrap(); - harness.event_tx.send(RuntimeEvent::Pending).unwrap(); + harness + .event_tx + .send(RuntimeEvent::ContentItem( + FunctionCallOutputContentItem::InputText { + text: "unobserved".to_string(), + }, + )) + .unwrap(); + harness + .event_tx + .send(RuntimeEvent::Pending { + pending_tool_call_ids: Vec::new(), + }) + .unwrap(); harness .event_tx .send(RuntimeEvent::Notify { - call_id: "notify-1".to_string(), - text: "pending processed".to_string(), + call_id: "pending-barrier".to_string(), + text: "barrier".to_string(), }) .unwrap(); - while !host.notified.load(Ordering::Acquire) { - tokio::task::yield_now().await; - } + wait_for_notification(&host).await; + let termination = harness.handle.terminate(); + drop(harness.event_tx); assert_eq!( - harness - .handle - .observe(ObserveMode::YieldAfter(Duration::from_secs(60))) - .await, - Ok(CellEvent::Yielded { - content_items: Vec::new(), + termination.await, + Ok(CellEvent::Terminated { + content_items: vec![OutputItem::Text { + text: "unobserved".to_string(), + }], }) ); - loop { - match harness.runtime_control_rx.try_recv() { - Ok(RuntimeControlCommand::Continue) => break, - Ok(command) => panic!("expected continue, got {command:?}"), - Err(std_mpsc::TryRecvError::Empty) => tokio::task::yield_now().await, - Err(std_mpsc::TryRecvError::Disconnected) => { - panic!("runtime control channel disconnected") - } - } - } + harness.task.await.unwrap(); +} - host.notified.store(false, Ordering::Release); - harness.event_tx.send(RuntimeEvent::Pending).unwrap(); +#[tokio::test] +async fn repeated_pending_observation_does_not_resume_an_unobserved_frontier() { + let host = Arc::new(RecordingHost::default()); + let harness = spawn_pausable_cell_actor_harness_with_host(Arc::clone(&host)); + harness.event_tx.send(RuntimeEvent::YieldRequested).unwrap(); + harness + .event_tx + .send(RuntimeEvent::Pending { + pending_tool_call_ids: Vec::new(), + }) + .unwrap(); harness .event_tx .send(RuntimeEvent::Notify { - call_id: "notify-2".to_string(), - text: "later pending processed".to_string(), + call_id: "notify-1".to_string(), + text: "pending processed".to_string(), }) .unwrap(); while !host.notified.load(Ordering::Acquire) { tokio::task::yield_now().await; } + + assert_eq!( + harness.handle.observe(ObserveMode::PendingFrontier).await, + Ok(CellEvent::Pending(PendingFrontier { + generation: PendingGeneration::new(/*value*/ 1), + content_items: Vec::new(), + pending_tool_call_ids: Vec::new(), + })) + ); assert!(matches!( harness.runtime_control_rx.try_recv(), Err(std_mpsc::TryRecvError::Empty) @@ -425,7 +461,7 @@ async fn first_observation_preserves_a_yield_that_raced_with_creation() { } #[tokio::test] -async fn dropped_pending_observer_preserves_pre_observation_yield() { +async fn dropped_pending_observer_preserves_the_durable_frontier() { let host = Arc::new(RecordingHost::default()); let harness = spawn_pausable_cell_actor_harness_with_host(Arc::clone(&host)); harness @@ -461,7 +497,12 @@ async fn dropped_pending_observer_preserves_pre_observation_yield() { Err(CellError::Busy) ); drop(dropped_observation); - harness.event_tx.send(RuntimeEvent::Pending).unwrap(); + harness + .event_tx + .send(RuntimeEvent::Pending { + pending_tool_call_ids: Vec::new(), + }) + .unwrap(); harness .event_tx .send(RuntimeEvent::Notify { @@ -472,19 +513,19 @@ async fn dropped_pending_observer_preserves_pre_observation_yield() { wait_for_notification(&host).await; assert_eq!( - tokio::time::timeout( - Duration::from_secs(1), - harness - .handle - .observe(ObserveMode::YieldAfter(Duration::from_secs(60))), - ) - .await - .expect("initial yield was not preserved after failed pending delivery"), - Ok(CellEvent::Yielded { - content_items: vec![OutputItem::Text { - text: "before".to_string(), - }], - }) + harness.handle.observe(ObserveMode::PendingFrontier).await, + Ok(CellEvent::Pending(PendingFrontier { + generation: PendingGeneration::new(/*value*/ 1), + content_items: vec![ + OutputItem::Text { + text: "before".to_string(), + }, + OutputItem::Text { + text: "after".to_string(), + }, + ], + pending_tool_call_ids: Vec::new(), + })) ); let termination = harness.handle.terminate(); @@ -492,9 +533,7 @@ async fn dropped_pending_observer_preserves_pre_observation_yield() { assert_eq!( termination.await, Ok(CellEvent::Terminated { - content_items: vec![OutputItem::Text { - text: "after".to_string(), - }], + content_items: Vec::new(), }) ); harness.task.await.unwrap(); @@ -861,7 +900,12 @@ async fn dropped_pending_observer_preserves_the_frontier_for_the_next_observatio input: Some(serde_json::json!({})), }) .unwrap(); - harness.event_tx.send(RuntimeEvent::Pending).unwrap(); + harness + .event_tx + .send(RuntimeEvent::Pending { + pending_tool_call_ids: vec!["tool-1".to_string()], + }) + .unwrap(); harness .event_tx .send(RuntimeEvent::Notify { @@ -873,10 +917,11 @@ async fn dropped_pending_observer_preserves_the_frontier_for_the_next_observatio assert_eq!( harness.handle.observe(ObserveMode::PendingFrontier).await, - Ok(CellEvent::Pending { + Ok(CellEvent::Pending(PendingFrontier { + generation: PendingGeneration::new(/*value*/ 1), content_items: Vec::new(), pending_tool_call_ids: vec!["tool-1".to_string()], - }) + })) ); assert!(matches!( harness.runtime_control_rx.try_recv(), @@ -894,6 +939,48 @@ async fn dropped_pending_observer_preserves_the_frontier_for_the_next_observatio harness.task.await.unwrap(); } +#[tokio::test] +async fn unexpected_runtime_loss_preserves_an_unobserved_pending_frontier() { + let host = Arc::new(RecordingHost::default()); + let harness = spawn_pausable_cell_actor_harness_with_host(Arc::clone(&host)); + harness + .event_tx + .send(RuntimeEvent::ContentItem( + FunctionCallOutputContentItem::InputText { + text: "pending output".to_string(), + }, + )) + .unwrap(); + harness + .event_tx + .send(RuntimeEvent::Pending { + pending_tool_call_ids: vec!["tool-1".to_string()], + }) + .unwrap(); + harness + .event_tx + .send(RuntimeEvent::Notify { + call_id: "pending-frontier-barrier".to_string(), + text: "barrier".to_string(), + }) + .unwrap(); + wait_for_notification(&host).await; + + drop(harness.event_tx); + wait_for_completion(&host).await; + + assert_eq!( + harness.handle.observe(ObserveMode::PendingFrontier).await, + Ok(CellEvent::Completed { + content_items: vec![OutputItem::Text { + text: "pending output".to_string(), + }], + error_text: Some("exec runtime ended unexpectedly".to_string()), + }) + ); + harness.task.await.unwrap(); +} + #[tokio::test] async fn only_the_first_termination_claims_a_buffered_completion() { let cell_state = CellState::new(CancellationToken::new()); diff --git a/codex-rs/code-mode/src/cell_actor/types.rs b/codex-rs/code-mode/src/cell_actor/types.rs index 36625616d01b..3a0962b7cd4a 100644 --- a/codex-rs/code-mode/src/cell_actor/types.rs +++ b/codex-rs/code-mode/src/cell_actor/types.rs @@ -12,17 +12,26 @@ use tokio_util::sync::CancellationToken; use crate::session_runtime::CellEvent; use crate::session_runtime::ObserveMode; use crate::session_runtime::OutputItem; +use crate::session_runtime::PendingGeneration; +use crate::session_runtime::ResumeOutcome; use crate::session_runtime::ToolKind; use crate::session_runtime::ToolName; pub(crate) type CellEventFuture = Pin> + Send + 'static>>; +pub(crate) type ResumeFuture = + Pin> + Send + 'static>>; + #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum CellError { Busy, AlreadyTerminating, Closed, + InvalidGeneration { + requested: PendingGeneration, + latest: Option, + }, } pub(crate) struct CellToolCall { @@ -91,6 +100,24 @@ impl CellHandle { response_event(response_rx) } + pub(crate) fn resume(&self, generation: PendingGeneration) -> ResumeFuture { + if !self.state.accepting_observations() { + return closed_resume(); + } + let (response_tx, response_rx) = oneshot::channel(); + if self + .command_tx + .send(CellCommand::Resume { + generation, + response_tx, + }) + .is_err() + { + return closed_resume(); + } + Box::pin(async move { response_rx.await.unwrap_or(Err(CellError::Closed)) }) + } + pub(crate) fn terminate(&self) -> CellEventFuture { self.state.request_termination() } @@ -426,15 +453,10 @@ fn prepend_initial_yield( content_items: pending_initial_yield_items, } } - CellEvent::Pending { - mut content_items, - pending_tool_call_ids, - } => { - pending_initial_yield_items.append(&mut content_items); - CellEvent::Pending { - content_items: pending_initial_yield_items, - pending_tool_call_ids, - } + CellEvent::Pending(mut frontier) => { + pending_initial_yield_items.append(&mut frontier.content_items); + frontier.content_items = pending_initial_yield_items; + CellEvent::Pending(frontier) } CellEvent::Completed { mut content_items, @@ -460,6 +482,10 @@ pub(super) enum CellCommand { mode: ObserveMode, response_tx: oneshot::Sender>, }, + Resume { + generation: PendingGeneration, + response_tx: oneshot::Sender>, + }, } fn response_event(response_rx: oneshot::Receiver>) -> CellEventFuture { @@ -473,3 +499,7 @@ fn ready_event(event: CellEvent) -> CellEventFuture { fn closed_event() -> CellEventFuture { Box::pin(async { Err(CellError::Closed) }) } + +fn closed_resume() -> ResumeFuture { + Box::pin(async { Err(CellError::Closed) }) +} diff --git a/codex-rs/code-mode/src/runtime/mod.rs b/codex-rs/code-mode/src/runtime/mod.rs index 42af5c0b9577..9e8f6878c507 100644 --- a/codex-rs/code-mode/src/runtime/mod.rs +++ b/codex-rs/code-mode/src/runtime/mod.rs @@ -25,7 +25,6 @@ pub(crate) enum RuntimeCommand { ToolResponse { id: String, result: JsonValue }, ToolError { id: String, error_text: String }, TimeoutFired { id: u64 }, - ObservePendingFrontier, Terminate, } @@ -39,14 +38,15 @@ pub(crate) enum PendingRuntimeMode { #[derive(Debug)] pub(crate) enum RuntimeControlCommand { Continue, - Resume, Terminate, } #[derive(Debug)] pub(crate) enum RuntimeEvent { Started, - Pending, + Pending { + pending_tool_call_ids: Vec, + }, ContentItem(FunctionCallOutputContentItem), YieldRequested, ToolCall { @@ -221,9 +221,21 @@ fn run_runtime( } let mut pending_promise = pending_promise; - while let Some(command) = - next_runtime_command(&event_tx, &command_rx, &control_rx, pending_mode) - { + loop { + let mut pending_tool_call_ids = scope + .get_slot::() + .map(|state| state.pending_tool_calls.keys().cloned().collect::>()) + .unwrap_or_default(); + pending_tool_call_ids.sort(); + let Some(command) = next_runtime_command( + &event_tx, + &command_rx, + &control_rx, + pending_mode, + pending_tool_call_ids, + ) else { + break; + }; match command { RuntimeCommand::Terminate => break, RuntimeCommand::ToolResponse { id, result } => { @@ -248,7 +260,6 @@ fn run_runtime( return; } } - RuntimeCommand::ObservePendingFrontier => {} } scope.perform_microtask_checkpoint(); @@ -277,24 +288,24 @@ fn next_runtime_command( command_rx: &std_mpsc::Receiver, control_rx: &std_mpsc::Receiver, pending_mode: PendingRuntimeMode, + pending_tool_call_ids: Vec, ) -> Option { - loop { - match command_rx.try_recv() { - Ok(command) => return Some(command), - Err(std_mpsc::TryRecvError::Disconnected) => return None, - Err(std_mpsc::TryRecvError::Empty) => {} - } + match command_rx.try_recv() { + Ok(command) => return Some(command), + Err(std_mpsc::TryRecvError::Disconnected) => return None, + Err(std_mpsc::TryRecvError::Empty) => {} + } - let _ = event_tx.send(RuntimeEvent::Pending); - match pending_mode { - #[cfg(test)] - PendingRuntimeMode::Continue => return command_rx.recv().ok(), - PendingRuntimeMode::PauseUntilResumed => match control_rx.recv().ok()? { - RuntimeControlCommand::Continue => return command_rx.recv().ok(), - RuntimeControlCommand::Resume => continue, - RuntimeControlCommand::Terminate => return Some(RuntimeCommand::Terminate), - }, - } + let _ = event_tx.send(RuntimeEvent::Pending { + pending_tool_call_ids, + }); + match pending_mode { + #[cfg(test)] + PendingRuntimeMode::Continue => command_rx.recv().ok(), + PendingRuntimeMode::PauseUntilResumed => match control_rx.recv().ok()? { + RuntimeControlCommand::Continue => command_rx.recv().ok(), + RuntimeControlCommand::Terminate => Some(RuntimeCommand::Terminate), + }, } } @@ -385,7 +396,7 @@ mod tests { } #[tokio::test] - async fn pending_mode_freezes_runtime_commands_until_resume() { + async fn pending_mode_freezes_runtime_commands_until_continue() { let (event_tx, mut event_rx) = mpsc::unbounded_channel(); let (runtime_tx, runtime_control_tx, _runtime_terminate_handle) = spawn_runtime( HashMap::new(), @@ -413,7 +424,7 @@ await new Promise(() => {}); .await .unwrap() .unwrap(), - RuntimeEvent::Pending + RuntimeEvent::Pending { .. } )); runtime_tx @@ -426,7 +437,7 @@ await new Promise(() => {}); ); runtime_control_tx - .send(RuntimeControlCommand::Resume) + .send(RuntimeControlCommand::Continue) .unwrap(); let content_event = tokio::time::timeout(Duration::from_secs(1), event_rx.recv()) @@ -444,7 +455,7 @@ await new Promise(() => {}); .await .unwrap() .unwrap(), - RuntimeEvent::Pending + RuntimeEvent::Pending { .. } )); runtime_control_tx diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs index 59afe9b7b1a7..a527080c8e91 100644 --- a/codex-rs/code-mode/src/service.rs +++ b/codex-rs/code-mode/src/service.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -23,6 +24,7 @@ use codex_code_mode_protocol::WaitRequest; use codex_code_mode_protocol::WaitToPendingOutcome; use codex_code_mode_protocol::WaitToPendingRequest; use serde_json::Value as JsonValue; +use tokio::sync::Mutex; use tokio::sync::oneshot; use tokio_util::sync::CancellationToken; @@ -74,6 +76,7 @@ impl CodeModeSessionProvider for InProcessCodeModeSessionProvider { pub struct CodeModeService { runtime: SessionRuntime, + pending_generations: Mutex>, } impl CodeModeService { @@ -84,6 +87,7 @@ impl CodeModeService { pub fn with_delegate(delegate: Arc) -> Self { Self { runtime: SessionRuntime::new(Arc::new(ProtocolDelegate { delegate })), + pending_generations: Mutex::new(HashMap::new()), } } @@ -91,10 +95,7 @@ impl CodeModeService { let yield_time_ms = request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS); let runtime_cell_id = self .runtime - .create_cell_with_execution_policy( - runtime_request(request), - runtime::CellExecutionPolicy::PauseAtPendingFrontier, - ) + .create_cell(runtime_request(request)) .await .map_err(|error| error.to_string())?; let pending_event = self @@ -125,15 +126,17 @@ impl CodeModeService { ) -> Result { let runtime_cell_id = self .runtime - .create_cell(runtime_request(request)) + .create_pausable_cell(runtime_request(request)) .await .map_err(|error| error.to_string())?; let cell_id = protocol_cell_id(&runtime_cell_id); let event = self .runtime - .observe(&runtime_cell_id, runtime::ObserveMode::PendingFrontier) + .wait_to_pending(&runtime_cell_id) .await .map_err(|error| error.to_string())?; + self.record_pending_generation(&runtime_cell_id, &event) + .await; pending_outcome(&cell_id, event) } @@ -175,13 +178,19 @@ impl CodeModeService { } pub async fn terminate(&self, cell_id: CellId) -> Result { - match self.runtime.terminate(&runtime_cell_id(&cell_id)).await { + let runtime_cell_id = runtime_cell_id(&cell_id); + let outcome = match self.runtime.terminate(&runtime_cell_id).await { Ok(event) => Ok(WaitOutcome::LiveCell(runtime_response(&cell_id, event)?)), Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => { Ok(WaitOutcome::MissingCell(missing_cell_response(cell_id))) } Err(error) => Err(error.to_string()), - } + }; + self.pending_generations + .lock() + .await + .remove(&runtime_cell_id); + outcome } pub async fn wait_to_pending( @@ -189,17 +198,28 @@ impl CodeModeService { request: WaitToPendingRequest, ) -> Result { let cell_id = request.cell_id; - match self - .runtime - .observe( - &runtime_cell_id(&cell_id), - runtime::ObserveMode::PendingFrontier, - ) - .await - { - Ok(event) => Ok(WaitToPendingOutcome::LiveCell(pending_outcome( - &cell_id, event, - )?)), + let runtime_cell_id = runtime_cell_id(&cell_id); + let generation = { + self.pending_generations + .lock() + .await + .get(&runtime_cell_id) + .copied() + }; + if let Some(generation) = generation { + self.runtime + .resume(&runtime_cell_id, generation) + .await + .map_err(|error| error.to_string())?; + } + match self.runtime.wait_to_pending(&runtime_cell_id).await { + Ok(event) => { + self.record_pending_generation(&runtime_cell_id, &event) + .await; + Ok(WaitToPendingOutcome::LiveCell(pending_outcome( + &cell_id, event, + )?)) + } Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => Ok( WaitToPendingOutcome::MissingCell(missing_cell_response(cell_id)), ), @@ -207,11 +227,31 @@ impl CodeModeService { } } + async fn record_pending_generation( + &self, + cell_id: &runtime::CellId, + event: &runtime::CellEvent, + ) { + let mut generations = self.pending_generations.lock().await; + match event { + runtime::CellEvent::Pending(frontier) => { + generations.insert(cell_id.clone(), frontier.generation); + } + runtime::CellEvent::Yielded { .. } => {} + runtime::CellEvent::Completed { .. } | runtime::CellEvent::Terminated { .. } => { + generations.remove(cell_id); + } + } + } + pub async fn shutdown(&self) -> Result<(), String> { - self.runtime + let result = self + .runtime .shutdown() .await - .map_err(|error| error.to_string()) + .map_err(|error| error.to_string()); + self.pending_generations.lock().await.clear(); + result } } @@ -334,10 +374,11 @@ fn pending_outcome( event: runtime::CellEvent, ) -> Result { match event { - runtime::CellEvent::Pending { + runtime::CellEvent::Pending(runtime::PendingFrontier { content_items, pending_tool_call_ids, - } => Ok(ExecuteToPendingOutcome::Pending { + .. + }) => Ok(ExecuteToPendingOutcome::Pending { cell_id: cell_id.clone(), content_items: content_items.into_iter().map(output_item).collect(), pending_tool_call_ids, @@ -369,7 +410,7 @@ fn runtime_response( cell_id: cell_id.clone(), content_items: content_items.into_iter().map(output_item).collect(), }), - runtime::CellEvent::Pending { .. } => { + runtime::CellEvent::Pending(_) => { Err("cell returned a pending frontier unexpectedly".to_string()) } } diff --git a/codex-rs/code-mode/src/service_tests.rs b/codex-rs/code-mode/src/service_tests.rs index 2a5e88ddbd45..3da411815e73 100644 --- a/codex-rs/code-mode/src/service_tests.rs +++ b/codex-rs/code-mode/src/service_tests.rs @@ -329,7 +329,7 @@ await Promise.all([ } #[tokio::test] -async fn execute_to_pending_excludes_delayed_timeout_tool_calls_until_wait() { +async fn wait_to_pending_retains_outstanding_calls_when_a_delayed_call_is_added() { let service = CodeModeService::new(); let initial_response = service @@ -391,7 +391,11 @@ await Promise.all([ WaitToPendingOutcome::LiveCell(ExecuteToPendingOutcome::Pending { cell_id: cell_id("1"), content_items: Vec::new(), - pending_tool_call_ids: vec!["tool-3".to_string()], + pending_tool_call_ids: vec![ + "tool-1".to_string(), + "tool-2".to_string(), + "tool-3".to_string(), + ], }) ); diff --git a/codex-rs/code-mode/src/session_runtime/mod.rs b/codex-rs/code-mode/src/session_runtime/mod.rs index 2a5e84327333..683f6346d3cc 100644 --- a/codex-rs/code-mode/src/session_runtime/mod.rs +++ b/codex-rs/code-mode/src/session_runtime/mod.rs @@ -21,6 +21,9 @@ pub(crate) use self::types::ImageDetail; pub(crate) use self::types::NestedToolCall; pub(crate) use self::types::ObserveMode; pub(crate) use self::types::OutputItem; +pub(crate) use self::types::PendingFrontier; +pub(crate) use self::types::PendingGeneration; +pub(crate) use self::types::ResumeOutcome; pub(crate) use self::types::SessionRuntimeDelegate; pub(crate) use self::types::ToolDefinition; pub(crate) use self::types::ToolKind; @@ -73,7 +76,15 @@ impl SessionRuntime { .await } - pub(crate) async fn create_cell_with_execution_policy( + pub(crate) async fn create_pausable_cell( + &self, + request: CreateCellRequest, + ) -> Result { + self.create_cell_with_execution_policy(request, CellExecutionPolicy::PauseAtPendingFrontier) + .await + } + + async fn create_cell_with_execution_policy( &self, request: CreateCellRequest, execution_policy: CellExecutionPolicy, @@ -113,6 +124,29 @@ impl SessionRuntime { }) } + pub(crate) async fn wait_to_pending(&self, cell_id: &CellId) -> Result { + self.observe(cell_id, ObserveMode::PendingFrontier).await + } + + pub(crate) async fn resume( + &self, + cell_id: &CellId, + generation: PendingGeneration, + ) -> Result { + let handle = self + .inner + .cells + .lock() + .await + .get(cell_id) + .cloned() + .ok_or_else(|| Error::MissingCell(cell_id.clone()))?; + handle + .resume(generation) + .await + .map_err(|error| actor_error(cell_id, error)) + } + pub(crate) async fn terminate(&self, cell_id: &CellId) -> Result { let handle = self .inner @@ -170,7 +204,7 @@ impl SessionRuntime { let (handle, task) = CellActor::prepare(request, stored_values, host, cell_state, execution_policy) .map_err(Error::Runtime)?; - cells.insert(cell_id.clone(), handle); + cells.insert(cell_id, handle); self.inner.cell_tasks.spawn(task); drop(cells); Ok(()) @@ -272,6 +306,11 @@ fn actor_error(cell_id: &CellId, error: CellError) -> Error { CellError::Busy => Error::BusyObserver(cell_id.clone()), CellError::AlreadyTerminating => Error::AlreadyTerminating(cell_id.clone()), CellError::Closed => Error::ClosedCell(cell_id.clone()), + CellError::InvalidGeneration { requested, latest } => Error::InvalidGeneration { + cell_id: cell_id.clone(), + requested, + latest, + }, } } diff --git a/codex-rs/code-mode/src/session_runtime/tests.rs b/codex-rs/code-mode/src/session_runtime/tests.rs index 3b32fd953322..bb7d899f3344 100644 --- a/codex-rs/code-mode/src/session_runtime/tests.rs +++ b/codex-rs/code-mode/src/session_runtime/tests.rs @@ -8,6 +8,7 @@ use std::time::Duration; use pretty_assertions::assert_eq; use serde_json::Value as JsonValue; +use tokio::sync::Semaphore; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; @@ -20,6 +21,11 @@ struct ImmediateToolDelegate { invocations_tx: mpsc::UnboundedSender, } +struct BlockingToolDelegate { + invocations_tx: mpsc::UnboundedSender, + release: Arc, +} + impl SessionRuntimeDelegate for RecordingDelegate { async fn invoke_tool( &self, @@ -65,6 +71,34 @@ impl SessionRuntimeDelegate for ImmediateToolDelegate { fn cell_closed(&self, _cell_id: &CellId) {} } +impl SessionRuntimeDelegate for BlockingToolDelegate { + async fn invoke_tool( + &self, + invocation: NestedToolCall, + cancellation_token: CancellationToken, + ) -> Result { + let _ = self.invocations_tx.send(invocation.tool_name.name); + let permit = tokio::select! { + permit = self.release.acquire() => permit.map_err(|error| error.to_string())?, + () = cancellation_token.cancelled() => return Err("cancelled".to_string()), + }; + permit.forget(); + Ok(JsonValue::Null) + } + + async fn notify( + &self, + _call_id: String, + _cell_id: CellId, + _text: String, + _cancellation_token: CancellationToken, + ) -> Result<(), String> { + Ok(()) + } + + fn cell_closed(&self, _cell_id: &CellId) {} +} + fn tool_definition(name: &str) -> ToolDefinition { ToolDefinition { name: name.to_string(), @@ -78,10 +112,10 @@ fn tool_definition(name: &str) -> ToolDefinition { } #[tokio::test] -async fn continuing_cell_resolves_tools_before_the_first_observation() { +async fn default_policy_resolves_tools_before_the_first_observation() { let (invocations_tx, mut invocations_rx) = mpsc::unbounded_channel(); let runtime = SessionRuntime::new(Arc::new(ImmediateToolDelegate { invocations_tx })); - let cell_id = runtime + let cell = runtime .create_cell(CreateCellRequest { tool_call_id: "call-1".to_string(), enabled_tools: vec![tool_definition("first"), tool_definition("second")], @@ -109,8 +143,241 @@ text("done"); ); assert_eq!( runtime - .observe(&cell_id, ObserveMode::YieldAfter(Duration::from_secs(1))) + .observe(&cell, ObserveMode::YieldAfter(Duration::from_secs(1))) + .await, + Ok(CellEvent::Completed { + content_items: vec![OutputItem::Text { + text: "done".to_string(), + }], + error_text: None, + }) + ); + runtime.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn pausable_cell_supports_a_synchronous_host_driver() { + let (invocations_tx, mut invocations_rx) = mpsc::unbounded_channel(); + let release = Arc::new(Semaphore::new(0)); + let runtime = SessionRuntime::new(Arc::new(BlockingToolDelegate { + invocations_tx, + release: Arc::clone(&release), + })); + let cell = runtime + .create_pausable_cell(CreateCellRequest { + tool_call_id: "call-1".to_string(), + enabled_tools: vec![tool_definition("first"), tool_definition("second")], + source: r#" +await tools.first({}); +await tools.second({}); +text("done"); +"# + .to_string(), + }) + .await + .unwrap(); + + assert_eq!( + tokio::time::timeout(Duration::from_secs(1), invocations_rx.recv()) + .await + .expect("first tool invocation timed out"), + Some("first".to_string()) + ); + let first = runtime.wait_to_pending(&cell).await.unwrap(); + let CellEvent::Pending(first_frontier) = &first else { + panic!("expected the first pending frontier, got {first:?}"); + }; + assert_eq!( + first_frontier.generation, + PendingGeneration::new(/*value*/ 1) + ); + assert_eq!(runtime.wait_to_pending(&cell).await, Ok(first.clone())); + assert!( + tokio::time::timeout(Duration::from_millis(50), invocations_rx.recv()) + .await + .is_err() + ); + + let (first_resume, duplicate_resume) = tokio::join!( + runtime.resume(&cell, first_frontier.generation), + runtime.resume(&cell, first_frontier.generation), + ); + assert!(matches!( + (first_resume, duplicate_resume), + ( + Ok(ResumeOutcome::Resumed), + Ok(ResumeOutcome::AlreadyRunning) + ) | ( + Ok(ResumeOutcome::AlreadyRunning), + Ok(ResumeOutcome::Resumed) + ) + )); + release.add_permits(1); + assert_eq!( + tokio::time::timeout(Duration::from_secs(1), invocations_rx.recv()) + .await + .expect("second tool invocation timed out"), + Some("second".to_string()) + ); + + let second = runtime.wait_to_pending(&cell).await.unwrap(); + let CellEvent::Pending(second_frontier) = &second else { + panic!("expected the second pending frontier, got {second:?}"); + }; + assert_eq!( + second_frontier.generation, + PendingGeneration::new(/*value*/ 2) + ); + assert_eq!( + runtime.resume(&cell, first_frontier.generation).await, + Ok(ResumeOutcome::AlreadyRunning) + ); + assert_eq!( + runtime + .resume(&cell, PendingGeneration::new(/*value*/ 3)) .await, + Err(Error::InvalidGeneration { + cell_id: cell.clone(), + requested: PendingGeneration::new(/*value*/ 3), + latest: Some(PendingGeneration::new(/*value*/ 2)), + }) + ); + assert_eq!( + runtime.resume(&cell, second_frontier.generation).await, + Ok(ResumeOutcome::Resumed) + ); + release.add_permits(1); + assert_eq!( + runtime.wait_to_pending(&cell).await, + Ok(CellEvent::Completed { + content_items: vec![OutputItem::Text { + text: "done".to_string(), + }], + error_text: None, + }) + ); + runtime.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn pending_frontier_reports_only_authoritatively_outstanding_parallel_tools() { + let (invocations_tx, mut invocations_rx) = mpsc::unbounded_channel(); + let release = Arc::new(Semaphore::new(0)); + let runtime = SessionRuntime::new(Arc::new(BlockingToolDelegate { + invocations_tx, + release: Arc::clone(&release), + })); + let cell = runtime + .create_pausable_cell(CreateCellRequest { + tool_call_id: "call-1".to_string(), + enabled_tools: vec![tool_definition("first"), tool_definition("second")], + source: r#" +await Promise.all([tools.first({}), tools.second({})]); +text("done"); +"# + .to_string(), + }) + .await + .unwrap(); + + let mut invocations = vec![ + tokio::time::timeout(Duration::from_secs(1), invocations_rx.recv()) + .await + .expect("first tool invocation timed out") + .expect("first tool invocation channel closed"), + tokio::time::timeout(Duration::from_secs(1), invocations_rx.recv()) + .await + .expect("second tool invocation timed out") + .expect("second tool invocation channel closed"), + ]; + invocations.sort(); + assert_eq!(invocations, vec!["first".to_string(), "second".to_string()]); + + let CellEvent::Pending(first_frontier) = runtime.wait_to_pending(&cell).await.unwrap() else { + panic!("expected first pending frontier"); + }; + assert_eq!( + first_frontier.pending_tool_call_ids, + vec!["tool-1".to_string(), "tool-2".to_string()] + ); + + assert_eq!( + runtime.resume(&cell, first_frontier.generation).await, + Ok(ResumeOutcome::Resumed) + ); + release.add_permits(1); + + let CellEvent::Pending(second_frontier) = runtime.wait_to_pending(&cell).await.unwrap() else { + panic!("expected second pending frontier"); + }; + assert_eq!(second_frontier.pending_tool_call_ids.len(), 1); + assert!( + first_frontier + .pending_tool_call_ids + .contains(&second_frontier.pending_tool_call_ids[0]) + ); + + assert_eq!( + runtime.resume(&cell, second_frontier.generation).await, + Ok(ResumeOutcome::Resumed) + ); + release.add_permits(1); + assert_eq!( + runtime.wait_to_pending(&cell).await, + Ok(CellEvent::Completed { + content_items: vec![OutputItem::Text { + text: "done".to_string(), + }], + error_text: None, + }) + ); + runtime.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn pending_observation_waits_for_resumed_work_to_reach_a_new_frontier() { + let (invocations_tx, mut invocations_rx) = mpsc::unbounded_channel(); + let release = Arc::new(Semaphore::new(0)); + let runtime = SessionRuntime::new(Arc::new(BlockingToolDelegate { + invocations_tx, + release: Arc::clone(&release), + })); + let cell = runtime + .create_pausable_cell(CreateCellRequest { + tool_call_id: "call-1".to_string(), + enabled_tools: vec![tool_definition("blocked")], + source: r#" +await tools.blocked({}); +text("done"); +"# + .to_string(), + }) + .await + .unwrap(); + assert_eq!( + tokio::time::timeout(Duration::from_secs(1), invocations_rx.recv()) + .await + .expect("tool invocation timed out"), + Some("blocked".to_string()) + ); + let CellEvent::Pending(frontier) = runtime.wait_to_pending(&cell).await.unwrap() else { + panic!("expected a pending frontier"); + }; + + assert_eq!( + runtime.resume(&cell, frontier.generation).await, + Ok(ResumeOutcome::Resumed) + ); + let next_event = runtime.wait_to_pending(&cell); + tokio::pin!(next_event); + assert!( + tokio::time::timeout(Duration::from_millis(50), &mut next_event) + .await + .is_err() + ); + release.add_permits(1); + assert_eq!( + next_event.await, Ok(CellEvent::Completed { content_items: vec![OutputItem::Text { text: "done".to_string(), @@ -241,15 +508,15 @@ async fn shutdown_rejects_cell_admission_queued_before_the_registry_lock() { #[tokio::test] async fn drop_terminates_cells_when_the_registry_is_locked() { let runtime = SessionRuntime::new(Arc::new(RecordingDelegate)); - let cell_id = runtime + let cell = runtime .create_cell(execute_request("while (true) {}")) .await .unwrap(); - assert_eq!(cell_id, CellId::new("1")); + assert_eq!(cell, CellId::new("1")); assert_eq!( runtime .observe( - &cell_id, + &cell, ObserveMode::YieldAfter(Duration::from_millis(/*millis*/ 1)), ) .await, diff --git a/codex-rs/code-mode/src/session_runtime/types.rs b/codex-rs/code-mode/src/session_runtime/types.rs index ea6816907fe5..21e8efd966f8 100644 --- a/codex-rs/code-mode/src/session_runtime/types.rs +++ b/codex-rs/code-mode/src/session_runtime/types.rs @@ -25,41 +25,51 @@ impl fmt::Display for CellId { } } -/// Selects the next observable frontier for a running cell. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum ObserveMode { - YieldAfter(Duration), - PendingFrontier, -} - /// Controls how a cell advances when its runtime is waiting for external input. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum CellExecutionPolicy { /// Process tool and timer results even when no observation is attached. ContinueWhenUnblocked, - /// Remain paused at a pending frontier until pending execution is advanced. + /// Remain paused at a pending frontier until an explicit resume advances it. PauseAtPendingFrontier, } -impl From for CellExecutionPolicy { - fn from(mode: ObserveMode) -> Self { - match mode { - ObserveMode::YieldAfter(_) => Self::ContinueWhenUnblocked, - ObserveMode::PendingFrontier => Self::PauseAtPendingFrontier, - } +/// Selects the next observable frontier for a running cell. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum ObserveMode { + YieldAfter(Duration), + PendingFrontier, +} + +/// Identifies one durable pending frontier of a pausable cell. +#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub(crate) struct PendingGeneration(u64); + +impl PendingGeneration { + pub(crate) fn new(value: u64) -> Self { + Self(value) + } + + pub(crate) fn get(self) -> u64 { + self.0 } } +/// A repeatable snapshot of one paused runtime frontier. +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct PendingFrontier { + pub(crate) generation: PendingGeneration, + pub(crate) content_items: Vec, + pub(crate) pending_tool_call_ids: Vec, +} + /// An observable cell lifecycle event. #[derive(Clone, Debug, PartialEq)] pub(crate) enum CellEvent { Yielded { content_items: Vec, }, - Pending { - content_items: Vec, - pending_tool_call_ids: Vec, - }, + Pending(PendingFrontier), Completed { content_items: Vec, error_text: Option, @@ -69,6 +79,12 @@ pub(crate) enum CellEvent { }, } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum ResumeOutcome { + Resumed, + AlreadyRunning, +} + /// Output emitted by a cell since its preceding observation. #[derive(Clone, Debug, PartialEq)] pub(crate) enum OutputItem { @@ -161,6 +177,11 @@ pub(crate) enum Error { BusyObserver(CellId), AlreadyTerminating(CellId), ClosedCell(CellId), + InvalidGeneration { + cell_id: CellId, + requested: PendingGeneration, + latest: Option, + }, Runtime(String), } @@ -182,6 +203,19 @@ impl fmt::Display for Error { Self::ClosedCell(cell_id) => { write!(formatter, "exec cell {cell_id} closed unexpectedly") } + Self::InvalidGeneration { + cell_id, + requested, + latest, + } => write!( + formatter, + "exec cell {cell_id} cannot resume generation {}; latest generation is {}", + requested.get(), + latest.map_or_else( + || "none".to_string(), + |generation| generation.get().to_string() + ) + ), Self::Runtime(error_text) => formatter.write_str(error_text), } }