diff --git a/codex-rs/code-mode/src/service.rs b/codex-rs/code-mode/src/service.rs index 9f2df24793f4..ece08c598e03 100644 --- a/codex-rs/code-mode/src/service.rs +++ b/codex-rs/code-mode/src/service.rs @@ -348,7 +348,9 @@ impl runtime::SessionRuntimeDelegate for ProtocolDelegate { } fn runtime_request(request: ExecuteRequest) -> runtime::CreateCellRequest { + let idempotency_key = format!("{}:{}", request.tool_call_id, request.source); runtime::CreateCellRequest { + idempotency_key, tool_call_id: request.tool_call_id, enabled_tools: request .enabled_tools diff --git a/codex-rs/code-mode/src/service_contract_tests.rs b/codex-rs/code-mode/src/service_contract_tests.rs index 8c07da9a8f3d..924adaa3949c 100644 --- a/codex-rs/code-mode/src/service_contract_tests.rs +++ b/codex-rs/code-mode/src/service_contract_tests.rs @@ -287,9 +287,12 @@ async fn observed_natural_completion_wins_over_termination() { } ); tokio::time::timeout(Duration::from_secs(1), async { + let mut probe_index = 0; loop { + probe_index += 1; let response = service .execute(ExecuteRequest { + tool_call_id: format!("probe-{probe_index}"), yield_time_ms: Some(60_000), ..execute_request(r#"text(String(load("finished")));"#) }) diff --git a/codex-rs/code-mode/src/session_runtime/mod.rs b/codex-rs/code-mode/src/session_runtime/mod.rs index b66b38a18805..5357ab7a776a 100644 --- a/codex-rs/code-mode/src/session_runtime/mod.rs +++ b/codex-rs/code-mode/src/session_runtime/mod.rs @@ -10,6 +10,7 @@ use std::time::Duration; use serde_json::Value as JsonValue; use tokio::sync::Mutex; +use tokio::sync::OnceCell; use tokio_util::sync::CancellationToken; use tokio_util::task::TaskTracker; @@ -55,12 +56,19 @@ pub(crate) struct SessionRuntime { struct Inner { stored_values: Mutex>, cells: Mutex>, + created_cells: Mutex>>>, cell_tasks: TaskTracker, shutdown_token: CancellationToken, delegate: Arc, next_cell_id: AtomicU64, } +#[derive(Clone)] +struct IdempotentCell { + id: CellId, + kind: CellKind, +} + #[derive(Clone)] enum RegisteredCell { Continuing(CellHandle), @@ -88,6 +96,7 @@ impl SessionRuntime { inner: Arc::new(Inner { stored_values: Mutex::new(HashMap::new()), cells: Mutex::new(HashMap::new()), + created_cells: Mutex::new(HashMap::new()), cell_tasks: TaskTracker::new(), shutdown_token: CancellationToken::new(), delegate, @@ -102,7 +111,7 @@ impl SessionRuntime { pub(crate) async fn create_cell(&self, request: CreateCellRequest) -> Result { let id = self - .start_cell( + .create_idempotent_cell( request, CellExecutionPolicy::ContinueWhenUnblocked, CellKind::Continuing, @@ -116,7 +125,7 @@ impl SessionRuntime { request: CreateCellRequest, ) -> Result { let id = self - .start_cell( + .create_idempotent_cell( request, CellExecutionPolicy::PauseAtPendingFrontier, CellKind::Pausable, @@ -254,6 +263,37 @@ impl SessionRuntime { Ok(cell_id) } + async fn create_idempotent_cell( + &self, + request: CreateCellRequest, + execution_policy: CellExecutionPolicy, + kind: CellKind, + ) -> Result { + let key = request.idempotency_key.clone(); + let created_cell = { + let mut created_cells = self.inner.created_cells.lock().await; + Arc::clone( + created_cells + .entry(key) + .or_insert_with(|| Arc::new(OnceCell::new())), + ) + }; + let existing = created_cell + .get_or_try_init(|| async move { + let id = self.start_cell(request, execution_policy, kind).await?; + Ok::(IdempotentCell { id, kind }) + }) + .await?; + if existing.kind != kind { + return Err(Error::WrongCellKind { + cell_id: existing.id.clone(), + expected: kind, + actual: existing.kind, + }); + } + Ok(existing.id.clone()) + } + async fn continuing_handle(&self, cell_id: &CellId) -> Result { self.handle_for_kind(cell_id, CellKind::Continuing).await } diff --git a/codex-rs/code-mode/src/session_runtime/tests.rs b/codex-rs/code-mode/src/session_runtime/tests.rs index f908293a738c..ca6069841f82 100644 --- a/codex-rs/code-mode/src/session_runtime/tests.rs +++ b/codex-rs/code-mode/src/session_runtime/tests.rs @@ -117,6 +117,7 @@ async fn default_policy_resolves_tools_before_the_first_observation() { let runtime = SessionRuntime::new(Arc::new(ImmediateToolDelegate { invocations_tx })); let cell = runtime .create_cell(CreateCellRequest { + idempotency_key: "default-policy".to_string(), tool_call_id: "call-1".to_string(), enabled_tools: vec![tool_definition("first"), tool_definition("second")], source: r#" @@ -153,6 +154,31 @@ text("done"); runtime.shutdown().await.unwrap(); } +#[tokio::test] +async fn concurrent_create_retries_return_the_same_cell_for_an_idempotency_key() { + let runtime = SessionRuntime::new(Arc::new(RecordingDelegate)); + let source = "await new Promise(() => {});"; + let (first, retry) = tokio::join!( + runtime.create_cell(execute_request(source)), + runtime.create_cell(execute_request(source)), + ); + let first = first.unwrap(); + let retry = retry.unwrap(); + + assert_eq!(retry, first); + assert_eq!(runtime.inner.cells.lock().await.len(), 1); + assert_eq!( + runtime.create_pausable_cell(execute_request(source)).await, + Err(Error::WrongCellKind { + cell_id: first.id().clone(), + expected: CellKind::Pausable, + actual: CellKind::Continuing, + }) + ); + + runtime.shutdown().await.unwrap(); +} + #[tokio::test] async fn pausable_cell_supports_a_synchronous_host_driver() { let (invocations_tx, mut invocations_rx) = mpsc::unbounded_channel(); @@ -163,6 +189,7 @@ async fn pausable_cell_supports_a_synchronous_host_driver() { })); let cell = runtime .create_pausable_cell(CreateCellRequest { + idempotency_key: "synchronous-driver".to_string(), tool_call_id: "call-1".to_string(), enabled_tools: vec![tool_definition("first"), tool_definition("second")], source: r#" @@ -267,6 +294,7 @@ async fn pending_frontier_reports_only_authoritatively_outstanding_parallel_tool })); let cell = runtime .create_pausable_cell(CreateCellRequest { + idempotency_key: "authoritative-outstanding".to_string(), tool_call_id: "call-1".to_string(), enabled_tools: vec![tool_definition("first"), tool_definition("second")], source: r#" @@ -344,6 +372,7 @@ async fn pausable_cell_drains_a_parallel_host_frontier_without_duplicate_output( })); let cell = runtime .create_pausable_cell(CreateCellRequest { + idempotency_key: "parallel-drain".to_string(), tool_call_id: "call-1".to_string(), enabled_tools: vec![tool_definition("first"), tool_definition("second")], source: r#" @@ -432,7 +461,10 @@ async fn cell_capabilities_reject_ids_of_the_other_kind() { .await .unwrap(); let pausable = runtime - .create_pausable_cell(execute_request("await new Promise(() => {});")) + .create_pausable_cell(CreateCellRequest { + idempotency_key: "pausable-cell".to_string(), + ..execute_request("await new Promise(() => {});") + }) .await .unwrap(); @@ -465,6 +497,7 @@ async fn pending_observation_waits_for_resumed_work_to_reach_a_new_frontier() { })); let cell = runtime .create_pausable_cell(CreateCellRequest { + idempotency_key: "pending-observation".to_string(), tool_call_id: "call-1".to_string(), enabled_tools: vec![tool_definition("blocked")], source: r#" @@ -561,6 +594,7 @@ async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_loa let reader = runtime .create_cell(CreateCellRequest { + idempotency_key: "reader".to_string(), tool_call_id: "reader".to_string(), enabled_tools: Vec::new(), source: r#"text(String(load("candidate")));"#.to_string(), @@ -581,6 +615,7 @@ async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_loa fn execute_request(source: &str) -> CreateCellRequest { CreateCellRequest { + idempotency_key: format!("call-1:{source}"), tool_call_id: "call-1".to_string(), enabled_tools: Vec::new(), source: source.to_string(), diff --git a/codex-rs/code-mode/src/session_runtime/types.rs b/codex-rs/code-mode/src/session_runtime/types.rs index 4b45ac6321a1..34d532e3bdfb 100644 --- a/codex-rs/code-mode/src/session_runtime/types.rs +++ b/codex-rs/code-mode/src/session_runtime/types.rs @@ -162,6 +162,7 @@ pub(crate) enum ImageDetail { /// /// The owning session assigns the cell ID when it admits the request. pub(crate) struct CreateCellRequest { + pub(crate) idempotency_key: String, pub(crate) tool_call_id: String, pub(crate) enabled_tools: Vec, pub(crate) source: String,