From 8dce47c20134c4fcddabe4b22568e61838d3dfc3 Mon Sep 17 00:00:00 2001 From: Timon Vonk Date: Thu, 14 Aug 2025 11:50:32 +0200 Subject: [PATCH] feat: Respect backoff policy for streaming --- async-openai/src/client.rs | 13 ++++-- async-openai/src/lib.rs | 1 + async-openai/src/streaming_backoff.rs | 64 +++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 3 deletions(-) create mode 100644 async-openai/src/streaming_backoff.rs diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index fe2ed232..e9df8cdc 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use reqwest::multipart::Form; -use reqwest_eventsource::{Event, EventSource, RequestBuilderExt}; +use reqwest_eventsource::{retry::ExponentialBackoff, Event, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Serialize}; use crate::{ @@ -12,6 +12,7 @@ use crate::{ file::Files, image::Images, moderation::Moderations, + streaming_backoff::StreamingBackoff, traits::AsyncTryFrom, Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites, Models, Projects, Responses, Threads, Uploads, Users, VectorStores, @@ -414,7 +415,7 @@ impl Client { I: Serialize, O: DeserializeOwned + std::marker::Send + 'static, { - let event_source = self + let mut event_source = self .http_client .post(self.config.url(path)) .query(&self.config.query()) @@ -423,6 +424,9 @@ impl Client { .eventsource() .unwrap(); + let retry_policy: StreamingBackoff = self.backoff.clone().into(); + event_source.set_retry_policy(Box::new(retry_policy)); + stream(event_source).await } @@ -436,7 +440,7 @@ impl Client { I: Serialize, O: DeserializeOwned + std::marker::Send + 'static, { - let event_source = self + let mut event_source = self .http_client .post(self.config.url(path)) .query(&self.config.query()) @@ -445,6 +449,9 @@ impl Client { .eventsource() .unwrap(); + let retry_policy: StreamingBackoff = self.backoff.clone().into(); + event_source.set_retry_policy(Box::new(retry_policy)); + stream_mapped_raw_events(event_source, event_mapper).await } diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs index c94bc495..ce7dd70a 100644 --- a/async-openai/src/lib.rs +++ b/async-openai/src/lib.rs @@ -165,6 +165,7 @@ mod projects; mod responses; mod runs; mod steps; +mod streaming_backoff; mod threads; pub mod traits; pub mod types; diff --git a/async-openai/src/streaming_backoff.rs b/async-openai/src/streaming_backoff.rs new file mode 100644 index 00000000..456e1c30 --- /dev/null +++ b/async-openai/src/streaming_backoff.rs @@ -0,0 +1,64 @@ +use std::time::Duration; + +use reqwest::StatusCode; +use reqwest_eventsource::retry::RetryPolicy; + +/// Wraps `backoff::ExponentialBackoff` to provide a custom backoff suitable for +/// reqwest_eventsource +pub struct StreamingBackoff(backoff::ExponentialBackoff); + +impl StreamingBackoff { + fn should_retry(&self, error: &reqwest_eventsource::Error) -> bool { + // Errors at the connection level only + if let reqwest_eventsource::Error::Transport(error) = error { + // TODO: We can't inspect the response body as reading consumes it. + // This is problematic because quota exceeded errors are also 429. + return error + .status() + .as_ref() + .is_some_and(StatusCode::is_server_error) + || error.status() == Some(reqwest::StatusCode::TOO_MANY_REQUESTS); + } + + true + } +} + +impl From for StreamingBackoff { + fn from(backoff: backoff::ExponentialBackoff) -> Self { + Self(backoff) + } +} + +impl RetryPolicy for StreamingBackoff { + fn retry( + &self, + error: &reqwest_eventsource::Error, + last_retry: Option<(usize, Duration)>, + ) -> Option { + if !self.should_retry(error) { + return None; + }; + + // Ignoring backoff randomization factor for simplicity + // Basically reimplements the retry policy from eventsource + if let Some((_retry_num, last_duration)) = last_retry { + let duration = last_duration.mul_f64(self.0.multiplier); + + if let Some(max_duration) = self.0.max_elapsed_time { + Some(duration.min(max_duration)) + } else { + Some(duration) + } + } else { + Some(self.0.initial_interval) + } + } + + fn set_reconnection_time(&mut self, duration: Duration) { + self.0.initial_interval = duration; + if let Some(max_elapsed_time) = self.0.max_elapsed_time { + self.0.max_elapsed_time = Some(max_elapsed_time.max(duration)) + } + } +}