Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions internal/api/handlers/management/auth_files.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ func (h *Handler) ListAuthFiles(c *gin.Context) {
h.listAuthFilesFromDisk(c)
return
}
// Sync from master if configured (triggered by frontend refresh)
if h.authManager.GetCredentialMaster() != "" && h.cfg != nil && h.cfg.AuthDir != "" {
if err := h.authManager.SyncAuthsFromMaster(c.Request.Context(), h.cfg.AuthDir); err != nil {
log.Debugf("ListAuthFiles: failed to sync from master: %v", err)
}
}
auths := h.authManager.List()
files := make([]gin.H, 0, len(auths))
for _, auth := range auths {
Expand Down
94 changes: 94 additions & 0 deletions internal/api/handlers/management/credential_sync.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package management

import (
"crypto/subtle"
"net/http"
"strings"
"time"

"github.com/gin-gonic/gin"
)

// PeerAuthMiddleware returns a middleware for peer-to-peer authentication.
// Both master and follower share the same secret-key value (typically a bcrypt hash),
// and this middleware does constant-time string comparison (not bcrypt verification).
// This differs from Middleware() which does bcrypt verification for human users.
func (h *Handler) PeerAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if h == nil || h.cfg == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "peer authentication not configured"})
return
}
expected := h.cfg.RemoteManagement.SecretKey
if expected == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "peer authentication not configured"})
return
}

// Accept Authorization: Bearer <secret> or X-Peer-Secret header
var provided string
if auth := c.GetHeader("Authorization"); auth != "" {
parts := strings.SplitN(auth, " ", 2)
if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") {
provided = parts[1]
}
}
if provided == "" {
provided = c.GetHeader("X-Peer-Secret")
}

if provided == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing peer secret"})
return
}
if subtle.ConstantTimeCompare([]byte(provided), []byte(expected)) != 1 {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid peer secret"})
return
}
c.Next()
}
}

// HandleCredentialQuery returns the current access_token for a given auth ID.
// This endpoint is used by follower nodes to fetch credentials from master.
func (h *Handler) HandleCredentialQuery(c *gin.Context) {
if h == nil || h.authManager == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "server not initialized"})
return
}

id := c.Query("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id parameter is required"})
return
}

h.authManager.RefreshIfNeeded(c.Request.Context(), id)

accessToken := h.authManager.GetAccessToken(id)
if accessToken == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "credential not found or no access_token"})
return
}

response := gin.H{
"id": id,
"access_token": accessToken,
}
if expiredAt, ok := h.authManager.GetExpirationTime(id); ok && !expiredAt.IsZero() {
response["expired"] = expiredAt.Format(time.RFC3339)
}
c.JSON(http.StatusOK, response)
}

// HandleAuthList returns all auth entries (without refresh_token).
// This endpoint is used by follower nodes for startup sync.
func (h *Handler) HandleAuthList(c *gin.Context) {
if h == nil || h.authManager == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "server not initialized"})
return
}

auths := h.authManager.GetAllAuthsForSync()
c.JSON(http.StatusOK, gin.H{"auths": auths})
}
9 changes: 9 additions & 0 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,13 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})

// Internal credential query endpoint for master-follower mode.
// Uses PeerAuthMiddleware (constant-time hash comparison) instead of Middleware (bcrypt),
// because both master and follower share the same secret-key hash value.
internal := s.engine.Group("/v0/internal", s.mgmt.PeerAuthMiddleware())
internal.GET("/credential", s.mgmt.HandleCredentialQuery)
internal.GET("/auth-list", s.mgmt.HandleAuthList)

// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
}

Expand Down Expand Up @@ -1044,3 +1051,5 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
}
}
}


5 changes: 5 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ type Config struct {
// DisableCooling disables quota cooldown scheduling when true.
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`

// CredentialMaster specifies the master node URL for credential synchronization.
// When set, this node acts as a follower and fetches access_token from master on 401 errors.
// Example: "http://192.168.1.100:8317"
CredentialMaster string `yaml:"credential-master" json:"credential-master"`

// RequestRetry defines the retry times when the request failed.
RequestRetry int `yaml:"request-retry" json:"request-retry"`
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
Expand Down
27 changes: 27 additions & 0 deletions internal/registry/model_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,33 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
log.Debugf("Resumed client %s for model %s", clientID, modelID)
}

// ResumeClientAllModels clears all suspensions for a client across all models.
// This is useful when credentials are refreshed or updated from master.
func (r *ModelRegistry) ResumeClientAllModels(clientID string) {
if clientID == "" {
return
}
r.mutex.Lock()
defer r.mutex.Unlock()

now := time.Now()
resumed := 0
for modelID, registration := range r.models {
if registration == nil || registration.SuspendedClients == nil {
continue
}
if _, ok := registration.SuspendedClients[clientID]; ok {
delete(registration.SuspendedClients, clientID)
registration.LastUpdated = now
resumed++
log.Debugf("Resumed client %s for model %s (bulk)", clientID, modelID)
}
}
if resumed > 0 {
log.Debugf("Resumed client %s for %d models", clientID, resumed)
}
}

// ClientSupportsModel reports whether the client registered support for modelID.
func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
clientID = strings.TrimSpace(clientID)
Expand Down
22 changes: 17 additions & 5 deletions internal/runtime/executor/claude_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,17 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
// Decode response body (handle gzip if needed)
errorBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
log.Warnf("failed to decode error response body: %v", decErr)
errorBody = httpResp.Body
}
b, _ := io.ReadAll(errorBody)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
if errClose := httpResp.Body.Close(); errClose != nil {
if errClose := errorBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return resp, err
Expand Down Expand Up @@ -309,10 +315,16 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
// Decode response body (handle gzip if needed)
errorBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
log.Warnf("failed to decode error response body: %v", decErr)
errorBody = httpResp.Body
}
b, _ := io.ReadAll(errorBody)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
if errClose := errorBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
Expand Down Expand Up @@ -489,7 +501,6 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
}

func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
log.Debugf("claude executor: refresh called")
if auth == nil {
return nil, fmt.Errorf("claude executor: auth is nil")
}
Expand Down Expand Up @@ -955,6 +966,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
} else {
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
}

return payload
}

Expand Down
40 changes: 36 additions & 4 deletions sdk/cliproxy/auth/conductor.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ func (m *Manager) SetConfig(cfg *internalconfig.Config) {
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
}



func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) string {
if m == nil {
return ""
Expand Down Expand Up @@ -569,6 +571,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
fetchedFromMaster := make(map[string]struct{}) // Track auths that already fetched from master
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
Expand All @@ -592,6 +595,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)

resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
Expand All @@ -606,6 +610,16 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}

statusCode := 0
if result.Error != nil {
statusCode = result.Error.HTTPStatus
}
if m.tryFetchFromMasterOnUnauthorized(ctx, statusCode, auth.ID, provider, fetchedFromMaster) {
delete(tried, auth.ID)
continue
}

m.MarkResult(execCtx, result)
lastErr = errExec
continue
Expand Down Expand Up @@ -675,6 +689,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
fetchedFromMaster := make(map[string]struct{}) // Track auths that already fetched from master
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
Expand Down Expand Up @@ -708,6 +723,12 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
if errors.As(errStream, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}

if m.tryFetchFromMasterOnUnauthorized(ctx, rerr.HTTPStatus, auth.ID, provider, fetchedFromMaster) {
delete(tried, auth.ID)
continue
}

result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
Expand Down Expand Up @@ -1179,11 +1200,10 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
statusCode := statusCodeFromResult(result.Error)
switch statusCode {
case 401:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
state.NextRetryAfter = now.Add(30 * time.Minute)
suspendReason = "unauthorized"
shouldSuspendModel = true
case 402, 403:
case 402:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
Expand Down Expand Up @@ -1983,7 +2003,19 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
exec = m.executors[auth.Provider]
}
m.mu.RUnlock()
if auth == nil || exec == nil {
if auth == nil {
return
}

// Follower nodes (with credential-master configured) fetch from master instead of local refresh
if m.GetCredentialMaster() != "" {
// Errors are logged inside fetchCredentialFromMaster, will retry on next refresh cycle or on 401
m.fetchCredentialFromMaster(ctx, id, auth.Provider)
return
}

// Master node: perform local refresh using executor
if exec == nil {
return
}
cloned := auth.Clone()
Expand Down
Loading