Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/code-mode/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,9 @@ impl runtime::SessionRuntimeDelegate for ProtocolDelegate {
}

fn runtime_request(request: ExecuteRequest) -> runtime::CreateCellRequest {
let idempotency_key = format!("{}:{}", request.tool_call_id, request.source);
runtime::CreateCellRequest {
idempotency_key,
tool_call_id: request.tool_call_id,
enabled_tools: request
.enabled_tools
Expand Down
3 changes: 3 additions & 0 deletions codex-rs/code-mode/src/service_contract_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,12 @@ async fn observed_natural_completion_wins_over_termination() {
}
);
tokio::time::timeout(Duration::from_secs(1), async {
let mut probe_index = 0;
loop {
probe_index += 1;
let response = service
.execute(ExecuteRequest {
tool_call_id: format!("probe-{probe_index}"),
yield_time_ms: Some(60_000),
..execute_request(r#"text(String(load("finished")));"#)
})
Expand Down
44 changes: 42 additions & 2 deletions codex-rs/code-mode/src/session_runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::time::Duration;

use serde_json::Value as JsonValue;
use tokio::sync::Mutex;
use tokio::sync::OnceCell;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;

Expand Down Expand Up @@ -55,12 +56,19 @@ pub(crate) struct SessionRuntime<D: SessionRuntimeDelegate> {
struct Inner<D: SessionRuntimeDelegate> {
stored_values: Mutex<HashMap<String, JsonValue>>,
cells: Mutex<HashMap<CellId, RegisteredCell>>,
created_cells: Mutex<HashMap<String, Arc<OnceCell<IdempotentCell>>>>,
cell_tasks: TaskTracker,
shutdown_token: CancellationToken,
delegate: Arc<D>,
next_cell_id: AtomicU64,
}

#[derive(Clone)]
struct IdempotentCell {
id: CellId,
kind: CellKind,
}

#[derive(Clone)]
enum RegisteredCell {
Continuing(CellHandle),
Expand Down Expand Up @@ -88,6 +96,7 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
inner: Arc::new(Inner {
stored_values: Mutex::new(HashMap::new()),
cells: Mutex::new(HashMap::new()),
created_cells: Mutex::new(HashMap::new()),
cell_tasks: TaskTracker::new(),
shutdown_token: CancellationToken::new(),
delegate,
Expand All @@ -102,7 +111,7 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {

pub(crate) async fn create_cell(&self, request: CreateCellRequest) -> Result<Cell, Error> {
let id = self
.start_cell(
.create_idempotent_cell(
request,
CellExecutionPolicy::ContinueWhenUnblocked,
CellKind::Continuing,
Expand All @@ -116,7 +125,7 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
request: CreateCellRequest,
) -> Result<PausableCell, Error> {
let id = self
.start_cell(
.create_idempotent_cell(
request,
CellExecutionPolicy::PauseAtPendingFrontier,
CellKind::Pausable,
Expand Down Expand Up @@ -254,6 +263,37 @@ impl<D: SessionRuntimeDelegate> SessionRuntime<D> {
Ok(cell_id)
}

async fn create_idempotent_cell(
&self,
request: CreateCellRequest,
execution_policy: CellExecutionPolicy,
kind: CellKind,
) -> Result<CellId, Error> {
let key = request.idempotency_key.clone();
let created_cell = {
let mut created_cells = self.inner.created_cells.lock().await;
Arc::clone(
created_cells
.entry(key)
.or_insert_with(|| Arc::new(OnceCell::new())),
)
};
let existing = created_cell
.get_or_try_init(|| async move {
let id = self.start_cell(request, execution_policy, kind).await?;
Ok::<IdempotentCell, Error>(IdempotentCell { id, kind })
})
.await?;
if existing.kind != kind {
return Err(Error::WrongCellKind {
cell_id: existing.id.clone(),
expected: kind,
actual: existing.kind,
});
}
Ok(existing.id.clone())
}

async fn continuing_handle(&self, cell_id: &CellId) -> Result<CellHandle, Error> {
self.handle_for_kind(cell_id, CellKind::Continuing).await
}
Expand Down
37 changes: 36 additions & 1 deletion codex-rs/code-mode/src/session_runtime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ async fn default_policy_resolves_tools_before_the_first_observation() {
let runtime = SessionRuntime::new(Arc::new(ImmediateToolDelegate { invocations_tx }));
let cell = runtime
.create_cell(CreateCellRequest {
idempotency_key: "default-policy".to_string(),
tool_call_id: "call-1".to_string(),
enabled_tools: vec![tool_definition("first"), tool_definition("second")],
source: r#"
Expand Down Expand Up @@ -153,6 +154,31 @@ text("done");
runtime.shutdown().await.unwrap();
}

#[tokio::test]
async fn concurrent_create_retries_return_the_same_cell_for_an_idempotency_key() {
let runtime = SessionRuntime::new(Arc::new(RecordingDelegate));
let source = "await new Promise(() => {});";
let (first, retry) = tokio::join!(
runtime.create_cell(execute_request(source)),
runtime.create_cell(execute_request(source)),
);
let first = first.unwrap();
let retry = retry.unwrap();

assert_eq!(retry, first);
assert_eq!(runtime.inner.cells.lock().await.len(), 1);
assert_eq!(
runtime.create_pausable_cell(execute_request(source)).await,
Err(Error::WrongCellKind {
cell_id: first.id().clone(),
expected: CellKind::Pausable,
actual: CellKind::Continuing,
})
);

runtime.shutdown().await.unwrap();
}

#[tokio::test]
async fn pausable_cell_supports_a_synchronous_host_driver() {
let (invocations_tx, mut invocations_rx) = mpsc::unbounded_channel();
Expand All @@ -163,6 +189,7 @@ async fn pausable_cell_supports_a_synchronous_host_driver() {
}));
let cell = runtime
.create_pausable_cell(CreateCellRequest {
idempotency_key: "synchronous-driver".to_string(),
tool_call_id: "call-1".to_string(),
enabled_tools: vec![tool_definition("first"), tool_definition("second")],
source: r#"
Expand Down Expand Up @@ -267,6 +294,7 @@ async fn pending_frontier_reports_only_authoritatively_outstanding_parallel_tool
}));
let cell = runtime
.create_pausable_cell(CreateCellRequest {
idempotency_key: "authoritative-outstanding".to_string(),
tool_call_id: "call-1".to_string(),
enabled_tools: vec![tool_definition("first"), tool_definition("second")],
source: r#"
Expand Down Expand Up @@ -344,6 +372,7 @@ async fn pausable_cell_drains_a_parallel_host_frontier_without_duplicate_output(
}));
let cell = runtime
.create_pausable_cell(CreateCellRequest {
idempotency_key: "parallel-drain".to_string(),
tool_call_id: "call-1".to_string(),
enabled_tools: vec![tool_definition("first"), tool_definition("second")],
source: r#"
Expand Down Expand Up @@ -432,7 +461,10 @@ async fn cell_capabilities_reject_ids_of_the_other_kind() {
.await
.unwrap();
let pausable = runtime
.create_pausable_cell(execute_request("await new Promise(() => {});"))
.create_pausable_cell(CreateCellRequest {
idempotency_key: "pausable-cell".to_string(),
..execute_request("await new Promise(() => {});")
})
.await
.unwrap();

Expand Down Expand Up @@ -465,6 +497,7 @@ async fn pending_observation_waits_for_resumed_work_to_reach_a_new_frontier() {
}));
let cell = runtime
.create_pausable_cell(CreateCellRequest {
idempotency_key: "pending-observation".to_string(),
tool_call_id: "call-1".to_string(),
enabled_tools: vec![tool_definition("blocked")],
source: r#"
Expand Down Expand Up @@ -561,6 +594,7 @@ async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_loa

let reader = runtime
.create_cell(CreateCellRequest {
idempotency_key: "reader".to_string(),
tool_call_id: "reader".to_string(),
enabled_tools: Vec::new(),
source: r#"text(String(load("candidate")));"#.to_string(),
Expand All @@ -581,6 +615,7 @@ async fn termination_rejects_a_waiting_store_commit_before_the_next_cell_can_loa

fn execute_request(source: &str) -> CreateCellRequest {
CreateCellRequest {
idempotency_key: format!("call-1:{source}"),
tool_call_id: "call-1".to_string(),
enabled_tools: Vec::new(),
source: source.to_string(),
Expand Down
1 change: 1 addition & 0 deletions codex-rs/code-mode/src/session_runtime/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ pub(crate) enum ImageDetail {
///
/// The owning session assigns the cell ID when it admits the request.
pub(crate) struct CreateCellRequest {
pub(crate) idempotency_key: String,
pub(crate) tool_call_id: String,
pub(crate) enabled_tools: Vec<ToolDefinition>,
pub(crate) source: String,
Expand Down
Loading