diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 996ea1a77..7de7fec64 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -255,6 +255,12 @@ func (h *Handler) ListAuthFiles(c *gin.Context) { h.listAuthFilesFromDisk(c) return } + // Sync from master if configured (triggered by frontend refresh) + if h.authManager.GetCredentialMaster() != "" && h.cfg != nil && h.cfg.AuthDir != "" { + if err := h.authManager.SyncAuthsFromMaster(c.Request.Context(), h.cfg.AuthDir); err != nil { + log.Debugf("ListAuthFiles: failed to sync from master: %v", err) + } + } auths := h.authManager.List() files := make([]gin.H, 0, len(auths)) for _, auth := range auths { diff --git a/internal/api/handlers/management/credential_sync.go b/internal/api/handlers/management/credential_sync.go new file mode 100644 index 000000000..57abc16c0 --- /dev/null +++ b/internal/api/handlers/management/credential_sync.go @@ -0,0 +1,94 @@ +package management + +import ( + "crypto/subtle" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +// PeerAuthMiddleware returns a middleware for peer-to-peer authentication. +// Both master and follower share the same secret-key value (typically a bcrypt hash), +// and this middleware does constant-time string comparison (not bcrypt verification). +// This differs from Middleware() which does bcrypt verification for human users. +func (h *Handler) PeerAuthMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if h == nil || h.cfg == nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "peer authentication not configured"}) + return + } + expected := h.cfg.RemoteManagement.SecretKey + if expected == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "peer authentication not configured"}) + return + } + + // Accept Authorization: Bearer or X-Peer-Secret header + var provided string + if auth := c.GetHeader("Authorization"); auth != "" { + parts := strings.SplitN(auth, " ", 2) + if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") { + provided = parts[1] + } + } + if provided == "" { + provided = c.GetHeader("X-Peer-Secret") + } + + if provided == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing peer secret"}) + return + } + if subtle.ConstantTimeCompare([]byte(provided), []byte(expected)) != 1 { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid peer secret"}) + return + } + c.Next() + } +} + +// HandleCredentialQuery returns the current access_token for a given auth ID. +// This endpoint is used by follower nodes to fetch credentials from master. +func (h *Handler) HandleCredentialQuery(c *gin.Context) { + if h == nil || h.authManager == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "server not initialized"}) + return + } + + id := c.Query("id") + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id parameter is required"}) + return + } + + h.authManager.RefreshIfNeeded(c.Request.Context(), id) + + accessToken := h.authManager.GetAccessToken(id) + if accessToken == "" { + c.JSON(http.StatusNotFound, gin.H{"error": "credential not found or no access_token"}) + return + } + + response := gin.H{ + "id": id, + "access_token": accessToken, + } + if expiredAt, ok := h.authManager.GetExpirationTime(id); ok && !expiredAt.IsZero() { + response["expired"] = expiredAt.Format(time.RFC3339) + } + c.JSON(http.StatusOK, response) +} + +// HandleAuthList returns all auth entries (without refresh_token). +// This endpoint is used by follower nodes for startup sync. +func (h *Handler) HandleAuthList(c *gin.Context) { + if h == nil || h.authManager == nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "server not initialized"}) + return + } + + auths := h.authManager.GetAllAuthsForSync() + c.JSON(http.StatusOK, gin.H{"auths": auths}) +} diff --git a/internal/api/server.go b/internal/api/server.go index f9a2abdd8..9330978b9 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -422,6 +422,13 @@ func (s *Server) setupRoutes() { c.String(http.StatusOK, oauthCallbackSuccessHTML) }) + // Internal credential query endpoint for master-follower mode. + // Uses PeerAuthMiddleware (constant-time hash comparison) instead of Middleware (bcrypt), + // because both master and follower share the same secret-key hash value. + internal := s.engine.Group("/v0/internal", s.mgmt.PeerAuthMiddleware()) + internal.GET("/credential", s.mgmt.HandleCredentialQuery) + internal.GET("/auth-list", s.mgmt.HandleAuthList) + // Management routes are registered lazily by registerManagementRoutes when a secret is configured. } @@ -1044,3 +1051,5 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc { } } } + + diff --git a/internal/config/config.go b/internal/config/config.go index dcf6b1f76..05a8959ba 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -67,6 +67,11 @@ type Config struct { // DisableCooling disables quota cooldown scheduling when true. DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"` + // CredentialMaster specifies the master node URL for credential synchronization. + // When set, this node acts as a follower and fetches access_token from master on 401 errors. + // Example: "http://192.168.1.100:8317" + CredentialMaster string `yaml:"credential-master" json:"credential-master"` + // RequestRetry defines the retry times when the request failed. RequestRetry int `yaml:"request-retry" json:"request-retry"` // MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential. diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index edb1f124d..162023af8 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -670,6 +670,33 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) { log.Debugf("Resumed client %s for model %s", clientID, modelID) } +// ResumeClientAllModels clears all suspensions for a client across all models. +// This is useful when credentials are refreshed or updated from master. +func (r *ModelRegistry) ResumeClientAllModels(clientID string) { + if clientID == "" { + return + } + r.mutex.Lock() + defer r.mutex.Unlock() + + now := time.Now() + resumed := 0 + for modelID, registration := range r.models { + if registration == nil || registration.SuspendedClients == nil { + continue + } + if _, ok := registration.SuspendedClients[clientID]; ok { + delete(registration.SuspendedClients, clientID) + registration.LastUpdated = now + resumed++ + log.Debugf("Resumed client %s for model %s (bulk)", clientID, modelID) + } + } + if resumed > 0 { + log.Debugf("Resumed client %s for %d models", clientID, resumed) + } +} + // ClientSupportsModel reports whether the client registered support for modelID. func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool { clientID = strings.TrimSpace(clientID) diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 5b76d02ae..ec4f0e7bd 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -169,11 +169,17 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) + // Decode response body (handle gzip if needed) + errorBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding")) + if decErr != nil { + log.Warnf("failed to decode error response body: %v", decErr) + errorBody = httpResp.Body + } + b, _ := io.ReadAll(errorBody) appendAPIResponseChunk(ctx, e.cfg, b) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) err = statusErr{code: httpResp.StatusCode, msg: string(b)} - if errClose := httpResp.Body.Close(); errClose != nil { + if errClose := errorBody.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) } return resp, err @@ -309,10 +315,16 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) + // Decode response body (handle gzip if needed) + errorBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding")) + if decErr != nil { + log.Warnf("failed to decode error response body: %v", decErr) + errorBody = httpResp.Body + } + b, _ := io.ReadAll(errorBody) appendAPIResponseChunk(ctx, e.cfg, b) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - if errClose := httpResp.Body.Close(); errClose != nil { + if errClose := errorBody.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) } err = statusErr{code: httpResp.StatusCode, msg: string(b)} @@ -489,7 +501,6 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut } func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - log.Debugf("claude executor: refresh called") if auth == nil { return nil, fmt.Errorf("claude executor: auth is nil") } @@ -955,6 +966,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte { } else { payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions)) } + return payload } diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 3a64c8c34..6286dc417 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -207,6 +207,8 @@ func (m *Manager) SetConfig(cfg *internalconfig.Config) { m.rebuildAPIKeyModelAliasFromRuntimeConfig() } + + func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) string { if m == nil { return "" @@ -569,6 +571,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) tried := make(map[string]struct{}) + fetchedFromMaster := make(map[string]struct{}) // Track auths that already fetched from master var lastErr error for { auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) @@ -592,6 +595,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req execReq.Model = rewriteModelForAuth(routeModel, auth) execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) + resp, errExec := executor.Execute(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -606,6 +610,16 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req if ra := retryAfterFromError(errExec); ra != nil { result.RetryAfter = ra } + + statusCode := 0 + if result.Error != nil { + statusCode = result.Error.HTTPStatus + } + if m.tryFetchFromMasterOnUnauthorized(ctx, statusCode, auth.ID, provider, fetchedFromMaster) { + delete(tried, auth.ID) + continue + } + m.MarkResult(execCtx, result) lastErr = errExec continue @@ -675,6 +689,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string routeModel := req.Model opts = ensureRequestedModelMetadata(opts, routeModel) tried := make(map[string]struct{}) + fetchedFromMaster := make(map[string]struct{}) // Track auths that already fetched from master var lastErr error for { auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) @@ -708,6 +723,12 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string if errors.As(errStream, &se) && se != nil { rerr.HTTPStatus = se.StatusCode() } + + if m.tryFetchFromMasterOnUnauthorized(ctx, rerr.HTTPStatus, auth.ID, provider, fetchedFromMaster) { + delete(tried, auth.ID) + continue + } + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} result.RetryAfter = retryAfterFromError(errStream) m.MarkResult(execCtx, result) @@ -1179,11 +1200,10 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { statusCode := statusCodeFromResult(result.Error) switch statusCode { case 401: - next := now.Add(30 * time.Minute) - state.NextRetryAfter = next + state.NextRetryAfter = now.Add(30 * time.Minute) suspendReason = "unauthorized" shouldSuspendModel = true - case 402, 403: + case 402: next := now.Add(30 * time.Minute) state.NextRetryAfter = next suspendReason = "payment_required" @@ -1983,7 +2003,19 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { exec = m.executors[auth.Provider] } m.mu.RUnlock() - if auth == nil || exec == nil { + if auth == nil { + return + } + + // Follower nodes (with credential-master configured) fetch from master instead of local refresh + if m.GetCredentialMaster() != "" { + // Errors are logged inside fetchCredentialFromMaster, will retry on next refresh cycle or on 401 + m.fetchCredentialFromMaster(ctx, id, auth.Provider) + return + } + + // Master node: perform local refresh using executor + if exec == nil { return } cloned := auth.Clone() diff --git a/sdk/cliproxy/auth/credential_master.go b/sdk/cliproxy/auth/credential_master.go new file mode 100644 index 000000000..6ddfed797 --- /dev/null +++ b/sdk/cliproxy/auth/credential_master.go @@ -0,0 +1,332 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + log "github.com/sirupsen/logrus" +) + +// AuthSyncData represents auth data for sync (without refresh_token). +type AuthSyncData struct { + ID string `json:"id"` + Provider string `json:"provider"` + Metadata map[string]any `json:"metadata"` +} + +// tryFetchFromMasterOnUnauthorized attempts to fetch credentials from master on 401 errors. +// Returns true if retry should happen (fetch succeeded and auth not yet retried). +// The fetched map tracks which auth IDs have already been fetched to prevent infinite loops. +func (m *Manager) tryFetchFromMasterOnUnauthorized(ctx context.Context, statusCode int, authID, provider string, fetched map[string]struct{}) bool { + if statusCode != 401 || m.GetCredentialMaster() == "" { + return false + } + if _, alreadyFetched := fetched[authID]; alreadyFetched { + log.Warnf("got %d again after fetching from master, not retrying", statusCode) + return false + } + log.Infof("got %d, fetching credential from master and retrying...", statusCode) + fetched[authID] = struct{}{} + if err := m.fetchCredentialFromMaster(ctx, authID, provider); err != nil { + log.Warnf("failed to fetch credential from master: %v", err) + return false + } + return true +} + +// GetCredentialMaster returns the configured master node URL from runtime config. +func (m *Manager) GetCredentialMaster() string { + if m == nil { + return "" + } + cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config) + if cfg == nil { + return "" + } + return strings.TrimSpace(cfg.CredentialMaster) +} + +// getPeerSecret returns the peer secret from runtime config. +func (m *Manager) getPeerSecret() string { + if m == nil { + return "" + } + cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config) + if cfg == nil { + return "" + } + return cfg.RemoteManagement.SecretKey +} + +// GetAccessToken returns the access_token for a given auth ID. +// Used by master node to serve credential queries from followers. +func (m *Manager) GetAccessToken(id string) string { + if m == nil || id == "" { + return "" + } + m.mu.RLock() + auth, ok := m.auths[id] + m.mu.RUnlock() + if !ok || auth == nil || auth.Metadata == nil { + return "" + } + if at, ok := auth.Metadata["access_token"].(string); ok { + return at + } + return "" +} + +// RefreshIfNeeded checks if the token for the given auth ID needs refresh, +// and refreshes it if necessary. This is called by master node when serving +// credential queries from followers, ensuring tokens are refreshed on-demand +// even when master itself is not making API requests. +func (m *Manager) RefreshIfNeeded(ctx context.Context, id string) { + if m == nil || id == "" { + return + } + m.mu.RLock() + auth := m.auths[id] + m.mu.RUnlock() + if auth == nil { + return + } + now := time.Now() + if m.shouldRefresh(auth, now) { + log.Debugf("RefreshIfNeeded: token needs refresh for %s, triggering refresh", id) + m.refreshAuth(ctx, id) + } +} + +// GetExpirationTime returns the expiration time for a given auth ID. +// Used by master node to include expiration info in credential responses. +func (m *Manager) GetExpirationTime(id string) (time.Time, bool) { + if m == nil || id == "" { + return time.Time{}, false + } + m.mu.RLock() + auth := m.auths[id] + m.mu.RUnlock() + if auth == nil { + return time.Time{}, false + } + return auth.ExpirationTime() +} + +// GetAllAuthsForSync returns all auth entries for sync (without refresh_token). +func (m *Manager) GetAllAuthsForSync() []AuthSyncData { + if m == nil { + return nil + } + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]AuthSyncData, 0, len(m.auths)) + for _, auth := range m.auths { + if auth == nil || auth.Disabled { + continue + } + syncData := AuthSyncData{ + ID: auth.ID, + Provider: auth.Provider, + Metadata: sanitizeMetadataForSync(auth.Metadata), + } + result = append(result, syncData) + } + return result +} + +// sanitizeMetadataForSync removes sensitive fields like refresh_token. +func sanitizeMetadataForSync(meta map[string]any) map[string]any { + if meta == nil { + return nil + } + result := make(map[string]any, len(meta)) + for k, v := range meta { + if k == "refresh_token" || k == "refreshToken" { + continue + } + result[k] = v + } + return result +} + +// fetchCredentialFromMaster fetches the latest access_token from master node. +func (m *Manager) fetchCredentialFromMaster(ctx context.Context, id, provider string) error { + master := m.GetCredentialMaster() + if master == "" { + return errors.New("credential-master not configured") + } + secret := m.getPeerSecret() + if secret == "" { + return errors.New("peer secret not configured") + } + + url := strings.TrimRight(master, "/") + "/v0/internal/credential?id=" + id + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+secret) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return errors.New("master returned " + resp.Status + ": " + string(body)) + } + + var result struct { + ID string `json:"id"` + AccessToken string `json:"access_token"` + Expired string `json:"expired"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return err + } + if result.AccessToken == "" { + return errors.New("master returned empty access_token") + } + + m.mu.Lock() + auth, ok := m.auths[id] + if !ok || auth == nil { + m.mu.Unlock() + return errors.New("auth not found locally") + } + if auth.Metadata == nil { + auth.Metadata = make(map[string]any) + } + auth.Metadata["access_token"] = result.AccessToken + auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339) + if result.Expired != "" { + auth.Metadata["expired"] = result.Expired + } + auth.UpdatedAt = time.Now() + auth.LastRefreshedAt = time.Now() + auth.LastError = nil + auth.Status = StatusActive + auth.Unavailable = false + auth.NextRetryAfter = time.Time{} + auth.ModelStates = nil + m.mu.Unlock() + + _ = m.persist(ctx, auth) + registry.GetGlobalRegistry().ResumeClientAllModels(id) + + log.Infof("fetched access_token from master: provider=%s, id=%s", provider, id) + m.hook.OnAuthUpdated(ctx, auth.Clone()) + return nil +} + +// SyncAuthsFromMaster syncs all auth entries from master node at startup. +// It writes auth files to the local auth directory for file watcher to pick up. +func (m *Manager) SyncAuthsFromMaster(ctx context.Context, authDir string) error { + master := m.GetCredentialMaster() + if master == "" { + return nil + } + secret := m.getPeerSecret() + if secret == "" { + log.Warnf("SyncAuthsFromMaster: peer secret not configured") + return errors.New("peer secret not configured") + } + log.Infof("SyncAuthsFromMaster: syncing from %s with authDir=%s", master, authDir) + + url := strings.TrimRight(master, "/") + "/v0/internal/auth-list" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+secret) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return errors.New("master returned " + resp.Status + ": " + string(body)) + } + + var result struct { + Auths []AuthSyncData `json:"auths"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return err + } + + for _, syncData := range result.Auths { + auth := syncDataToAuth(syncData, authDir) + m.mu.Lock() + m.auths[auth.ID] = auth + m.mu.Unlock() + + if err := writeAuthToFile(authDir, syncData); err != nil { + log.Warnf("failed to write auth file %s: %v", syncData.ID, err) + } + + registry.GetGlobalRegistry().ResumeClientAllModels(auth.ID) + m.hook.OnAuthUpdated(ctx, auth.Clone()) + } + + log.Infof("synced %d auths from master", len(result.Auths)) + return nil +} + +// syncDataToAuth converts AuthSyncData to Auth for memory storage. +func syncDataToAuth(data AuthSyncData, authDir string) *Auth { + now := time.Now() + filename := data.ID + if !strings.HasSuffix(filename, ".json") { + filename += ".json" + } + return &Auth{ + ID: data.ID, + Provider: data.Provider, + FileName: filename, + Metadata: data.Metadata, + Status: StatusActive, + CreatedAt: now, + UpdatedAt: now, + Attributes: map[string]string{ + "path": filepath.Join(authDir, filename), + }, + } +} + +// writeAuthToFile writes an auth entry to local file for persistence. +func writeAuthToFile(authDir string, syncData AuthSyncData) error { + if authDir == "" || syncData.ID == "" { + return nil + } + + filename := syncData.ID + if !strings.HasSuffix(filename, ".json") { + filename += ".json" + } + filePath := filepath.Join(authDir, filename) + + data, err := json.MarshalIndent(syncData.Metadata, "", " ") + if err != nil { + return err + } + + return os.WriteFile(filePath, data, 0600) +} diff --git a/sdk/cliproxy/auth/credential_master_test.go b/sdk/cliproxy/auth/credential_master_test.go new file mode 100644 index 000000000..11acb8275 --- /dev/null +++ b/sdk/cliproxy/auth/credential_master_test.go @@ -0,0 +1,264 @@ +package auth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func TestSanitizeMetadataForSync(t *testing.T) { + meta := map[string]any{ + "access_token": "tok123", + "refresh_token": "ref456", + "refreshToken": "ref789", + "expired": "2026-01-01T00:00:00Z", + } + result := sanitizeMetadataForSync(meta) + if _, ok := result["refresh_token"]; ok { + t.Error("refresh_token should be stripped") + } + if _, ok := result["refreshToken"]; ok { + t.Error("refreshToken should be stripped") + } + if result["access_token"] != "tok123" { + t.Error("access_token should be preserved") + } + if result["expired"] != "2026-01-01T00:00:00Z" { + t.Error("expired should be preserved") + } +} + +func TestSanitizeMetadataForSync_Nil(t *testing.T) { + if sanitizeMetadataForSync(nil) != nil { + t.Error("nil input should return nil") + } +} + +func TestSyncDataToAuth(t *testing.T) { + data := AuthSyncData{ + ID: "test-auth", + Provider: "claude", + Metadata: map[string]any{"access_token": "abc"}, + } + auth := syncDataToAuth(data, "/tmp/auths") + if auth.ID != "test-auth" { + t.Errorf("expected ID test-auth, got %s", auth.ID) + } + if auth.Provider != "claude" { + t.Errorf("expected Provider claude, got %s", auth.Provider) + } + if auth.FileName != "test-auth.json" { + t.Errorf("expected FileName test-auth.json, got %s", auth.FileName) + } + if auth.Status != StatusActive { + t.Errorf("expected StatusActive, got %v", auth.Status) + } + if auth.Attributes["path"] != filepath.Join("/tmp/auths", "test-auth.json") { + t.Errorf("unexpected path attribute: %s", auth.Attributes["path"]) + } +} + +func TestSyncDataToAuth_AlreadyHasJsonSuffix(t *testing.T) { + data := AuthSyncData{ID: "my.json", Provider: "claude"} + auth := syncDataToAuth(data, "/tmp") + if auth.FileName != "my.json" { + t.Errorf("should not double-append .json, got %s", auth.FileName) + } +} + +func TestWriteAuthToFile(t *testing.T) { + dir := t.TempDir() + data := AuthSyncData{ + ID: "write-test", + Provider: "claude", + Metadata: map[string]any{"access_token": "tok"}, + } + if err := writeAuthToFile(dir, data); err != nil { + t.Fatalf("writeAuthToFile failed: %v", err) + } + content, err := os.ReadFile(filepath.Join(dir, "write-test.json")) + if err != nil { + t.Fatalf("failed to read written file: %v", err) + } + var meta map[string]any + if err := json.Unmarshal(content, &meta); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if meta["access_token"] != "tok" { + t.Errorf("expected access_token=tok, got %v", meta["access_token"]) + } +} + +func TestWriteAuthToFile_EmptyDir(t *testing.T) { + if err := writeAuthToFile("", AuthSyncData{ID: "x"}); err != nil { + t.Error("empty authDir should return nil") + } +} + +func TestWriteAuthToFile_EmptyID(t *testing.T) { + if err := writeAuthToFile("/tmp", AuthSyncData{}); err != nil { + t.Error("empty ID should return nil") + } +} + +func TestFetchCredentialFromMaster(t *testing.T) { + master := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v0/internal/credential" { + http.NotFound(w, r) + return + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-secret" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + id := r.URL.Query().Get("id") + json.NewEncoder(w).Encode(map[string]string{ + "id": id, + "access_token": "new-token-123", + "expired": "2026-12-31T23:59:59Z", + }) + })) + defer master.Close() + + mgr := NewManager(nil, nil, nil) + mgr.SetConfig(&config.Config{ + CredentialMaster: master.URL, + RemoteManagement: config.RemoteManagement{SecretKey: "test-secret"}, + }) + + // Register a local auth entry + mgr.mu.Lock() + mgr.auths["auth-1"] = &Auth{ + ID: "auth-1", + Provider: "claude", + Metadata: map[string]any{"access_token": "old-token"}, + Status: StatusActive, + } + mgr.mu.Unlock() + + err := mgr.fetchCredentialFromMaster(context.Background(), "auth-1", "claude") + if err != nil { + t.Fatalf("fetchCredentialFromMaster failed: %v", err) + } + + mgr.mu.RLock() + auth := mgr.auths["auth-1"] + mgr.mu.RUnlock() + + if at, ok := auth.Metadata["access_token"].(string); !ok || at != "new-token-123" { + t.Errorf("expected new-token-123, got %v", auth.Metadata["access_token"]) + } + if exp, ok := auth.Metadata["expired"].(string); !ok || exp != "2026-12-31T23:59:59Z" { + t.Errorf("expected expired field to be updated, got %v", auth.Metadata["expired"]) + } +} + +func TestFetchCredentialFromMaster_NoMaster(t *testing.T) { + mgr := NewManager(nil, nil, nil) + err := mgr.fetchCredentialFromMaster(context.Background(), "x", "claude") + if err == nil || err.Error() != "credential-master not configured" { + t.Errorf("expected 'credential-master not configured', got %v", err) + } +} + +func TestFetchCredentialFromMaster_NoSecret(t *testing.T) { + mgr := NewManager(nil, nil, nil) + mgr.SetConfig(&config.Config{ + CredentialMaster: "http://localhost:9999", + // No SecretKey configured + }) + err := mgr.fetchCredentialFromMaster(context.Background(), "x", "claude") + if err == nil || err.Error() != "peer secret not configured" { + t.Errorf("expected 'peer secret not configured', got %v", err) + } +} + +func TestGetAllAuthsForSync(t *testing.T) { + mgr := NewManager(nil, nil, nil) + mgr.mu.Lock() + mgr.auths["a1"] = &Auth{ + ID: "a1", Provider: "claude", + Metadata: map[string]any{"access_token": "t1", "refresh_token": "rt1"}, + } + mgr.auths["a2"] = &Auth{ + ID: "a2", Provider: "claude", Disabled: true, + Metadata: map[string]any{"access_token": "t2"}, + } + mgr.mu.Unlock() + + result := mgr.GetAllAuthsForSync() + if len(result) != 1 { + t.Fatalf("expected 1 auth (disabled excluded), got %d", len(result)) + } + if result[0].ID != "a1" { + t.Errorf("expected a1, got %s", result[0].ID) + } + if _, ok := result[0].Metadata["refresh_token"]; ok { + t.Error("refresh_token should be stripped from sync data") + } +} + +func TestSyncAuthsFromMaster(t *testing.T) { + master := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v0/internal/auth-list" { + http.NotFound(w, r) + return + } + auth := r.Header.Get("Authorization") + if auth != "Bearer sync-secret" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + json.NewEncoder(w).Encode(map[string]any{ + "auths": []map[string]any{ + {"id": "sync-1", "provider": "claude", "metadata": map[string]any{"access_token": "st1"}}, + {"id": "sync-2", "provider": "claude", "metadata": map[string]any{"access_token": "st2"}}, + }, + }) + })) + defer master.Close() + + dir := t.TempDir() + mgr := NewManager(nil, nil, nil) + mgr.SetConfig(&config.Config{ + CredentialMaster: master.URL, + RemoteManagement: config.RemoteManagement{SecretKey: "sync-secret"}, + }) + + err := mgr.SyncAuthsFromMaster(context.Background(), dir) + if err != nil { + t.Fatalf("SyncAuthsFromMaster failed: %v", err) + } + + mgr.mu.RLock() + defer mgr.mu.RUnlock() + if len(mgr.auths) != 2 { + t.Errorf("expected 2 auths, got %d", len(mgr.auths)) + } + if mgr.auths["sync-1"] == nil || mgr.auths["sync-2"] == nil { + t.Error("expected both sync-1 and sync-2 to be registered") + } + + // Check files were written + for _, id := range []string{"sync-1", "sync-2"} { + path := filepath.Join(dir, id+".json") + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("expected file %s to exist", path) + } + } +} + +func TestSyncAuthsFromMaster_NoMaster(t *testing.T) { + mgr := NewManager(nil, nil, nil) + err := mgr.SyncAuthsFromMaster(context.Background(), "/tmp") + if err != nil { + t.Errorf("expected nil when no master configured, got %v", err) + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 4223b5b28..eb6316fa4 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -277,6 +277,7 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A return } auth = auth.Clone() + log.Debugf("applyCoreAuthAddOrUpdate: id=%s", auth.ID) s.ensureExecutorsForAuth(auth) // IMPORTANT: Update coreManager FIRST, before model registration. @@ -593,6 +594,14 @@ func (s *Service) Run(ctx context.Context) error { watcherCtx, watcherCancel := context.WithCancel(context.Background()) s.watcherCancel = watcherCancel + + // Sync auth files from master node if configured (before starting file watcher) + if s.coreManager != nil { + if err := s.coreManager.SyncAuthsFromMaster(ctx, s.cfg.AuthDir); err != nil { + log.Warnf("failed to sync auths from master: %v", err) + } + } + if err = watcherWrapper.Start(watcherCtx); err != nil { return fmt.Errorf("cliproxy: failed to start watcher: %w", err) }