diff --git a/codex-rs/core/src/account_switching.rs b/codex-rs/core/src/account_switching.rs index c9bbb9b1478e..2621711b4b15 100644 --- a/codex-rs/core/src/account_switching.rs +++ b/codex-rs/core/src/account_switching.rs @@ -15,12 +15,21 @@ use codex_protocol::protocol::RateLimitReachedType; #[derive(Debug, Default)] pub struct RateLimitSwitchState { tried_accounts: HashSet, + limited_chatgpt_accounts: HashSet, blocked_until: HashMap>, } impl RateLimitSwitchState { - pub fn mark_limited(&mut self, account_id: &str, blocked_until: Option>) { + pub fn mark_limited( + &mut self, + account_id: &str, + mode: AuthMode, + blocked_until: Option>, + ) { self.tried_accounts.insert(account_id.to_string()); + if mode.has_chatgpt_account() { + self.limited_chatgpt_accounts.insert(account_id.to_string()); + } if let Some(until) = blocked_until { self.blocked_until .entry(account_id.to_string()) @@ -30,6 +39,8 @@ impl RateLimitSwitchState { } }) .or_insert(until); + } else { + self.blocked_until.remove(account_id); } } @@ -40,6 +51,10 @@ impl RateLimitSwitchState { fn blocked_until(&self, account_id: &str) -> Option> { self.blocked_until.get(account_id).copied() } + + fn is_chatgpt_limited(&self, account_id: &str) -> bool { + self.limited_chatgpt_accounts.contains(account_id) + } } #[derive(Debug, Clone, Copy, PartialEq)] @@ -201,6 +216,7 @@ fn has_unexpired_tried_marker( pub fn select_next_account_id( codex_home: &Path, + auth_home: &Path, state: &RateLimitSwitchState, allow_api_key_fallback: bool, now: DateTime, @@ -208,9 +224,9 @@ pub fn select_next_account_id( ) -> io::Result> { let current = match current_account_id { Some(id) => Some(id.to_string()), - None => codex_login::get_active_account_id(codex_home)?, + None => codex_login::get_active_account_id(auth_home)?, }; - let accounts = codex_login::list_accounts(codex_home)?; + let accounts = codex_login::list_accounts(auth_home)?; let snapshots = account_usage::list_rate_limit_snapshots(codex_home).unwrap_or_default(); let snapshot_map: HashMap = snapshots @@ -239,6 +255,9 @@ pub fn select_next_account_id( if current.is_some_and(|id| id == account.id) { continue; } + if state.is_chatgpt_limited(&account.id) { + continue; + } if has_unexpired_tried_marker(state, &account.id, now) { continue; } @@ -265,9 +284,11 @@ pub fn select_next_account_id( } let all_chatgpt_unavailable = chatgpt_accounts.iter().all(|account| { - current.is_some_and(|id| id == account.id) - || has_unexpired_tried_marker(state, &account.id, now) - || is_blocked(now, blocked_until_for(state, &snapshot_map, &account.id)) + let blocked_until = blocked_until_for(state, &snapshot_map, &account.id); + let blocked = is_blocked(now, blocked_until); + let exhausted = state.is_chatgpt_limited(&account.id); + let tried = state.has_tried(&account.id); + current.is_some_and(|id| id == account.id) || blocked || (tried && exhausted) }); if !all_chatgpt_unavailable { return Ok(None); @@ -347,23 +368,26 @@ pub fn switch_active_account_to_preferred_for_new_session( pub fn switch_active_account_on_rate_limit( codex_home: &Path, + auth_home: &Path, state: &mut RateLimitSwitchState, allow_api_key_fallback: bool, now: DateTime, current_account_id: &str, + current_mode: AuthMode, blocked_until: Option>, auth_credentials_store_mode: AuthCredentialsStoreMode, ) -> io::Result { - state.mark_limited(current_account_id, blocked_until); + state.mark_limited(current_account_id, current_mode, blocked_until); match select_next_account_id( codex_home, + auth_home, state, allow_api_key_fallback, now, Some(current_account_id), )? { Some(account_id) => { - codex_login::activate_account(codex_home, &account_id, auth_credentials_store_mode) + codex_login::activate_account(auth_home, &account_id, auth_credentials_store_mode) .map(AccountSwitchOutcome::Switched) } None => Ok(AccountSwitchOutcome::NoCandidate), @@ -486,6 +510,7 @@ mod tests { .expect("record faster"); let selected = select_next_account_id( + temp.path(), temp.path(), &RateLimitSwitchState::default(), false, @@ -506,6 +531,7 @@ mod tests { let _api_key = upsert_api_key(temp.path(), "sk-test"); let selected = select_next_account_id( + temp.path(), temp.path(), &RateLimitSwitchState::default(), true, @@ -535,6 +561,7 @@ mod tests { .expect("record usage hint"); let selected = select_next_account_id( + temp.path(), temp.path(), &RateLimitSwitchState::default(), true, @@ -554,10 +581,15 @@ mod tests { let first_api_key = upsert_api_key(temp.path(), "sk-first"); let second_api_key = upsert_api_key(temp.path(), "sk-second"); let mut state = RateLimitSwitchState::default(); - state.mark_limited(&first_api_key, Some(now + Duration::hours(1))); + state.mark_limited( + &first_api_key, + AuthMode::ApiKey, + Some(now + Duration::hours(1)), + ); let selected = - select_next_account_id(temp.path(), &state, true, now, Some(¤t)).expect("select"); + select_next_account_id(temp.path(), temp.path(), &state, true, now, Some(¤t)) + .expect("select"); assert_eq!(selected, Some(second_api_key)); } @@ -582,11 +614,13 @@ mod tests { let mut state = RateLimitSwitchState::default(); let outcome = switch_active_account_on_rate_limit( + temp.path(), temp.path(), &mut state, false, now, ¤t, + AuthMode::Chatgpt, Some(now + Duration::hours(1)), AuthCredentialsStoreMode::File, ) @@ -607,6 +641,98 @@ mod tests { assert!(state.has_tried(¤t)); } + #[test] + fn api_key_fallback_requires_all_chatgpt_accounts_marked_limited() { + let temp = tempfile::tempdir().expect("tempdir"); + let now = Utc::now(); + let current = upsert_chatgpt(temp.path(), "current"); + let candidate = upsert_chatgpt(temp.path(), "candidate"); + let api_key = upsert_api_key(temp.path(), "sk-test"); + let mut state = RateLimitSwitchState::default(); + state.mark_limited(¤t, AuthMode::Chatgpt, None); + + let selected = + select_next_account_id(temp.path(), temp.path(), &state, true, now, Some(¤t)) + .expect("select"); + assert_eq!(selected, Some(candidate.clone())); + + state.mark_limited(&candidate, AuthMode::Chatgpt, None); + let selected = select_next_account_id( + temp.path(), + temp.path(), + &state, + true, + now, + Some(&candidate), + ) + .expect("select"); + assert_eq!(selected, Some(api_key)); + } + + #[test] + fn limited_chatgpt_account_is_not_reselected_without_reset_hint() { + let temp = tempfile::tempdir().expect("tempdir"); + let now = Utc::now(); + let current = upsert_chatgpt(temp.path(), "current"); + let candidate = upsert_chatgpt(temp.path(), "candidate"); + let api_key = upsert_api_key(temp.path(), "sk-test"); + let mut state = RateLimitSwitchState::default(); + state.mark_limited(¤t, AuthMode::Chatgpt, None); + + let selected = + select_next_account_id(temp.path(), temp.path(), &state, true, now, Some(¤t)) + .expect("select"); + assert_eq!(selected, Some(candidate.clone())); + + state.mark_limited(&candidate, AuthMode::Chatgpt, None); + let selected = select_next_account_id( + temp.path(), + temp.path(), + &state, + true, + now, + Some(&candidate), + ) + .expect("select"); + assert_eq!(selected, Some(api_key)); + } + + #[test] + fn limited_chatgpt_account_is_not_reselected_after_expired_reset_hint() { + let temp = tempfile::tempdir().expect("tempdir"); + let now = Utc::now(); + let current = upsert_chatgpt(temp.path(), "current"); + let candidate = upsert_chatgpt(temp.path(), "candidate"); + let api_key = upsert_api_key(temp.path(), "sk-test"); + let mut state = RateLimitSwitchState::default(); + state.mark_limited( + ¤t, + AuthMode::Chatgpt, + Some(now - Duration::minutes(1)), + ); + + let selected = + select_next_account_id(temp.path(), temp.path(), &state, true, now, Some(¤t)) + .expect("select"); + assert_eq!(selected, Some(candidate.clone())); + + state.mark_limited( + &candidate, + AuthMode::Chatgpt, + Some(now - Duration::minutes(1)), + ); + let selected = select_next_account_id( + temp.path(), + temp.path(), + &state, + true, + now, + Some(&candidate), + ) + .expect("select"); + assert_eq!(selected, Some(api_key)); + } + #[test] fn new_session_switch_selects_best_chatgpt_candidate() { let temp = tempfile::tempdir().expect("tempdir"); diff --git a/codex-rs/core/src/session/mod.rs b/codex-rs/core/src/session/mod.rs index cb1cea0f7c78..14abffa62656 100644 --- a/codex-rs/core/src/session/mod.rs +++ b/codex-rs/core/src/session/mod.rs @@ -46,8 +46,10 @@ use chrono::Utc; use codex_analytics::AnalyticsEventsClient; use codex_analytics::SubAgentThreadStartedInput; use codex_analytics::TurnCodexErrorFact; +use codex_app_server_protocol::AuthMode; use codex_app_server_protocol::McpServerElicitationRequest; use codex_app_server_protocol::McpServerElicitationRequestParams; +use codex_config::types::AuthCredentialsStoreMode; use codex_config::types::OAuthCredentialsStoreMode; use codex_exec_server::Environment; use codex_exec_server::EnvironmentManager; @@ -1059,6 +1061,15 @@ impl Session { state.session_configuration.codex_home().clone() } + pub(crate) async fn auth_home(&self) -> AbsolutePathBuf { + let state = self.state.lock().await; + state + .session_configuration + .original_config_do_not_use + .auth_home + .clone() + } + pub(crate) fn subscribe_out_of_band_elicitation_pause_state(&self) -> watch::Receiver { self.out_of_band_elicitation_paused.subscribe() } @@ -1488,6 +1499,37 @@ impl Session { .clone() } + pub(crate) async fn auto_switch_accounts_on_rate_limit(&self) -> bool { + let state = self.state.lock().await; + state + .session_configuration + .original_config_do_not_use + .auto_switch_accounts_on_rate_limit + } + + pub(crate) async fn api_key_fallback_on_all_accounts_limited(&self) -> bool { + let state = self.state.lock().await; + state + .session_configuration + .original_config_do_not_use + .api_key_fallback_on_all_accounts_limited + } + + pub(crate) async fn cli_auth_credentials_store_mode(&self) -> AuthCredentialsStoreMode { + let state = self.state.lock().await; + state + .session_configuration + .original_config_do_not_use + .cli_auth_credentials_store_mode + } + + pub(crate) fn current_auth_mode(&self) -> Option { + self.services + .auth_manager + .auth_cached() + .map(|auth| auth.auth_mode()) + } + pub(crate) async fn provider(&self) -> ModelProviderInfo { let state = self.state.lock().await; state.session_configuration.provider.clone() diff --git a/codex-rs/core/src/session/turn.rs b/codex-rs/core/src/session/turn.rs index 2ca37dfd4f55..4d037e8c2bc0 100644 --- a/codex-rs/core/src/session/turn.rs +++ b/codex-rs/core/src/session/turn.rs @@ -5,6 +5,8 @@ use std::sync::Arc; use std::sync::atomic::Ordering; use crate::SkillInjections; +use crate::account_switching::AccountSwitchOutcome; +use crate::account_switching::RateLimitSwitchState; use crate::build_skill_injections; use crate::client::ModelClientSession; use crate::client_common::Prompt; @@ -79,6 +81,7 @@ use codex_protocol::config_types::ModeKind; use codex_protocol::config_types::ServiceTier; use codex_protocol::error::CodexErr; use codex_protocol::error::Result as CodexResult; +use codex_protocol::error::UsageLimitReachedError; use codex_protocol::items::PlanItem; use codex_protocol::items::TurnItem; use codex_protocol::items::build_hook_prompt_message; @@ -198,6 +201,7 @@ pub(crate) async fn run_turn( // However, we defer that drain until after sampling in two cases: // 1. At the start of a turn, so the fresh turn input in `input` gets sampled first. // 2. After auto-compact, when model/tool continuation needs to resume before any steer. + let mut rate_limit_switch_state = RateLimitSwitchState::default(); loop { // Note that pending_input would be something like a message the user @@ -239,6 +243,7 @@ pub(crate) async fn run_turn( turn_metadata_header.as_deref(), sampling_request_input.clone(), auto_review_awareness_input_item.clone(), + &mut rate_limit_switch_state, cancellation_token.child_token(), ) .await @@ -1009,6 +1014,7 @@ async fn run_sampling_request( turn_metadata_header: Option<&str>, input: Vec, request_only_input_item: Option, + rate_limit_switch_state: &mut RateLimitSwitchState, cancellation_token: CancellationToken, ) -> CodexResult { let router = built_tools(sess.as_ref(), turn_context.as_ref(), &cancellation_token).await?; @@ -1080,6 +1086,17 @@ async fn run_sampling_request( if let Some(rate_limits) = rate_limits { sess.update_rate_limits(&turn_context, *rate_limits).await; } + if maybe_switch_account_after_usage_limit( + &sess, + client_session, + rate_limit_switch_state, + &e, + ) + .await + { + turn_context.turn_timing_state.record_sampling_retry(); + continue; + } return Err(CodexErr::UsageLimitReached(e)); } Err(err) => err, @@ -1103,6 +1120,63 @@ async fn run_sampling_request( } } +async fn maybe_switch_account_after_usage_limit( + sess: &Arc, + client_session: &mut ModelClientSession, + rate_limit_switch_state: &mut RateLimitSwitchState, + limit_err: &UsageLimitReachedError, +) -> bool { + if !sess.auto_switch_accounts_on_rate_limit().await { + return false; + } + if codex_login::auth::read_codex_api_key_from_env().is_some() { + return false; + } + let Some((codex_home, current_account_id, _)) = sess.account_usage_context().await else { + return false; + }; + + let now = chrono::Utc::now(); + let auth_home = sess.auth_home().await; + let allow_api_key_fallback = sess.api_key_fallback_on_all_accounts_limited().await; + let current_mode = sess + .current_auth_mode() + .unwrap_or(codex_app_server_protocol::AuthMode::ApiKey); + let auth_credentials_store_mode = sess.cli_auth_credentials_store_mode().await; + match crate::account_switching::switch_active_account_on_rate_limit( + codex_home.as_path(), + auth_home.as_path(), + rate_limit_switch_state, + allow_api_key_fallback, + now, + current_account_id.as_str(), + current_mode, + limit_err.resets_at, + auth_credentials_store_mode, + ) { + Ok(AccountSwitchOutcome::Switched(account)) => { + info!( + from_account_id = %current_account_id, + to_account_id = %account.id, + reason = "usage_limit_reached", + "usage limit hit; auto-switching active account" + ); + sess.services.auth_manager.reload().await; + *client_session = sess.services.model_client.new_session(); + true + } + Ok(AccountSwitchOutcome::NoCandidate) => false, + Err(err) => { + warn!( + from_account_id = %current_account_id, + error = %err, + "failed to auto-switch account after usage limit" + ); + false + } + } +} + #[expect( clippy::await_holding_invalid_type, reason = "tool router construction reads through the session-owned manager guard"