diff --git a/codex-rs/code-mode/src/cell_actor/mod.rs b/codex-rs/code-mode/src/cell_actor/mod.rs index c1ff0f27ed9e..b04fc7f96984 100644 --- a/codex-rs/code-mode/src/cell_actor/mod.rs +++ b/codex-rs/code-mode/src/cell_actor/mod.rs @@ -22,6 +22,9 @@ use self::conversions::output_item; use self::conversions::runtime_request; use self::types::CellCommand; pub(crate) use self::types::CellError; +#[cfg(test)] +pub(crate) use self::types::CellEvent; +pub(crate) use self::types::CellEvent as ActorEvent; pub(crate) use self::types::CellEventFuture; pub(crate) use self::types::CellHandle; pub(crate) use self::types::CellHost; @@ -30,15 +33,14 @@ pub(crate) use self::types::CellToolCall; pub(crate) use self::types::CompletionCommit; use self::types::CompletionDelivery; use self::types::ObservationDelivery; +pub(crate) use self::types::ObserveMode; use crate::runtime::PendingRuntimeMode; use crate::runtime::RuntimeCommand; use crate::runtime::RuntimeControlCommand; use crate::runtime::RuntimeEvent; use crate::runtime::spawn_runtime; -use crate::session_runtime::CellEvent; use crate::session_runtime::CellExecutionPolicy; use crate::session_runtime::CreateCellRequest as CellRequest; -use crate::session_runtime::ObserveMode; use crate::session_runtime::OutputItem; use crate::session_runtime::PendingFrontier; use crate::session_runtime::PendingGeneration; @@ -89,7 +91,7 @@ struct CellContext { struct Observer { mode: ObserveMode, - response_tx: oneshot::Sender>, + response_tx: oneshot::Sender>, } async fn run_cell( @@ -148,7 +150,7 @@ async fn run_cell( finish_termination( &cell_state, observer.take().map(|observer| observer.response_tx), - CellEvent::Terminated { + ActorEvent::Terminated { content_items: take_termination_content( &mut pending_frontier, pending_frontier_observed, @@ -235,12 +237,12 @@ async fn run_cell( { let delivered = match send_cell_event( response_tx, - CellEvent::Yielded { + ActorEvent::Yielded { content_items: yielded_items, }, ) { Ok(()) => true, - Err(CellEvent::Yielded { content_items }) => { + Err(ActorEvent::Yielded { content_items }) => { pending_initial_yield_items = Some(content_items); has_been_observed = false; false @@ -260,7 +262,7 @@ async fn run_cell( if matches!(mode, ObserveMode::PendingFrontier) && let Some(frontier) = pending_frontier.as_ref() { - if send_cell_event(response_tx, CellEvent::Pending(frontier.clone())).is_ok() { + if send_cell_event(response_tx, ActorEvent::Pending(frontier.clone())).is_ok() { pending_frontier_observed = true; } continue; @@ -285,7 +287,7 @@ async fn run_cell( restore_undelivered_yield( send_observer_event( observer.take(), - CellEvent::Yielded { + ActorEvent::Yielded { content_items: std::mem::take(&mut content_items), }, ), @@ -317,7 +319,7 @@ async fn run_cell( finish_termination( &cell_state, observer.take().map(|observer| observer.response_tx), - CellEvent::Terminated { + ActorEvent::Terminated { content_items: termination_content_items, }, ); @@ -330,7 +332,7 @@ async fn run_cell( CallbackCompletion::DrainNotifications, ) .await; - let event = CellEvent::Completed { + let event = ActorEvent::Completed { content_items: take_termination_content( &mut pending_frontier, pending_frontier_observed, @@ -362,7 +364,7 @@ async fn run_cell( finish_termination( &cell_state, response_tx, - CellEvent::Terminated { + ActorEvent::Terminated { content_items: rejected_completion_content(rejected_event), }, ); @@ -408,7 +410,7 @@ async fn run_cell( yield_timer = None; if send_cell_event( observer.response_tx, - CellEvent::Pending(frontier.clone()), + ActorEvent::Pending(frontier.clone()), ) .is_ok() { @@ -433,7 +435,7 @@ async fn run_cell( restore_undelivered_yield( send_observer_event( observer.take(), - CellEvent::Yielded { + ActorEvent::Yielded { content_items: std::mem::take(&mut content_items), }, ), @@ -491,7 +493,7 @@ async fn run_cell( finish_termination( &cell_state, observer.take().map(|observer| observer.response_tx), - CellEvent::Terminated { + ActorEvent::Terminated { content_items: termination_content_items, }, ); @@ -504,7 +506,7 @@ async fn run_cell( CallbackCompletion::DrainNotifications, ) .await; - let event = CellEvent::Completed { + let event = ActorEvent::Completed { content_items: std::mem::take(&mut content_items), error_text, }; @@ -531,7 +533,7 @@ async fn run_cell( finish_termination( &cell_state, response_tx, - CellEvent::Terminated { + ActorEvent::Terminated { content_items: rejected_completion_content(rejected_event), }, ); @@ -568,7 +570,7 @@ async fn run_cell( host.closed().await; } -fn send_observer_event(observer: Option, event: CellEvent) -> Result<(), CellEvent> { +fn send_observer_event(observer: Option, event: ActorEvent) -> Result<(), ActorEvent> { let Some(observer) = observer else { return Err(event); }; @@ -576,9 +578,9 @@ fn send_observer_event(observer: Option, event: CellEvent) -> Result<( } fn send_cell_event( - response_tx: oneshot::Sender>, - event: CellEvent, -) -> Result<(), CellEvent> { + response_tx: oneshot::Sender>, + event: ActorEvent, +) -> Result<(), ActorEvent> { match response_tx.send(Ok(event)) { Ok(()) => Ok(()), Err(Ok(event)) => Err(event), @@ -586,10 +588,13 @@ fn send_cell_event( } } -fn restore_undelivered_yield(delivery: Result<(), CellEvent>, content_items: &mut Vec) { +fn restore_undelivered_yield( + delivery: Result<(), ActorEvent>, + content_items: &mut Vec, +) { match delivery { Ok(()) => {} - Err(CellEvent::Yielded { + Err(ActorEvent::Yielded { content_items: mut undelivered_items, }) => { undelivered_items.append(content_items); @@ -599,9 +604,9 @@ fn restore_undelivered_yield(delivery: Result<(), CellEvent>, content_items: &mu } } -fn rejected_completion_content(event: Option) -> Vec { +fn rejected_completion_content(event: Option) -> Vec { match event { - Some(CellEvent::Completed { content_items, .. }) => content_items, + Some(ActorEvent::Completed { content_items, .. }) => content_items, None => Vec::new(), Some(event) => panic!("completion commit rejected an unexpected event: {event:?}"), } @@ -638,8 +643,8 @@ fn take_termination_content( fn finish_termination( cell_state: &CellState, - observer_tx: Option>>, - event: CellEvent, + observer_tx: Option>>, + event: ActorEvent, ) { if let Some(event) = cell_state.finish_termination(event) && let Some(observer_tx) = observer_tx diff --git a/codex-rs/code-mode/src/cell_actor/types.rs b/codex-rs/code-mode/src/cell_actor/types.rs index 3a0962b7cd4a..425a1e5d520f 100644 --- a/codex-rs/code-mode/src/cell_actor/types.rs +++ b/codex-rs/code-mode/src/cell_actor/types.rs @@ -3,15 +3,15 @@ use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::sync::Mutex; +use std::time::Duration; use serde_json::Value as JsonValue; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio_util::sync::CancellationToken; -use crate::session_runtime::CellEvent; -use crate::session_runtime::ObserveMode; use crate::session_runtime::OutputItem; +use crate::session_runtime::PendingFrontier; use crate::session_runtime::PendingGeneration; use crate::session_runtime::ResumeOutcome; use crate::session_runtime::ToolKind; @@ -23,6 +23,27 @@ pub(crate) type CellEventFuture = pub(crate) type ResumeFuture = Pin> + Send + 'static>>; +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum ObserveMode { + YieldAfter(Duration), + PendingFrontier, +} + +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum CellEvent { + Yielded { + content_items: Vec, + }, + Pending(PendingFrontier), + Completed { + content_items: Vec, + error_text: Option, + }, + Terminated { + content_items: Vec, + }, +} + #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum CellError { Busy, diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs index a527080c8e91..9f2df24793f4 100644 --- a/codex-rs/code-mode/src/service.rs +++ b/codex-rs/code-mode/src/service.rs @@ -93,20 +93,17 @@ impl CodeModeService { pub async fn execute(&self, request: ExecuteRequest) -> Result { let yield_time_ms = request.yield_time_ms.unwrap_or(DEFAULT_EXEC_YIELD_TIME_MS); - let runtime_cell_id = self + let cell = self .runtime .create_cell(runtime_request(request)) .await .map_err(|error| error.to_string())?; let pending_event = self .runtime - .begin_observe( - &runtime_cell_id, - runtime::ObserveMode::YieldAfter(Duration::from_millis(yield_time_ms)), - ) + .begin_wait(&cell, Duration::from_millis(yield_time_ms)) .await .map_err(|error| error.to_string())?; - let cell_id = protocol_cell_id(&runtime_cell_id); + let cell_id = protocol_cell_id(cell.id()); let response_cell_id = cell_id.clone(); let (response_tx, response_rx) = oneshot::channel(); tokio::spawn(async move { @@ -124,19 +121,18 @@ impl CodeModeService { &self, request: ExecuteRequest, ) -> Result { - let runtime_cell_id = self + let cell = self .runtime .create_pausable_cell(runtime_request(request)) .await .map_err(|error| error.to_string())?; - let cell_id = protocol_cell_id(&runtime_cell_id); + let cell_id = protocol_cell_id(cell.id()); let event = self .runtime - .wait_to_pending(&runtime_cell_id) + .wait_to_pending(&cell) .await .map_err(|error| error.to_string())?; - self.record_pending_generation(&runtime_cell_id, &event) - .await; + self.record_pending_generation(cell.id(), &event).await; pending_outcome(&cell_id, event) } @@ -153,12 +149,16 @@ impl CodeModeService { yield_time_ms, } = request; let runtime_cell_id = runtime_cell_id(&cell_id); + let cell = match self.runtime.cell(&runtime_cell_id).await { + Ok(cell) => cell, + Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => { + return missing_wait(cell_id); + } + Err(error) => return Box::pin(async move { Err(error.to_string()) }), + }; match self .runtime - .begin_observe( - &runtime_cell_id, - runtime::ObserveMode::YieldAfter(Duration::from_millis(yield_time_ms)), - ) + .begin_wait(&cell, Duration::from_millis(yield_time_ms)) .await { Ok(pending_event) => Box::pin(async move { @@ -199,6 +199,15 @@ impl CodeModeService { ) -> Result { let cell_id = request.cell_id; let runtime_cell_id = runtime_cell_id(&cell_id); + let cell = match self.runtime.pausable_cell(&runtime_cell_id).await { + Ok(cell) => cell, + Err(runtime::Error::MissingCell(_) | runtime::Error::ClosedCell(_)) => { + return Ok(WaitToPendingOutcome::MissingCell(missing_cell_response( + cell_id, + ))); + } + Err(error) => return Err(error.to_string()), + }; let generation = { self.pending_generations .lock() @@ -208,11 +217,11 @@ impl CodeModeService { }; if let Some(generation) = generation { self.runtime - .resume(&runtime_cell_id, generation) + .resume(&cell, generation) .await .map_err(|error| error.to_string())?; } - match self.runtime.wait_to_pending(&runtime_cell_id).await { + match self.runtime.wait_to_pending(&cell).await { Ok(event) => { self.record_pending_generation(&runtime_cell_id, &event) .await; @@ -230,15 +239,15 @@ impl CodeModeService { async fn record_pending_generation( &self, cell_id: &runtime::CellId, - event: &runtime::CellEvent, + event: &runtime::PausableCellEvent, ) { let mut generations = self.pending_generations.lock().await; match event { - runtime::CellEvent::Pending(frontier) => { + runtime::PausableCellEvent::Pending(frontier) => { generations.insert(cell_id.clone(), frontier.generation); } - runtime::CellEvent::Yielded { .. } => {} - runtime::CellEvent::Completed { .. } | runtime::CellEvent::Terminated { .. } => { + runtime::PausableCellEvent::Completed { .. } + | runtime::PausableCellEvent::Terminated { .. } => { generations.remove(cell_id); } } @@ -371,10 +380,10 @@ fn protocol_cell_id(cell_id: &runtime::CellId) -> CellId { fn pending_outcome( cell_id: &CellId, - event: runtime::CellEvent, + event: runtime::PausableCellEvent, ) -> Result { match event { - runtime::CellEvent::Pending(runtime::PendingFrontier { + runtime::PausableCellEvent::Pending(runtime::PendingFrontier { content_items, pending_tool_call_ids, .. @@ -383,9 +392,22 @@ fn pending_outcome( content_items: content_items.into_iter().map(output_item).collect(), pending_tool_call_ids, }), - event => Ok(ExecuteToPendingOutcome::Completed(runtime_response( - cell_id, event, - )?)), + runtime::PausableCellEvent::Completed { + content_items, + error_text, + } => Ok(ExecuteToPendingOutcome::Completed( + RuntimeResponse::Result { + cell_id: cell_id.clone(), + content_items: content_items.into_iter().map(output_item).collect(), + error_text, + }, + )), + runtime::PausableCellEvent::Terminated { content_items } => Ok( + ExecuteToPendingOutcome::Completed(RuntimeResponse::Terminated { + cell_id: cell_id.clone(), + content_items: content_items.into_iter().map(output_item).collect(), + }), + ), } } @@ -410,9 +432,6 @@ fn runtime_response( cell_id: cell_id.clone(), content_items: content_items.into_iter().map(output_item).collect(), }), - runtime::CellEvent::Pending(_) => { - Err("cell returned a pending frontier unexpectedly".to_string()) - } } } diff --git a/codex-rs/code-mode/src/session_runtime/mod.rs b/codex-rs/code-mode/src/session_runtime/mod.rs index 683f6346d3cc..b66b38a18805 100644 --- a/codex-rs/code-mode/src/session_runtime/mod.rs +++ b/codex-rs/code-mode/src/session_runtime/mod.rs @@ -6,21 +6,25 @@ use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; +use std::time::Duration; use serde_json::Value as JsonValue; use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; use tokio_util::task::TaskTracker; +pub(crate) use self::types::Cell; pub(crate) use self::types::CellEvent; pub(crate) use self::types::CellExecutionPolicy; pub(crate) use self::types::CellId; +pub(crate) use self::types::CellKind; pub(crate) use self::types::CreateCellRequest; pub(crate) use self::types::Error; pub(crate) use self::types::ImageDetail; pub(crate) use self::types::NestedToolCall; -pub(crate) use self::types::ObserveMode; pub(crate) use self::types::OutputItem; +pub(crate) use self::types::PausableCell; +pub(crate) use self::types::PausableCellEvent; pub(crate) use self::types::PendingFrontier; pub(crate) use self::types::PendingGeneration; pub(crate) use self::types::ResumeOutcome; @@ -28,6 +32,7 @@ pub(crate) use self::types::SessionRuntimeDelegate; pub(crate) use self::types::ToolDefinition; pub(crate) use self::types::ToolKind; pub(crate) use self::types::ToolName; +use crate::cell_actor::ActorEvent; use crate::cell_actor::CellActor; use crate::cell_actor::CellError; use crate::cell_actor::CellEventFuture; @@ -36,8 +41,11 @@ use crate::cell_actor::CellHost; use crate::cell_actor::CellState; use crate::cell_actor::CellToolCall; use crate::cell_actor::CompletionCommit; +use crate::cell_actor::ObserveMode; type RuntimeEventFuture = Pin> + Send + 'static>>; +type PausableRuntimeEventFuture = + Pin> + Send + 'static>>; /// Owns all cells and shared state for one transport-neutral code-mode session. pub(crate) struct SessionRuntime { @@ -46,13 +54,34 @@ pub(crate) struct SessionRuntime { struct Inner { stored_values: Mutex>, - cells: Mutex>, + cells: Mutex>, cell_tasks: TaskTracker, shutdown_token: CancellationToken, delegate: Arc, next_cell_id: AtomicU64, } +#[derive(Clone)] +enum RegisteredCell { + Continuing(CellHandle), + Pausable(CellHandle), +} + +impl RegisteredCell { + fn kind(&self) -> CellKind { + match self { + Self::Continuing(_) => CellKind::Continuing, + Self::Pausable(_) => CellKind::Pausable, + } + } + + fn handle(&self) -> &CellHandle { + match self { + Self::Continuing(handle) | Self::Pausable(handle) => handle, + } + } +} + impl SessionRuntime { pub(crate) fn new(delegate: Arc) -> Self { Self { @@ -71,84 +100,90 @@ impl SessionRuntime { !self.inner.shutdown_token.is_cancelled() } - pub(crate) async fn create_cell(&self, request: CreateCellRequest) -> Result { - self.create_cell_with_execution_policy(request, CellExecutionPolicy::ContinueWhenUnblocked) - .await + pub(crate) async fn create_cell(&self, request: CreateCellRequest) -> Result { + let id = self + .start_cell( + request, + CellExecutionPolicy::ContinueWhenUnblocked, + CellKind::Continuing, + ) + .await?; + Ok(Cell::new(id)) } pub(crate) async fn create_pausable_cell( &self, request: CreateCellRequest, - ) -> Result { - self.create_cell_with_execution_policy(request, CellExecutionPolicy::PauseAtPendingFrontier) - .await - } - - async fn create_cell_with_execution_policy( - &self, - request: CreateCellRequest, - execution_policy: CellExecutionPolicy, - ) -> Result { - if self.inner.shutdown_token.is_cancelled() { - return Err(Error::ShuttingDown); - } - let cell_id = self.allocate_cell_id(); - self.start_cell(cell_id.clone(), request, execution_policy) + ) -> Result { + let id = self + .start_cell( + request, + CellExecutionPolicy::PauseAtPendingFrontier, + CellKind::Pausable, + ) .await?; - Ok(cell_id) + Ok(PausableCell::new(id)) } - pub(crate) async fn observe( + #[cfg(test)] + pub(crate) async fn wait( &self, - cell_id: &CellId, - mode: ObserveMode, + cell: &Cell, + yield_after: Duration, ) -> Result { - self.begin_observe(cell_id, mode).await?.event().await + self.begin_wait(cell, yield_after).await?.event().await } - pub(crate) async fn begin_observe( + pub(crate) async fn begin_wait( &self, - cell_id: &CellId, - mode: ObserveMode, + cell: &Cell, + yield_after: Duration, ) -> Result { - let handle = self - .inner - .cells - .lock() - .await - .get(cell_id) - .cloned() - .ok_or_else(|| Error::MissingCell(cell_id.clone()))?; + let handle = self.continuing_handle(cell.id()).await?; Ok(PendingEvent { - event: map_actor_event(cell_id.clone(), handle.observe(mode)), + event: map_cell_event( + cell.id().clone(), + handle.observe(ObserveMode::YieldAfter(yield_after)), + ), }) } - pub(crate) async fn wait_to_pending(&self, cell_id: &CellId) -> Result { - self.observe(cell_id, ObserveMode::PendingFrontier).await + pub(crate) async fn wait_to_pending( + &self, + cell: &PausableCell, + ) -> Result { + let handle = self.pausable_handle(cell.id()).await?; + map_pausable_event( + cell.id().clone(), + handle.observe(ObserveMode::PendingFrontier), + ) + .await } pub(crate) async fn resume( &self, - cell_id: &CellId, + cell: &PausableCell, generation: PendingGeneration, ) -> Result { - let handle = self - .inner - .cells - .lock() - .await - .get(cell_id) - .cloned() - .ok_or_else(|| Error::MissingCell(cell_id.clone()))?; + let handle = self.pausable_handle(cell.id()).await?; handle .resume(generation) .await - .map_err(|error| actor_error(cell_id, error)) + .map_err(|error| actor_error(cell.id(), error)) + } + + pub(crate) async fn cell(&self, cell_id: &CellId) -> Result { + self.continuing_handle(cell_id).await?; + Ok(Cell::new(cell_id.clone())) + } + + pub(crate) async fn pausable_cell(&self, cell_id: &CellId) -> Result { + self.pausable_handle(cell_id).await?; + Ok(PausableCell::new(cell_id.clone())) } pub(crate) async fn terminate(&self, cell_id: &CellId) -> Result { - let handle = self + let cell = self .inner .cells .lock() @@ -156,9 +191,10 @@ impl SessionRuntime { .get(cell_id) .cloned() .ok_or_else(|| Error::MissingCell(cell_id.clone()))?; - handle + cell.handle() .terminate() .await + .map(map_terminal_event) .map_err(|error| actor_error(cell_id, error)) } @@ -184,10 +220,14 @@ impl SessionRuntime { async fn start_cell( &self, - cell_id: CellId, request: CreateCellRequest, execution_policy: CellExecutionPolicy, - ) -> Result<(), Error> { + kind: CellKind, + ) -> Result { + if self.inner.shutdown_token.is_cancelled() { + return Err(Error::ShuttingDown); + } + let cell_id = self.allocate_cell_id(); let stored_values = self.inner.stored_values.lock().await.clone(); let host = Arc::new(RuntimeCellHost { cell_id: cell_id.clone(), @@ -204,10 +244,42 @@ impl SessionRuntime { let (handle, task) = CellActor::prepare(request, stored_values, host, cell_state, execution_policy) .map_err(Error::Runtime)?; - cells.insert(cell_id, handle); + let registered = match kind { + CellKind::Continuing => RegisteredCell::Continuing(handle), + CellKind::Pausable => RegisteredCell::Pausable(handle), + }; + cells.insert(cell_id.clone(), registered); self.inner.cell_tasks.spawn(task); drop(cells); - Ok(()) + Ok(cell_id) + } + + async fn continuing_handle(&self, cell_id: &CellId) -> Result { + self.handle_for_kind(cell_id, CellKind::Continuing).await + } + + async fn pausable_handle(&self, cell_id: &CellId) -> Result { + self.handle_for_kind(cell_id, CellKind::Pausable).await + } + + async fn handle_for_kind( + &self, + cell_id: &CellId, + expected: CellKind, + ) -> Result { + let cells = self.inner.cells.lock().await; + let cell = cells + .get(cell_id) + .ok_or_else(|| Error::MissingCell(cell_id.clone()))?; + let actual = cell.kind(); + if actual != expected { + return Err(Error::WrongCellKind { + cell_id: cell_id.clone(), + expected, + actual, + }); + } + Ok(cell.handle().clone()) } fn begin_shutdown(&self) { @@ -274,7 +346,7 @@ impl CellHost for RuntimeCellHost { async fn commit_completion( &self, stored_value_writes: HashMap, - event: CellEvent, + event: ActorEvent, pending_initial_yield_items: Option>, cell_state: Arc, ) -> CompletionCommit { @@ -297,8 +369,60 @@ impl CellHost for RuntimeCellHost { } } -fn map_actor_event(cell_id: CellId, event: CellEventFuture) -> RuntimeEventFuture { - Box::pin(async move { event.await.map_err(|error| actor_error(&cell_id, error)) }) +fn map_cell_event(cell_id: CellId, event: CellEventFuture) -> RuntimeEventFuture { + Box::pin(async move { + match event.await.map_err(|error| actor_error(&cell_id, error))? { + ActorEvent::Yielded { content_items } => Ok(CellEvent::Yielded { content_items }), + ActorEvent::Completed { + content_items, + error_text, + } => Ok(CellEvent::Completed { + content_items, + error_text, + }), + ActorEvent::Terminated { content_items } => Ok(CellEvent::Terminated { content_items }), + ActorEvent::Pending(_) => Err(Error::Runtime(format!( + "continuing cell {cell_id} unexpectedly reached a visible pending frontier" + ))), + } + }) +} + +fn map_pausable_event(cell_id: CellId, event: CellEventFuture) -> PausableRuntimeEventFuture { + Box::pin(async move { + match event.await.map_err(|error| actor_error(&cell_id, error))? { + ActorEvent::Pending(frontier) => Ok(PausableCellEvent::Pending(frontier)), + ActorEvent::Completed { + content_items, + error_text, + } => Ok(PausableCellEvent::Completed { + content_items, + error_text, + }), + ActorEvent::Terminated { content_items } => { + Ok(PausableCellEvent::Terminated { content_items }) + } + ActorEvent::Yielded { .. } => Err(Error::Runtime(format!( + "pausable cell {cell_id} unexpectedly yielded" + ))), + } + }) +} + +fn map_terminal_event(event: ActorEvent) -> CellEvent { + match event { + ActorEvent::Completed { + content_items, + error_text, + } => CellEvent::Completed { + content_items, + error_text, + }, + ActorEvent::Terminated { content_items } => CellEvent::Terminated { content_items }, + ActorEvent::Yielded { .. } | ActorEvent::Pending(_) => { + panic!("termination returned a non-terminal cell event") + } + } } fn actor_error(cell_id: &CellId, error: CellError) -> Error { diff --git a/codex-rs/code-mode/src/session_runtime/tests.rs b/codex-rs/code-mode/src/session_runtime/tests.rs index bb7d899f3344..f908293a738c 100644 --- a/codex-rs/code-mode/src/session_runtime/tests.rs +++ b/codex-rs/code-mode/src/session_runtime/tests.rs @@ -142,9 +142,7 @@ text("done"); Some("second".to_string()) ); assert_eq!( - runtime - .observe(&cell, ObserveMode::YieldAfter(Duration::from_secs(1))) - .await, + runtime.wait(&cell, Duration::from_secs(1)).await, Ok(CellEvent::Completed { content_items: vec![OutputItem::Text { text: "done".to_string(), @@ -184,7 +182,7 @@ text("done"); Some("first".to_string()) ); let first = runtime.wait_to_pending(&cell).await.unwrap(); - let CellEvent::Pending(first_frontier) = &first else { + let PausableCellEvent::Pending(first_frontier) = &first else { panic!("expected the first pending frontier, got {first:?}"); }; assert_eq!( @@ -221,7 +219,7 @@ text("done"); ); let second = runtime.wait_to_pending(&cell).await.unwrap(); - let CellEvent::Pending(second_frontier) = &second else { + let PausableCellEvent::Pending(second_frontier) = &second else { panic!("expected the second pending frontier, got {second:?}"); }; assert_eq!( @@ -237,7 +235,7 @@ text("done"); .resume(&cell, PendingGeneration::new(/*value*/ 3)) .await, Err(Error::InvalidGeneration { - cell_id: cell.clone(), + cell_id: cell.id().clone(), requested: PendingGeneration::new(/*value*/ 3), latest: Some(PendingGeneration::new(/*value*/ 2)), }) @@ -249,7 +247,7 @@ text("done"); release.add_permits(1); assert_eq!( runtime.wait_to_pending(&cell).await, - Ok(CellEvent::Completed { + Ok(PausableCellEvent::Completed { content_items: vec![OutputItem::Text { text: "done".to_string(), }], @@ -293,7 +291,8 @@ text("done"); invocations.sort(); assert_eq!(invocations, vec!["first".to_string(), "second".to_string()]); - let CellEvent::Pending(first_frontier) = runtime.wait_to_pending(&cell).await.unwrap() else { + let PausableCellEvent::Pending(first_frontier) = runtime.wait_to_pending(&cell).await.unwrap() + else { panic!("expected first pending frontier"); }; assert_eq!( @@ -307,7 +306,8 @@ text("done"); ); release.add_permits(1); - let CellEvent::Pending(second_frontier) = runtime.wait_to_pending(&cell).await.unwrap() else { + let PausableCellEvent::Pending(second_frontier) = runtime.wait_to_pending(&cell).await.unwrap() + else { panic!("expected second pending frontier"); }; assert_eq!(second_frontier.pending_tool_call_ids.len(), 1); @@ -324,7 +324,7 @@ text("done"); release.add_permits(1); assert_eq!( runtime.wait_to_pending(&cell).await, - Ok(CellEvent::Completed { + Ok(PausableCellEvent::Completed { content_items: vec![OutputItem::Text { text: "done".to_string(), }], @@ -334,6 +334,127 @@ text("done"); runtime.shutdown().await.unwrap(); } +#[tokio::test] +async fn pausable_cell_drains_a_parallel_host_frontier_without_duplicate_output() { + let (invocations_tx, mut invocations_rx) = mpsc::unbounded_channel(); + let release = Arc::new(Semaphore::new(0)); + let runtime = SessionRuntime::new(Arc::new(BlockingToolDelegate { + invocations_tx, + release: Arc::clone(&release), + })); + let cell = runtime + .create_pausable_cell(CreateCellRequest { + tool_call_id: "call-1".to_string(), + enabled_tools: vec![tool_definition("first"), tool_definition("second")], + source: r#" +await Promise.all([tools.first({}), tools.second({})]); +text("done"); +"# + .to_string(), + }) + .await + .unwrap(); + + let mut invocations = Vec::new(); + for _ in 0..2 { + invocations.push( + tokio::time::timeout(Duration::from_secs(1), invocations_rx.recv()) + .await + .expect("parallel tool invocation timed out") + .expect("parallel invocation channel closed"), + ); + } + invocations.sort(); + assert_eq!(invocations, vec!["first".to_string(), "second".to_string()]); + + let first_event = runtime.wait_to_pending(&cell).await.unwrap(); + let PausableCellEvent::Pending(first_frontier) = first_event else { + panic!("expected the parallel pending frontier"); + }; + assert_eq!( + first_frontier.generation, + PendingGeneration::new(/*value*/ 1) + ); + assert_eq!(first_frontier.pending_tool_call_ids.len(), 2); + assert_eq!( + runtime.wait_to_pending(&cell).await, + Ok(PausableCellEvent::Pending(first_frontier.clone())) + ); + + release.add_permits(2); + let mut generation = first_frontier.generation; + let mut accumulated_content = first_frontier.content_items; + let mut completed = false; + for _ in 0..3 { + assert_eq!( + runtime.resume(&cell, generation).await, + Ok(ResumeOutcome::Resumed) + ); + match tokio::time::timeout(Duration::from_secs(1), runtime.wait_to_pending(&cell)) + .await + .expect("synchronous driver timed out") + .unwrap() + { + PausableCellEvent::Pending(frontier) => { + assert!(frontier.generation > generation); + generation = frontier.generation; + accumulated_content.extend(frontier.content_items); + } + PausableCellEvent::Completed { + content_items, + error_text, + } => { + assert_eq!(error_text, None); + accumulated_content.extend(content_items); + completed = true; + break; + } + PausableCellEvent::Terminated { content_items } => { + panic!("synchronous driver terminated with output: {content_items:?}"); + } + } + } + assert!(completed, "synchronous driver did not reach completion"); + assert_eq!( + accumulated_content, + vec![OutputItem::Text { + text: "done".to_string(), + }] + ); + runtime.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn cell_capabilities_reject_ids_of_the_other_kind() { + let runtime = SessionRuntime::new(Arc::new(RecordingDelegate)); + let cell = runtime + .create_cell(execute_request("await new Promise(() => {});")) + .await + .unwrap(); + let pausable = runtime + .create_pausable_cell(execute_request("await new Promise(() => {});")) + .await + .unwrap(); + + assert_eq!( + runtime.pausable_cell(cell.id()).await, + Err(Error::WrongCellKind { + cell_id: cell.id().clone(), + expected: CellKind::Pausable, + actual: CellKind::Continuing, + }) + ); + assert_eq!( + runtime.cell(pausable.id()).await, + Err(Error::WrongCellKind { + cell_id: pausable.id().clone(), + expected: CellKind::Continuing, + actual: CellKind::Pausable, + }) + ); + runtime.shutdown().await.unwrap(); +} + #[tokio::test] async fn pending_observation_waits_for_resumed_work_to_reach_a_new_frontier() { let (invocations_tx, mut invocations_rx) = mpsc::unbounded_channel(); @@ -360,7 +481,7 @@ text("done"); .expect("tool invocation timed out"), Some("blocked".to_string()) ); - let CellEvent::Pending(frontier) = runtime.wait_to_pending(&cell).await.unwrap() else { + let PausableCellEvent::Pending(frontier) = runtime.wait_to_pending(&cell).await.unwrap() else { panic!("expected a pending frontier"); }; @@ -378,7 +499,7 @@ text("done"); release.add_permits(1); assert_eq!( next_event.await, - Ok(CellEvent::Completed { + Ok(PausableCellEvent::Completed { content_items: vec![OutputItem::Text { text: "done".to_string(), }], @@ -396,7 +517,7 @@ async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_loa cell_id: CellId::new("terminating-writer"), inner: Arc::clone(&runtime.inner), }; - let completion = CellEvent::Completed { + let completion = ActorEvent::Completed { content_items: vec![OutputItem::Text { text: "uncommitted output".to_string(), }], @@ -421,7 +542,7 @@ async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_loa let termination = cell_state.request_termination(); drop(stored_values); assert_eq!(commit.await, CompletionCommit::Rejected(completion)); - let terminated = CellEvent::Terminated { + let terminated = ActorEvent::Terminated { content_items: Vec::new(), }; assert_eq!( @@ -447,9 +568,7 @@ async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_loa .await .unwrap(); assert_eq!( - runtime - .observe(&reader, ObserveMode::YieldAfter(Duration::from_secs(1))) - .await, + runtime.wait(&reader, Duration::from_secs(1)).await, Ok(CellEvent::Completed { content_items: vec![OutputItem::Text { text: "undefined".to_string(), @@ -512,13 +631,10 @@ async fn drop_terminates_cells_when_the_registry_is_locked() { .create_cell(execute_request("while (true) {}")) .await .unwrap(); - assert_eq!(cell, CellId::new("1")); + assert_eq!(cell.id(), &CellId::new("1")); assert_eq!( runtime - .observe( - &cell, - ObserveMode::YieldAfter(Duration::from_millis(/*millis*/ 1)), - ) + .wait(&cell, Duration::from_millis(/*millis*/ 1)) .await, Ok(CellEvent::Yielded { content_items: Vec::new(), diff --git a/codex-rs/code-mode/src/session_runtime/types.rs b/codex-rs/code-mode/src/session_runtime/types.rs index 21e8efd966f8..4b45ac6321a1 100644 --- a/codex-rs/code-mode/src/session_runtime/types.rs +++ b/codex-rs/code-mode/src/session_runtime/types.rs @@ -1,6 +1,5 @@ use std::fmt; use std::future::Future; -use std::time::Duration; use serde_json::Value as JsonValue; use tokio_util::sync::CancellationToken; @@ -26,19 +25,60 @@ impl fmt::Display for CellId { } /// Controls how a cell advances when its runtime is waiting for external input. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub(crate) enum CellExecutionPolicy { /// Process tool and timer results even when no observation is attached. + #[default] ContinueWhenUnblocked, /// Remain paused at a pending frontier until an explicit resume advances it. PauseAtPendingFrontier, } -/// Selects the next observable frontier for a running cell. +/// A cell that continues whenever external input unblocks its runtime. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub(crate) struct Cell { + id: CellId, +} + +impl Cell { + pub(super) fn new(id: CellId) -> Self { + Self { id } + } + + pub(crate) fn id(&self) -> &CellId { + &self.id + } +} + +/// A cell that remains paused at each pending frontier until explicitly resumed. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub(crate) struct PausableCell { + id: CellId, +} + +impl PausableCell { + pub(super) fn new(id: CellId) -> Self { + Self { id } + } + + pub(crate) fn id(&self) -> &CellId { + &self.id + } +} + #[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub(crate) enum ObserveMode { - YieldAfter(Duration), - PendingFrontier, +pub(crate) enum CellKind { + Continuing, + Pausable, +} + +impl fmt::Display for CellKind { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Continuing => formatter.write_str("continuing"), + Self::Pausable => formatter.write_str("pausable"), + } + } } /// Identifies one durable pending frontier of a pausable cell. @@ -69,6 +109,18 @@ pub(crate) enum CellEvent { Yielded { content_items: Vec, }, + Completed { + content_items: Vec, + error_text: Option, + }, + Terminated { + content_items: Vec, + }, +} + +/// An observable lifecycle event for a pausable cell. +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum PausableCellEvent { Pending(PendingFrontier), Completed { content_items: Vec, @@ -177,6 +229,11 @@ pub(crate) enum Error { BusyObserver(CellId), AlreadyTerminating(CellId), ClosedCell(CellId), + WrongCellKind { + cell_id: CellId, + expected: CellKind, + actual: CellKind, + }, InvalidGeneration { cell_id: CellId, requested: PendingGeneration, @@ -203,6 +260,14 @@ impl fmt::Display for Error { Self::ClosedCell(cell_id) => { write!(formatter, "exec cell {cell_id} closed unexpectedly") } + Self::WrongCellKind { + cell_id, + expected, + actual, + } => write!( + formatter, + "exec cell {cell_id} is {actual}, expected {expected}" + ), Self::InvalidGeneration { cell_id, requested,