Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 136 additions & 10 deletions codex-rs/core/src/account_switching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,21 @@ use codex_protocol::protocol::RateLimitReachedType;
#[derive(Debug, Default)]
pub struct RateLimitSwitchState {
tried_accounts: HashSet<String>,
limited_chatgpt_accounts: HashSet<String>,
blocked_until: HashMap<String, DateTime<Utc>>,
}

impl RateLimitSwitchState {
pub fn mark_limited(&mut self, account_id: &str, blocked_until: Option<DateTime<Utc>>) {
pub fn mark_limited(
&mut self,
account_id: &str,
mode: AuthMode,
blocked_until: Option<DateTime<Utc>>,
) {
self.tried_accounts.insert(account_id.to_string());
if mode.has_chatgpt_account() {
self.limited_chatgpt_accounts.insert(account_id.to_string());
}
if let Some(until) = blocked_until {
self.blocked_until
.entry(account_id.to_string())
Expand All @@ -30,6 +39,8 @@ impl RateLimitSwitchState {
}
})
.or_insert(until);
} else {
self.blocked_until.remove(account_id);
}
}

Expand All @@ -40,6 +51,10 @@ impl RateLimitSwitchState {
fn blocked_until(&self, account_id: &str) -> Option<DateTime<Utc>> {
self.blocked_until.get(account_id).copied()
}

fn is_chatgpt_limited(&self, account_id: &str) -> bool {
self.limited_chatgpt_accounts.contains(account_id)
}
}

#[derive(Debug, Clone, Copy, PartialEq)]
Expand Down Expand Up @@ -201,16 +216,17 @@ fn has_unexpired_tried_marker(

pub fn select_next_account_id(
codex_home: &Path,
auth_home: &Path,
state: &RateLimitSwitchState,
allow_api_key_fallback: bool,
now: DateTime<Utc>,
current_account_id: Option<&str>,
) -> io::Result<Option<String>> {
let current = match current_account_id {
Some(id) => Some(id.to_string()),
None => codex_login::get_active_account_id(codex_home)?,
None => codex_login::get_active_account_id(auth_home)?,
};
let accounts = codex_login::list_accounts(codex_home)?;
let accounts = codex_login::list_accounts(auth_home)?;

let snapshots = account_usage::list_rate_limit_snapshots(codex_home).unwrap_or_default();
let snapshot_map: HashMap<String, account_usage::StoredRateLimitSnapshot> = snapshots
Expand Down Expand Up @@ -239,6 +255,9 @@ pub fn select_next_account_id(
if current.is_some_and(|id| id == account.id) {
continue;
}
if state.is_chatgpt_limited(&account.id) {
continue;
}
if has_unexpired_tried_marker(state, &account.id, now) {
continue;
}
Expand All @@ -265,9 +284,11 @@ pub fn select_next_account_id(
}

let all_chatgpt_unavailable = chatgpt_accounts.iter().all(|account| {
current.is_some_and(|id| id == account.id)
|| has_unexpired_tried_marker(state, &account.id, now)
|| is_blocked(now, blocked_until_for(state, &snapshot_map, &account.id))
let blocked_until = blocked_until_for(state, &snapshot_map, &account.id);
let blocked = is_blocked(now, blocked_until);
let exhausted = state.is_chatgpt_limited(&account.id);
let tried = state.has_tried(&account.id);
current.is_some_and(|id| id == account.id) || blocked || (tried && exhausted)
});
if !all_chatgpt_unavailable {
return Ok(None);
Expand Down Expand Up @@ -347,23 +368,26 @@ pub fn switch_active_account_to_preferred_for_new_session(

pub fn switch_active_account_on_rate_limit(
codex_home: &Path,
auth_home: &Path,
state: &mut RateLimitSwitchState,
allow_api_key_fallback: bool,
now: DateTime<Utc>,
current_account_id: &str,
current_mode: AuthMode,
blocked_until: Option<DateTime<Utc>>,
auth_credentials_store_mode: AuthCredentialsStoreMode,
) -> io::Result<AccountSwitchOutcome> {
state.mark_limited(current_account_id, blocked_until);
state.mark_limited(current_account_id, current_mode, blocked_until);
match select_next_account_id(
codex_home,
auth_home,
state,
allow_api_key_fallback,
now,
Some(current_account_id),
)? {
Some(account_id) => {
codex_login::activate_account(codex_home, &account_id, auth_credentials_store_mode)
codex_login::activate_account(auth_home, &account_id, auth_credentials_store_mode)
.map(AccountSwitchOutcome::Switched)
}
None => Ok(AccountSwitchOutcome::NoCandidate),
Expand Down Expand Up @@ -486,6 +510,7 @@ mod tests {
.expect("record faster");

let selected = select_next_account_id(
temp.path(),
temp.path(),
&RateLimitSwitchState::default(),
false,
Expand All @@ -506,6 +531,7 @@ mod tests {
let _api_key = upsert_api_key(temp.path(), "sk-test");

let selected = select_next_account_id(
temp.path(),
temp.path(),
&RateLimitSwitchState::default(),
true,
Expand Down Expand Up @@ -535,6 +561,7 @@ mod tests {
.expect("record usage hint");

let selected = select_next_account_id(
temp.path(),
temp.path(),
&RateLimitSwitchState::default(),
true,
Expand All @@ -554,10 +581,15 @@ mod tests {
let first_api_key = upsert_api_key(temp.path(), "sk-first");
let second_api_key = upsert_api_key(temp.path(), "sk-second");
let mut state = RateLimitSwitchState::default();
state.mark_limited(&first_api_key, Some(now + Duration::hours(1)));
state.mark_limited(
&first_api_key,
AuthMode::ApiKey,
Some(now + Duration::hours(1)),
);

let selected =
select_next_account_id(temp.path(), &state, true, now, Some(&current)).expect("select");
select_next_account_id(temp.path(), temp.path(), &state, true, now, Some(&current))
.expect("select");

assert_eq!(selected, Some(second_api_key));
}
Expand All @@ -582,11 +614,13 @@ mod tests {

let mut state = RateLimitSwitchState::default();
let outcome = switch_active_account_on_rate_limit(
temp.path(),
temp.path(),
&mut state,
false,
now,
&current,
AuthMode::Chatgpt,
Some(now + Duration::hours(1)),
AuthCredentialsStoreMode::File,
)
Expand All @@ -607,6 +641,98 @@ mod tests {
assert!(state.has_tried(&current));
}

#[test]
fn api_key_fallback_requires_all_chatgpt_accounts_marked_limited() {
let temp = tempfile::tempdir().expect("tempdir");
let now = Utc::now();
let current = upsert_chatgpt(temp.path(), "current");
let candidate = upsert_chatgpt(temp.path(), "candidate");
let api_key = upsert_api_key(temp.path(), "sk-test");
let mut state = RateLimitSwitchState::default();
state.mark_limited(&current, AuthMode::Chatgpt, None);

let selected =
select_next_account_id(temp.path(), temp.path(), &state, true, now, Some(&current))
.expect("select");
assert_eq!(selected, Some(candidate.clone()));

state.mark_limited(&candidate, AuthMode::Chatgpt, None);
let selected = select_next_account_id(
temp.path(),
temp.path(),
&state,
true,
now,
Some(&candidate),
)
.expect("select");
assert_eq!(selected, Some(api_key));
}

#[test]
fn limited_chatgpt_account_is_not_reselected_without_reset_hint() {
let temp = tempfile::tempdir().expect("tempdir");
let now = Utc::now();
let current = upsert_chatgpt(temp.path(), "current");
let candidate = upsert_chatgpt(temp.path(), "candidate");
let api_key = upsert_api_key(temp.path(), "sk-test");
let mut state = RateLimitSwitchState::default();
state.mark_limited(&current, AuthMode::Chatgpt, None);

let selected =
select_next_account_id(temp.path(), temp.path(), &state, true, now, Some(&current))
.expect("select");
assert_eq!(selected, Some(candidate.clone()));

state.mark_limited(&candidate, AuthMode::Chatgpt, None);
let selected = select_next_account_id(
temp.path(),
temp.path(),
&state,
true,
now,
Some(&candidate),
)
.expect("select");
assert_eq!(selected, Some(api_key));
}

#[test]
fn limited_chatgpt_account_is_not_reselected_after_expired_reset_hint() {
let temp = tempfile::tempdir().expect("tempdir");
let now = Utc::now();
let current = upsert_chatgpt(temp.path(), "current");
let candidate = upsert_chatgpt(temp.path(), "candidate");
let api_key = upsert_api_key(temp.path(), "sk-test");
let mut state = RateLimitSwitchState::default();
state.mark_limited(
&current,
AuthMode::Chatgpt,
Some(now - Duration::minutes(1)),
);

let selected =
select_next_account_id(temp.path(), temp.path(), &state, true, now, Some(&current))
.expect("select");
assert_eq!(selected, Some(candidate.clone()));

state.mark_limited(
&candidate,
AuthMode::Chatgpt,
Some(now - Duration::minutes(1)),
);
let selected = select_next_account_id(
temp.path(),
temp.path(),
&state,
true,
now,
Some(&candidate),
)
.expect("select");
assert_eq!(selected, Some(api_key));
}

#[test]
fn new_session_switch_selects_best_chatgpt_candidate() {
let temp = tempfile::tempdir().expect("tempdir");
Expand Down
42 changes: 42 additions & 0 deletions codex-rs/core/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ use chrono::Utc;
use codex_analytics::AnalyticsEventsClient;
use codex_analytics::SubAgentThreadStartedInput;
use codex_analytics::TurnCodexErrorFact;
use codex_app_server_protocol::AuthMode;
use codex_app_server_protocol::McpServerElicitationRequest;
use codex_app_server_protocol::McpServerElicitationRequestParams;
use codex_config::types::AuthCredentialsStoreMode;
use codex_config::types::OAuthCredentialsStoreMode;
use codex_exec_server::Environment;
use codex_exec_server::EnvironmentManager;
Expand Down Expand Up @@ -1059,6 +1061,15 @@ impl Session {
state.session_configuration.codex_home().clone()
}

pub(crate) async fn auth_home(&self) -> AbsolutePathBuf {
let state = self.state.lock().await;
state
.session_configuration
.original_config_do_not_use
.auth_home
.clone()
}

pub(crate) fn subscribe_out_of_band_elicitation_pause_state(&self) -> watch::Receiver<bool> {
self.out_of_band_elicitation_paused.subscribe()
}
Expand Down Expand Up @@ -1488,6 +1499,37 @@ impl Session {
.clone()
}

pub(crate) async fn auto_switch_accounts_on_rate_limit(&self) -> bool {
let state = self.state.lock().await;
state
.session_configuration
.original_config_do_not_use
.auto_switch_accounts_on_rate_limit
}

pub(crate) async fn api_key_fallback_on_all_accounts_limited(&self) -> bool {
let state = self.state.lock().await;
state
.session_configuration
.original_config_do_not_use
.api_key_fallback_on_all_accounts_limited
}

pub(crate) async fn cli_auth_credentials_store_mode(&self) -> AuthCredentialsStoreMode {
let state = self.state.lock().await;
state
.session_configuration
.original_config_do_not_use
.cli_auth_credentials_store_mode
}

pub(crate) fn current_auth_mode(&self) -> Option<AuthMode> {
self.services
.auth_manager
.auth_cached()
.map(|auth| auth.auth_mode())
}

pub(crate) async fn provider(&self) -> ModelProviderInfo {
let state = self.state.lock().await;
state.session_configuration.provider.clone()
Expand Down
Loading
Loading