diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index e3dff6b8f..1d5997a65 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -175,6 +175,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String()) require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) + require.Equal(t, "/v1/messages?beta=true", upstream.lastReq.URL.RequestURI()) require.Empty(t, upstream.lastReq.Header.Get("authorization")) require.Empty(t, upstream.lastReq.Header.Get("x-goog-api-key")) require.Empty(t, upstream.lastReq.Header.Get("cookie")) @@ -256,6 +257,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体") require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String()) require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) + require.Equal(t, "/v1/messages/count_tokens?beta=true", upstream.lastReq.URL.RequestURI()) require.Empty(t, upstream.lastReq.Header.Get("authorization")) require.Empty(t, upstream.lastReq.Header.Get("cookie")) require.Equal(t, http.StatusOK, rec.Code) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 3fabead05..888be21bf 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -12,6 +12,7 @@ import ( "log/slog" mathrand "math/rand" "net/http" + "net/url" "os" "regexp" "sort" @@ -52,6 +53,21 @@ const ( defaultModelsListCacheTTL = 15 * time.Second ) +func buildAnthropicTargetURL(validatedURL, endpoint string) string { + targetURL := validatedURL + endpoint + parsedURL, err := url.Parse(targetURL) + if err != nil { + if strings.Contains(targetURL, "?") { + return targetURL + "&beta=true" + } + return targetURL + "?beta=true" + } + query := parsedURL.Query() + query.Set("beta", "true") + parsedURL.RawQuery = query.Encode() + return parsedURL.String() +} + const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) @@ -3963,7 +3979,7 @@ func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages" + targetURL = buildAnthropicTargetURL(validatedURL, "/v1/messages") } req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) @@ -4343,7 +4359,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages" + targetURL = buildAnthropicTargetURL(validatedURL, "/v1/messages") } } @@ -6291,7 +6307,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages/count_tokens" + targetURL = buildAnthropicTargetURL(validatedURL, "/v1/messages/count_tokens") } req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) @@ -6338,7 +6354,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages/count_tokens" + targetURL = buildAnthropicTargetURL(validatedURL, "/v1/messages/count_tokens") } }