Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/cmd/server/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions backend/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ type DefaultConfig struct {

type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401临时不可调度冷却(分钟)
}

// APIKeyAuthCacheConfig API Key 认证缓存配置
Expand Down Expand Up @@ -1190,6 +1191,7 @@ func setDefaults() {

// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)

// Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
Expand Down
28 changes: 22 additions & 6 deletions backend/internal/service/ratelimit_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,29 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
}
// 3. 临时不可调度,替代 SetError(保持 status=active 让刷新服务能拾取)
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "OAuth 401: " + upstreamMsg
}
cooldownMinutes := s.cfg.RateLimit.OAuth401CooldownMinutes
if cooldownMinutes <= 0 {
cooldownMinutes = 10
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil {
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
}
shouldDisable = true
} else {
// 非 OAuth 账号(APIKey):保持原有 SetError 行为
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
}
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 402:
// 支付要求:余额不足或计费问题,停止调度
msg := "Payment required (402): insufficient balance or billing issue"
Expand Down
10 changes: 5 additions & 5 deletions backend/internal/service/ratelimit_service_401_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
return r.err
}

func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
tests := []struct {
name string
platform string
Expand Down Expand Up @@ -76,9 +76,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))

require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
})
}
Expand All @@ -98,7 +97,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))

require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
}

Expand Down
53 changes: 33 additions & 20 deletions backend/internal/service/token_refresh_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ type TokenRefreshService struct {
refreshers []TokenRefresher
cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存

stopCh chan struct{}
wg sync.WaitGroup
Expand All @@ -34,12 +35,14 @@ func NewTokenRefreshService(
cacheInvalidator TokenCacheInvalidator,
schedulerCache SchedulerCache,
cfg *config.Config,
tempUnschedCache TempUnschedCache,
) *TokenRefreshService {
s := &TokenRefreshService{
accountRepo: accountRepo,
cfg: &cfg.TokenRefresh,
cacheInvalidator: cacheInvalidator,
schedulerCache: schedulerCache,
tempUnschedCache: tempUnschedCache,
stopCh: make(chan struct{}),
}

Expand Down Expand Up @@ -231,6 +234,26 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
}
}
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
if clearErr := s.accountRepo.ClearTempUnschedulable(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_temp_unschedulable_failed",
"account_id", account.ID,
"error", clearErr,
)
} else {
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
}
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
if s.tempUnschedCache != nil {
if clearErr := s.tempUnschedCache.DeleteTempUnsched(ctx, account.ID); clearErr != nil {
slog.Warn("token_refresh.clear_temp_unsched_cache_failed",
"account_id", account.ID,
"error", clearErr,
)
}
}
}
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
Expand All @@ -257,8 +280,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
return nil
}

// Antigravity 账户:不可重试错误直接标记 error 状态并返回
if account.Platform == PlatformAntigravity && isNonRetryableRefreshError(err) {
// 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回
if isNonRetryableRefreshError(err) {
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
slog.Error("token_refresh.set_error_status_failed",
Expand All @@ -285,23 +308,13 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
}
}

// Antigravity 账户:其他错误仅记录日志,不标记 error(可能是临时网络问题)
// 其他平台账户:重试失败后标记 error
if account.Platform == PlatformAntigravity {
slog.Warn("token_refresh.retry_exhausted_antigravity",
"account_id", account.ID,
"max_retries", s.cfg.MaxRetries,
"error", lastErr,
)
} else {
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
slog.Error("token_refresh.set_error_status_failed",
"account_id", account.ID,
"error", err,
)
}
}
// 可重试错误耗尽:仅记录日志,不标记 error(可能是临时网络问题,下个周期继续重试)
slog.Warn("token_refresh.retry_exhausted",
"account_id", account.ID,
"platform", account.Platform,
"max_retries", s.cfg.MaxRetries,
"error", lastErr,
)

return lastErr
}
Expand Down
Loading
Loading