diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index b5626ce9c0..7c243ebb91 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -125,6 +125,8 @@ func (m *AmpModule) Register(ctx modules.Context) error { m.registerOnce.Do(func() { // Initialize model mapper from config (for routing unavailable models to alternatives) m.modelMapper = NewModelMapper(settings.ModelMappings) + // Load oauth-model-alias for provider lookup via aliases + m.modelMapper.UpdateOAuthModelAlias(ctx.Config.OAuthModelAlias) // Store initial config for partial reload comparison settingsCopy := settings @@ -212,6 +214,11 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { } } + // Always update oauth-model-alias for model mapper (used for provider lookup) + if m.modelMapper != nil { + m.modelMapper.UpdateOAuthModelAlias(cfg.OAuthModelAlias) + } + if m.enabled { // Check upstream URL change - now supports hot-reload if newUpstreamURL == "" && oldUpstreamURL != "" { diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 7d7f7f5f28..f46af1c0f4 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -2,12 +2,15 @@ package amp import ( "bytes" + "errors" "io" + "net/http" "net/http/httputil" "strings" "time" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" @@ -30,7 +33,13 @@ const ( ) // MappedModelContextKey is the Gin context key for passing mapped model names. -const MappedModelContextKey = "mapped_model" +// Deprecated: Use ctxkeys.MappedModel instead. +const MappedModelContextKey = string(ctxkeys.MappedModel) + +// FallbackModelsContextKey is the Gin context key for passing fallback model names. +// When the primary mapped model fails (e.g., quota exceeded), these models can be tried. +// Deprecated: Use ctxkeys.FallbackModels instead. +const FallbackModelsContextKey = string(ctxkeys.FallbackModels) // logAmpRouting logs the routing decision for an Amp request with structured fields func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) { @@ -77,6 +86,10 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid // FallbackHandler wraps a standard handler with fallback logic to ampcode.com // when the model's provider is not available in CLIProxyAPI +// +// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper. +// Use routing.NewModelRoutingWrapper() instead for unified routing logic. +// This type is kept for backward compatibility and test purposes. type FallbackHandler struct { getProxy func() *httputil.ReverseProxy modelMapper ModelMapper @@ -85,6 +98,8 @@ type FallbackHandler struct { // NewFallbackHandler creates a new fallback handler wrapper // The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) +// +// Deprecated: Use routing.NewModelRoutingWrapper() instead. func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { return &FallbackHandler{ getProxy: getProxy, @@ -93,6 +108,8 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler } // NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support +// +// Deprecated: Use routing.NewModelRoutingWrapper() instead. func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler { if forceModelMappings == nil { forceModelMappings = func() bool { return false } @@ -113,6 +130,20 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) { // If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { + // Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces. + // ReverseProxy raises this panic when the client connection is closed prematurely + // (e.g., user cancels request, network disconnect) or when ServeHTTP is called + // with a ResponseWriter that doesn't implement http.CloseNotifier. + // This is an expected error condition, not a bug, so we handle it gracefully. + defer func() { + if rec := recover(); rec != nil { + if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) { + return + } + panic(rec) + } + }() + requestPath := c.Request.URL.Path // Read the request body to extract the model name @@ -142,36 +173,57 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc thinkingSuffix = "(" + suffixResult.RawSuffix + ")" } - resolveMappedModel := func() (string, []string) { + // resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one. + resolveMappedModels := func() ([]string, []string) { if fh.modelMapper == nil { - return "", nil + return nil, nil } - mappedModel := fh.modelMapper.MapModel(modelName) - if mappedModel == "" { - mappedModel = fh.modelMapper.MapModel(normalizedModel) + mapper, ok := fh.modelMapper.(*DefaultModelMapper) + if !ok { + // Fallback to single model for non-DefaultModelMapper + mappedModel := fh.modelMapper.MapModel(modelName) + if mappedModel == "" { + mappedModel = fh.modelMapper.MapModel(normalizedModel) + } + if mappedModel == "" { + return nil, nil + } + mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName + mappedProviders := util.GetProviderName(mappedBaseModel) + if len(mappedProviders) == 0 { + return nil, nil + } + return []string{mappedModel}, mappedProviders } - mappedModel = strings.TrimSpace(mappedModel) - if mappedModel == "" { - return "", nil + + // Use MapModelWithFallbacks for DefaultModelMapper + mappedModels := mapper.MapModelWithFallbacks(modelName) + if len(mappedModels) == 0 { + mappedModels = mapper.MapModelWithFallbacks(normalizedModel) + } + if len(mappedModels) == 0 { + return nil, nil } - // Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target - // already specifies its own thinking suffix. - if thinkingSuffix != "" { - mappedSuffixResult := thinking.ParseSuffix(mappedModel) - if !mappedSuffixResult.HasSuffix { - mappedModel += thinkingSuffix + // Apply thinking suffix if needed + for i, model := range mappedModels { + if thinkingSuffix != "" { + suffixResult := thinking.ParseSuffix(model) + if !suffixResult.HasSuffix { + mappedModels[i] = model + thinkingSuffix + } } } - mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName - mappedProviders := util.GetProviderName(mappedBaseModel) - if len(mappedProviders) == 0 { - return "", nil + // Get providers for the first model + firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName + providers := util.GetProviderName(firstBaseModel) + if len(providers) == 0 { + return nil, nil } - return mappedModel, mappedProviders + return mappedModels, providers } // Track resolved model for logging (may change if mapping is applied) @@ -179,21 +231,27 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc usedMapping := false var providers []string + // Helper to apply model mapping and update state + applyMapping := func(mappedModels []string, mappedProviders []string) { + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0]) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + c.Set(string(ctxkeys.MappedModel), mappedModels[0]) + if len(mappedModels) > 1 { + c.Set(string(ctxkeys.FallbackModels), mappedModels[1:]) + } + resolvedModel = mappedModels[0] + usedMapping = true + providers = mappedProviders + } + // Check if model mappings should be forced ahead of local API keys forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings() if forceMappings { // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) // This allows users to route Amp requests to their preferred OAuth providers - if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders + if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 { + applyMapping(mappedModels, mappedProviders) } // If no mapping applied, check for local providers @@ -206,15 +264,8 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if len(providers) == 0 { // No providers configured - check if we have a model mapping - if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders + if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 { + applyMapping(mappedModels, mappedProviders) } } } diff --git a/internal/api/modules/amp/fallback_handlers_characterization_test.go b/internal/api/modules/amp/fallback_handlers_characterization_test.go new file mode 100644 index 0000000000..e52bc5cef2 --- /dev/null +++ b/internal/api/modules/amp/fallback_handlers_characterization_test.go @@ -0,0 +1,326 @@ +package amp + +import ( + "bytes" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing/testutil" + "github.com/stretchr/testify/assert" +) + +// Characterization tests for fallback_handlers.go using testutil recorders +// These tests capture existing behavior before refactoring to routing layer + +func TestCharacterization_LocalProvider(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the test model + reg := registry.GetGlobalRegistry() + reg.RegisterClient("char-test-local", "anthropic", []*registry.ModelInfo{ + {ID: "test-model-local"}, + }) + defer reg.UnregisterClient("char-test-local") + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create gin context + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body := `{"model": "test-model-local", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Create fallback handler with proxy recorder + // Create a test server to act as the proxy target + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + // Create a reverse proxy that forwards to our test server + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Execute + wrapped := fh.WrapHandler(handlerRecorder.GinHandler()) + wrapped(c) + + // Assert: proxy NOT called + assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider") + + // Assert: local handler called once + assert.True(t, handlerRecorder.WasCalled(), "local handler should be called") + assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once") + + // Assert: request body model unchanged + assert.Contains(t, string(handlerRecorder.RequestBody), "test-model-local", "request body model should be unchanged") +} + +func TestCharacterization_ModelMapping(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the TARGET model (the mapped-to model) + reg := registry.GetGlobalRegistry() + reg.RegisterClient("char-test-mapped", "openai", []*registry.ModelInfo{ + {ID: "gpt-4-local"}, + }) + defer reg.UnregisterClient("char-test-mapped") + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create model mapper with a mapping + mapper := NewModelMapper([]config.AmpModelMapping{ + {From: "gpt-4-turbo", To: "gpt-4-local"}, + }) + + // Create gin context + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Request with original model that gets mapped + body := `{"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Create fallback handler with mapper + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + fh := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }, mapper, func() bool { return false }) + + // Execute - use handler that returns model in response for rewriter to work + wrapped := fh.WrapHandler(handlerRecorder.GinHandlerWithModel()) + wrapped(c) + + // Assert: proxy NOT called + assert.False(t, proxyRecorder.Called, "proxy should NOT be called for model mapping") + + // Assert: local handler called once + assert.True(t, handlerRecorder.WasCalled(), "local handler should be called") + assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once") + + // Assert: request body model was rewritten to mapped model + assert.Contains(t, string(handlerRecorder.RequestBody), "gpt-4-local", "request body model should be rewritten to mapped model") + assert.NotContains(t, string(handlerRecorder.RequestBody), "gpt-4-turbo", "request body should NOT contain original model") + + // Assert: context has mapped_model key set + mappedModel, exists := handlerRecorder.GetContextKey("mapped_model") + assert.True(t, exists, "context should have mapped_model key") + assert.Equal(t, "gpt-4-local", mappedModel, "mapped_model should be the target model") + + // Assert: response body model rewritten back to original + // The response writer should rewrite model names in the response + responseBody := w.Body.String() + assert.Contains(t, responseBody, "gpt-4-turbo", "response should have original model name") +} + +func TestCharacterization_AmpCreditsProxy(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Setup recorders - NO local provider registered, NO mapping configured + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create gin context with CloseNotifier support (required for ReverseProxy) + w := testutil.NewCloseNotifierRecorder() + c, _ := gin.CreateTestContext(w) + + // Request with a model that has no local provider and no mapping + body := `{"model": "unknown-model-no-provider", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Create fallback handler + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Execute + wrapped := fh.WrapHandler(handlerRecorder.GinHandler()) + wrapped(c) + + // Assert: proxy called once + assert.True(t, proxyRecorder.Called, "proxy should be called when no local provider and no mapping") + assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once") + + // Assert: local handler NOT called + assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called when falling back to proxy") + + // Assert: body forwarded to proxy is original (no rewrite) + assert.Contains(t, string(proxyRecorder.RequestBody), "unknown-model-no-provider", "request body model should be unchanged when proxying") +} + +func TestCharacterization_BodyRestore(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the test model + reg := registry.GetGlobalRegistry() + reg.RegisterClient("char-test-body", "anthropic", []*registry.ModelInfo{ + {ID: "test-model-body"}, + }) + defer reg.UnregisterClient("char-test-body") + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create gin context + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Create a complex request body that will be read by the wrapper for model extraction + originalBody := `{"model": "test-model-body", "messages": [{"role": "user", "content": "hello"}], "temperature": 0.7, "stream": true}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(originalBody))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Create fallback handler with proxy recorder + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Execute + wrapped := fh.WrapHandler(handlerRecorder.GinHandler()) + wrapped(c) + + // Assert: local handler called (not proxy, since we have a local provider) + assert.True(t, handlerRecorder.WasCalled(), "local handler should be called") + assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider") + + // Assert: handler receives complete original body + // This verifies that the body was properly restored after the wrapper read it for model extraction + assert.Equal(t, originalBody, string(handlerRecorder.RequestBody), "handler should receive complete original body after wrapper reads it for model extraction") +} + +// TestCharacterization_GeminiV1Beta1_PostModels tests that POST requests with /models/ path use Gemini bridge handler +// This is a characterization test for the route gating logic in routes.go +func TestCharacterization_GeminiV1Beta1_PostModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the test model (Gemini format uses path-based model extraction) + reg := registry.GetGlobalRegistry() + reg.RegisterClient("char-test-gemini", "google", []*registry.ModelInfo{ + {ID: "gemini-pro"}, + }) + defer reg.UnregisterClient("char-test-gemini") + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create a test server for the proxy + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + // Create fallback handler + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Create the Gemini bridge handler (simulating what routes.go does) + geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler()) + geminiV1Beta1Handler := fh.WrapHandler(geminiBridge) + + // Create router with the same gating logic as routes.go + r := gin.New() + r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) { + if c.Request.Method == "POST" { + if path := c.Param("path"); strings.Contains(path, "/models/") { + // POST with /models/ path -> use Gemini bridge with fallback handler + geminiV1Beta1Handler(c) + return + } + } + // Non-POST or no /models/ in path -> proxy upstream + proxyRecorder.ServeHTTP(c.Writer, c.Request) + }) + + // Execute: POST request with /models/ in path + body := `{"contents": [{"role": "user", "parts": [{"text": "hello"}]}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro:generateContent", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Assert: local Gemini handler called + assert.True(t, handlerRecorder.WasCalled(), "local Gemini handler should be called for POST /models/") + + // Assert: proxy NOT called + assert.False(t, proxyRecorder.Called, "proxy should NOT be called for POST /models/ path") +} + +// TestCharacterization_GeminiV1Beta1_GetProxies tests that GET requests to Gemini v1beta1 always use proxy +// This is a characterization test for the route gating logic in routes.go +func TestCharacterization_GeminiV1Beta1_GetProxies(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Setup recorders + proxyRecorder := testutil.NewFakeProxyRecorder() + handlerRecorder := testutil.NewFakeHandlerRecorder() + + // Create a test server for the proxy + proxyServer := httptest.NewServer(proxyRecorder.ToHandler()) + defer proxyServer.Close() + + // Create fallback handler + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + targetURL, _ := url.Parse(proxyServer.URL) + return httputil.NewSingleHostReverseProxy(targetURL) + }) + + // Create the Gemini bridge handler + geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler()) + geminiV1Beta1Handler := fh.WrapHandler(geminiBridge) + + // Create router with the same gating logic as routes.go + r := gin.New() + r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) { + if c.Request.Method == "POST" { + if path := c.Param("path"); strings.Contains(path, "/models/") { + geminiV1Beta1Handler(c) + return + } + } + proxyRecorder.ServeHTTP(c.Writer, c.Request) + }) + + // Execute: GET request (even with /models/ in path) + req := httptest.NewRequest(http.MethodGet, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Assert: proxy called + assert.True(t, proxyRecorder.Called, "proxy should be called for GET requests") + assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once") + + // Assert: local handler NOT called + assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called for GET requests") +} diff --git a/internal/api/modules/amp/fallback_handlers_test.go b/internal/api/modules/amp/fallback_handlers_test.go index a687fd116b..eef73cbd9b 100644 --- a/internal/api/modules/amp/fallback_handlers_test.go +++ b/internal/api/modules/amp/fallback_handlers_test.go @@ -2,7 +2,7 @@ package amp import ( "bytes" - "encoding/json" + "io" "net/http" "net/http/httptest" "net/http/httputil" @@ -11,63 +11,138 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/stretchr/testify/assert" ) -func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) { +// Characterization tests for fallback_handlers.go +// These tests capture existing behavior before refactoring to routing layer + +func TestFallbackHandler_WrapHandler_LocalProvider_NoMapping(t *testing.T) { gin.SetMode(gin.TestMode) - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{ - {ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"}, + // Setup: model that has local providers (gemini-2.5-pro is registered) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body := `{"model": "gemini-2.5-pro", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Handler that should be called (not proxy) + handlerCalled := false + handler := func(c *gin.Context) { + handlerCalled = true + c.JSON(200, gin.H{"status": "ok"}) + } + + // Create fallback handler + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + return nil // no proxy }) - defer reg.UnregisterClient("test-client-amp-fallback") - mapper := NewModelMapper([]config.AmpModelMapping{ - {From: "gpt-5.2", To: "test/gpt-5.2"}, + // Execute + wrapped := fh.WrapHandler(handler) + wrapped(c) + + // Assert: handler should be called directly (no mapping needed) + assert.True(t, handlerCalled, "handler should be called for local provider") + assert.Equal(t, 200, w.Code) +} + +func TestFallbackHandler_WrapHandler_MappingApplied(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the target model + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client", "anthropic", []*registry.ModelInfo{ + {ID: "claude-opus-4-5-thinking"}, }) - fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil) + // Setup: model that needs mapping + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + body := `{"model": "claude-opus-4-5-20251101", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Handler to capture rewritten body + var capturedBody []byte handler := func(c *gin.Context) { - var req struct { - Model string `json:"model"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "model": req.Model, - "seen_model": req.Model, - }) + capturedBody, _ = io.ReadAll(c.Request.Body) + c.JSON(200, gin.H{"status": "ok"}) } - r := gin.New() - r.POST("/chat/completions", fallback.WrapHandler(handler)) + // Create fallback handler with mapper + mapper := NewModelMapper([]config.AmpModelMapping{ + {From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"}, + }) + + fh := NewFallbackHandlerWithMapper( + func() *httputil.ReverseProxy { return nil }, + mapper, + func() bool { return false }, + ) + + // Execute + wrapped := fh.WrapHandler(handler) + wrapped(c) + + // Assert: body should be rewritten + assert.Contains(t, string(capturedBody), "claude-opus-4-5-thinking") + + // Assert: context should have mapped model + mappedModel, exists := c.Get(MappedModelContextKey) + assert.True(t, exists, "MappedModelContextKey should be set") + assert.NotEmpty(t, mappedModel) +} + +func TestFallbackHandler_WrapHandler_ThinkingSuffixPreserved(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Register a mock provider for the target model + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-2", "anthropic", []*registry.ModelInfo{ + {ID: "claude-opus-4-5-thinking"}, + }) - reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`) - req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody)) - req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - r.ServeHTTP(w, req) + c, _ := gin.CreateTestContext(w) - if w.Code != http.StatusOK { - t.Fatalf("Expected status 200, got %d", w.Code) - } + // Model with thinking suffix + body := `{"model": "claude-opus-4-5-20251101(xhigh)", "messages": []}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req - var resp struct { - Model string `json:"model"` - SeenModel string `json:"seen_model"` - } - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("Failed to parse response JSON: %v", err) + var capturedBody []byte + handler := func(c *gin.Context) { + capturedBody, _ = io.ReadAll(c.Request.Body) + c.JSON(200, gin.H{"status": "ok"}) } - if resp.Model != "gpt-5.2(xhigh)" { - t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model) - } - if resp.SeenModel != "test/gpt-5.2(xhigh)" { - t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel) - } + mapper := NewModelMapper([]config.AmpModelMapping{ + {From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"}, + }) + + fh := NewFallbackHandlerWithMapper( + func() *httputil.ReverseProxy { return nil }, + mapper, + func() bool { return false }, + ) + + wrapped := fh.WrapHandler(handler) + wrapped(c) + + // Assert: thinking suffix should be preserved + assert.Contains(t, string(capturedBody), "(xhigh)") +} + +func TestFallbackHandler_WrapHandler_NoProvider_NoMapping_ProxyEnabled(t *testing.T) { + // Skip: httptest.ResponseRecorder doesn't implement http.CloseNotifier + // which is required by httputil.ReverseProxy. This test requires a real + // HTTP server and client to properly test proxy behavior. + t.Skip("requires real HTTP server for proxy testing") } diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go index 4159a2b576..b8d4743296 100644 --- a/internal/api/modules/amp/model_mapping.go +++ b/internal/api/modules/amp/model_mapping.go @@ -30,18 +30,98 @@ type DefaultModelMapper struct { mu sync.RWMutex mappings map[string]string // exact: from -> to (normalized lowercase keys) regexps []regexMapping // regex rules evaluated in order + + // oauthAliasForward maps channel -> name (lower) -> []alias for oauth-model-alias lookup. + // This allows model-mappings targets to find providers via their aliases. + oauthAliasForward map[string]map[string][]string } // NewModelMapper creates a new model mapper with the given initial mappings. func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { m := &DefaultModelMapper{ - mappings: make(map[string]string), - regexps: nil, + mappings: make(map[string]string), + regexps: nil, + oauthAliasForward: nil, } m.UpdateMappings(mappings) return m } +// UpdateOAuthModelAlias updates the oauth-model-alias lookup table. +// This is called during initialization and on config hot-reload. +func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.OAuthModelAlias) { + m.mu.Lock() + defer m.mu.Unlock() + + if len(aliases) == 0 { + m.oauthAliasForward = nil + return + } + + forward := make(map[string]map[string][]string, len(aliases)) + for rawChannel, entries := range aliases { + channel := strings.ToLower(strings.TrimSpace(rawChannel)) + if channel == "" || len(entries) == 0 { + continue + } + channelMap := make(map[string][]string) + for _, entry := range entries { + name := strings.TrimSpace(entry.Name) + alias := strings.TrimSpace(entry.Alias) + if name == "" || alias == "" { + continue + } + if strings.EqualFold(name, alias) { + continue + } + nameKey := strings.ToLower(name) + channelMap[nameKey] = append(channelMap[nameKey], alias) + } + if len(channelMap) > 0 { + forward[channel] = channelMap + } + } + if len(forward) == 0 { + m.oauthAliasForward = nil + return + } + m.oauthAliasForward = forward + log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward)) +} + +// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel +// that have available providers. Useful for fallback when one alias is quota-exceeded. +func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string { + if m.oauthAliasForward == nil { + return nil + } + + targetKey := strings.ToLower(strings.TrimSpace(targetModel)) + if targetKey == "" { + return nil + } + + var result []string + seen := make(map[string]struct{}) + + // Check all channels for this model name + for _, channelMap := range m.oauthAliasForward { + aliases := channelMap[targetKey] + for _, alias := range aliases { + aliasLower := strings.ToLower(alias) + if _, exists := seen[aliasLower]; exists { + continue + } + providers := util.GetProviderName(alias) + if len(providers) > 0 { + result = append(result, alias) + seen[aliasLower] = struct{}{} + } + } + } + return result +} + // MapModel checks if a mapping exists for the requested model and if the // target model has available local providers. Returns the mapped model name // or empty string if no valid mapping exists. @@ -51,9 +131,20 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper { // However, if the mapping target already contains a suffix, the config suffix // takes priority over the user's suffix. func (m *DefaultModelMapper) MapModel(requestedModel string) string { - if requestedModel == "" { + models := m.MapModelWithFallbacks(requestedModel) + if len(models) == 0 { return "" } + return models[0] +} + +// MapModelWithFallbacks returns all possible target models for the requested model, +// including fallback aliases from oauth-model-alias. The first model is the primary target, +// and subsequent models are fallbacks to try if the primary is unavailable (e.g., quota exceeded). +func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []string { + if requestedModel == "" { + return nil + } m.mu.RLock() defer m.mu.RUnlock() @@ -78,34 +169,54 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string { } } if !exists { - return "" + return nil } } // Check if target model already has a thinking suffix (config priority) targetResult := thinking.ParseSuffix(targetModel) + targetBase := targetResult.ModelName + + // Helper to apply suffix to a model + applySuffix := func(model string) string { + modelResult := thinking.ParseSuffix(model) + if modelResult.HasSuffix { + return model + } + if requestResult.HasSuffix && requestResult.RawSuffix != "" { + return model + "(" + requestResult.RawSuffix + ")" + } + return model + } // Verify target model has available providers (use base model for lookup) - providers := util.GetProviderName(targetResult.ModelName) - if len(providers) == 0 { - log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) - return "" + providers := util.GetProviderName(targetBase) + + // If direct provider available, return it as primary + if len(providers) > 0 { + return []string{applySuffix(targetModel)} } - // Suffix handling: config suffix takes priority, otherwise preserve user suffix - if targetResult.HasSuffix { - // Config's "to" already contains a suffix - use it as-is (config priority) - return targetModel + // No direct providers - check oauth-model-alias for all aliases that have providers + allAliases := m.findAllAliasesWithProviders(targetBase) + if len(allAliases) == 0 { + log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) + return nil } - // Preserve user's thinking suffix on the mapped model - // (skip empty suffixes to avoid returning "model()") - if requestResult.HasSuffix && requestResult.RawSuffix != "" { - return targetModel + "(" + requestResult.RawSuffix + ")" + // Log resolution + if len(allAliases) == 1 { + log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0]) + } else { + log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)-1) } - // Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go - return targetModel + // Apply suffix to all aliases + result := make([]string, len(allAliases)) + for i, alias := range allAliases { + result[i] = applySuffix(alias) + } + return result } // UpdateMappings refreshes the mapping configuration from config. @@ -165,6 +276,22 @@ func (m *DefaultModelMapper) GetMappings() map[string]string { return result } +// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice. +// Safe for concurrent use. +func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make([]config.AmpModelMapping, 0, len(m.mappings)) + for from, to := range m.mappings { + result = append(result, config.AmpModelMapping{ + From: from, + To: to, + }) + } + return result +} + type regexMapping struct { re *regexp.Regexp to string diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 456a50ac12..790a3cce3f 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -5,11 +5,12 @@ import ( "errors" "net" "net/http" - "net/http/httputil" "strings" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini" @@ -234,19 +235,20 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha // If no local OAuth is available, falls back to ampcode.com proxy. geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler) - geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { - return m.getProxy() - }, m.modelMapper, m.forceModelMappings) - geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) - // Route POST model calls through Gemini bridge with FallbackHandler. - // FallbackHandler checks provider -> mapping -> proxy fallback automatically. + // T-025: Migrated Gemini v1beta1 bridge to use ModelRoutingWrapper + // Create a dedicated routing wrapper for the Gemini bridge + geminiBridgeWrapper := m.createModelRoutingWrapper() + geminiV1Beta1Handler := geminiBridgeWrapper.Wrap(geminiBridge) + + // Route POST model calls through Gemini bridge with ModelRoutingWrapper. + // ModelRoutingWrapper checks provider -> mapping -> proxy fallback automatically. // All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior. ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) { if c.Request.Method == "POST" { if path := c.Param("path"); strings.Contains(path, "/models/") { - // POST with /models/ path -> use Gemini bridge with fallback handler - // FallbackHandler will check provider/mapping and proxy if needed + // POST with /models/ path -> use Gemini bridge with unified routing wrapper + // ModelRoutingWrapper will check provider/mapping and proxy if needed geminiV1Beta1Handler(c) return } @@ -256,6 +258,41 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha }) } +// createModelRoutingWrapper creates a new ModelRoutingWrapper for unified routing. +// This is used for testing the new routing implementation (T-021 onwards). +func (m *AmpModule) createModelRoutingWrapper() *routing.ModelRoutingWrapper { + // Create a registry - in production this would be populated with actual providers + registry := routing.NewRegistry() + + // Create a minimal config with just AmpCode settings + // The Router only needs AmpCode.ModelMappings and OAuthModelAlias + cfg := &config.Config{ + AmpCode: func() config.AmpCode { + if m.modelMapper != nil { + return config.AmpCode{ + ModelMappings: m.modelMapper.GetMappingsAsConfig(), + } + } + return config.AmpCode{} + }(), + } + + // Create router with registry and config + router := routing.NewRouter(registry, cfg) + + // Create wrapper with proxy function + proxyFunc := func(c *gin.Context) { + proxy := m.getProxy() + if proxy != nil { + proxy.ServeHTTP(c.Writer, c.Request) + } else { + c.JSON(503, gin.H{"error": "amp upstream proxy not available"}) + } + } + + return routing.NewModelRoutingWrapper(router, nil, nil, proxyFunc) +} + // registerProviderAliases registers /api/provider/{provider}/... routes // These allow Amp CLI to route requests like: // @@ -269,12 +306,9 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler) openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) - // Create fallback handler wrapper that forwards to ampcode.com when provider not found - // Uses m.getProxy() for hot-reload support (proxy can be updated at runtime) - // Also includes model mapping support for routing unavailable models to alternatives - fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { - return m.getProxy() - }, m.modelMapper, m.forceModelMappings) + // Create unified routing wrapper (T-021 onwards) + // Replaces FallbackHandler with Router-based unified routing + routingWrapper := m.createModelRoutingWrapper() // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") @@ -302,33 +336,36 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han } // Root-level routes (for providers that omit /v1, like groq/cerebras) - // Wrap handlers with fallback logic to forward to ampcode.com when provider not found + // T-022: Migrated all OpenAI routes to use ModelRoutingWrapper for unified routing provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check) - provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions)) - provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions)) - provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses)) + provider.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions)) + provider.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions)) + provider.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses)) // /v1 routes (OpenAI/Claude-compatible endpoints) v1Amp := provider.Group("/v1") { v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback - // OpenAI-compatible endpoints with fallback - v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions)) - v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions)) - v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses)) + // OpenAI-compatible endpoints with ModelRoutingWrapper + // T-021, T-022: Migrated to unified routing wrapper + v1Amp.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions)) + v1Amp.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions)) + v1Amp.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses)) - // Claude/Anthropic-compatible endpoints with fallback - v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages)) - v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens)) + // Claude/Anthropic-compatible endpoints with ModelRoutingWrapper + // T-023: Migrated Claude routes to unified routing wrapper + v1Amp.POST("/messages", routingWrapper.Wrap(claudeCodeHandlers.ClaudeMessages)) + v1Amp.POST("/messages/count_tokens", routingWrapper.Wrap(claudeCodeHandlers.ClaudeCountTokens)) } // /v1beta routes (Gemini native API) // Note: Gemini handler extracts model from URL path, so fallback logic needs special handling + // T-024: Migrated Gemini v1beta routes to unified routing wrapper v1betaAmp := provider.Group("/v1beta") { v1betaAmp.GET("/models", geminiHandlers.GeminiModels) - v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler)) + v1betaAmp.POST("/models/*action", routingWrapper.Wrap(geminiHandlers.GeminiHandler)) v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler) } } diff --git a/internal/api/server.go b/internal/api/server.go index f9a2abdd89..bcb855d5a2 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -960,8 +960,8 @@ func (s *Server) UpdateClients(cfg *config.Config) { s.mgmt.SetAuthManager(s.handlers.AuthManager) } - // Notify Amp module only when Amp config has changed. - ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) + // Notify Amp module when Amp config or OAuth model aliases have changed. + ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) || !reflect.DeepEqual(oldCfg.OAuthModelAlias, cfg.OAuthModelAlias) if ampConfigChanged { if s.ampModule != nil { log.Debugf("triggering amp module config update") diff --git a/internal/cache/signature_cache.go b/internal/cache/signature_cache.go index af5371bfbc..e15b0802ae 100644 --- a/internal/cache/signature_cache.go +++ b/internal/cache/signature_cache.go @@ -6,6 +6,8 @@ import ( "strings" "sync" "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" ) // SignatureEntry holds a cached thinking signature with timestamp @@ -184,6 +186,7 @@ func HasValidSignature(modelName, signature string) bool { } func GetModelGroup(modelName string) string { + // Fast path: check model name patterns first if strings.Contains(modelName, "gpt") { return "gpt" } else if strings.Contains(modelName, "claude") { @@ -191,5 +194,21 @@ func GetModelGroup(modelName string) string { } else if strings.Contains(modelName, "gemini") { return "gemini" } + + // Slow path: check registry for provider-based grouping + // This handles models registered via claude-api-key, gemini-api-key, etc. + // that don't have provider name in their model name (e.g., kimi-k2.5 via claude-api-key) + if providers := registry.GetGlobalRegistry().GetModelProviders(modelName); len(providers) > 0 { + provider := strings.ToLower(providers[0]) + switch provider { + case "claude": + return "claude" + case "gemini", "gemini-cli", "aistudio", "vertex", "antigravity": + return "gemini" + case "codex": + return "gpt" + } + } + return modelName } diff --git a/internal/cache/signature_cache_test.go b/internal/cache/signature_cache_test.go index 8340815934..af4361f9aa 100644 --- a/internal/cache/signature_cache_test.go +++ b/internal/cache/signature_cache_test.go @@ -208,3 +208,84 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) { // but the logic is verified by the implementation _ = time.Now() // Acknowledge we're not testing time passage } + +// === GetModelGroup Tests === +// These tests verify that GetModelGroup correctly identifies model groups +// both by name pattern (fast path) and by registry provider lookup (slow path). + +func TestGetModelGroup_ByNamePattern(t *testing.T) { + tests := []struct { + modelName string + expectedGroup string + }{ + {"gpt-4o", "gpt"}, + {"gpt-4-turbo", "gpt"}, + {"claude-sonnet-4-20250514", "claude"}, + {"claude-opus-4-5-thinking", "claude"}, + {"gemini-2.5-pro", "gemini"}, + {"gemini-3-pro-preview", "gemini"}, + } + + for _, tt := range tests { + t.Run(tt.modelName, func(t *testing.T) { + result := GetModelGroup(tt.modelName) + if result != tt.expectedGroup { + t.Errorf("GetModelGroup(%q) = %q, expected %q", tt.modelName, result, tt.expectedGroup) + } + }) + } +} + +func TestGetModelGroup_UnknownModel(t *testing.T) { + // For unknown models with no registry entry, should return the model name itself + result := GetModelGroup("unknown-model-xyz") + if result != "unknown-model-xyz" { + t.Errorf("GetModelGroup for unknown model should return model name, got %q", result) + } +} + +// TestGetModelGroup_RegistryFallback tests that models registered via +// provider-specific API keys (e.g., kimi-k2.5 via claude-api-key) are +// correctly grouped by their provider. +// This test requires a populated global registry. +func TestGetModelGroup_RegistryFallback(t *testing.T) { + // This test only makes sense when the global registry is populated + // In unit test context, skip if registry is empty + + // Example: kimi-k2.5 registered via claude-api-key should group as "claude" + // The model name doesn't contain "claude", so name pattern matching fails. + // The registry should be checked to find the provider. + + // Skip for now - this requires integration test setup + t.Skip("Requires populated global registry - run as integration test") +} + +// === Cross-Model Signature Validation Tests === +// These tests verify that signatures cached under one model name can be +// validated under mapped model names (same provider group). + +func TestCacheSignature_CrossModelValidation(t *testing.T) { + ClearSignatureCache("") + + // Original request uses "claude-opus-4-5-20251101" + originalModel := "claude-opus-4-5-20251101" + // Mapped model is "claude-opus-4-5-thinking" + mappedModel := "claude-opus-4-5-thinking" + + text := "Some thinking block content" + sig := "validSignature123456789012345678901234567890123456789012" + + // Cache signature under the original model + CacheSignature(originalModel, text, sig) + + // Both should return the same signature because they're in the same group + retrieved1 := GetCachedSignature(originalModel, text) + retrieved2 := GetCachedSignature(mappedModel, text) + + if retrieved1 != sig { + t.Errorf("Original model signature mismatch: got %q", retrieved1) + } + if retrieved2 != sig { + t.Errorf("Mapped model signature mismatch: got %q", retrieved2) + } +} diff --git a/internal/routing/adapter.go b/internal/routing/adapter.go new file mode 100644 index 0000000000..1d90b0fed5 --- /dev/null +++ b/internal/routing/adapter.go @@ -0,0 +1,39 @@ +// Package routing provides adapter to integrate with existing codebase. +package routing + +import ( + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +// Adapter bridges the new routing layer with existing auth manager. +type Adapter struct { + router *Router + exec *Executor +} + +// NewAdapter creates a new adapter with the given configuration and auth manager. +func NewAdapter(cfg *config.Config, authManager *coreauth.Manager) *Adapter { + registry := NewRegistry() + + // TODO: Register OAuth providers from authManager + // TODO: Register API key providers from cfg + + router := NewRouter(registry, cfg) + exec := NewExecutor(router) + + return &Adapter{ + router: router, + exec: exec, + } +} + +// Router returns the underlying router. +func (a *Adapter) Router() *Router { + return a.router +} + +// Executor returns the underlying executor. +func (a *Adapter) Executor() *Executor { + return a.exec +} diff --git a/internal/routing/ctxkeys/keys.go b/internal/routing/ctxkeys/keys.go new file mode 100644 index 0000000000..5838d54d2f --- /dev/null +++ b/internal/routing/ctxkeys/keys.go @@ -0,0 +1,11 @@ +package ctxkeys + +type key string + +const ( + MappedModel key = "mapped_model" + FallbackModels key = "fallback_models" + RouteCandidates key = "route_candidates" + RoutingDecision key = "routing_decision" + MappingApplied key = "mapping_applied" +) diff --git a/internal/routing/executor.go b/internal/routing/executor.go new file mode 100644 index 0000000000..30b5750b0b --- /dev/null +++ b/internal/routing/executor.go @@ -0,0 +1,111 @@ +package routing + +import ( + "context" + "errors" + + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + log "github.com/sirupsen/logrus" +) + +// Executor handles request execution with fallback support. +type Executor struct { + router *Router +} + +// NewExecutor creates a new executor with the given router. +func NewExecutor(router *Router) *Executor { + return &Executor{router: router} +} + +// Execute sends the request through the routing decision. +func (e *Executor) Execute(ctx context.Context, req executor.Request) (executor.Response, error) { + decision := e.router.Resolve(req.Model) + + log.Debugf("routing: %s -> %s (%d candidates)", + decision.RequestedModel, + decision.ResolvedModel, + len(decision.Candidates)) + + var lastErr error + tried := make(map[string]struct{}) + + for i, candidate := range decision.Candidates { + key := candidate.Provider.Name() + "/" + candidate.Model + if _, ok := tried[key]; ok { + continue + } + tried[key] = struct{}{} + + log.Debugf("routing: trying candidate %d/%d: %s with model %s", + i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model) + + req.Model = candidate.Model + resp, err := candidate.Provider.Execute(ctx, candidate.Model, req) + if err == nil { + return resp, nil + } + + lastErr = err + log.Debugf("routing: candidate failed: %v", err) + + // Check if it's a fatal error (not retryable) + if isFatalError(err) { + break + } + } + + if lastErr != nil { + return executor.Response{}, lastErr + } + return executor.Response{}, errors.New("no available providers") +} + +// ExecuteStream sends a streaming request through the routing decision. +func (e *Executor) ExecuteStream(ctx context.Context, req executor.Request) (<-chan executor.StreamChunk, error) { + decision := e.router.Resolve(req.Model) + + log.Debugf("routing stream: %s -> %s (%d candidates)", + decision.RequestedModel, + decision.ResolvedModel, + len(decision.Candidates)) + + var lastErr error + tried := make(map[string]struct{}) + + for i, candidate := range decision.Candidates { + key := candidate.Provider.Name() + "/" + candidate.Model + if _, ok := tried[key]; ok { + continue + } + tried[key] = struct{}{} + + log.Debugf("routing stream: trying candidate %d/%d: %s with model %s", + i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model) + + req.Model = candidate.Model + chunks, err := candidate.Provider.ExecuteStream(ctx, candidate.Model, req) + if err == nil { + return chunks, nil + } + + lastErr = err + log.Debugf("routing stream: candidate failed: %v", err) + + if isFatalError(err) { + break + } + } + + if lastErr != nil { + return nil, lastErr + } + return nil, errors.New("no available providers") +} + +// isFatalError returns true if the error is not retryable. +func isFatalError(err error) bool { + // TODO: implement based on error type + // For now, all errors are retryable + return false +} diff --git a/internal/routing/extractor.go b/internal/routing/extractor.go new file mode 100644 index 0000000000..94fe969ac9 --- /dev/null +++ b/internal/routing/extractor.go @@ -0,0 +1,59 @@ +package routing + +import ( + "strings" + + "github.com/tidwall/gjson" +) + +// ModelExtractor extracts model names from request data. +type ModelExtractor interface { + // Extract returns the model name from the request body and gin parameters. + // The ginParams map contains route parameters like "action" and "path". + Extract(body []byte, ginParams map[string]string) (string, error) +} + +// DefaultModelExtractor is the standard implementation of ModelExtractor. +type DefaultModelExtractor struct{} + +// NewModelExtractor creates a new DefaultModelExtractor. +func NewModelExtractor() *DefaultModelExtractor { + return &DefaultModelExtractor{} +} + +// Extract extracts the model name from the request. +// It checks in order: +// 1. JSON body "model" field (OpenAI, Claude format) +// 2. "action" parameter for Gemini standard format (e.g., "gemini-pro:generateContent") +// 3. "path" parameter for AMP CLI Gemini format (e.g., "/publishers/google/models/gemini-3-pro:streamGenerateContent") +func (e *DefaultModelExtractor) Extract(body []byte, ginParams map[string]string) (string, error) { + // First try to parse from JSON body (OpenAI, Claude, etc.) + if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String { + return result.String(), nil + } + + // For Gemini requests, model is in the URL path + // Standard format: /models/{model}:generateContent -> :action parameter + if action, ok := ginParams["action"]; ok && action != "" { + // Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro") + parts := strings.Split(action, ":") + if len(parts) > 0 && parts[0] != "" { + return parts[0], nil + } + } + + // AMP CLI format: /publishers/google/models/{model}:method -> *path parameter + // Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent + if path, ok := ginParams["path"]; ok && path != "" { + // Look for /models/{model}:method pattern + if idx := strings.Index(path, "/models/"); idx >= 0 { + modelPart := path[idx+8:] // Skip "/models/" + // Split by colon to get model name + if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 { + return modelPart[:colonIdx], nil + } + } + } + + return "", nil +} diff --git a/internal/routing/extractor_test.go b/internal/routing/extractor_test.go new file mode 100644 index 0000000000..485b4831b1 --- /dev/null +++ b/internal/routing/extractor_test.go @@ -0,0 +1,214 @@ +package routing + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestModelExtractor_ExtractFromJSONBody(t *testing.T) { + extractor := NewModelExtractor() + + tests := []struct { + name string + body []byte + want string + wantErr bool + }{ + { + name: "extract from JSON body with model field", + body: []byte(`{"model":"gpt-4.1"}`), + want: "gpt-4.1", + }, + { + name: "extract claude model from JSON body", + body: []byte(`{"model":"claude-3-5-sonnet-20241022"}`), + want: "claude-3-5-sonnet-20241022", + }, + { + name: "extract with additional fields", + body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`), + want: "gpt-4", + }, + { + name: "empty body returns empty", + body: []byte{}, + want: "", + }, + { + name: "no model field returns empty", + body: []byte(`{"messages":[]}`), + want: "", + }, + { + name: "model is not string returns empty", + body: []byte(`{"model":123}`), + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractor.Extract(tt.body, nil) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestModelExtractor_ExtractFromGeminiActionParam(t *testing.T) { + extractor := NewModelExtractor() + + tests := []struct { + name string + body []byte + ginParams map[string]string + want string + }{ + { + name: "extract from action parameter - gemini-pro", + body: []byte(`{}`), + ginParams: map[string]string{"action": "gemini-pro:generateContent"}, + want: "gemini-pro", + }, + { + name: "extract from action parameter - gemini-ultra", + body: []byte(`{}`), + ginParams: map[string]string{"action": "gemini-ultra:chat"}, + want: "gemini-ultra", + }, + { + name: "empty action returns empty", + body: []byte(`{}`), + ginParams: map[string]string{"action": ""}, + want: "", + }, + { + name: "action without colon returns full value", + body: []byte(`{}`), + ginParams: map[string]string{"action": "gemini-model"}, + want: "gemini-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractor.Extract(tt.body, tt.ginParams) + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestModelExtractor_ExtractFromGeminiV1Beta1Path(t *testing.T) { + extractor := NewModelExtractor() + + tests := []struct { + name string + body []byte + ginParams map[string]string + want string + }{ + { + name: "extract from v1beta1 path - gemini-3-pro", + body: []byte(`{}`), + ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro:streamGenerateContent"}, + want: "gemini-3-pro", + }, + { + name: "extract from v1beta1 path with preview", + body: []byte(`{}`), + ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro-preview:generateContent"}, + want: "gemini-3-pro-preview", + }, + { + name: "path without models segment returns empty", + body: []byte(`{}`), + ginParams: map[string]string{"path": "/publishers/google/gemini-3-pro:streamGenerateContent"}, + want: "", + }, + { + name: "empty path returns empty", + body: []byte(`{}`), + ginParams: map[string]string{"path": ""}, + want: "", + }, + { + name: "path with /models/ but no colon returns empty", + body: []byte(`{}`), + ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro"}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractor.Extract(tt.body, tt.ginParams) + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestModelExtractor_ExtractPriority(t *testing.T) { + extractor := NewModelExtractor() + + // JSON body takes priority over gin params + t.Run("JSON body takes priority over action param", func(t *testing.T) { + body := []byte(`{"model":"gpt-4"}`) + params := map[string]string{"action": "gemini-pro:generateContent"} + got, err := extractor.Extract(body, params) + assert.NoError(t, err) + assert.Equal(t, "gpt-4", got) + }) + + // Action param takes priority over path param + t.Run("action param takes priority over path param", func(t *testing.T) { + body := []byte(`{}`) + params := map[string]string{ + "action": "gemini-action:generate", + "path": "/publishers/google/models/gemini-path:streamGenerateContent", + } + got, err := extractor.Extract(body, params) + assert.NoError(t, err) + assert.Equal(t, "gemini-action", got) + }) +} + +func TestModelExtractor_NoModelFound(t *testing.T) { + extractor := NewModelExtractor() + + tests := []struct { + name string + body []byte + ginParams map[string]string + }{ + { + name: "empty body and no params", + body: []byte{}, + ginParams: nil, + }, + { + name: "body without model and no params", + body: []byte(`{"messages":[]}`), + ginParams: map[string]string{}, + }, + { + name: "irrelevant params only", + body: []byte(`{}`), + ginParams: map[string]string{"other": "value"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractor.Extract(tt.body, tt.ginParams) + assert.NoError(t, err) + assert.Empty(t, got) + }) + } +} diff --git a/internal/routing/provider.go b/internal/routing/provider.go new file mode 100644 index 0000000000..8e1606c850 --- /dev/null +++ b/internal/routing/provider.go @@ -0,0 +1,80 @@ +// Package routing provides unified model routing for all provider types. +package routing + +import ( + "context" + + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// ProviderType indicates the type of provider. +type ProviderType string + +const ( + ProviderTypeOAuth ProviderType = "oauth" + ProviderTypeAPIKey ProviderType = "api_key" + ProviderTypeVertex ProviderType = "vertex" +) + +// Provider is the unified interface for all provider types (OAuth, API key, etc.). +type Provider interface { + // Name returns the unique provider identifier. + Name() string + + // Type returns the provider type. + Type() ProviderType + + // SupportsModel returns true if this provider can handle the given model. + SupportsModel(model string) bool + + // Available returns true if the provider is available for the model (not quota exceeded). + Available(model string) bool + + // Priority returns the priority for this provider (lower = tried first). + Priority() int + + // Execute sends the request to the provider. + Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) + + // ExecuteStream sends a streaming request to the provider. + ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) +} + +// ProviderCandidate represents a provider + model combination to try. +type ProviderCandidate struct { + Provider Provider + Model string // The actual model name to use (may be different from requested due to aliasing) +} + +// Registry manages all available providers. +type Registry struct { + providers []Provider +} + +// NewRegistry creates a new provider registry. +func NewRegistry() *Registry { + return &Registry{ + providers: make([]Provider, 0), + } +} + +// Register adds a provider to the registry. +func (r *Registry) Register(p Provider) { + r.providers = append(r.providers, p) +} + +// FindProviders returns all providers that support the given model and are available. +func (r *Registry) FindProviders(model string) []Provider { + var result []Provider + for _, p := range r.providers { + if p.SupportsModel(model) && p.Available(model) { + result = append(result, p) + } + } + return result +} + +// All returns all registered providers. +func (r *Registry) All() []Provider { + return r.providers +} diff --git a/internal/routing/providers/apikey.go b/internal/routing/providers/apikey.go new file mode 100644 index 0000000000..4603702dc6 --- /dev/null +++ b/internal/routing/providers/apikey.go @@ -0,0 +1,156 @@ +package providers + +import ( + "context" + "errors" + "net/http" + "strings" + "sync" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// APIKeyProvider wraps API key configs as routing.Provider. +type APIKeyProvider struct { + name string + provider string // claude, gemini, codex, vertex + keys []APIKeyEntry + mu sync.RWMutex + client HTTPClient +} + +// APIKeyEntry represents a single API key configuration. +type APIKeyEntry struct { + APIKey string + BaseURL string + Models []config.ClaudeModel // Using ClaudeModel as generic model alias +} + +// HTTPClient interface for making HTTP requests. +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +// NewAPIKeyProvider creates a new API key provider. +func NewAPIKeyProvider(name, provider string, client HTTPClient) *APIKeyProvider { + return &APIKeyProvider{ + name: name, + provider: provider, + keys: make([]APIKeyEntry, 0), + client: client, + } +} + +// Name returns the provider name. +func (p *APIKeyProvider) Name() string { + return p.name +} + +// Type returns ProviderTypeAPIKey. +func (p *APIKeyProvider) Type() routing.ProviderType { + return routing.ProviderTypeAPIKey +} + +// SupportsModel checks if the model is supported by this provider. +func (p *APIKeyProvider) SupportsModel(model string) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, key := range p.keys { + for _, m := range key.Models { + if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) { + return true + } + } + } + return false +} + +// Available always returns true for API keys (unless explicitly disabled). +func (p *APIKeyProvider) Available(model string) bool { + return p.SupportsModel(model) +} + +// Priority returns the priority (API key is lower priority than OAuth). +func (p *APIKeyProvider) Priority() int { + return 20 +} + +// Execute sends the request using the API key. +func (p *APIKeyProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) { + key := p.selectKey(model) + if key == nil { + return executor.Response{}, ErrNoMatchingAPIKey + } + + // Resolve the actual model name from alias + actualModel := p.resolveModel(key, model) + + // Execute via HTTP client + return p.executeHTTP(ctx, key, actualModel, req) +} + +// ExecuteStream sends a streaming request. +func (p *APIKeyProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) ( + <-chan executor.StreamChunk, error) { + key := p.selectKey(model) + if key == nil { + return nil, ErrNoMatchingAPIKey + } + + actualModel := p.resolveModel(key, model) + return p.executeHTTPStream(ctx, key, actualModel, req) +} + +// AddKey adds an API key entry. +func (p *APIKeyProvider) AddKey(entry APIKeyEntry) { + p.mu.Lock() + defer p.mu.Unlock() + p.keys = append(p.keys, entry) +} + +// selectKey selects a key that supports the model. +func (p *APIKeyProvider) selectKey(model string) *APIKeyEntry { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, key := range p.keys { + for _, m := range key.Models { + if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) { + return &key + } + } + } + return nil +} + +// resolveModel resolves alias to actual model name. +func (p *APIKeyProvider) resolveModel(key *APIKeyEntry, requested string) string { + for _, m := range key.Models { + if strings.EqualFold(m.Alias, requested) { + return m.Name + } + } + return requested +} + +// executeHTTP makes the HTTP request. +func (p *APIKeyProvider) executeHTTP(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) (executor.Response, error) { + // TODO: implement actual HTTP execution + // This is a placeholder - actual implementation would build HTTP request + return executor.Response{}, errors.New("not yet implemented") +} + +// executeHTTPStream makes a streaming HTTP request. +func (p *APIKeyProvider) executeHTTPStream(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) ( + <-chan executor.StreamChunk, error) { + // TODO: implement actual HTTP streaming + return nil, errors.New("not yet implemented") +} + +// Errors +var ( + ErrNoMatchingAPIKey = errors.New("no API key supports the requested model") +) diff --git a/internal/routing/providers/oauth.go b/internal/routing/providers/oauth.go new file mode 100644 index 0000000000..ae0c09e28d --- /dev/null +++ b/internal/routing/providers/oauth.go @@ -0,0 +1,132 @@ +package providers + +import ( + "context" + "errors" + "sync" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// OAuthProvider wraps OAuth-based auths as routing.Provider. +type OAuthProvider struct { + name string + auths []*coreauth.Auth + mu sync.RWMutex + executor coreauth.ProviderExecutor +} + +// NewOAuthProvider creates a new OAuth provider. +func NewOAuthProvider(name string, exec coreauth.ProviderExecutor) *OAuthProvider { + return &OAuthProvider{ + name: name, + auths: make([]*coreauth.Auth, 0), + executor: exec, + } +} + +// Name returns the provider name. +func (p *OAuthProvider) Name() string { + return p.name +} + +// Type returns ProviderTypeOAuth. +func (p *OAuthProvider) Type() routing.ProviderType { + return routing.ProviderTypeOAuth +} + +// SupportsModel checks if any auth supports the model. +func (p *OAuthProvider) SupportsModel(model string) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + // OAuth providers typically support models via oauth-model-alias + // The actual model support is determined at execution time + return true +} + +// Available checks if there's an available auth for the model. +func (p *OAuthProvider) Available(model string) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, auth := range p.auths { + if p.isAuthAvailable(auth, model) { + return true + } + } + return false +} + +// Priority returns the priority (OAuth is preferred over API key). +func (p *OAuthProvider) Priority() int { + return 10 +} + +// Execute sends the request using an available OAuth auth. +func (p *OAuthProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) { + auth := p.selectAuth(model) + if auth == nil { + return executor.Response{}, ErrNoAvailableAuth + } + + return p.executor.Execute(ctx, auth, req, executor.Options{}) +} + +// ExecuteStream sends a streaming request. +func (p *OAuthProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) { + auth := p.selectAuth(model) + if auth == nil { + return nil, ErrNoAvailableAuth + } + + return p.executor.ExecuteStream(ctx, auth, req, executor.Options{}) +} + +// AddAuth adds an auth to this provider. +func (p *OAuthProvider) AddAuth(auth *coreauth.Auth) { + p.mu.Lock() + defer p.mu.Unlock() + p.auths = append(p.auths, auth) +} + +// RemoveAuth removes an auth from this provider. +func (p *OAuthProvider) RemoveAuth(authID string) { + p.mu.Lock() + defer p.mu.Unlock() + + filtered := make([]*coreauth.Auth, 0, len(p.auths)) + for _, auth := range p.auths { + if auth.ID != authID { + filtered = append(filtered, auth) + } + } + p.auths = filtered +} + +// isAuthAvailable checks if an auth is available for the model. +func (p *OAuthProvider) isAuthAvailable(auth *coreauth.Auth, model string) bool { + // TODO: integrate with model_registry for quota checking + // For now, just check if auth exists + return auth != nil +} + +// selectAuth selects an available auth for the model. +func (p *OAuthProvider) selectAuth(model string) *coreauth.Auth { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, auth := range p.auths { + if p.isAuthAvailable(auth, model) { + return auth + } + } + return nil +} + +// Errors +var ( + ErrNoAvailableAuth = errors.New("no available OAuth auth for model") +) diff --git a/internal/routing/rewriter.go b/internal/routing/rewriter.go new file mode 100644 index 0000000000..d0c027716a --- /dev/null +++ b/internal/routing/rewriter.go @@ -0,0 +1,159 @@ +package routing + +import ( + "bytes" + "net/http" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + log "github.com/sirupsen/logrus" +) + +// ModelRewriter handles model name rewriting in requests and responses. +type ModelRewriter interface { + // RewriteRequestBody rewrites the model field in a JSON request body. + // Returns the modified body or the original if no rewrite was needed. + RewriteRequestBody(body []byte, newModel string) ([]byte, error) + + // WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response. + // Returns the wrapped writer and a cleanup function that must be called after the response is complete. + WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) +} + +// DefaultModelRewriter is the standard implementation of ModelRewriter. +type DefaultModelRewriter struct{} + +// NewModelRewriter creates a new DefaultModelRewriter. +func NewModelRewriter() *DefaultModelRewriter { + return &DefaultModelRewriter{} +} + +// RewriteRequestBody replaces the model name in a JSON request body. +func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) { + if !gjson.GetBytes(body, "model").Exists() { + return body, nil + } + result, err := sjson.SetBytes(body, "model", newModel) + if err != nil { + return body, err + } + return result, nil +} + +// WrapResponseWriter wraps a response writer to rewrite model names. +// The cleanup function must be called after the handler completes to flush any buffered data. +func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) { + rw := &responseRewriter{ + ResponseWriter: w, + body: &bytes.Buffer{}, + requestedModel: requestedModel, + resolvedModel: resolvedModel, + } + return rw, func() { rw.flush() } +} + +// responseRewriter wraps http.ResponseWriter to intercept and modify the response body. +type responseRewriter struct { + http.ResponseWriter + body *bytes.Buffer + requestedModel string + resolvedModel string + isStreaming bool + wroteHeader bool + flushed bool +} + +// Write intercepts response writes and buffers them for model name replacement. +func (rw *responseRewriter) Write(data []byte) (int, error) { + // Ensure header is written + if !rw.wroteHeader { + rw.WriteHeader(http.StatusOK) + } + + // Detect streaming on first write + if rw.body.Len() == 0 && !rw.isStreaming { + contentType := rw.Header().Get("Content-Type") + rw.isStreaming = strings.Contains(contentType, "text/event-stream") || + strings.Contains(contentType, "stream") + } + + if rw.isStreaming { + n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)) + if err == nil { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } + } + return n, err + } + return rw.body.Write(data) +} + +// WriteHeader captures the status code and delegates to the underlying writer. +func (rw *responseRewriter) WriteHeader(code int) { + if !rw.wroteHeader { + rw.wroteHeader = true + rw.ResponseWriter.WriteHeader(code) + } +} + +// flush writes the buffered response with model names rewritten. +func (rw *responseRewriter) flush() { + if rw.flushed { + return + } + rw.flushed = true + + if rw.isStreaming { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } + return + } + if rw.body.Len() > 0 { + data := rw.rewriteModelInResponse(rw.body.Bytes()) + if _, err := rw.ResponseWriter.Write(data); err != nil { + log.Warnf("response rewriter: failed to write rewritten response: %v", err) + } + } +} + +// modelFieldPaths lists all JSON paths where model name may appear. +var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"} + +// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model. +func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte { + if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel { + return data + } + + for _, path := range modelFieldPaths { + if gjson.GetBytes(data, path).Exists() { + data, _ = sjson.SetBytes(data, path, rw.requestedModel) + } + } + return data +} + +// rewriteStreamChunk rewrites model names in SSE stream chunks. +func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte { + if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel { + return chunk + } + + // SSE format: "data: {json}\n\n" + lines := bytes.Split(chunk, []byte("\n")) + for i, line := range lines { + if bytes.HasPrefix(line, []byte("data: ")) { + jsonData := bytes.TrimPrefix(line, []byte("data: ")) + if len(jsonData) > 0 && jsonData[0] == '{' { + // Rewrite JSON in the data line + rewritten := rw.rewriteModelInResponse(jsonData) + lines[i] = append([]byte("data: "), rewritten...) + } + } + } + + return bytes.Join(lines, []byte("\n")) +} diff --git a/internal/routing/rewriter_test.go b/internal/routing/rewriter_test.go new file mode 100644 index 0000000000..d628f71076 --- /dev/null +++ b/internal/routing/rewriter_test.go @@ -0,0 +1,342 @@ +package routing + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestModelRewriter_RewriteRequestBody(t *testing.T) { + rewriter := NewModelRewriter() + + tests := []struct { + name string + body []byte + newModel string + wantModel string + wantChange bool + }{ + { + name: "rewrites model field in JSON body", + body: []byte(`{"model":"gpt-4.1","messages":[]}`), + newModel: "claude-local", + wantModel: "claude-local", + wantChange: true, + }, + { + name: "rewrites with empty body returns empty", + body: []byte{}, + newModel: "gpt-4", + wantModel: "", + wantChange: false, + }, + { + name: "handles missing model field gracefully", + body: []byte(`{"messages":[{"role":"user"}]}`), + newModel: "gpt-4", + wantModel: "", + wantChange: false, + }, + { + name: "preserves other fields when rewriting", + body: []byte(`{"model":"old-model","temperature":0.7,"max_tokens":100}`), + newModel: "new-model", + wantModel: "new-model", + wantChange: true, + }, + { + name: "handles nested JSON structure", + body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":true}`), + newModel: "claude-3-opus", + wantModel: "claude-3-opus", + wantChange: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := rewriter.RewriteRequestBody(tt.body, tt.newModel) + require.NoError(t, err) + + if tt.wantChange { + assert.NotEqual(t, string(tt.body), string(result), "body should have been modified") + } + + if tt.wantModel != "" { + // Parse result and check model field + model, _ := NewModelExtractor().Extract(result, nil) + assert.Equal(t, tt.wantModel, model) + } + }) + } +} + +func TestModelRewriter_WrapResponseWriter(t *testing.T) { + rewriter := NewModelRewriter() + + t.Run("response writer wraps without error", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + require.NotNil(t, wrapped) + require.NotNil(t, cleanup) + defer cleanup() + }) + + t.Run("rewrites model in non-streaming response", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + // Write a response with the resolved model + response := []byte(`{"model":"claude-local","content":"hello"}`) + wrapped.Header().Set("Content-Type", "application/json") + _, err := wrapped.Write(response) + require.NoError(t, err) + + // Cleanup triggers the rewrite + cleanup() + + // Check the response was rewritten to the requested model + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"model":"gpt-4"`) + assert.NotContains(t, string(body), `"model":"claude-local"`) + }) + + t.Run("no-op when requested equals resolved", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "gpt-4") + + response := []byte(`{"model":"gpt-4","content":"hello"}`) + wrapped.Header().Set("Content-Type", "application/json") + _, err := wrapped.Write(response) + require.NoError(t, err) + + cleanup() + + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"model":"gpt-4"`) + }) + + t.Run("rewrites modelVersion field", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"modelVersion":"claude-local","content":"hello"}`) + wrapped.Header().Set("Content-Type", "application/json") + _, err := wrapped.Write(response) + require.NoError(t, err) + + cleanup() + + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"modelVersion":"gpt-4"`) + }) + + t.Run("handles streaming responses", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + // Set streaming content type + wrapped.Header().Set("Content-Type", "text/event-stream") + + // Write SSE chunks with resolved model + chunk1 := []byte("data: {\"model\":\"claude-local\",\"delta\":\"hello\"}\n\n") + _, err := wrapped.Write(chunk1) + require.NoError(t, err) + + chunk2 := []byte("data: {\"model\":\"claude-local\",\"delta\":\" world\"}\n\n") + _, err = wrapped.Write(chunk2) + require.NoError(t, err) + + cleanup() + + // For streaming, data is written immediately with rewrites + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"model":"gpt-4"`) + assert.NotContains(t, string(body), `"model":"claude-local"`) + }) + + t.Run("empty body handled gracefully", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + wrapped.Header().Set("Content-Type", "application/json") + // Don't write anything + + cleanup() + + body := recorder.Body.Bytes() + assert.Empty(t, body) + }) + + t.Run("preserves other JSON fields", func(t *testing.T) { + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"model":"claude-local","temperature":0.7,"usage":{"prompt_tokens":10}}`) + wrapped.Header().Set("Content-Type", "application/json") + _, err := wrapped.Write(response) + require.NoError(t, err) + + cleanup() + + body := recorder.Body.Bytes() + assert.Contains(t, string(body), `"temperature":0.7`) + assert.Contains(t, string(body), `"prompt_tokens":10`) + }) +} + +func TestResponseRewriter_ImplementsInterfaces(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + defer cleanup() + + // Should implement http.ResponseWriter + assert.Implements(t, (*http.ResponseWriter)(nil), wrapped) + + // Should preserve header access + wrapped.Header().Set("X-Custom", "value") + assert.Equal(t, "value", recorder.Header().Get("X-Custom")) + + // Should write status + wrapped.WriteHeader(http.StatusCreated) + assert.Equal(t, http.StatusCreated, recorder.Code) +} + +func TestResponseRewriter_Flush(t *testing.T) { + t.Run("flush writes buffered content", func(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"model":"claude-local","content":"test"}`) + wrapped.Header().Set("Content-Type", "application/json") + wrapped.Write(response) + + // Before cleanup, response should be empty (buffered) + assert.Empty(t, recorder.Body.Bytes()) + + // After cleanup, response should be written + cleanup() + assert.NotEmpty(t, recorder.Body.Bytes()) + }) + + t.Run("multiple flush calls are safe", func(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"model":"claude-local"}`) + wrapped.Header().Set("Content-Type", "application/json") + wrapped.Write(response) + + // First cleanup + cleanup() + firstBody := recorder.Body.Bytes() + + // Second cleanup should not write again + cleanup() + secondBody := recorder.Body.Bytes() + + assert.Equal(t, firstBody, secondBody) + }) +} + +func TestResponseRewriter_StreamingWithDataLines(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + wrapped.Header().Set("Content-Type", "text/event-stream") + + // SSE format with multiple data lines + chunk := []byte("data: {\"model\":\"claude-local\"}\n\ndata: {\"model\":\"claude-local\",\"done\":true}\n\n") + wrapped.Write(chunk) + + cleanup() + + body := recorder.Body.Bytes() + // Both data lines should have model rewritten + assert.Contains(t, string(body), `"model":"gpt-4"`) + assert.NotContains(t, string(body), `"model":"claude-local"`) +} + +func TestModelRewriter_RoundTrip(t *testing.T) { + // Simulate a full request -> response cycle with model rewriting + rewriter := NewModelRewriter() + + // Step 1: Rewrite request body + originalRequest := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`) + rewrittenRequest, err := rewriter.RewriteRequestBody(originalRequest, "claude-local") + require.NoError(t, err) + + // Verify request was rewritten + extractor := NewModelExtractor() + requestModel, _ := extractor.Extract(rewrittenRequest, nil) + assert.Equal(t, "claude-local", requestModel) + + // Step 2: Simulate response with resolved model + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + response := []byte(`{"model":"claude-local","content":"Hello! How can I help?"}`) + wrapped.Header().Set("Content-Type", "application/json") + wrapped.Write(response) + cleanup() + + // Verify response was rewritten back + body, _ := io.ReadAll(recorder.Result().Body) + responseModel, _ := extractor.Extract(body, nil) + assert.Equal(t, "gpt-4", responseModel) +} + +func TestModelRewriter_NonJSONBody(t *testing.T) { + rewriter := NewModelRewriter() + + // Binary/non-JSON body should be returned unchanged + body := []byte{0x00, 0x01, 0x02, 0x03} + result, err := rewriter.RewriteRequestBody(body, "gpt-4") + require.NoError(t, err) + assert.Equal(t, body, result) +} + +func TestModelRewriter_InvalidJSON(t *testing.T) { + rewriter := NewModelRewriter() + + // Invalid JSON without model field should be returned unchanged + body := []byte(`not valid json`) + result, err := rewriter.RewriteRequestBody(body, "gpt-4") + require.NoError(t, err) + assert.Equal(t, body, result) +} + +func TestResponseRewriter_StatusCodePreserved(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + wrapped.WriteHeader(http.StatusAccepted) + wrapped.Write([]byte(`{"model":"claude-local"}`)) + cleanup() + + assert.Equal(t, http.StatusAccepted, recorder.Code) +} + +func TestResponseRewriter_HeaderFlushed(t *testing.T) { + rewriter := NewModelRewriter() + recorder := httptest.NewRecorder() + wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local") + + wrapped.Header().Set("Content-Type", "application/json") + wrapped.Header().Set("X-Request-ID", "abc123") + wrapped.Write([]byte(`{"model":"claude-local"}`)) + cleanup() + + result := recorder.Result() + assert.Equal(t, "application/json", result.Header.Get("Content-Type")) + assert.Equal(t, "abc123", result.Header.Get("X-Request-ID")) +} diff --git a/internal/routing/router.go b/internal/routing/router.go new file mode 100644 index 0000000000..543c7ecf7d --- /dev/null +++ b/internal/routing/router.go @@ -0,0 +1,317 @@ +package routing + +import ( + "context" + "sort" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// Router resolves models to provider candidates. +type Router struct { + registry *Registry + modelMappings map[string]string // normalized from -> to + oauthAliases map[string][]string // normalized model -> []alias +} + +// NewRouter creates a new router with the given configuration. +func NewRouter(registry *Registry, cfg *config.Config) *Router { + r := &Router{ + registry: registry, + modelMappings: make(map[string]string), + oauthAliases: make(map[string][]string), + } + + if cfg != nil { + r.loadModelMappings(cfg.AmpCode.ModelMappings) + r.loadOAuthAliases(cfg.OAuthModelAlias) + } + + return r +} + +// LegacyRoutingDecision contains the resolved routing information. +// Deprecated: Will be replaced by RoutingDecision from types.go in T-013. +type LegacyRoutingDecision struct { + RequestedModel string // Original model from request + ResolvedModel string // After model-mappings + Candidates []ProviderCandidate // Ordered list of providers to try +} + +// Resolve determines the routing decision for the requested model. +// Deprecated: Will be updated to use RoutingRequest and return *RoutingDecision in T-013. +func (r *Router) Resolve(requestedModel string) *LegacyRoutingDecision { + // 1. Extract thinking suffix + suffixResult := thinking.ParseSuffix(requestedModel) + baseModel := suffixResult.ModelName + + // 2. Apply model-mappings + targetModel := r.applyMappings(baseModel) + + // 3. Find primary providers + candidates := r.findCandidates(targetModel, suffixResult) + + // 4. Add fallback aliases + for _, alias := range r.oauthAliases[strings.ToLower(targetModel)] { + candidates = append(candidates, r.findCandidates(alias, suffixResult)...) + } + + // 5. Sort by priority + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].Provider.Priority() < candidates[j].Provider.Priority() + }) + + return &LegacyRoutingDecision{ + RequestedModel: requestedModel, + ResolvedModel: targetModel, + Candidates: candidates, + } +} + +// ResolveV2 determines the routing decision for a routing request. +// It uses the new RoutingRequest and RoutingDecision types. +func (r *Router) ResolveV2(req RoutingRequest) *RoutingDecision { + // 1. Extract thinking suffix + suffixResult := thinking.ParseSuffix(req.RequestedModel) + baseModel := suffixResult.ModelName + thinkingSuffix := "" + if suffixResult.HasSuffix { + thinkingSuffix = "(" + suffixResult.RawSuffix + ")" + } + + // 2. Check for local providers + localCandidates := r.findLocalCandidates(baseModel, suffixResult) + + // 3. Apply model-mappings if needed + mappedModel := r.applyMappings(baseModel) + mappingCandidates := r.findLocalCandidates(mappedModel, suffixResult) + + // 4. Determine route type based on preferences and availability + var decision *RoutingDecision + + if req.ForceModelMapping && mappedModel != baseModel && len(mappingCandidates) > 0 { + // FORCE MODE: Use mapping even if local provider exists + decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:]) + } else if req.PreferLocalProvider && len(localCandidates) > 0 { + // DEFAULT MODE with local preference: Use local provider first + decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix) + } else if len(localCandidates) > 0 { + // DEFAULT MODE: Local provider available + decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix) + } else if mappedModel != baseModel && len(mappingCandidates) > 0 { + // DEFAULT MODE: No local provider, but mapping available + decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:]) + } else { + // No local provider, no mapping - use amp credits proxy + decision = &RoutingDecision{ + RouteType: RouteTypeAmpCredits, + ResolvedModel: req.RequestedModel, + ShouldProxy: true, + } + } + + return decision +} + +// findLocalCandidates finds local provider candidates for a model. +// If the internal registry is empty, it falls back to the global model registry. +func (r *Router) findLocalCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate { + var candidates []ProviderCandidate + + // Check internal registry first + registryProviders := r.registry.All() + if len(registryProviders) > 0 { + for _, p := range registryProviders { + if !p.SupportsModel(model) { + continue + } + + // Apply thinking suffix if needed + actualModel := model + if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix { + actualModel = model + "(" + suffixResult.RawSuffix + ")" + } + + if p.Available(actualModel) { + candidates = append(candidates, ProviderCandidate{ + Provider: p, + Model: actualModel, + }) + } + } + } else { + // Fallback to global model registry (same logic as FallbackHandler) + // This ensures compatibility when the wrapper is initialized with an empty registry + providers := registry.GetGlobalRegistry().GetModelProviders(model) + if len(providers) > 0 { + actualModel := model + if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix { + actualModel = model + "(" + suffixResult.RawSuffix + ")" + } + // Create a synthetic provider candidate for each provider + for _, providerName := range providers { + candidates = append(candidates, ProviderCandidate{ + Provider: &globalRegistryProvider{name: providerName, model: actualModel}, + Model: actualModel, + }) + } + } + } + + // Sort by priority + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].Provider.Priority() < candidates[j].Provider.Priority() + }) + + return candidates +} + +// globalRegistryProvider is a synthetic Provider implementation that wraps +// a provider name from the global model registry. It is used only for routing +// decisions when the internal registry is empty - actual execution goes through +// the normal handler path, not through this provider's Execute methods. +type globalRegistryProvider struct { + name string + model string +} + +func (p *globalRegistryProvider) Name() string { return p.name } +func (p *globalRegistryProvider) Type() ProviderType { return ProviderTypeOAuth } +func (p *globalRegistryProvider) Priority() int { return 0 } +func (p *globalRegistryProvider) SupportsModel(string) bool { return true } +func (p *globalRegistryProvider) Available(string) bool { return true } + +// Execute is not used for globalRegistryProvider - routing wrapper calls the handler directly. +func (p *globalRegistryProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) { + return executor.Response{}, nil +} + +// ExecuteStream is not used for globalRegistryProvider - routing wrapper calls the handler directly. +func (p *globalRegistryProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) { + return nil, nil +} + +// buildLocalProviderDecision creates a decision for local provider routing. +func (r *Router) buildLocalProviderDecision(requestedModel string, candidates []ProviderCandidate, thinkingSuffix string) *RoutingDecision { + resolvedModel := requestedModel + if thinkingSuffix != "" { + // Ensure thinking suffix is preserved + sr := thinking.ParseSuffix(requestedModel) + if !sr.HasSuffix { + resolvedModel = requestedModel + thinkingSuffix + } + } + + var fallbackModels []string + if len(candidates) > 1 { + for _, c := range candidates[1:] { + fallbackModels = append(fallbackModels, c.Model) + } + } + + return &RoutingDecision{ + RouteType: RouteTypeLocalProvider, + ResolvedModel: resolvedModel, + ProviderName: candidates[0].Provider.Name(), + FallbackModels: fallbackModels, + ShouldProxy: false, + } +} + +// buildMappingDecision creates a decision for model mapping routing. +func (r *Router) buildMappingDecision(requestedModel, mappedModel string, candidates []ProviderCandidate, thinkingSuffix string, fallbackCandidates []ProviderCandidate) *RoutingDecision { + // Apply thinking suffix to resolved model if needed + resolvedModel := mappedModel + if thinkingSuffix != "" { + sr := thinking.ParseSuffix(mappedModel) + if !sr.HasSuffix { + resolvedModel = mappedModel + thinkingSuffix + } + } + + var fallbackModels []string + for _, c := range fallbackCandidates { + fallbackModels = append(fallbackModels, c.Model) + } + + // Also add oauth aliases as fallbacks + baseMapped := thinking.ParseSuffix(mappedModel).ModelName + for _, alias := range r.oauthAliases[strings.ToLower(baseMapped)] { + // Check if this alias has providers + aliasCandidates := r.findLocalCandidates(alias, thinking.SuffixResult{ModelName: alias}) + for _, c := range aliasCandidates { + fallbackModels = append(fallbackModels, c.Model) + } + } + + return &RoutingDecision{ + RouteType: RouteTypeModelMapping, + ResolvedModel: resolvedModel, + ProviderName: candidates[0].Provider.Name(), + FallbackModels: fallbackModels, + ShouldProxy: false, + } +} + +// applyMappings applies model-mappings configuration. +func (r *Router) applyMappings(model string) string { + key := strings.ToLower(strings.TrimSpace(model)) + if mapped, ok := r.modelMappings[key]; ok { + return mapped + } + return model +} + +// findCandidates finds all provider candidates for a model. +func (r *Router) findCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate { + var candidates []ProviderCandidate + + for _, p := range r.registry.All() { + if !p.SupportsModel(model) { + continue + } + + // Apply thinking suffix if needed + actualModel := model + if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix { + actualModel = model + "(" + suffixResult.RawSuffix + ")" + } + + if p.Available(actualModel) { + candidates = append(candidates, ProviderCandidate{ + Provider: p, + Model: actualModel, + }) + } + } + + return candidates +} + +// loadModelMappings loads model-mappings from config. +func (r *Router) loadModelMappings(mappings []config.AmpModelMapping) { + for _, m := range mappings { + from := strings.ToLower(strings.TrimSpace(m.From)) + to := strings.TrimSpace(m.To) + if from != "" && to != "" { + r.modelMappings[from] = to + } + } +} + +// loadOAuthAliases loads oauth-model-alias from config. +func (r *Router) loadOAuthAliases(aliases map[string][]config.OAuthModelAlias) { + for _, entries := range aliases { + for _, entry := range entries { + name := strings.ToLower(strings.TrimSpace(entry.Name)) + alias := strings.TrimSpace(entry.Alias) + if name != "" && alias != "" && name != alias { + r.oauthAliases[name] = append(r.oauthAliases[name], alias) + } + } + } +} diff --git a/internal/routing/router_test.go b/internal/routing/router_test.go new file mode 100644 index 0000000000..c3674d01be --- /dev/null +++ b/internal/routing/router_test.go @@ -0,0 +1,202 @@ +package routing + +import ( + "context" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + globalRegistry "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/stretchr/testify/assert" +) + +// mockProvider is a test double for Provider. +type mockProvider struct { + name string + providerType ProviderType + supportsModels map[string]bool + available bool + priority int +} + +func (m *mockProvider) Name() string { return m.name } +func (m *mockProvider) Type() ProviderType { return m.providerType } +func (m *mockProvider) SupportsModel(model string) bool { return m.supportsModels[model] } +func (m *mockProvider) Available(model string) bool { return m.available } +func (m *mockProvider) Priority() int { return m.priority } +func (m *mockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) { + return executor.Response{}, nil +} +func (m *mockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) { + return nil, nil +} + +func TestRouter_Resolve_ModelMappings(t *testing.T) { + registry := NewRegistry() + + // Add a provider + p := &mockProvider{ + name: "test-provider", + providerType: ProviderTypeOAuth, + supportsModels: map[string]bool{"target-model": true}, + available: true, + priority: 1, + } + registry.Register(p) + + // Create router with model mapping + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "user-model", To: "target-model"}, + }, + }, + } + router := NewRouter(registry, cfg) + + // Resolve + decision := router.Resolve("user-model") + + assert.Equal(t, "user-model", decision.RequestedModel) + assert.Equal(t, "target-model", decision.ResolvedModel) + assert.Len(t, decision.Candidates, 1) + assert.Equal(t, "target-model", decision.Candidates[0].Model) +} + +func TestRouter_Resolve_OAuthAliases(t *testing.T) { + registry := NewRegistry() + + // Add providers + p1 := &mockProvider{ + name: "oauth-1", + providerType: ProviderTypeOAuth, + supportsModels: map[string]bool{"primary-model": true}, + available: true, + priority: 1, + } + p2 := &mockProvider{ + name: "oauth-2", + providerType: ProviderTypeOAuth, + supportsModels: map[string]bool{"fallback-model": true}, + available: true, + priority: 2, + } + registry.Register(p1) + registry.Register(p2) + + // Create router with oauth aliases + cfg := &config.Config{ + OAuthModelAlias: map[string][]config.OAuthModelAlias{ + "test-channel": { + {Name: "primary-model", Alias: "fallback-model"}, + }, + }, + } + router := NewRouter(registry, cfg) + + // Resolve + decision := router.Resolve("primary-model") + + assert.Equal(t, "primary-model", decision.ResolvedModel) + assert.Len(t, decision.Candidates, 2) + // Primary should come first (lower priority value) + assert.Equal(t, "primary-model", decision.Candidates[0].Model) + assert.Equal(t, "fallback-model", decision.Candidates[1].Model) +} + +func TestRouter_Resolve_NoProviders(t *testing.T) { + registry := NewRegistry() + cfg := &config.Config{} + router := NewRouter(registry, cfg) + + decision := router.Resolve("unknown-model") + + assert.Equal(t, "unknown-model", decision.ResolvedModel) + assert.Empty(t, decision.Candidates) +} + +// === Global Registry Fallback Tests (T-027) === +// These tests verify that when the internal registry is empty, +// the router falls back to the global model registry. +// This is the core fix for the thinking signature 400 error. + +func TestRouter_GlobalRegistryFallback_LocalProvider(t *testing.T) { + // This test requires registering a model in the global registry. + // We use a model that's already registered via api-key config in production. + // For isolated testing, we can skip if global registry is not populated. + + globalReg := globalRegistry.GetGlobalRegistry() + modelCount := globalReg.GetModelCount("claude-sonnet-4-20250514") + + if modelCount == 0 { + t.Skip("Global registry not populated - run with server context") + } + + // Empty internal registry + emptyRegistry := NewRegistry() + cfg := &config.Config{} + router := NewRouter(emptyRegistry, cfg) + + req := RoutingRequest{ + RequestedModel: "claude-sonnet-4-20250514", + PreferLocalProvider: true, + } + decision := router.ResolveV2(req) + + // Should find provider from global registry + assert.Equal(t, RouteTypeLocalProvider, decision.RouteType) + assert.Equal(t, "claude-sonnet-4-20250514", decision.ResolvedModel) + assert.False(t, decision.ShouldProxy) +} + +func TestRouter_GlobalRegistryFallback_ModelMapping(t *testing.T) { + // This test verifies that model mapping works with global registry fallback. + + globalReg := globalRegistry.GetGlobalRegistry() + modelCount := globalReg.GetModelCount("claude-opus-4-5-thinking") + + if modelCount == 0 { + t.Skip("Global registry not populated - run with server context") + } + + // Empty internal registry + emptyRegistry := NewRegistry() + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"}, + }, + }, + } + router := NewRouter(emptyRegistry, cfg) + + req := RoutingRequest{ + RequestedModel: "claude-opus-4-5-20251101", + PreferLocalProvider: true, + } + decision := router.ResolveV2(req) + + // Should find mapped model from global registry + assert.Equal(t, RouteTypeModelMapping, decision.RouteType) + assert.Equal(t, "claude-opus-4-5-thinking", decision.ResolvedModel) + assert.False(t, decision.ShouldProxy) +} + +func TestRouter_GlobalRegistryFallback_AmpCreditsWhenNotFound(t *testing.T) { + // Empty internal registry + emptyRegistry := NewRegistry() + cfg := &config.Config{} + router := NewRouter(emptyRegistry, cfg) + + // Use a model that definitely doesn't exist anywhere + req := RoutingRequest{ + RequestedModel: "nonexistent-model-12345", + PreferLocalProvider: true, + } + decision := router.ResolveV2(req) + + // Should fall back to AMP credits proxy + assert.Equal(t, RouteTypeAmpCredits, decision.RouteType) + assert.Equal(t, "nonexistent-model-12345", decision.ResolvedModel) + assert.True(t, decision.ShouldProxy) +} diff --git a/internal/routing/router_v2_test.go b/internal/routing/router_v2_test.go new file mode 100644 index 0000000000..903b7aa855 --- /dev/null +++ b/internal/routing/router_v2_test.go @@ -0,0 +1,245 @@ +package routing + +import ( + "context" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/stretchr/testify/assert" +) + +func TestRouter_DefaultMode_PrefersLocal(t *testing.T) { + // Setup: Create a router with a mock provider that supports "gpt-4" + registry := NewRegistry() + mockProvider := &MockProvider{ + name: "openai", + supportedModels: []string{"gpt-4"}, + available: true, + priority: 1, + } + registry.Register(mockProvider) + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "claude-local"}, + }, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request gpt-4 when local provider exists + req := RoutingRequest{ + RequestedModel: "gpt-4", + PreferLocalProvider: true, + ForceModelMapping: false, + } + + decision := router.ResolveV2(req) + + // Assert: Should return LOCAL_PROVIDER, not MODEL_MAPPING + assert.Equal(t, RouteTypeLocalProvider, decision.RouteType) + assert.Equal(t, "gpt-4", decision.ResolvedModel) + assert.Equal(t, "openai", decision.ProviderName) + assert.False(t, decision.ShouldProxy) +} + +func TestRouter_DefaultMode_MapsWhenNoLocal(t *testing.T) { + // Setup: Create a router with NO provider for "gpt-4" but a mapping to "claude-local" + // which has a provider + registry := NewRegistry() + mockProvider := &MockProvider{ + name: "anthropic", + supportedModels: []string{"claude-local"}, + available: true, + priority: 1, + } + registry.Register(mockProvider) + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "claude-local"}, + }, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request gpt-4 when no local provider exists, but mapping exists + req := RoutingRequest{ + RequestedModel: "gpt-4", + PreferLocalProvider: true, + ForceModelMapping: false, + } + + decision := router.ResolveV2(req) + + // Assert: Should return MODEL_MAPPING + assert.Equal(t, RouteTypeModelMapping, decision.RouteType) + assert.Equal(t, "claude-local", decision.ResolvedModel) + assert.Equal(t, "anthropic", decision.ProviderName) + assert.False(t, decision.ShouldProxy) +} + +func TestRouter_DefaultMode_AmpCreditsWhenNoLocalOrMapping(t *testing.T) { + // Setup: Create a router with no providers and no mappings + registry := NewRegistry() + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{}, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request a model with no local provider and no mapping + req := RoutingRequest{ + RequestedModel: "unknown-model", + PreferLocalProvider: true, + ForceModelMapping: false, + } + + decision := router.ResolveV2(req) + + // Assert: Should return AMP_CREDITS with ShouldProxy=true + assert.Equal(t, RouteTypeAmpCredits, decision.RouteType) + assert.Equal(t, "unknown-model", decision.ResolvedModel) + assert.True(t, decision.ShouldProxy) + assert.Empty(t, decision.ProviderName) +} + +func TestRouter_ForceMode_MapsEvenWithLocal(t *testing.T) { + // Setup: Create a router with BOTH a local provider for "gpt-4" AND a mapping from "gpt-4" to "claude-local" + // The mapping target "claude-local" also has a provider + registry := NewRegistry() + + // Local provider for gpt-4 + openaiProvider := &MockProvider{ + name: "openai", + supportedModels: []string{"gpt-4"}, + available: true, + priority: 1, + } + registry.Register(openaiProvider) + + // Local provider for the mapped model + anthropicProvider := &MockProvider{ + name: "anthropic", + supportedModels: []string{"claude-local"}, + available: true, + priority: 2, + } + registry.Register(anthropicProvider) + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "gpt-4", To: "claude-local"}, + }, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request gpt-4 with ForceModelMapping=true + // Even though gpt-4 has a local provider, mapping should take precedence + req := RoutingRequest{ + RequestedModel: "gpt-4", + PreferLocalProvider: false, + ForceModelMapping: true, + } + + decision := router.ResolveV2(req) + + // Assert: Should return MODEL_MAPPING, not LOCAL_PROVIDER + assert.Equal(t, RouteTypeModelMapping, decision.RouteType) + assert.Equal(t, "claude-local", decision.ResolvedModel) + assert.Equal(t, "anthropic", decision.ProviderName) + assert.False(t, decision.ShouldProxy) +} + +func TestRouter_ThinkingSuffix_Preserved(t *testing.T) { + // Setup: Create a router with mapping and provider for mapped model + registry := NewRegistry() + + mockProvider := &MockProvider{ + name: "anthropic", + supportedModels: []string{"claude-local"}, + available: true, + priority: 1, + } + registry.Register(mockProvider) + + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "claude-3-5-sonnet", To: "claude-local"}, + }, + }, + } + + router := NewRouter(registry, cfg) + + // Test: Request claude-3-5-sonnet with thinking suffix + req := RoutingRequest{ + RequestedModel: "claude-3-5-sonnet(thinking:foo)", + PreferLocalProvider: true, + ForceModelMapping: false, + } + + decision := router.ResolveV2(req) + + // Assert: Thinking suffix should be preserved in resolved model + assert.Equal(t, RouteTypeModelMapping, decision.RouteType) + assert.Equal(t, "claude-local(thinking:foo)", decision.ResolvedModel) + assert.Equal(t, "anthropic", decision.ProviderName) +} + +// MockProvider is a mock implementation of Provider for testing +type MockProvider struct { + name string + providerType ProviderType + supportedModels []string + available bool + priority int +} + +func (m *MockProvider) Name() string { + return m.name +} + +func (m *MockProvider) Type() ProviderType { + if m.providerType == "" { + return ProviderTypeOAuth + } + return m.providerType +} + +func (m *MockProvider) SupportsModel(model string) bool { + for _, supported := range m.supportedModels { + if supported == model { + return true + } + } + return false +} + +func (m *MockProvider) Available(model string) bool { + return m.available +} + +func (m *MockProvider) Priority() int { + return m.priority +} + +func (m *MockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) { + return executor.Response{}, nil +} + +func (m *MockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) { + return nil, nil +} diff --git a/internal/routing/testutil/fake_handler.go b/internal/routing/testutil/fake_handler.go new file mode 100644 index 0000000000..160aaad8b1 --- /dev/null +++ b/internal/routing/testutil/fake_handler.go @@ -0,0 +1,113 @@ +package testutil + +import ( + "io" + "net/http" + + "github.com/gin-gonic/gin" +) + +// FakeHandlerRecorder records handler invocations for testing. +type FakeHandlerRecorder struct { + Called bool + CallCount int + RequestBody []byte + RequestHeader http.Header + ContextKeys map[string]interface{} + ResponseStatus int + ResponseBody []byte +} + +// NewFakeHandlerRecorder creates a new fake handler recorder. +func NewFakeHandlerRecorder() *FakeHandlerRecorder { + return &FakeHandlerRecorder{ + ContextKeys: make(map[string]interface{}), + ResponseStatus: http.StatusOK, + ResponseBody: []byte(`{"status":"handled"}`), + } +} + +// GinHandler returns a gin.HandlerFunc that records the invocation. +func (f *FakeHandlerRecorder) GinHandler() gin.HandlerFunc { + return func(c *gin.Context) { + f.record(c) + c.Data(f.ResponseStatus, "application/json", f.ResponseBody) + } +} + +// GinHandlerWithModel returns a gin.HandlerFunc that records the invocation and returns the model from context. +// Useful for testing response rewriting in model mapping scenarios. +func (f *FakeHandlerRecorder) GinHandlerWithModel() gin.HandlerFunc { + return func(c *gin.Context) { + f.record(c) + // Return a response with the model field that would be in the actual API response + // If ResponseBody was explicitly set (not default), use that; otherwise generate from context + var body []byte + if mappedModel, exists := c.Get("mapped_model"); exists { + body = []byte(`{"model":"` + mappedModel.(string) + `","status":"handled"}`) + } else { + body = f.ResponseBody + } + c.Data(f.ResponseStatus, "application/json", body) + } +} + +// HTTPHandler returns an http.HandlerFunc that records the invocation. +func (f *FakeHandlerRecorder) HTTPHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + f.Called = true + f.CallCount++ + f.RequestBody = body + f.RequestHeader = r.Header.Clone() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(f.ResponseStatus) + w.Write(f.ResponseBody) + } +} + +// record captures the request details from gin context. +func (f *FakeHandlerRecorder) record(c *gin.Context) { + f.Called = true + f.CallCount++ + + body, _ := io.ReadAll(c.Request.Body) + f.RequestBody = body + f.RequestHeader = c.Request.Header.Clone() + + // Capture common context keys used by routing + if val, exists := c.Get("mapped_model"); exists { + f.ContextKeys["mapped_model"] = val + } + if val, exists := c.Get("fallback_models"); exists { + f.ContextKeys["fallback_models"] = val + } + if val, exists := c.Get("route_type"); exists { + f.ContextKeys["route_type"] = val + } +} + +// Reset clears the recorder state. +func (f *FakeHandlerRecorder) Reset() { + f.Called = false + f.CallCount = 0 + f.RequestBody = nil + f.RequestHeader = nil + f.ContextKeys = make(map[string]interface{}) +} + +// GetContextKey returns a captured context key value. +func (f *FakeHandlerRecorder) GetContextKey(key string) (interface{}, bool) { + val, ok := f.ContextKeys[key] + return val, ok +} + +// WasCalled returns true if the handler was called. +func (f *FakeHandlerRecorder) WasCalled() bool { + return f.Called +} + +// GetCallCount returns the number of times the handler was called. +func (f *FakeHandlerRecorder) GetCallCount() int { + return f.CallCount +} diff --git a/internal/routing/testutil/fake_proxy.go b/internal/routing/testutil/fake_proxy.go new file mode 100644 index 0000000000..3deea5a546 --- /dev/null +++ b/internal/routing/testutil/fake_proxy.go @@ -0,0 +1,83 @@ +package testutil + +import ( + "io" + "net/http" + "net/http/httptest" +) + +// CloseNotifierRecorder wraps httptest.ResponseRecorder with CloseNotify support. +// This is needed because ReverseProxy requires http.CloseNotifier. +type CloseNotifierRecorder struct { + *httptest.ResponseRecorder + closeChan chan bool +} + +// NewCloseNotifierRecorder creates a ResponseRecorder that implements CloseNotifier. +func NewCloseNotifierRecorder() *CloseNotifierRecorder { + return &CloseNotifierRecorder{ + ResponseRecorder: httptest.NewRecorder(), + closeChan: make(chan bool, 1), + } +} + +// CloseNotify implements http.CloseNotifier. +func (c *CloseNotifierRecorder) CloseNotify() <-chan bool { + return c.closeChan +} + +// FakeProxyRecorder records proxy invocations for testing. +type FakeProxyRecorder struct { + Called bool + CallCount int + RequestBody []byte + RequestHeaders http.Header + ResponseStatus int + ResponseBody []byte +} + +// NewFakeProxyRecorder creates a new fake proxy recorder. +func NewFakeProxyRecorder() *FakeProxyRecorder { + return &FakeProxyRecorder{ + ResponseStatus: http.StatusOK, + ResponseBody: []byte(`{"status":"proxied"}`), + } +} + +// ServeHTTP implements http.Handler to act as a reverse proxy. +func (f *FakeProxyRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) { + f.Called = true + f.CallCount++ + f.RequestHeaders = r.Header.Clone() + + body, err := io.ReadAll(r.Body) + if err == nil { + f.RequestBody = body + } + + w.WriteHeader(f.ResponseStatus) + w.Write(f.ResponseBody) +} + +// GetCallCount returns the number of times the proxy was called. +func (f *FakeProxyRecorder) GetCallCount() int { + return f.CallCount +} + +// Reset clears the recorder state. +func (f *FakeProxyRecorder) Reset() { + f.Called = false + f.CallCount = 0 + f.RequestBody = nil + f.RequestHeaders = nil +} + +// ToHandler returns the recorder as an http.Handler for use with httptest. +func (f *FakeProxyRecorder) ToHandler() http.Handler { + return http.HandlerFunc(f.ServeHTTP) +} + +// CreateTestServer creates an httptest server with this fake proxy. +func (f *FakeProxyRecorder) CreateTestServer() *httptest.Server { + return httptest.NewServer(f.ToHandler()) +} diff --git a/internal/routing/types.go b/internal/routing/types.go new file mode 100644 index 0000000000..30c5061005 --- /dev/null +++ b/internal/routing/types.go @@ -0,0 +1,62 @@ +package routing + +// RouteType represents the type of routing decision made for a request. +type RouteType string + +const ( + // RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free). + RouteTypeLocalProvider RouteType = "LOCAL_PROVIDER" + // RouteTypeModelMapping indicates the request was remapped to another available model (free). + RouteTypeModelMapping RouteType = "MODEL_MAPPING" + // RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits). + RouteTypeAmpCredits RouteType = "AMP_CREDITS" + // RouteTypeNoProvider indicates no provider or fallback available. + RouteTypeNoProvider RouteType = "NO_PROVIDER" +) + +// RoutingRequest contains the information needed to make a routing decision. +type RoutingRequest struct { + // RequestedModel is the model name from the incoming request. + RequestedModel string + // PreferLocalProvider indicates whether to prefer local providers over mappings. + // When true, check local providers first before applying model mappings. + PreferLocalProvider bool + // ForceModelMapping indicates whether to force model mapping even if local provider exists. + // When true, apply model mappings first and skip local provider checks. + ForceModelMapping bool +} + +// RoutingDecision contains the result of a routing decision. +type RoutingDecision struct { + // RouteType indicates the type of routing decision. + RouteType RouteType + // ResolvedModel is the final model name after any mappings. + ResolvedModel string + // ProviderName is the name of the selected provider (if any). + ProviderName string + // FallbackModels is a list of alternative models to try if the primary fails. + FallbackModels []string + // ShouldProxy indicates whether the request should be proxied to ampcode.com. + ShouldProxy bool +} + +// NewRoutingDecision creates a new RoutingDecision with the given parameters. +func NewRoutingDecision(routeType RouteType, resolvedModel, providerName string, fallbackModels []string, shouldProxy bool) *RoutingDecision { + return &RoutingDecision{ + RouteType: routeType, + ResolvedModel: resolvedModel, + ProviderName: providerName, + FallbackModels: fallbackModels, + ShouldProxy: shouldProxy, + } +} + +// IsLocal returns true if the decision routes to a local provider. +func (d *RoutingDecision) IsLocal() bool { + return d.RouteType == RouteTypeLocalProvider || d.RouteType == RouteTypeModelMapping +} + +// HasFallbacks returns true if there are fallback models available. +func (d *RoutingDecision) HasFallbacks() bool { + return len(d.FallbackModels) > 0 +} diff --git a/internal/routing/wrapper.go b/internal/routing/wrapper.go new file mode 100644 index 0000000000..90d10eea08 --- /dev/null +++ b/internal/routing/wrapper.go @@ -0,0 +1,270 @@ +package routing + +import ( + "bufio" + "bytes" + "io" + "net" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys" + "github.com/sirupsen/logrus" +) + +// ProxyFunc is the function type for proxying requests. +type ProxyFunc func(c *gin.Context) + +// ModelRoutingWrapper wraps HTTP handlers with unified model routing logic. +// It replaces the FallbackHandler logic with a Router-based approach. +type ModelRoutingWrapper struct { + router *Router + extractor ModelExtractor + rewriter ModelRewriter + proxyFunc ProxyFunc + logger *logrus.Logger +} + +// NewModelRoutingWrapper creates a new ModelRoutingWrapper with the given dependencies. +// If extractor is nil, a DefaultModelExtractor is used. +// If rewriter is nil, a DefaultModelRewriter is used. +// proxyFunc is called for AMP_CREDITS route type; if nil, the handler will be called instead. +func NewModelRoutingWrapper(router *Router, extractor ModelExtractor, rewriter ModelRewriter, proxyFunc ProxyFunc) *ModelRoutingWrapper { + if extractor == nil { + extractor = NewModelExtractor() + } + if rewriter == nil { + rewriter = NewModelRewriter() + } + return &ModelRoutingWrapper{ + router: router, + extractor: extractor, + rewriter: rewriter, + proxyFunc: proxyFunc, + logger: logrus.New(), + } +} + +// SetLogger sets the logger for the wrapper. +func (w *ModelRoutingWrapper) SetLogger(logger *logrus.Logger) { + w.logger = logger +} + +// Wrap wraps a gin.HandlerFunc with model routing logic. +// The returned handler will: +// 1. Extract the model from the request +// 2. Get a routing decision from the Router +// 3. Handle the request according to the decision type (LOCAL_PROVIDER, MODEL_MAPPING, AMP_CREDITS) +func (w *ModelRoutingWrapper) Wrap(handler gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + // Read request body + bodyBytes, err := io.ReadAll(c.Request.Body) + if err != nil { + w.logger.Errorf("routing wrapper: failed to read request body: %v", err) + handler(c) + return + } + + // Extract model from request + ginParams := map[string]string{ + "action": c.Param("action"), + "path": c.Param("path"), + } + modelName, err := w.extractor.Extract(bodyBytes, ginParams) + if err != nil { + w.logger.Warnf("routing wrapper: failed to extract model: %v", err) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + handler(c) + return + } + + if modelName == "" { + // No model found, proceed with original handler + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + handler(c) + return + } + + // Get routing decision + req := RoutingRequest{ + RequestedModel: modelName, + PreferLocalProvider: true, + ForceModelMapping: false, // TODO: Get from config + } + decision := w.router.ResolveV2(req) + + // Store decision in context for downstream handlers + c.Set(string(ctxkeys.RoutingDecision), decision) + + // Handle based on route type + switch decision.RouteType { + case RouteTypeLocalProvider: + w.handleLocalProvider(c, handler, bodyBytes, decision) + case RouteTypeModelMapping: + w.handleModelMapping(c, handler, bodyBytes, decision) + case RouteTypeAmpCredits: + w.handleAmpCredits(c, handler, bodyBytes) + default: + // No provider available + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + handler(c) + } + } +} + +// handleLocalProvider handles the LOCAL_PROVIDER route type. +func (w *ModelRoutingWrapper) handleLocalProvider(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) { + // Filter Anthropic-Beta header for local provider + filterAnthropicBetaHeader(c) + + // Restore body with original content + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + + // Call handler + handler(c) +} + +// handleModelMapping handles the MODEL_MAPPING route type. +func (w *ModelRoutingWrapper) handleModelMapping(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) { + // Rewrite request body with mapped model + rewrittenBody, err := w.rewriter.RewriteRequestBody(bodyBytes, decision.ResolvedModel) + if err != nil { + w.logger.Warnf("routing wrapper: failed to rewrite request body: %v", err) + rewrittenBody = bodyBytes + } + _ = rewrittenBody + + // Store mapped model in context + c.Set(string(ctxkeys.MappedModel), decision.ResolvedModel) + + // Store fallback models in context if present + if len(decision.FallbackModels) > 0 { + c.Set(string(ctxkeys.FallbackModels), decision.FallbackModels) + } + + // Filter Anthropic-Beta header for local provider + filterAnthropicBetaHeader(c) + + // Restore body with rewritten content + c.Request.Body = io.NopCloser(bytes.NewReader(rewrittenBody)) + + // Wrap response writer to rewrite model back + wrappedWriter, cleanup := w.rewriter.WrapResponseWriter(c.Writer, decision.ResolvedModel, decision.ResolvedModel) + c.Writer = &ginResponseWriterAdapter{ResponseWriter: wrappedWriter, original: c.Writer} + + // Call handler + handler(c) + + // Cleanup (flush response rewriting) + cleanup() +} + +// handleAmpCredits handles the AMP_CREDITS route type. +// It calls the proxy function directly if available, otherwise passes to handler. +// Does NOT filter headers or rewrite body - proxy handles everything. +func (w *ModelRoutingWrapper) handleAmpCredits(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte) { + // Restore body with original content (no rewriting for proxy) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + + // Call proxy function if available, otherwise fall back to handler + if w.proxyFunc != nil { + w.proxyFunc(c) + } else { + handler(c) + } +} + +// filterAnthropicBetaHeader filters Anthropic-Beta header for local providers. +func filterAnthropicBetaHeader(c *gin.Context) { + if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" { + filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07") + if filtered != "" { + c.Request.Header.Set("Anthropic-Beta", filtered) + } else { + c.Request.Header.Del("Anthropic-Beta") + } + } +} + +// filterBetaFeatures removes specified beta features from the header. +func filterBetaFeatures(betaHeader, featureToRemove string) string { + // Simple implementation - can be enhanced + if betaHeader == featureToRemove { + return "" + } + return betaHeader +} + +// ginResponseWriterAdapter adapts http.ResponseWriter to gin.ResponseWriter. +type ginResponseWriterAdapter struct { + http.ResponseWriter + original gin.ResponseWriter +} + +func (a *ginResponseWriterAdapter) WriteHeader(code int) { + a.ResponseWriter.WriteHeader(code) +} + +func (a *ginResponseWriterAdapter) Write(data []byte) (int, error) { + return a.ResponseWriter.Write(data) +} + +func (a *ginResponseWriterAdapter) Header() http.Header { + return a.ResponseWriter.Header() +} + +// CloseNotify implements http.CloseNotifier. +func (a *ginResponseWriterAdapter) CloseNotify() <-chan bool { + if notifier, ok := a.ResponseWriter.(http.CloseNotifier); ok { + return notifier.CloseNotify() + } + return a.original.CloseNotify() +} + +// Flush implements http.Flusher. +func (a *ginResponseWriterAdapter) Flush() { + if flusher, ok := a.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +// Hijack implements http.Hijacker. +func (a *ginResponseWriterAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := a.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return a.original.Hijack() +} + +// Status returns the HTTP status code. +func (a *ginResponseWriterAdapter) Status() int { + return a.original.Status() +} + +// Size returns the number of bytes already written into the response http body. +func (a *ginResponseWriterAdapter) Size() int { + return a.original.Size() +} + +// Written returns whether or not the response for this context has been written. +func (a *ginResponseWriterAdapter) Written() bool { + return a.original.Written() +} + +// WriteHeaderNow forces WriteHeader to be called. +func (a *ginResponseWriterAdapter) WriteHeaderNow() { + a.original.WriteHeaderNow() +} + +// WriteString writes the given string into the response body. +func (a *ginResponseWriterAdapter) WriteString(s string) (int, error) { + return a.Write([]byte(s)) +} + +// Pusher returns the http.Pusher for server push. +func (a *ginResponseWriterAdapter) Pusher() http.Pusher { + if pusher, ok := a.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +} diff --git a/internal/thinking/provider/claude/apply.go b/internal/thinking/provider/claude/apply.go index 3c74d5146d..3faf4786fb 100644 --- a/internal/thinking/provider/claude/apply.go +++ b/internal/thinking/provider/claude/apply.go @@ -83,6 +83,10 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo * // Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint) result = a.normalizeClaudeBudget(result, config.Budget, modelInfo) + + // When thinking is enabled, Claude API requires assistant messages with tool_use + // to have a thinking block. Inject empty thinking block if missing. + result = injectThinkingBlockForToolUse(result) return result, nil } @@ -149,18 +153,85 @@ func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, body = []byte(`{}`) } + var result []byte switch config.Mode { case thinking.ModeNone: - result, _ := sjson.SetBytes(body, "thinking.type", "disabled") + result, _ = sjson.SetBytes(body, "thinking.type", "disabled") result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") return result, nil case thinking.ModeAuto: - result, _ := sjson.SetBytes(body, "thinking.type", "enabled") + result, _ = sjson.SetBytes(body, "thinking.type", "enabled") result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") - return result, nil default: - result, _ := sjson.SetBytes(body, "thinking.type", "enabled") + result, _ = sjson.SetBytes(body, "thinking.type", "enabled") result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget) - return result, nil } + + // When thinking is enabled, Claude API requires assistant messages with tool_use + // to have a thinking block. Inject empty thinking block if missing. + result = injectThinkingBlockForToolUse(result) + return result, nil +} + +// injectThinkingBlockForToolUse adds empty thinking block to assistant messages +// that have tool_use but no thinking block. This is required by Claude API when +// thinking is enabled. +func injectThinkingBlockForToolUse(body []byte) []byte { + messages := gjson.GetBytes(body, "messages") + if !messages.IsArray() { + return body + } + + messageArray := messages.Array() + modified := false + newMessages := "[]" + + for _, msg := range messageArray { + role := msg.Get("role").String() + if role != "assistant" { + newMessages, _ = sjson.SetRaw(newMessages, "-1", msg.Raw) + continue + } + + content := msg.Get("content") + if !content.IsArray() { + newMessages, _ = sjson.SetRaw(newMessages, "-1", msg.Raw) + continue + } + + contentArray := content.Array() + hasToolUse := false + hasThinking := false + + for _, part := range contentArray { + partType := part.Get("type").String() + if partType == "tool_use" { + hasToolUse = true + } + if partType == "thinking" { + hasThinking = true + } + } + + if hasToolUse && !hasThinking { + // Inject empty thinking block at the beginning of content + newContent := "[]" + newContent, _ = sjson.SetRaw(newContent, "-1", `{"type":"thinking","thinking":""}`) + for _, part := range contentArray { + newContent, _ = sjson.SetRaw(newContent, "-1", part.Raw) + } + msgJSON := msg.Raw + msgJSON, _ = sjson.SetRaw(msgJSON, "content", newContent) + newMessages, _ = sjson.SetRaw(newMessages, "-1", msgJSON) + modified = true + continue + } + + newMessages, _ = sjson.SetRaw(newMessages, "-1", msg.Raw) + } + + if modified { + body, _ = sjson.SetRawBytes(body, "messages", []byte(newMessages)) + } + return body } diff --git a/internal/thinking/provider/claude/apply_test.go b/internal/thinking/provider/claude/apply_test.go new file mode 100644 index 0000000000..dc7916e84f --- /dev/null +++ b/internal/thinking/provider/claude/apply_test.go @@ -0,0 +1,187 @@ +package claude + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + "github.com/tidwall/gjson" +) + +func TestInjectThinkingBlockForToolUse(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "assistant with tool_use but no thinking - should inject thinking", + input: `{ + "model": "kimi-k2.5", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me use a tool"}, + {"type": "tool_use", "id": "tool_1", "name": "test_tool", "input": {}} + ] + } + ] + }`, + expected: "thinking", + }, + { + name: "assistant with tool_use and thinking - should not modify", + input: `{ + "model": "kimi-k2.5", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "I need to use a tool"}, + {"type": "tool_use", "id": "tool_1", "name": "test_tool", "input": {}} + ] + } + ] + }`, + expected: "thinking", + }, + { + name: "user message with tool_use - should not modify", + input: `{ + "model": "kimi-k2.5", + "messages": [ + { + "role": "user", + "content": [ + {"type": "tool_result", "tool_use_id": "tool_1", "content": "result"} + ] + } + ] + }`, + expected: "", + }, + { + name: "assistant without tool_use - should not modify", + input: `{ + "model": "kimi-k2.5", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Hello!"} + ] + } + ] + }`, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := injectThinkingBlockForToolUse([]byte(tt.input)) + + // Check if thinking block exists in assistant messages with tool_use + messages := gjson.GetBytes(result, "messages") + if !messages.IsArray() { + t.Fatal("messages is not an array") + } + + for _, msg := range messages.Array() { + if msg.Get("role").String() == "assistant" { + content := msg.Get("content") + if !content.IsArray() { + continue + } + + hasToolUse := false + hasThinking := false + for _, part := range content.Array() { + partType := part.Get("type").String() + if partType == "tool_use" { + hasToolUse = true + } + if partType == "thinking" { + hasThinking = true + } + } + + if hasToolUse && tt.expected == "thinking" && !hasThinking { + t.Errorf("Expected thinking block in assistant message with tool_use, but not found") + } + } + } + }) + } +} + +func TestApplyCompatibleClaude(t *testing.T) { + tests := []struct { + name string + input string + config thinking.ThinkingConfig + expectThinking bool + }{ + { + name: "thinking enabled with tool_use - should inject thinking block", + input: `{ + "model": "kimi-k2.5", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "tool_1", "name": "test_tool", "input": {}} + ] + } + ] + }`, + config: thinking.ThinkingConfig{ + Mode: thinking.ModeBudget, + Budget: 4000, + }, + expectThinking: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := applyCompatibleClaude([]byte(tt.input), tt.config) + if err != nil { + t.Fatalf("applyCompatibleClaude failed: %v", err) + } + + // Check if thinking.type is enabled + thinkingType := gjson.GetBytes(result, "thinking.type").String() + if thinkingType != "enabled" { + t.Errorf("Expected thinking.type=enabled, got %s", thinkingType) + } + + // Check if thinking block is injected + messages := gjson.GetBytes(result, "messages") + if !messages.IsArray() { + t.Fatal("messages is not an array") + } + + for _, msg := range messages.Array() { + if msg.Get("role").String() == "assistant" { + content := msg.Get("content") + if !content.IsArray() { + continue + } + + hasThinking := false + for _, part := range content.Array() { + if part.Get("type").String() == "thinking" { + hasThinking = true + break + } + } + + if tt.expectThinking && !hasThinking { + t.Errorf("Expected thinking block in assistant message, but not found. Result: %s", string(result)) + } + } + } + }) + } +} diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index 9bef7125d5..3a0c8d7b0d 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -115,7 +115,9 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ if signatureResult.Exists() && signatureResult.String() != "" { arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2) if len(arrayClientSignatures) == 2 { - if modelName == arrayClientSignatures[0] { + // Compare using model group to handle model mapping + // e.g., claude-opus-4-5-thinking -> "claude" group should match "claude#signature" + if cache.GetModelGroup(modelName) == arrayClientSignatures[0] { clientSignature = arrayClientSignatures[1] } } diff --git a/internal/translator/openai/claude/openai_claude_request.go b/internal/translator/openai/claude/openai_claude_request.go index dc832e9cee..8fac14ecf6 100644 --- a/internal/translator/openai/claude/openai_claude_request.go +++ b/internal/translator/openai/claude/openai_claude_request.go @@ -61,10 +61,13 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream out, _ = sjson.Set(out, "stream", stream) // Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort + // Also track if thinking is enabled to ensure reasoning_content is added for tool_calls + thinkingEnabled := false if thinkingConfig := root.Get("thinking"); thinkingConfig.Exists() && thinkingConfig.IsObject() { if thinkingType := thinkingConfig.Get("type"); thinkingType.Exists() { switch thinkingType.String() { case "enabled": + thinkingEnabled = true if budgetTokens := thinkingConfig.Get("budget_tokens"); budgetTokens.Exists() { budget := int(budgetTokens.Int()) if effort, ok := thinking.ConvertBudgetToLevel(budget); ok && effort != "" { @@ -217,6 +220,10 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream // Add reasoning_content if present if hasReasoning { msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent) + } else if thinkingEnabled && hasToolCalls { + // Claude API requires reasoning_content in assistant messages with tool_calls + // when thinking mode is enabled, even if empty + msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", "") } // Add tool_calls if present (in same message as content) diff --git a/internal/translator/openai/claude/openai_claude_request_test.go b/internal/translator/openai/claude/openai_claude_request_test.go index d08de1b25c..3e7fe8fd07 100644 --- a/internal/translator/openai/claude/openai_claude_request_test.go +++ b/internal/translator/openai/claude/openai_claude_request_test.go @@ -588,3 +588,124 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got) } } + +// TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning tests that +// when thinking mode is enabled and assistant message has tool_calls but no thinking content, +// an empty reasoning_content is added to satisfy Claude API requirements. +func TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning(t *testing.T) { + tests := []struct { + name string + inputJSON string + wantHasReasoningContent bool + wantReasoningContent string + }{ + { + name: "thinking enabled with tool_calls but no thinking content adds empty reasoning_content", + inputJSON: `{ + "model": "claude-3-opus", + "thinking": {"type": "enabled", "budget_tokens": 4000}, + "messages": [{ + "role": "assistant", + "content": [ + {"type": "text", "text": "I will help you."}, + {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}} + ] + }] + }`, + wantHasReasoningContent: true, + wantReasoningContent: "", + }, + { + name: "thinking enabled with tool_calls and thinking content uses actual reasoning", + inputJSON: `{ + "model": "claude-3-opus", + "thinking": {"type": "enabled", "budget_tokens": 4000}, + "messages": [{ + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Let me analyze this..."}, + {"type": "text", "text": "I will help you."}, + {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}} + ] + }] + }`, + wantHasReasoningContent: true, + wantReasoningContent: "Let me analyze this...", + }, + { + name: "thinking disabled with tool_calls does not add reasoning_content", + inputJSON: `{ + "model": "claude-3-opus", + "thinking": {"type": "disabled"}, + "messages": [{ + "role": "assistant", + "content": [ + {"type": "text", "text": "I will help you."}, + {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}} + ] + }] + }`, + wantHasReasoningContent: false, + wantReasoningContent: "", + }, + { + name: "no thinking config with tool_calls does not add reasoning_content", + inputJSON: `{ + "model": "claude-3-opus", + "messages": [{ + "role": "assistant", + "content": [ + {"type": "text", "text": "I will help you."}, + {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}} + ] + }] + }`, + wantHasReasoningContent: false, + wantReasoningContent: "", + }, + { + name: "thinking enabled without tool_calls and no thinking content does not add reasoning_content", + inputJSON: `{ + "model": "claude-3-opus", + "thinking": {"type": "enabled", "budget_tokens": 4000}, + "messages": [{ + "role": "assistant", + "content": [ + {"type": "text", "text": "Simple response without tools."} + ] + }] + }`, + wantHasReasoningContent: false, + wantReasoningContent: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false) + resultJSON := gjson.ParseBytes(result) + + messages := resultJSON.Get("messages").Array() + if len(messages) == 0 { + t.Fatal("Expected at least one message") + } + + assistantMsg := messages[0] + if assistantMsg.Get("role").String() != "assistant" { + t.Fatalf("Expected assistant message, got %s", assistantMsg.Get("role").String()) + } + + hasReasoningContent := assistantMsg.Get("reasoning_content").Exists() + if hasReasoningContent != tt.wantHasReasoningContent { + t.Errorf("reasoning_content existence = %v, want %v", hasReasoningContent, tt.wantHasReasoningContent) + } + + if hasReasoningContent { + gotReasoningContent := assistantMsg.Get("reasoning_content").String() + if gotReasoningContent != tt.wantReasoningContent { + t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent) + } + } + }) + } +} diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 3de2b22953..36ffe2074d 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -255,16 +255,15 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c * parentCtx = logging.WithRequestID(parentCtx, requestID) } } - newCtx, cancel := context.WithCancel(parentCtx) - if requestCtx != nil && requestCtx != parentCtx { - go func() { - select { - case <-requestCtx.Done(): - cancel() - case <-newCtx.Done(): - } - }() + + // Use requestCtx as base if available to preserve amp context values (fallback_models, etc.) + // Falls back to parentCtx if no request context + baseCtx := parentCtx + if requestCtx != nil { + baseCtx = requestCtx } + + newCtx, cancel := context.WithCancel(baseCtx) newCtx = context.WithValue(newCtx, "gin", c) newCtx = context.WithValue(newCtx, "handler", handler) return newCtx, func(params ...interface{}) { diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 3a64c8c347..26c538d76d 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -18,6 +18,7 @@ import ( internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -562,192 +563,188 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli return nil, &Error{Code: "auth_not_found", Message: "no auth available"} } -func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - if len(providers) == 0 { - return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} - } +func (m *Manager) executeWithFallback( + ctx context.Context, + initialProviders []string, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + exec func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error, +) error { routeModel := req.Model + providers := initialProviders opts = ensureRequestedModelMetadata(opts, routeModel) tried := make(map[string]struct{}) var lastErr error + + // Track fallback models from context (provided by Amp module fallback_models key) + var fallbacks []string + if v := ctx.Value(ctxkeys.FallbackModels); v != nil { + if fs, ok := v.([]string); ok { + fallbacks = fs + } + } + fallbackIdx := -1 + for { auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) if errPick != nil { + // No more auths for current model. Try next fallback model if available. + if fallbackIdx+1 < len(fallbacks) { + fallbackIdx++ + routeModel = fallbacks[fallbackIdx] + log.Debugf("no more auths for current model, trying fallback model: %s (fallback %d/%d)", routeModel, fallbackIdx+1, len(fallbacks)) + + // Reset tried set for the new model and find its providers + tried = make(map[string]struct{}) + providers = util.GetProviderName(thinking.ParseSuffix(routeModel).ModelName) + // Reset opts for the new model + opts = ensureRequestedModelMetadata(opts, routeModel) + if len(providers) == 0 { + log.Debugf("fallback model %s has no providers, skipping", routeModel) + continue // Try next fallback if this one has no providers + } + continue + } + if lastErr != nil { - return cliproxyexecutor.Response{}, lastErr + return lastErr } - return cliproxyexecutor.Response{}, errPick + return errPick } - entry := logEntryWithRequestID(ctx) - debugLogAuthSelection(entry, auth, provider, req.Model) - tried[auth.ID] = struct{}{} - execCtx := ctx - if rt := m.roundTripperFor(auth); rt != nil { - execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) - execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) - } - execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) - execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) - execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - resp, errExec := executor.Execute(execCtx, auth, execReq, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} - if errExec != nil { - if errCtx := execCtx.Err(); errCtx != nil { - return cliproxyexecutor.Response{}, errCtx - } - result.Error = &Error{Message: errExec.Error()} - var se cliproxyexecutor.StatusError - if errors.As(errExec, &se) && se != nil { - result.Error.HTTPStatus = se.StatusCode() - } - if ra := retryAfterFromError(errExec); ra != nil { - result.RetryAfter = ra + if err := exec(ctx, executor, auth, provider, routeModel); err != nil { + if errCtx := ctx.Err(); errCtx != nil { + return errCtx } - m.MarkResult(execCtx, result) - lastErr = errExec + lastErr = err continue } - m.MarkResult(execCtx, result) - return resp, nil + return nil } } -func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { +func (m *Manager) executeMixedAttempt( + ctx context.Context, + auth *Auth, + provider, routeModel string, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + exec func(ctx context.Context, execReq cliproxyexecutor.Request) error, +) error { + entry := logEntryWithRequestID(ctx) + debugLogAuthSelection(entry, auth, provider, req.Model) + + execCtx := ctx + if rt := m.roundTripperFor(auth); rt != nil { + execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) + execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) + } + + execReq := req + execReq.Model = rewriteModelForAuth(routeModel, auth) + execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) + execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) + + err := exec(execCtx, execReq) + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: err == nil} + if err != nil { + result.Error = &Error{Message: err.Error()} + var se cliproxyexecutor.StatusError + if errors.As(err, &se) && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + if ra := retryAfterFromError(err); ra != nil { + result.RetryAfter = ra + } + } + m.MarkResult(execCtx, result) + return err +} + +func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { if len(providers) == 0 { return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} } - routeModel := req.Model - opts = ensureRequestedModelMetadata(opts, routeModel) - tried := make(map[string]struct{}) - var lastErr error - for { - auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) - if errPick != nil { - if lastErr != nil { - return cliproxyexecutor.Response{}, lastErr - } - return cliproxyexecutor.Response{}, errPick - } - entry := logEntryWithRequestID(ctx) - debugLogAuthSelection(entry, auth, provider, req.Model) + var resp cliproxyexecutor.Response + err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error { + return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error { + var errExec error + resp, errExec = executor.Execute(execCtx, auth, execReq, opts) + return errExec + }) + }) + return resp, err +} - tried[auth.ID] = struct{}{} - execCtx := ctx - if rt := m.roundTripperFor(auth); rt != nil { - execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) - execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) - } - execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) - execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) - execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} - if errExec != nil { - if errCtx := execCtx.Err(); errCtx != nil { - return cliproxyexecutor.Response{}, errCtx - } - result.Error = &Error{Message: errExec.Error()} - var se cliproxyexecutor.StatusError - if errors.As(errExec, &se) && se != nil { - result.Error.HTTPStatus = se.StatusCode() - } - if ra := retryAfterFromError(errExec); ra != nil { - result.RetryAfter = ra - } - m.MarkResult(execCtx, result) - lastErr = errExec - continue - } - m.MarkResult(execCtx, result) - return resp, nil +func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if len(providers) == 0 { + return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} } + + var resp cliproxyexecutor.Response + err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error { + return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error { + var errExec error + resp, errExec = executor.CountTokens(execCtx, auth, execReq, opts) + return errExec + }) + }) + return resp, err } func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { if len(providers) == 0 { return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} } - routeModel := req.Model - opts = ensureRequestedModelMetadata(opts, routeModel) - tried := make(map[string]struct{}) - var lastErr error - for { - auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried) - if errPick != nil { - if lastErr != nil { - return nil, lastErr - } - return nil, errPick - } - entry := logEntryWithRequestID(ctx) - debugLogAuthSelection(entry, auth, provider, req.Model) - - tried[auth.ID] = struct{}{} - execCtx := ctx - if rt := m.roundTripperFor(auth); rt != nil { - execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) - execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) - } - execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) - execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) - execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) - if errStream != nil { - if errCtx := execCtx.Err(); errCtx != nil { - return nil, errCtx + var chunks <-chan cliproxyexecutor.StreamChunk + err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error { + return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error { + var errExec error + chunks, errExec = executor.ExecuteStream(execCtx, auth, execReq, opts) + if errExec != nil { + return errExec } - rerr := &Error{Message: errStream.Error()} - var se cliproxyexecutor.StatusError - if errors.As(errStream, &se) && se != nil { - rerr.HTTPStatus = se.StatusCode() - } - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} - result.RetryAfter = retryAfterFromError(errStream) - m.MarkResult(execCtx, result) - lastErr = errStream - continue - } - out := make(chan cliproxyexecutor.StreamChunk) - go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) { - defer close(out) - var failed bool - forward := true - for chunk := range streamChunks { - if chunk.Err != nil && !failed { - failed = true - rerr := &Error{Message: chunk.Err.Error()} - var se cliproxyexecutor.StatusError - if errors.As(chunk.Err, &se) && se != nil { - rerr.HTTPStatus = se.StatusCode() + + out := make(chan cliproxyexecutor.StreamChunk) + go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) { + defer close(out) + var failed bool + forward := true + for chunk := range streamChunks { + if chunk.Err != nil && !failed { + failed = true + rerr := &Error{Message: chunk.Err.Error()} + var se cliproxyexecutor.StatusError + if errors.As(chunk.Err, &se) && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr}) + } + if !forward { + continue + } + if streamCtx == nil { + out <- chunk + continue + } + select { + case <-streamCtx.Done(): + forward = false + case out <- chunk: } - m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr}) - } - if !forward { - continue - } - if streamCtx == nil { - out <- chunk - continue } - select { - case <-streamCtx.Done(): - forward = false - case out <- chunk: + if !failed { + m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) } - } - if !failed { - m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) - } - }(execCtx, auth.Clone(), provider, chunks) - return out, nil - } + }(execCtx, auth.Clone(), provider, chunks) + chunks = out + return nil + }) + }) + return chunks, err } func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {