diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 90acb1e7726..fc8985497e6 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -819,6 +819,8 @@ version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -985,6 +987,7 @@ dependencies = [ "tokio-util", "tracing", "wiremock", + "zstd", ] [[package]] @@ -1348,6 +1351,7 @@ dependencies = [ "which", "wildmatch", "wiremock", + "zstd", ] [[package]] @@ -2109,16 +2113,19 @@ dependencies = [ "codex-protocol", "codex-utils-absolute-path", "codex-utils-cargo-bin", + "http 1.3.1", "notify", "pretty_assertions", "regex-lite", "reqwest", + "serde", "serde_json", "shlex", "tempfile", "tokio", "walkdir", "wiremock", + "zstd", ] [[package]] @@ -3924,6 +3931,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.3", + "libc", +] + [[package]] name = "js-sys" version = "0.3.77" @@ -8809,6 +8826,34 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "zune-core" version = "0.4.12" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 645d5aa2107..98241a8a2ee 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -235,6 +235,7 @@ wildmatch = "2.6.1" wiremock = "0.6" zeroize = "1.8.2" +zstd = "0.13" [workspace.lints] rust = {} diff --git a/codex-rs/codex-api/Cargo.toml b/codex-rs/codex-api/Cargo.toml index e9fc78878b1..3d14848f91d 100644 --- a/codex-rs/codex-api/Cargo.toml +++ b/codex-rs/codex-api/Cargo.toml @@ -19,6 +19,7 @@ tracing = { workspace = true } eventsource-stream = { workspace = true } regex-lite = { workspace = true } tokio-util = { workspace = true, features = ["codec"] } +zstd = { workspace = true } [dev-dependencies] anyhow = { workspace = true } diff --git a/codex-rs/codex-api/src/endpoint/chat.rs b/codex-rs/codex-api/src/endpoint/chat.rs index b7fa0572f0e..45f581222eb 100644 --- a/codex-rs/codex-api/src/endpoint/chat.rs +++ b/codex-rs/codex-api/src/endpoint/chat.rs @@ -6,7 +6,10 @@ use crate::common::ResponseStream; use crate::endpoint::streaming::StreamingClient; use crate::error::ApiError; use crate::provider::Provider; +use crate::provider::RequestCompression; use crate::provider::WireApi; +use crate::requests::body::encode_body; +use crate::requests::body::insert_compression_headers; use crate::sse::chat::spawn_chat_stream; use crate::telemetry::SseTelemetry; use codex_client::HttpTransport; @@ -45,8 +48,13 @@ impl ChatClient { } } - pub async fn stream_request(&self, request: ChatRequest) -> Result { - self.stream(request.body, request.headers).await + pub async fn stream_request( + &self, + request: ChatRequest, + request_compression: RequestCompression, + ) -> Result { + self.stream(request.body, request.headers, request_compression) + .await } pub async fn stream_prompt( @@ -55,6 +63,7 @@ impl ChatClient { prompt: &ApiPrompt, conversation_id: Option, session_source: Option, + request_compression: RequestCompression, ) -> Result { use crate::requests::ChatRequestBuilder; @@ -64,7 +73,7 @@ impl ChatClient { .session_source(session_source) .build(self.streaming.provider())?; - self.stream_request(request).await + self.stream_request(request, request_compression).await } fn path(&self) -> &'static str { @@ -78,9 +87,13 @@ impl ChatClient { &self, body: Value, extra_headers: HeaderMap, + request_compression: RequestCompression, ) -> Result { + let mut headers = extra_headers; + insert_compression_headers(&mut headers, request_compression); + let encoded_body = encode_body(&body, request_compression)?; self.streaming - .stream(self.path(), body, extra_headers, spawn_chat_stream) + .stream(self.path(), encoded_body, headers, spawn_chat_stream) .await } } diff --git a/codex-rs/codex-api/src/endpoint/compact.rs b/codex-rs/codex-api/src/endpoint/compact.rs index 2b02ebd0f09..d661bbeba1e 100644 --- a/codex-rs/codex-api/src/endpoint/compact.rs +++ b/codex-rs/codex-api/src/endpoint/compact.rs @@ -5,6 +5,7 @@ use crate::error::ApiError; use crate::provider::Provider; use crate::provider::WireApi; use crate::telemetry::run_with_request_telemetry; +use codex_client::Body; use codex_client::HttpTransport; use codex_client::RequestTelemetry; use codex_protocol::models::ResponseItem; @@ -54,7 +55,7 @@ impl CompactClient { let builder = || { let mut req = self.provider.build_request(Method::POST, path); req.headers.extend(extra_headers.clone()); - req.body = Some(body.clone()); + req.body = Some(Body::Json(body.clone())); add_auth_headers(&self.auth, req) }; @@ -89,6 +90,7 @@ struct CompactHistoryResponse { #[cfg(test)] mod tests { use super::*; + use crate::provider::RetryConfig; use async_trait::async_trait; use codex_client::Request; diff --git a/codex-rs/codex-api/src/endpoint/responses.rs b/codex-rs/codex-api/src/endpoint/responses.rs index 476e8b8f138..b783dd6c67a 100644 --- a/codex-rs/codex-api/src/endpoint/responses.rs +++ b/codex-rs/codex-api/src/endpoint/responses.rs @@ -6,6 +6,7 @@ use crate::common::TextControls; use crate::endpoint::streaming::StreamingClient; use crate::error::ApiError; use crate::provider::Provider; +use crate::provider::RequestCompression; use crate::provider::WireApi; use crate::requests::ResponsesRequest; use crate::requests::ResponsesRequestBuilder; @@ -15,7 +16,6 @@ use codex_client::HttpTransport; use codex_client::RequestTelemetry; use codex_protocol::protocol::SessionSource; use http::HeaderMap; -use serde_json::Value; use std::sync::Arc; use tracing::instrument; @@ -33,6 +33,7 @@ pub struct ResponsesOptions { pub conversation_id: Option, pub session_source: Option, pub extra_headers: HeaderMap, + pub request_compression: RequestCompression, } impl ResponsesClient { @@ -56,7 +57,7 @@ impl ResponsesClient { &self, request: ResponsesRequest, ) -> Result { - self.stream(request.body, request.headers).await + self.stream(request).await } #[instrument(level = "trace", skip_all, err)] @@ -75,6 +76,7 @@ impl ResponsesClient { conversation_id, session_source, extra_headers, + request_compression, } = options; let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input) @@ -88,6 +90,7 @@ impl ResponsesClient { .session_source(session_source) .store_override(store_override) .extra_headers(extra_headers) + .request_compression(request_compression) .build(self.streaming.provider())?; self.stream_request(request).await @@ -100,13 +103,14 @@ impl ResponsesClient { } } - pub async fn stream( - &self, - body: Value, - extra_headers: HeaderMap, - ) -> Result { + pub async fn stream(&self, request: ResponsesRequest) -> Result { self.streaming - .stream(self.path(), body, extra_headers, spawn_response_stream) + .stream( + self.path(), + request.body, + request.headers, + spawn_response_stream, + ) .await } } diff --git a/codex-rs/codex-api/src/endpoint/streaming.rs b/codex-rs/codex-api/src/endpoint/streaming.rs index 156d4084bc8..b6944e75605 100644 --- a/codex-rs/codex-api/src/endpoint/streaming.rs +++ b/codex-rs/codex-api/src/endpoint/streaming.rs @@ -5,12 +5,15 @@ use crate::error::ApiError; use crate::provider::Provider; use crate::telemetry::SseTelemetry; use crate::telemetry::run_with_request_telemetry; +use codex_client::Body; use codex_client::HttpTransport; use codex_client::RequestTelemetry; use codex_client::StreamResponse; use http::HeaderMap; +use http::HeaderValue; use http::Method; -use serde_json::Value; +use http::header::ACCEPT; +use http::header::CONTENT_TYPE; use std::sync::Arc; use std::time::Duration; @@ -50,17 +53,18 @@ impl StreamingClient { pub(crate) async fn stream( &self, path: &str, - body: Value, + body: Body, extra_headers: HeaderMap, spawner: fn(StreamResponse, Duration, Option>) -> ResponseStream, ) -> Result { let builder = || { let mut req = self.provider.build_request(Method::POST, path); req.headers.extend(extra_headers.clone()); - req.headers.insert( - http::header::ACCEPT, - http::HeaderValue::from_static("text/event-stream"), - ); + req.headers + .insert(ACCEPT, HeaderValue::from_static("text/event-stream")); + req.headers + .entry(CONTENT_TYPE) + .or_insert_with(|| HeaderValue::from_static("application/json")); req.body = Some(body.clone()); add_auth_headers(&self.auth, req) }; diff --git a/codex-rs/codex-api/src/provider.rs b/codex-rs/codex-api/src/provider.rs index 8bd5fc9093c..5d79562ce0d 100644 --- a/codex-rs/codex-api/src/provider.rs +++ b/codex-rs/codex-api/src/provider.rs @@ -41,6 +41,13 @@ impl RetryConfig { } } +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum RequestCompression { + #[default] + None, + Zstd, +} + /// HTTP endpoint configuration used to talk to a concrete API deployment. /// /// Encapsulates base URL, default headers, query params, retry policy, and diff --git a/codex-rs/codex-api/src/requests/body.rs b/codex-rs/codex-api/src/requests/body.rs new file mode 100644 index 00000000000..16ed351778e --- /dev/null +++ b/codex-rs/codex-api/src/requests/body.rs @@ -0,0 +1,40 @@ +use crate::error::ApiError; +use crate::provider::RequestCompression; +use bytes::Bytes; +use codex_client::Body; +use http::HeaderMap; +use http::HeaderValue; +use http::header::CONTENT_ENCODING; +use serde_json::Value; +use std::time::Instant; +use tracing::info; +use zstd::stream::encode_all; + +pub(crate) fn encode_body(body: &Value, compression: RequestCompression) -> Result { + match compression { + RequestCompression::None => Ok(Body::Json(body.clone())), + RequestCompression::Zstd => { + let json = serde_json::to_vec(body).map_err(|err| { + ApiError::Stream(format!("failed to encode request body as json: {err}")) + })?; + let started_at = Instant::now(); + let compressed = encode_all(json.as_slice(), 0).map_err(|err| { + ApiError::Stream(format!("failed to compress request body: {err}")) + })?; + let elapsed = started_at.elapsed(); + info!( + input_bytes = json.len(), + output_bytes = compressed.len(), + elapsed_ms = elapsed.as_millis(), + "compressed request body" + ); + Ok(Body::Bytes(Bytes::from(compressed))) + } + } +} + +pub(crate) fn insert_compression_headers(headers: &mut HeaderMap, compression: RequestCompression) { + if matches!(compression, RequestCompression::Zstd) { + headers.insert(CONTENT_ENCODING, HeaderValue::from_static("zstd")); + } +} diff --git a/codex-rs/codex-api/src/requests/chat.rs b/codex-rs/codex-api/src/requests/chat.rs index 60f450ca0d1..a000354bd94 100644 --- a/codex-rs/codex-api/src/requests/chat.rs +++ b/codex-rs/codex-api/src/requests/chat.rs @@ -351,6 +351,7 @@ fn push_tool_call_message(messages: &mut Vec, tool_call: Value, reasoning #[cfg(test)] mod tests { use super::*; + use crate::provider::RetryConfig; use crate::provider::WireApi; use codex_protocol::models::FunctionCallOutputPayload; diff --git a/codex-rs/codex-api/src/requests/mod.rs b/codex-rs/codex-api/src/requests/mod.rs index f0ab23a25fa..7bf5a798761 100644 --- a/codex-rs/codex-api/src/requests/mod.rs +++ b/codex-rs/codex-api/src/requests/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod body; pub mod chat; pub(crate) mod headers; pub mod responses; diff --git a/codex-rs/codex-api/src/requests/responses.rs b/codex-rs/codex-api/src/requests/responses.rs index 543b79bbe9d..4d4171b75a0 100644 --- a/codex-rs/codex-api/src/requests/responses.rs +++ b/codex-rs/codex-api/src/requests/responses.rs @@ -3,9 +3,13 @@ use crate::common::ResponsesApiRequest; use crate::common::TextControls; use crate::error::ApiError; use crate::provider::Provider; +use crate::provider::RequestCompression; +use crate::requests::body::encode_body; +use crate::requests::body::insert_compression_headers; use crate::requests::headers::build_conversation_headers; use crate::requests::headers::insert_header; use crate::requests::headers::subagent_header; +use codex_client::Body; use codex_protocol::models::ResponseItem; use codex_protocol::protocol::SessionSource; use http::HeaderMap; @@ -13,7 +17,7 @@ use serde_json::Value; /// Assembled request body plus headers for a Responses stream request. pub struct ResponsesRequest { - pub body: Value, + pub body: Body, pub headers: HeaderMap, } @@ -32,6 +36,7 @@ pub struct ResponsesRequestBuilder<'a> { session_source: Option, store_override: Option, headers: HeaderMap, + request_compression: RequestCompression, } impl<'a> ResponsesRequestBuilder<'a> { @@ -94,6 +99,11 @@ impl<'a> ResponsesRequestBuilder<'a> { self } + pub fn request_compression(mut self, request_compression: RequestCompression) -> Self { + self.request_compression = request_compression; + self + } + pub fn build(self, provider: &Provider) -> Result { let model = self .model @@ -137,6 +147,8 @@ impl<'a> ResponsesRequestBuilder<'a> { if let Some(subagent) = subagent_header(&self.session_source) { insert_header(&mut headers, "x-openai-subagent", &subagent); } + insert_compression_headers(&mut headers, self.request_compression); + let body = encode_body(&body, self.request_compression)?; Ok(ResponsesRequest { body, headers }) } @@ -172,8 +184,10 @@ fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) { #[cfg(test)] mod tests { use super::*; + use crate::provider::RetryConfig; use crate::provider::WireApi; + use codex_client::Body; use codex_protocol::protocol::SubAgentSource; use http::HeaderValue; use pretty_assertions::assert_eq; @@ -219,10 +233,12 @@ mod tests { .build(&provider) .expect("request"); - assert_eq!(request.body.get("store"), Some(&Value::Bool(true))); + let Body::Json(body) = &request.body else { + panic!("expected json body for responses request"); + }; + assert_eq!(body.get("store"), Some(&Value::Bool(true))); - let ids: Vec> = request - .body + let ids: Vec> = body .get("input") .and_then(|v| v.as_array()) .into_iter() diff --git a/codex-rs/codex-api/tests/clients.rs b/codex-rs/codex-api/tests/clients.rs index 3dafaf74fae..cf1b5b7e4a1 100644 --- a/codex-rs/codex-api/tests/clients.rs +++ b/codex-rs/codex-api/tests/clients.rs @@ -10,7 +10,9 @@ use codex_api::ChatClient; use codex_api::Provider; use codex_api::ResponsesClient; use codex_api::ResponsesOptions; +use codex_api::ResponsesRequest; use codex_api::WireApi; +use codex_client::Body; use codex_client::HttpTransport; use codex_client::Request; use codex_client::Response; @@ -136,6 +138,13 @@ fn provider(name: &str, wire: WireApi) -> Provider { } } +fn responses_request(body: Value) -> ResponsesRequest { + ResponsesRequest { + body: Body::Json(body), + headers: HeaderMap::new(), + } +} + #[derive(Clone)] struct FlakyTransport { state: Arc>, @@ -201,7 +210,9 @@ async fn chat_client_uses_chat_completions_path_for_chat_wire() -> Result<()> { let client = ChatClient::new(transport, provider("openai", WireApi::Chat), NoAuth); let body = serde_json::json!({ "echo": true }); - let _stream = client.stream(body, HeaderMap::new()).await?; + let _stream = client + .stream(body, HeaderMap::new(), Default::default()) + .await?; let requests = state.take_stream_requests(); assert_path_ends_with(&requests, "/chat/completions"); @@ -215,7 +226,9 @@ async fn chat_client_uses_responses_path_for_responses_wire() -> Result<()> { let client = ChatClient::new(transport, provider("openai", WireApi::Responses), NoAuth); let body = serde_json::json!({ "echo": true }); - let _stream = client.stream(body, HeaderMap::new()).await?; + let _stream = client + .stream(body, HeaderMap::new(), Default::default()) + .await?; let requests = state.take_stream_requests(); assert_path_ends_with(&requests, "/responses"); @@ -228,8 +241,8 @@ async fn responses_client_uses_responses_path_for_responses_wire() -> Result<()> let transport = RecordingTransport::new(state.clone()); let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth); - let body = serde_json::json!({ "echo": true }); - let _stream = client.stream(body, HeaderMap::new()).await?; + let request = responses_request(serde_json::json!({ "echo": true })); + let _stream = client.stream(request).await?; let requests = state.take_stream_requests(); assert_path_ends_with(&requests, "/responses"); @@ -242,8 +255,8 @@ async fn responses_client_uses_chat_path_for_chat_wire() -> Result<()> { let transport = RecordingTransport::new(state.clone()); let client = ResponsesClient::new(transport, provider("openai", WireApi::Chat), NoAuth); - let body = serde_json::json!({ "echo": true }); - let _stream = client.stream(body, HeaderMap::new()).await?; + let request = responses_request(serde_json::json!({ "echo": true })); + let _stream = client.stream(request).await?; let requests = state.take_stream_requests(); assert_path_ends_with(&requests, "/chat/completions"); @@ -257,8 +270,8 @@ async fn streaming_client_adds_auth_headers() -> Result<()> { let auth = StaticAuth::new("secret-token", "acct-1"); let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), auth); - let body = serde_json::json!({ "model": "gpt-test" }); - let _stream = client.stream(body, HeaderMap::new()).await?; + let request = responses_request(serde_json::json!({ "model": "gpt-test" })); + let _stream = client.stream(request).await?; let requests = state.take_stream_requests(); assert_eq!(requests.len(), 1); diff --git a/codex-rs/codex-api/tests/sse_end_to_end.rs b/codex-rs/codex-api/tests/sse_end_to_end.rs index b91cf3a5d8e..8e14c7779e3 100644 --- a/codex-rs/codex-api/tests/sse_end_to_end.rs +++ b/codex-rs/codex-api/tests/sse_end_to_end.rs @@ -8,7 +8,9 @@ use codex_api::AuthProvider; use codex_api::Provider; use codex_api::ResponseEvent; use codex_api::ResponsesClient; +use codex_api::ResponsesRequest; use codex_api::WireApi; +use codex_client::Body; use codex_client::HttpTransport; use codex_client::Request; use codex_client::Response; @@ -94,6 +96,13 @@ fn build_responses_body(events: Vec) -> String { body } +fn responses_request(body: Value) -> ResponsesRequest { + ResponsesRequest { + body: Body::Json(body), + headers: HeaderMap::new(), + } +} + #[tokio::test] async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()> { let item1 = serde_json::json!({ @@ -123,9 +132,8 @@ async fn responses_stream_parses_items_and_completed_end_to_end() -> Result<()> let transport = FixtureSseTransport::new(body); let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth); - let mut stream = client - .stream(serde_json::json!({"echo": true}), HeaderMap::new()) - .await?; + let request = responses_request(serde_json::json!({"echo": true})); + let mut stream = client.stream(request).await?; let mut events = Vec::new(); while let Some(ev) = stream.next().await { @@ -188,9 +196,8 @@ async fn responses_stream_aggregates_output_text_deltas() -> Result<()> { let transport = FixtureSseTransport::new(body); let client = ResponsesClient::new(transport, provider("openai", WireApi::Responses), NoAuth); - let stream = client - .stream(serde_json::json!({"echo": true}), HeaderMap::new()) - .await?; + let request = responses_request(serde_json::json!({"echo": true})); + let stream = client.stream(request).await?; let mut stream = stream.aggregate(); let mut events = Vec::new(); diff --git a/codex-rs/codex-client/src/default_client.rs b/codex-rs/codex-client/src/default_client.rs index efb4d5aec41..1ab612e8428 100644 --- a/codex-rs/codex-client/src/default_client.rs +++ b/codex-rs/codex-client/src/default_client.rs @@ -104,6 +104,13 @@ impl CodexRequestBuilder { self.map(|builder| builder.json(value)) } + pub fn body(self, body: T) -> Self + where + T: Into, + { + self.map(|builder| builder.body(body)) + } + pub async fn send(self) -> Result { let headers = trace_headers(); diff --git a/codex-rs/codex-client/src/lib.rs b/codex-rs/codex-client/src/lib.rs index 66d1083c07d..f0742a6f7e1 100644 --- a/codex-rs/codex-client/src/lib.rs +++ b/codex-rs/codex-client/src/lib.rs @@ -10,6 +10,7 @@ pub use crate::default_client::CodexHttpClient; pub use crate::default_client::CodexRequestBuilder; pub use crate::error::StreamError; pub use crate::error::TransportError; +pub use crate::request::Body; pub use crate::request::Request; pub use crate::request::Response; pub use crate::retry::RetryOn; diff --git a/codex-rs/codex-client/src/request.rs b/codex-rs/codex-client/src/request.rs index f3d205de99c..70974250d42 100644 --- a/codex-rs/codex-client/src/request.rs +++ b/codex-rs/codex-client/src/request.rs @@ -5,12 +5,18 @@ use serde::Serialize; use serde_json::Value; use std::time::Duration; +#[derive(Debug, Clone)] +pub enum Body { + Json(Value), + Bytes(Bytes), +} + #[derive(Debug, Clone)] pub struct Request { pub method: Method, pub url: String, pub headers: HeaderMap, - pub body: Option, + pub body: Option, pub timeout: Option, } @@ -26,7 +32,7 @@ impl Request { } pub fn with_json(mut self, body: &T) -> Self { - self.body = serde_json::to_value(body).ok(); + self.body = serde_json::to_value(body).ok().map(Body::Json); self } } diff --git a/codex-rs/codex-client/src/transport.rs b/codex-rs/codex-client/src/transport.rs index abe6e29ee55..4039f3c6b85 100644 --- a/codex-rs/codex-client/src/transport.rs +++ b/codex-rs/codex-client/src/transport.rs @@ -1,6 +1,7 @@ use crate::default_client::CodexHttpClient; use crate::default_client::CodexRequestBuilder; use crate::error::TransportError; +use crate::request::Body; use crate::request::Request; use crate::request::Response; use async_trait::async_trait; @@ -52,7 +53,10 @@ impl ReqwestTransport { builder = builder.timeout(timeout); } if let Some(body) = req.body { - builder = builder.json(&body); + builder = match body { + Body::Json(value) => builder.json(&value), + Body::Bytes(bytes) => builder.body(bytes), + }; } Ok(builder) } @@ -101,10 +105,10 @@ impl HttpTransport for ReqwestTransport { async fn stream(&self, req: Request) -> Result { if enabled!(Level::TRACE) { trace!( - "{} to {}: {}", - req.method, - req.url, - req.body.as_ref().unwrap_or_default() + method = %req.method, + url = %req.url, + body = ?req.body, + "Sending streaming request" ); } diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 51fab19decf..01c294b4518 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -89,6 +89,7 @@ url = { workspace = true } uuid = { workspace = true, features = ["serde", "v4", "v5"] } which = { workspace = true } wildmatch = { workspace = true } +zstd = { workspace = true } [features] deterministic_process_ids = [] diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 11a3c5c65f3..bc8d4140cda 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -156,6 +156,9 @@ impl ModelClient { let mut refreshed = false; loop { let auth = auth_manager.as_ref().and_then(|m| m.auth()); + let request_compression = self + .provider + .request_compression_for(auth.as_ref().map(|a| a.mode), &self.config.features); let api_provider = self .provider .to_api_provider(auth.as_ref().map(|a| a.mode))?; @@ -171,6 +174,7 @@ impl ModelClient { &api_prompt, Some(conversation_id.clone()), Some(session_source.clone()), + request_compression, ) .await; @@ -245,6 +249,9 @@ impl ModelClient { let mut refreshed = false; loop { let auth = auth_manager.as_ref().and_then(|m| m.auth()); + let request_compression = self + .provider + .request_compression_for(auth.as_ref().map(|a| a.mode), &self.config.features); let api_provider = self .provider .to_api_provider(auth.as_ref().map(|a| a.mode))?; @@ -263,6 +270,7 @@ impl ModelClient { conversation_id: Some(conversation_id.clone()), session_source: Some(session_source.clone()), extra_headers: beta_feature_headers(&self.config), + request_compression, }; let stream_result = client diff --git a/codex-rs/core/src/default_client.rs b/codex-rs/core/src/default_client.rs index 3cd882489d0..bcf06882fd3 100644 --- a/codex-rs/core/src/default_client.rs +++ b/codex-rs/core/src/default_client.rs @@ -24,7 +24,6 @@ use std::sync::OnceLock; pub static USER_AGENT_SUFFIX: LazyLock>> = LazyLock::new(|| Mutex::new(None)); pub const DEFAULT_ORIGINATOR: &str = "codex_cli_rs"; pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE"; - #[derive(Debug, Clone)] pub struct Originator { pub value: String, diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs index 3b22bfc3f45..4085783106d 100644 --- a/codex-rs/core/src/features.rs +++ b/codex-rs/core/src/features.rs @@ -8,7 +8,6 @@ use crate::config::ConfigToml; use crate::config::profile::ConfigProfile; use serde::Deserialize; -use serde::Serialize; use std::collections::BTreeMap; use std::collections::BTreeSet; @@ -74,6 +73,8 @@ pub enum Feature { ApplyPatchFreeform, /// Allow the model to request web searches. WebSearchRequest, + /// Allow request body compression when using ChatGPT auth. + RequestCompression, /// Gate the execpolicy enforcement for shell/unified exec. ExecPolicy, /// Enable Windows sandbox (restricted token) on Windows. @@ -150,16 +151,16 @@ impl FeatureOverrides { impl Features { /// Starts with built-in defaults. pub fn with_defaults() -> Self { - let mut set = BTreeSet::new(); + let mut features = Self { + enabled: BTreeSet::new(), + legacy_usages: BTreeSet::new(), + }; for spec in FEATURES { if spec.default_enabled { - set.insert(spec.id); + features.enable(spec.id); } } - Self { - enabled: set, - legacy_usages: BTreeSet::new(), - } + features } pub fn enabled(&self, f: Feature) -> bool { @@ -196,7 +197,7 @@ impl Features { .map(|usage| (usage.alias.as_str(), usage.feature)) } - /// Apply a table of key -> bool toggles (e.g. from TOML). + /// Apply a table of key -> value toggles (e.g. from TOML). pub fn apply_map(&mut self, m: &BTreeMap) { for (k, v) in m { match feature_for_key(k) { @@ -330,6 +331,12 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::Stable, default_enabled: false, }, + FeatureSpec { + id: Feature::RequestCompression, + key: "request_compression", + stage: Stage::Experimental, + default_enabled: false, + }, // Beta program. Rendered in the `/experimental` menu for users. FeatureSpec { id: Feature::UnifiedExec, diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index 96173922372..ab39c7ba0ee 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -19,6 +19,8 @@ use std::env::VarError; use std::time::Duration; use crate::error::EnvVarError; +use crate::features::Feature; +use crate::features::Features; const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000; const DEFAULT_STREAM_MAX_RETRIES: u64 = 5; const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4; @@ -253,6 +255,21 @@ impl ModelProviderInfo { pub fn is_openai(&self) -> bool { self.name == OPENAI_PROVIDER_NAME } + + pub fn request_compression_for( + &self, + auth_mode: Option, + features: &Features, + ) -> codex_api::provider::RequestCompression { + if self.is_openai() + && matches!(auth_mode, Some(AuthMode::ChatGPT)) + && features.enabled(Feature::RequestCompression) + { + codex_api::provider::RequestCompression::Zstd + } else { + codex_api::provider::RequestCompression::None + } + } } pub const DEFAULT_LMSTUDIO_PORT: u16 = 1234; diff --git a/codex-rs/core/tests/common/Cargo.toml b/codex-rs/core/tests/common/Cargo.toml index c61a0956862..a2b1231c979 100644 --- a/codex-rs/core/tests/common/Cargo.toml +++ b/codex-rs/core/tests/common/Cargo.toml @@ -15,14 +15,17 @@ codex-core = { workspace = true, features = ["test-support"] } codex-protocol = { workspace = true } codex-utils-absolute-path = { workspace = true } codex-utils-cargo-bin = { workspace = true } +http = { workspace = true } notify = { workspace = true } regex-lite = { workspace = true } serde_json = { workspace = true } +serde = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = ["time"] } walkdir = { workspace = true } wiremock = { workspace = true } shlex = { workspace = true } +zstd = { workspace = true } [dev-dependencies] pretty_assertions = { workspace = true } diff --git a/codex-rs/core/tests/common/lib.rs b/codex-rs/core/tests/common/lib.rs index 9568ec2786c..a447a5d1745 100644 --- a/codex-rs/core/tests/common/lib.rs +++ b/codex-rs/core/tests/common/lib.rs @@ -11,11 +11,15 @@ use regex_lite::Regex; use std::path::PathBuf; pub mod process; +pub mod request; pub mod responses; pub mod streaming_sse; pub mod test_codex; pub mod test_codex_exec; +pub use request::RequestBodyExt; +pub use request::body_contains; + #[track_caller] pub fn assert_regex_match<'s>(pattern: &str, actual: &'s str) -> regex_lite::Captures<'s> { let regex = Regex::new(pattern).unwrap_or_else(|err| { @@ -178,7 +182,7 @@ where F: FnMut(&codex_core::protocol::EventMsg) -> bool, { use tokio::time::Duration; - wait_for_event_with_timeout(codex, predicate, Duration::from_secs(1)).await + wait_for_event_with_timeout(codex, predicate, Duration::from_secs(10)).await } pub async fn wait_for_event_match(codex: &CodexConversation, matcher: F) -> T diff --git a/codex-rs/core/tests/common/request.rs b/codex-rs/core/tests/common/request.rs new file mode 100644 index 00000000000..ce93c4a2471 --- /dev/null +++ b/codex-rs/core/tests/common/request.rs @@ -0,0 +1,59 @@ +use http::header::CONTENT_ENCODING; +use serde::de::DeserializeOwned; +use wiremock::Match; + +pub fn decoded_body_bytes(request: &wiremock::Request) -> Vec { + if is_zstd_encoded(request) { + zstd::decode_all(request.body.as_slice()).unwrap_or_else(|err| { + panic!("failed to decode zstd-encoded request body: {err}"); + }) + } else { + request.body.clone() + } +} + +pub fn decoded_body_string(request: &wiremock::Request) -> String { + String::from_utf8_lossy(&decoded_body_bytes(request)).into_owned() +} + +pub trait RequestBodyExt { + fn json_body(&self) -> T; + fn text_body(&self) -> String; +} + +impl RequestBodyExt for wiremock::Request { + fn json_body(&self) -> T { + serde_json::from_slice(&decoded_body_bytes(self)).unwrap_or_else(|err| { + panic!("failed to decode request body as JSON: {err}"); + }) + } + + fn text_body(&self) -> String { + decoded_body_string(self) + } +} + +pub fn body_contains(needle: impl Into) -> impl Match { + BodyContains { + needle: needle.into(), + } +} + +struct BodyContains { + needle: String, +} + +impl Match for BodyContains { + fn matches(&self, request: &wiremock::Request) -> bool { + decoded_body_string(request).contains(self.needle.as_str()) + } +} + +fn is_zstd_encoded(request: &wiremock::Request) -> bool { + request + .headers + .get(CONTENT_ENCODING) + .and_then(|value| value.to_str().ok()) + .map(|value| value.eq_ignore_ascii_case("zstd")) + .unwrap_or(false) +} diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 39347714096..8d7901d9d0e 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -15,6 +15,7 @@ use wiremock::ResponseTemplate; use wiremock::matchers::method; use wiremock::matchers::path_regex; +use crate::RequestBodyExt; use crate::test_codex::ApplyPatchModelOutput; #[derive(Debug, Clone)] @@ -67,7 +68,7 @@ pub struct ResponsesRequest(wiremock::Request); impl ResponsesRequest { pub fn body_json(&self) -> Value { - self.0.body_json().unwrap() + self.0.json_body() } /// Returns all `input_text` spans from `message` inputs for the provided role. @@ -83,7 +84,7 @@ impl ResponsesRequest { } pub fn input(&self) -> Vec { - self.0.body_json::().unwrap()["input"] + self.body_json()["input"] .as_array() .expect("input array not found in request") .clone() @@ -721,10 +722,7 @@ pub async fn get_responses_request_bodies(server: &MockServer) -> Vec { get_responses_requests(server) .await .into_iter() - .map(|req| { - req.body_json::() - .expect("request body to be valid JSON") - }) + .map(|req| req.json_body::()) .collect() } diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs index 3b0ffd2983c..dc2b3990156 100644 --- a/codex-rs/core/tests/responses_headers.rs +++ b/codex-rs/core/tests/responses_headers.rs @@ -10,6 +10,7 @@ use codex_core::Prompt; use codex_core::ResponseEvent; use codex_core::ResponseItem; use codex_core::WireApi; +use codex_core::features::Feature; use codex_core::models_manager::manager::ModelsManager; use codex_otel::otel_manager::OtelManager; use codex_protocol::ConversationId; @@ -317,3 +318,191 @@ async fn responses_respects_model_family_overrides_from_config() { Some("detailed") ); } + +#[tokio::test] +async fn responses_request_body_is_zstd_encoded() { + core_test_support::skip_if_no_network!(); + + let server = responses::start_mock_server().await; + let response_body = responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_completed("resp-1"), + ]); + + let request_recorder = responses::mount_sse_once(&server, response_body).await; + + let provider = ModelProviderInfo { + name: "OpenAI".into(), + base_url: Some(format!("{}/v1", server.uri())), + env_key: None, + env_key_instructions: None, + experimental_bearer_token: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(5_000), + requires_openai_auth: false, + }; + + let codex_home = TempDir::new().expect("failed to create TempDir"); + let mut config = load_default_config_for_test(&codex_home).await; + config.model_provider_id = provider.name.clone(); + config.model_provider = provider.clone(); + config.features.enable(Feature::RequestCompression); + let effort = config.model_reasoning_effort; + let summary = config.model_reasoning_summary; + let model = ModelsManager::get_model_offline(config.model.as_deref()); + config.model = Some(model.clone()); + let config = Arc::new(config); + + let conversation_id = ConversationId::new(); + let session_source = SessionSource::Exec; + let model_family = ModelsManager::construct_model_family_offline(model.as_str(), &config); + let otel_manager = OtelManager::new( + conversation_id, + model.as_str(), + model_family.slug.as_str(), + None, + Some("test@test.com".to_string()), + Some(AuthMode::ChatGPT), + false, + "test".to_string(), + session_source.clone(), + ); + + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let client = ModelClient::new( + Arc::clone(&config), + Some(auth_manager), + model_family, + otel_manager, + provider, + effort, + summary, + conversation_id, + session_source, + ); + let mut prompt = Prompt::default(); + prompt.input = vec![ResponseItem::Message { + id: None, + role: "user".into(), + content: vec![ContentItem::InputText { + text: "hello".into(), + }], + }]; + + let mut stream = client.stream(&prompt).await.expect("stream failed"); + while let Some(event) = stream.next().await { + if matches!(event, Ok(ResponseEvent::Completed { .. })) { + break; + } + } + + let request = request_recorder.single_request(); + assert_eq!(request.header("content-encoding").as_deref(), Some("zstd")); + assert_eq!( + request.header("content-type").as_deref(), + Some("application/json") + ); + let request_body = request.body_json(); + assert_eq!(request_body["stream"].as_bool(), Some(true)); + assert_eq!( + request_body["input"][0]["content"][0]["text"].as_str(), + Some("hello") + ); +} + +#[tokio::test] +async fn responses_request_body_is_uncompressed_when_disabled() { + core_test_support::skip_if_no_network!(); + + let server = responses::start_mock_server().await; + let response_body = responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_completed("resp-1"), + ]); + + let request_recorder = responses::mount_sse_once(&server, response_body).await; + + let provider = ModelProviderInfo { + name: "OpenAI".into(), + base_url: Some(format!("{}/v1", server.uri())), + env_key: None, + env_key_instructions: None, + experimental_bearer_token: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(5_000), + requires_openai_auth: false, + }; + + let codex_home = TempDir::new().expect("failed to create TempDir"); + let mut config = load_default_config_for_test(&codex_home).await; + config.model_provider_id = provider.name.clone(); + config.model_provider = provider.clone(); + let effort = config.model_reasoning_effort; + let summary = config.model_reasoning_summary; + let model = ModelsManager::get_model_offline(config.model.as_deref()); + config.model = Some(model.clone()); + let config = Arc::new(config); + + let conversation_id = ConversationId::new(); + let session_source = SessionSource::Exec; + let model_family = ModelsManager::construct_model_family_offline(model.as_str(), &config); + let otel_manager = OtelManager::new( + conversation_id, + model.as_str(), + model_family.slug.as_str(), + None, + Some("test@test.com".to_string()), + Some(AuthMode::ChatGPT), + false, + "test".to_string(), + session_source.clone(), + ); + + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let client = ModelClient::new( + Arc::clone(&config), + Some(auth_manager), + model_family, + otel_manager, + provider, + effort, + summary, + conversation_id, + session_source, + ); + + let mut prompt = Prompt::default(); + prompt.input = vec![ResponseItem::Message { + id: None, + role: "user".into(), + content: vec![ContentItem::InputText { + text: "hello".into(), + }], + }]; + + let mut stream = client.stream(&prompt).await.expect("stream failed"); + while let Some(event) = stream.next().await { + if matches!(event, Ok(ResponseEvent::Completed { .. })) { + break; + } + } + + let request = request_recorder.single_request(); + assert_eq!(request.header("content-encoding"), None); + assert_eq!( + request.header("content-type").as_deref(), + Some("application/json") + ); +} diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index 5ce6e9f2ffd..63eea84ea0b 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -29,6 +29,8 @@ use codex_protocol::models::ReasoningItemReasoningSummary; use codex_protocol::models::WebSearchAction; use codex_protocol::openai_models::ReasoningEffort; use codex_protocol::user_input::UserInput; +use core_test_support::RequestBodyExt; +use core_test_support::body_contains; use core_test_support::load_default_config_for_test; use core_test_support::load_sse_fixture_with_id; use core_test_support::responses::ev_completed_with_tokens; @@ -51,7 +53,6 @@ use uuid::Uuid; use wiremock::Mock; use wiremock::MockServer; use wiremock::ResponseTemplate; -use wiremock::matchers::body_string_contains; use wiremock::matchers::header_regex; use wiremock::matchers::method; use wiremock::matchers::path; @@ -507,7 +508,7 @@ async fn chatgpt_auth_sends_correct_request() { let request_authorization = request.headers.get("authorization").unwrap(); let request_originator = request.headers.get("originator").unwrap(); let request_chatgpt_account_id = request.headers.get("chatgpt-account-id").unwrap(); - let request_body = request.body_json::().unwrap(); + let request_body = request.json_body::(); assert_eq!( request_conversation_id.to_str().unwrap(), @@ -1495,7 +1496,7 @@ async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Res mount_sse_once_match( &server, - body_string_contains("trigger context window"), + body_contains("trigger context window"), sse_failed( "resp_context_window", "context_length_exceeded", @@ -1506,7 +1507,7 @@ async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Res mount_sse_once_match( &server, - body_string_contains("seed turn"), + body_contains("seed turn"), sse_completed("resp_seed"), ) .await; @@ -1882,8 +1883,7 @@ async fn history_dedupes_streamed_and_final_messages_across_turns() { ]); let r3_input_array = requests[2] - .body_json::() - .unwrap() + .json_body::() .get("input") .and_then(|v| v.as_array()) .cloned() diff --git a/codex-rs/core/tests/suite/compact.rs b/codex-rs/core/tests/suite/compact.rs index 09b5fb18b85..4b88b916ae2 100644 --- a/codex-rs/core/tests/suite/compact.rs +++ b/codex-rs/core/tests/suite/compact.rs @@ -17,6 +17,7 @@ use codex_core::protocol::SandboxPolicy; use codex_core::protocol::WarningEvent; use codex_protocol::config_types::ReasoningSummary; use codex_protocol::user_input::UserInput; +use core_test_support::RequestBodyExt; use core_test_support::load_default_config_for_test; use core_test_support::responses::ev_local_shell_call; use core_test_support::responses::ev_reasoning_item; @@ -132,7 +133,6 @@ async fn summarize_context_three_requests_and_instructions() { // SSE 3: minimal completed; we only need to capture the request body. let sse3 = sse(vec![ev_completed("r3")]); - // Mount the three expected requests in sequence so the assertions below can // inspect them without relying on specific prompt markers. let request_log = mount_sse_sequence(&server, vec![sse1, sse2, sse3]).await; @@ -361,7 +361,8 @@ async fn manual_compact_uses_custom_prompt() { let requests = get_responses_requests(&server).await; let body = requests .iter() - .find_map(|req| req.body_json::().ok()) + .map(core_test_support::RequestBodyExt::json_body::) + .next() .expect("summary request body"); let input = body @@ -591,9 +592,7 @@ async fn multiple_auto_compact_per_task_runs_after_token_limit_hit() { // collect the requests payloads from the model let requests_payloads = get_responses_requests(&server).await; - let body = requests_payloads[0] - .body_json::() - .unwrap(); + let body = requests_payloads[0].json_body::(); let input = body.get("input").and_then(|v| v.as_array()).unwrap(); fn normalize_inputs(values: &[serde_json::Value]) -> Vec { @@ -634,9 +633,7 @@ async fn multiple_auto_compact_per_task_runs_after_token_limit_hit() { prefixed_third_summary.as_str(), ]; for (i, expected_summary) in compaction_indices.into_iter().zip(expected_summaries) { - let body = requests_payloads.clone()[i] - .body_json::() - .unwrap(); + let body = requests_payloads.clone()[i].json_body::(); let input = body.get("input").and_then(|v| v.as_array()).unwrap(); let input = normalize_inputs(input); assert_eq!(input.len(), 3); @@ -999,7 +996,7 @@ async fn multiple_auto_compact_per_task_runs_after_token_limit_hit() { ]); for (i, request) in requests_payloads.iter().enumerate() { - let body = request.body_json::().unwrap(); + let body = request.json_body::(); let input = body.get("input").and_then(|v| v.as_array()).unwrap(); let expected_input = expected_requests_inputs[i].as_array().unwrap(); assert_eq!(normalize_inputs(input), normalize_inputs(expected_input)); @@ -1038,30 +1035,30 @@ async fn auto_compact_runs_after_token_limit_hit() { let prefixed_auto_summary = AUTO_SUMMARY_TEXT; let first_matcher = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); + let body = req.text_body(); body.contains(FIRST_AUTO_MSG) && !body.contains(SECOND_AUTO_MSG) - && !body_contains_text(body, SUMMARIZATION_PROMPT) + && !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) }; mount_sse_once_match(&server, first_matcher, sse1).await; let second_matcher = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); + let body = req.text_body(); body.contains(SECOND_AUTO_MSG) && body.contains(FIRST_AUTO_MSG) - && !body_contains_text(body, SUMMARIZATION_PROMPT) + && !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) }; mount_sse_once_match(&server, second_matcher, sse2).await; let third_matcher = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); - body_contains_text(body, SUMMARIZATION_PROMPT) + let body = req.text_body(); + body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) }; mount_sse_once_match(&server, third_matcher, sse3).await; - let fourth_matcher = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); - body.contains(POST_AUTO_USER_MSG) && !body_contains_text(body, SUMMARIZATION_PROMPT) + let body = req.text_body(); + body.contains(POST_AUTO_USER_MSG) + && !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) }; mount_sse_once_match(&server, fourth_matcher, sse4).await; @@ -1126,10 +1123,7 @@ async fn auto_compact_runs_after_token_limit_hit() { requests.len() ); let is_auto_compact = |req: &wiremock::Request| { - body_contains_text( - std::str::from_utf8(&req.body).unwrap_or(""), - SUMMARIZATION_PROMPT, - ) + body_contains_text(req.text_body().as_str(), SUMMARIZATION_PROMPT) }; let auto_compact_count = requests.iter().filter(|req| is_auto_compact(req)).count(); assert_eq!( @@ -1151,20 +1145,16 @@ async fn auto_compact_runs_after_token_limit_hit() { .enumerate() .rev() .find_map(|(idx, req)| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); - (body.contains(POST_AUTO_USER_MSG) && !body_contains_text(body, SUMMARIZATION_PROMPT)) + let body = req.text_body(); + (body.contains(POST_AUTO_USER_MSG) && !body_contains_text(&body, SUMMARIZATION_PROMPT)) .then_some(idx) }) .expect("follow-up request missing"); assert_eq!(follow_up_index, 3, "follow-up request should be last"); - let body_first = requests[0].body_json::().unwrap(); - let body_auto = requests[auto_compact_index] - .body_json::() - .unwrap(); - let body_follow_up = requests[follow_up_index] - .body_json::() - .unwrap(); + let body_first = requests[0].json_body::(); + let body_auto = requests[auto_compact_index].json_body::(); + let body_follow_up = requests[follow_up_index].json_body::(); let instructions = body_auto .get("instructions") .and_then(|v| v.as_str()) @@ -1375,24 +1365,24 @@ async fn auto_compact_persists_rollout_entries() { ]); let first_matcher = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); + let body = req.text_body(); body.contains(FIRST_AUTO_MSG) && !body.contains(SECOND_AUTO_MSG) - && !body_contains_text(body, SUMMARIZATION_PROMPT) + && !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) }; mount_sse_once_match(&server, first_matcher, sse1).await; let second_matcher = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); + let body = req.text_body(); body.contains(SECOND_AUTO_MSG) && body.contains(FIRST_AUTO_MSG) - && !body_contains_text(body, SUMMARIZATION_PROMPT) + && !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) }; mount_sse_once_match(&server, second_matcher, sse2).await; let third_matcher = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); - body_contains_text(body, SUMMARIZATION_PROMPT) + let body = req.text_body(); + body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) }; mount_sse_once_match(&server, third_matcher, sse3).await; diff --git a/codex-rs/core/tests/suite/compact_resume_fork.rs b/codex-rs/core/tests/suite/compact_resume_fork.rs index 3e38c89b338..4edb1c83dd9 100644 --- a/codex-rs/core/tests/suite/compact_resume_fork.rs +++ b/codex-rs/core/tests/suite/compact_resume_fork.rs @@ -23,6 +23,7 @@ use codex_core::protocol::Op; use codex_core::protocol::WarningEvent; use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; use codex_protocol::user_input::UserInput; +use core_test_support::RequestBodyExt; use core_test_support::load_default_config_for_test; use core_test_support::responses::ev_assistant_message; use core_test_support::responses::ev_completed; @@ -796,8 +797,9 @@ async fn mount_initial_flow(server: &MockServer) { let sse5 = sse(vec![ev_completed("r5")]); let match_first = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); + let body = req.text_body(); body.contains("\"text\":\"hello world\"") + && !body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) && !body.contains(&format!("\"text\":\"{SUMMARY_TEXT}\"")) && !body.contains("\"text\":\"AFTER_COMPACT\"") && !body.contains("\"text\":\"AFTER_RESUME\"") @@ -806,13 +808,13 @@ async fn mount_initial_flow(server: &MockServer) { mount_sse_once_match(server, match_first, sse1).await; let match_compact = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); - body_contains_text(body, SUMMARIZATION_PROMPT) || body.contains(&json_fragment(FIRST_REPLY)) + let body = req.text_body(); + body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) }; mount_sse_once_match(server, match_compact, sse2).await; let match_after_compact = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); + let body = req.text_body(); body.contains("\"text\":\"AFTER_COMPACT\"") && !body.contains("\"text\":\"AFTER_RESUME\"") && !body.contains("\"text\":\"AFTER_FORK\"") @@ -820,13 +822,13 @@ async fn mount_initial_flow(server: &MockServer) { mount_sse_once_match(server, match_after_compact, sse3).await; let match_after_resume = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); + let body = req.text_body(); body.contains("\"text\":\"AFTER_RESUME\"") }; mount_sse_once_match(server, match_after_resume, sse4).await; let match_after_fork = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); + let body = req.text_body(); body.contains("\"text\":\"AFTER_FORK\"") }; mount_sse_once_match(server, match_after_fork, sse5).await; @@ -840,13 +842,13 @@ async fn mount_second_compact_flow(server: &MockServer) { let sse7 = sse(vec![ev_completed("r7")]); let match_second_compact = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); - body.contains("AFTER_FORK") + let body = req.text_body(); + body_contains_text(body.as_str(), SUMMARIZATION_PROMPT) && body.contains("AFTER_FORK") }; mount_sse_once_match(server, match_second_compact, sse6).await; let match_after_second_resume = |req: &wiremock::Request| { - let body = std::str::from_utf8(&req.body).unwrap_or(""); + let body = req.text_body(); body.contains(&format!("\"text\":\"{AFTER_SECOND_RESUME}\"")) }; mount_sse_once_match(server, match_after_second_resume, sse7).await; diff --git a/codex-rs/core/tests/suite/json_result.rs b/codex-rs/core/tests/suite/json_result.rs index 1b9949102e6..7731b8b7e0f 100644 --- a/codex-rs/core/tests/suite/json_result.rs +++ b/codex-rs/core/tests/suite/json_result.rs @@ -6,6 +6,7 @@ use codex_core::protocol::Op; use codex_core::protocol::SandboxPolicy; use codex_protocol::config_types::ReasoningSummary; use codex_protocol::user_input::UserInput; +use core_test_support::RequestBodyExt; use core_test_support::responses; use core_test_support::skip_if_no_network; use core_test_support::test_codex::TestCodex; @@ -54,7 +55,7 @@ async fn codex_returns_json_result(model: String) -> anyhow::Result<()> { let expected_schema: serde_json::Value = serde_json::from_str(SCHEMA)?; let match_json_text_param = move |req: &wiremock::Request| { - let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap_or_default(); + let body: serde_json::Value = req.json_body(); let Some(text) = body.get("text") else { return false; }; diff --git a/codex-rs/core/tests/suite/review.rs b/codex-rs/core/tests/suite/review.rs index b35213f7e64..d09caafe359 100644 --- a/codex-rs/core/tests/suite/review.rs +++ b/codex-rs/core/tests/suite/review.rs @@ -21,6 +21,7 @@ use codex_core::protocol::RolloutItem; use codex_core::protocol::RolloutLine; use codex_core::review_format::render_review_output_text; use codex_protocol::user_input::UserInput; +use core_test_support::RequestBodyExt; use core_test_support::load_default_config_for_test; use core_test_support::load_sse_fixture_with_id_from_str; use core_test_support::responses::get_responses_requests; @@ -430,7 +431,7 @@ async fn review_uses_custom_review_model_from_config() { let request = requests .first() .expect("expected POST request to /responses"); - let body = request.body_json::().unwrap(); + let body = request.json_body::(); assert_eq!(body["model"].as_str().unwrap(), "gpt-5.1"); server.verify().await; @@ -551,7 +552,7 @@ async fn review_input_isolated_from_parent_history() { let request = requests .first() .expect("expected POST request to /responses"); - let body = request.body_json::().unwrap(); + let body = request.json_body::(); let input = body["input"].as_array().expect("input array"); assert!( input.len() >= 2, @@ -676,7 +677,7 @@ async fn review_history_surfaces_in_parent_session() { // Critically, no messages from the review thread should appear. let requests = get_responses_requests(&server).await; assert_eq!(requests.len(), 2); - let body = requests[1].body_json::().unwrap(); + let body = requests[1].json_body::(); let input = body["input"].as_array().expect("input array"); // Must include the followup as the last item for this turn diff --git a/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs b/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs index e7a60912643..f249dbfc46e 100644 --- a/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs +++ b/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs @@ -3,6 +3,7 @@ use codex_core::WireApi; use codex_core::protocol::EventMsg; use codex_core::protocol::Op; use codex_protocol::user_input::UserInput; +use core_test_support::body_contains; use core_test_support::load_sse_fixture_with_id; use core_test_support::skip_if_no_network; use core_test_support::test_codex::TestCodex; @@ -11,7 +12,6 @@ use core_test_support::wait_for_event; use wiremock::Mock; use wiremock::MockServer; use wiremock::ResponseTemplate; -use wiremock::matchers::body_string_contains; use wiremock::matchers::method; use wiremock::matchers::path; @@ -38,7 +38,7 @@ async fn continue_after_stream_error() { // so the failing request should only occur once. Mock::given(method("POST")) .and(path("/v1/responses")) - .and(body_string_contains("first message")) + .and(body_contains("first message")) .respond_with(fail) .up_to_n_times(2) .mount(&server) @@ -50,7 +50,7 @@ async fn continue_after_stream_error() { Mock::given(method("POST")) .and(path("/v1/responses")) - .and(body_string_contains("follow up")) + .and(body_contains("follow up")) .respond_with(ok) .expect(1) .mount(&server)