Skip to content
Draft
7 changes: 7 additions & 0 deletions internal/api/modules/amp/amp.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
m.registerOnce.Do(func() {
// Initialize model mapper from config (for routing unavailable models to alternatives)
m.modelMapper = NewModelMapper(settings.ModelMappings)
// Load oauth-model-alias for provider lookup via aliases
m.modelMapper.UpdateOAuthModelAlias(ctx.Config.OAuthModelAlias)

// Store initial config for partial reload comparison
settingsCopy := settings
Expand Down Expand Up @@ -212,6 +214,11 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
}
}

// Always update oauth-model-alias for model mapper (used for provider lookup)
if m.modelMapper != nil {
m.modelMapper.UpdateOAuthModelAlias(cfg.OAuthModelAlias)
}

if m.enabled {
// Check upstream URL change - now supports hot-reload
if newUpstreamURL == "" && oldUpstreamURL != "" {
Expand Down
127 changes: 89 additions & 38 deletions internal/api/modules/amp/fallback_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package amp

import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httputil"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
Expand All @@ -30,7 +33,13 @@ const (
)

// MappedModelContextKey is the Gin context key for passing mapped model names.
const MappedModelContextKey = "mapped_model"
// Deprecated: Use ctxkeys.MappedModel instead.
const MappedModelContextKey = string(ctxkeys.MappedModel)

// FallbackModelsContextKey is the Gin context key for passing fallback model names.
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
// Deprecated: Use ctxkeys.FallbackModels instead.
const FallbackModelsContextKey = string(ctxkeys.FallbackModels)

// logAmpRouting logs the routing decision for an Amp request with structured fields
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
Expand Down Expand Up @@ -77,6 +86,10 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid

// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
//
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
// This type is kept for backward compatibility and test purposes.
type FallbackHandler struct {
getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper
Expand All @@ -85,6 +98,8 @@ type FallbackHandler struct {

// NewFallbackHandler creates a new fallback handler wrapper
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
//
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
return &FallbackHandler{
getProxy: getProxy,
Expand All @@ -93,6 +108,8 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
}

// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
//
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
if forceModelMappings == nil {
forceModelMappings = func() bool { return false }
Expand All @@ -113,6 +130,20 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces.
// ReverseProxy raises this panic when the client connection is closed prematurely
// (e.g., user cancels request, network disconnect) or when ServeHTTP is called
// with a ResponseWriter that doesn't implement http.CloseNotifier.
// This is an expected error condition, not a bug, so we handle it gracefully.
defer func() {
if rec := recover(); rec != nil {
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
return
}
panic(rec)
}
}()

requestPath := c.Request.URL.Path

// Read the request body to extract the model name
Expand Down Expand Up @@ -142,58 +173,85 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
}

resolveMappedModel := func() (string, []string) {
// resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one.
resolveMappedModels := func() ([]string, []string) {
if fh.modelMapper == nil {
return "", nil
return nil, nil
}

mappedModel := fh.modelMapper.MapModel(modelName)
if mappedModel == "" {
mappedModel = fh.modelMapper.MapModel(normalizedModel)
mapper, ok := fh.modelMapper.(*DefaultModelMapper)
if !ok {
// Fallback to single model for non-DefaultModelMapper
mappedModel := fh.modelMapper.MapModel(modelName)
if mappedModel == "" {
mappedModel = fh.modelMapper.MapModel(normalizedModel)
}
if mappedModel == "" {
return nil, nil
}
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
mappedProviders := util.GetProviderName(mappedBaseModel)
if len(mappedProviders) == 0 {
return nil, nil
}
return []string{mappedModel}, mappedProviders
}
mappedModel = strings.TrimSpace(mappedModel)
if mappedModel == "" {
return "", nil

// Use MapModelWithFallbacks for DefaultModelMapper
mappedModels := mapper.MapModelWithFallbacks(modelName)
if len(mappedModels) == 0 {
mappedModels = mapper.MapModelWithFallbacks(normalizedModel)
}
if len(mappedModels) == 0 {
return nil, nil
}

// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
// already specifies its own thinking suffix.
if thinkingSuffix != "" {
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
if !mappedSuffixResult.HasSuffix {
mappedModel += thinkingSuffix
// Apply thinking suffix if needed
for i, model := range mappedModels {
if thinkingSuffix != "" {
suffixResult := thinking.ParseSuffix(model)
if !suffixResult.HasSuffix {
mappedModels[i] = model + thinkingSuffix
}
}
}

mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
mappedProviders := util.GetProviderName(mappedBaseModel)
if len(mappedProviders) == 0 {
return "", nil
// Get providers for the first model
firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName
providers := util.GetProviderName(firstBaseModel)
if len(providers) == 0 {
return nil, nil
}

return mappedModel, mappedProviders
return mappedModels, providers
}

// Track resolved model for logging (may change if mapping is applied)
resolvedModel := normalizedModel
usedMapping := false
var providers []string

// Helper to apply model mapping and update state
applyMapping := func(mappedModels []string, mappedProviders []string) {
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
if len(mappedModels) > 1 {
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
}
resolvedModel = mappedModels[0]
usedMapping = true
providers = mappedProviders
}

// Check if model mappings should be forced ahead of local API keys
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()

if forceMappings {
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
// This allows users to route Amp requests to their preferred OAuth providers
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
applyMapping(mappedModels, mappedProviders)
}
Comment on lines +253 to 255
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code is identical to the one at lines 249-261. To improve maintainability and avoid code duplication, consider extracting this logic into a local helper function.

For example:

applyMapping := func(mappedModels []string, mappedProviders []string) {
    bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
    c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
    // Store mapped model and fallbacks in context for handlers
    c.Set(MappedModelContextKey, mappedModels[0])
    if len(mappedModels) > 1 {
        c.Set(FallbackModelsContextKey, mappedModels[1:])
    }
    resolvedModel = mappedModels[0]
    usedMapping = true
    providers = mappedProviders
}

if forceMappings {
    if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
        applyMapping(mappedModels, mappedProviders)
    }
    // ...
} else {
    // ...
    if len(providers) == 0 {
        if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
            applyMapping(mappedModels, mappedProviders)
        }
    }
}

Note that bodyBytes, resolvedModel, usedMapping, and providers would need to be handled as they are modified by this helper. Using a closure that captures these variables would be a clean way to implement this.


// If no mapping applied, check for local providers
Expand All @@ -206,15 +264,8 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc

if len(providers) == 0 {
// No providers configured - check if we have a model mapping
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
applyMapping(mappedModels, mappedProviders)
}
}
}
Expand Down
Loading
Loading