diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 51156a81d1..61888b4d98 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -31,15 +31,16 @@ description: AI 代理插件配置参考 `provider`的配置字段说明如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `type` | string | 必填 | - | AI 服务提供商名称 | -| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | -| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | -| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | -| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | -| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | -| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|------------------| --------------- | -------- | ------ |-----------------------------------------------------------------------------------------------------------------------------------------------------------| +| `type` | string | 必填 | - | AI 服务提供商名称 | +| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 | +| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 | +| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。
1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;
2. 支持使用 "*" 为键来配置通用兜底映射关系;
3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | +| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) | +| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | +| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | +| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | `context`的配置字段说明如下: @@ -75,6 +76,16 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字 如果启用了raw模式,custom-setting会直接用输入的`name`和`value`去更改请求中的json内容,而不对参数名称做任何限制和修改。 对于大多数协议,custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。 +`failover` 的配置字段说明如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|------------------|--------|------|-------|-----------------------------| +| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 | +| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) | +| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) | +| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 | +| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 | +| healthCheckModel | string | 必填 | | 健康检测使用的模型 | ### 提供商特有配置 diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go index b545271a70..48f08dd9e4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/config/config.go +++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go @@ -1,9 +1,9 @@ package config import ( - "github.com/tidwall/gjson" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" ) // @Name ai-proxy @@ -75,13 +75,17 @@ func (c *PluginConfig) Validate() error { return nil } -func (c *PluginConfig) Complete() error { +func (c *PluginConfig) Complete(log wrapper.Log) error { if c.activeProviderConfig == nil { c.activeProvider = nil return nil } var err error c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig) + + providerConfig := c.GetProviderConfig() + err = providerConfig.SetApiTokensFailover(log, c.activeProvider) + return err } diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 9e0fafe179..3a29575c21 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -44,9 +44,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log if err := pluginConfig.Validate(); err != nil { return err } - if err := pluginConfig.Complete(); err != nil { + if err := pluginConfig.Complete(log); err != nil { return err } + return nil } @@ -59,9 +60,10 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug if err := pluginConfig.Validate(); err != nil { return err } - if err := pluginConfig.Complete(); err != nil { + if err := pluginConfig.Complete(log); err != nil { return err } + return nil } @@ -78,9 +80,16 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf rawPath := ctx.Path() path, _ := url.Parse(rawPath) - apiName := getOpenAiApiName(path.Path) + + var apiName provider.ApiName providerConfig := pluginConfig.GetProviderConfig() - if apiName == "" && !providerConfig.IsOriginal() { + if providerConfig.IsOriginal() { + apiName = activeProvider.GetApiName(path.Path) + } else { + apiName = provider.GetOpenAiApiName(path.Path) + } + + if apiName == "" { log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path) _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) return types.ActionContinue @@ -88,8 +97,11 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf ctx.SetContext(ctxKeyApiName, apiName) if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok { - // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. + // Disable the route re-calculation since the plugin may modify some headers related to the chosen route. ctx.DisableReroute() + // Set the apiToken for the current request. + providerConfig.SetApiTokenInUse(ctx, log) + hasRequestBody := wrapper.HasRequestBody() action, err := handler.OnRequestHeaders(ctx, apiName, log) if err == nil { @@ -101,6 +113,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf } return action } + _ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err)) return types.ActionContinue } @@ -155,15 +168,24 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType()) + providerConfig := pluginConfig.GetProviderConfig() + apiTokenInUse := providerConfig.GetApiTokenInUse(ctx) + status, err := proxywasm.GetHttpResponseHeader(":status") if err != nil || status != "200" { if err != nil { log.Errorf("unable to load :status header from response: %v", err) } ctx.DontReadResponseBody() + providerConfig.OnRequestFailed(ctx, apiTokenInUse, log) + return types.ActionContinue } + // Reset ctxApiTokenRequestFailureCount if the request is successful, + // the apiToken is removed only when the number of consecutive request failures exceeds the threshold. + providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log) + if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok { apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) action, err := handler.OnResponseHeaders(ctx, apiName, log) @@ -232,16 +254,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi return types.ActionContinue } -func getOpenAiApiName(path string) provider.ApiName { - if strings.HasSuffix(path, "/v1/chat/completions") { - return provider.ApiNameChatCompletion - } - if strings.HasSuffix(path, "/v1/embeddings") { - return provider.ApiNameEmbeddings - } - return "" -} - func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) { contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go index 00443fcf5e..439fd8ff10 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ai360.go @@ -1,19 +1,19 @@ package provider import ( - "encoding/json" "errors" - "fmt" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) // ai360Provider is the provider for 360 OpenAI service. const ( - ai360Domain = "api.360.cn" + ai360Domain = "api.360.cn" + ai360ChatCompletionPath = "/v1/chat/completions" ) type ai360ProviderInitializer struct { @@ -46,10 +46,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(ai360Domain) - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken()) + m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } @@ -58,47 +55,19 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - if apiName == ApiNameChatCompletion { - return m.onChatCompletionRequestBody(ctx, body, log) - } - if apiName == ApiNameEmbeddings { - return m.onEmbeddingsRequestBody(ctx, body, log) - } - return types.ActionContinue, errUnsupportedApiName + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } -func (m *ai360Provider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - // 映射模型 - mappedModel := getMappedModel(request.Model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) +func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, ai360Domain) + util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } -func (m *ai360Provider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - request := &embeddingsRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in embeddings request") - } - // 映射模型 - mappedModel := getMappedModel(request.Model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") +func (m *ai360Provider) GetApiName(path string) ApiName { + if strings.Contains(path, ai360ChatCompletionPath) { + return ApiNameChatCompletion } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go index 2dcba2f8ff..9919aeb073 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/azure.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/azure.go @@ -3,16 +3,20 @@ package provider import ( "errors" "fmt" + "net/http" "net/url" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) -// azureProvider is the provider for Azure OpenAI service. +const ( + azureChatCompletionPath = "/chat/completions" +) +// azureProvider is the provider for Azure OpenAI service. type azureProviderInitializer struct { } @@ -55,47 +59,30 @@ func (m *azureProvider) GetProviderType() string { } func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - _ = util.OverwriteRequestPath(m.serviceUrl.RequestURI()) - _ = util.OverwriteRequestHost(m.serviceUrl.Host) - _ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0]) - if apiName == ApiNameChatCompletion { - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - } else { - ctx.DontReadRequestBody() + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName } + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { - // We don't need to process the request body for other APIs. - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - return types.ActionContinue, nil - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.azure.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.azure.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI()) + util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host) + util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *azureProvider) GetApiName(path string) ApiName { + if strings.Contains(path, azureChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go index c16a8e4395..e016dc4553 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go @@ -2,11 +2,11 @@ package provider import ( "errors" - "fmt" + "net/http" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -47,10 +47,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(baichuanChatCompletionPath) - _ = util.OverwriteRequestHost(baichuanDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -58,28 +55,19 @@ func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.baichuan.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.baichuan.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath) + util.OverwriteRequestHostHeader(headers, baichuanDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *baichuanProvider) GetApiName(path string) ApiName { + if strings.Contains(path, baichuanChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index fc779d5306..42a1bc723d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -16,7 +17,8 @@ import ( // baiduProvider is the provider for baidu ernie bot service. const ( - baiduDomain = "aip.baidubce.com" + baiduDomain = "aip.baidubce.com" + baiduChatCompletionPath = "/chat" ) var baiduModelToPathSuffixMap = map[string]string{ @@ -60,98 +62,35 @@ func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(baiduDomain) - - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + b.config.handleRequestHeaders(b, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } +func (b *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, baiduDomain) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") +} + func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - // 使用文心一言接口协议 - if b.config.protocol == protocolOriginal { - request := &baiduTextGenRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 根据模型重写requestPath - path := b.getRequestPath(request.Model) - _ = util.OverwriteRequestPath(path) - - if b.config.context == nil { - return types.ActionContinue, nil - } + return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body, log) +} - err := b.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - b.setSystemContent(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err - } +func (b *baiduProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + err := b.config.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err } + path := b.getRequestPath(ctx, request.Model) + util.OverwriteRequestPathHeader(headers, path) - // 映射模型重写requestPath - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, b.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - path := b.getRequestPath(mappedModel) - _ = util.OverwriteRequestPath(path) - - if b.config.context == nil { - baiduRequest := b.baiduTextGenRequest(request) - return types.ActionContinue, replaceJsonRequestBody(baiduRequest, log) - } - - err := b.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - baiduRequest := b.baiduTextGenRequest(request) - if err := replaceJsonRequestBody(baiduRequest, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + baiduRequest := b.baiduTextGenRequest(request) + return json.Marshal(baiduRequest) } func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -226,13 +165,13 @@ type baiduTextGenRequest struct { UserId string `json:"user_id,omitempty"` } -func (b *baiduProvider) getRequestPath(baiduModel string) string { +func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string { // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t suffix, ok := baiduModelToPathSuffixMap[baiduModel] if !ok { suffix = baiduModel } - return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken()) + return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx)) } func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) { @@ -339,3 +278,10 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (b *baiduProvider) GetApiName(path string) ApiName { + if strings.Contains(path, baiduChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index 7bbbc93d79..8b98d62d64 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -105,102 +106,39 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } + c.config.handleRequestHeaders(c, ctx, apiName, log) + return types.ActionContinue, nil +} + +func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath) + util.OverwriteRequestHostHeader(headers, claudeDomain) - _ = util.OverwriteRequestPath(claudeChatCompletionPath) - _ = util.OverwriteRequestHost(claudeDomain) - _ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetRandomToken()) + headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx)) if c.config.claudeVersion == "" { c.config.claudeVersion = defaultVersion } - _ = proxywasm.AddHttpRequestHeader("anthropic-version", c.config.claudeVersion) - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - return types.ActionContinue, nil + headers.Add("anthropic-version", c.config.claudeVersion) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } + return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) +} - // use original protocol - if c.config.protocol == protocolOriginal { - if c.config.context == nil { - return types.ActionContinue, nil - } - - request := &claudeTextGenRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - - err := c.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err - } - - // use openai protocol +func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") + if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { + return nil, err } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, c.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - - streaming := request.Stream - if streaming { - _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") - } - - if c.config.context == nil { - claudeRequest := c.buildClaudeTextGenRequest(request) - return types.ActionContinue, replaceJsonRequestBody(claudeRequest, log) - } - - err := c.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - claudeRequest := c.buildClaudeTextGenRequest(request) - if err := replaceJsonRequestBody(claudeRequest, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + claudeRequest := c.buildClaudeTextGenRequest(request) + return json.Marshal(claudeRequest) } func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -369,3 +307,25 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) { + request := &claudeTextGenRequest{} + if err := json.Unmarshal(body, request); err != nil { + return nil, fmt.Errorf("unable to unmarshal request: %v", err) + } + + if request.System == "" { + request.System = content + } else { + request.System = content + "\n" + request.System + } + + return json.Marshal(request) +} + +func (c *claudeProvider) GetApiName(path string) ApiName { + if strings.Contains(path, claudeChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go index 35f6f2dc78..a4c02381fa 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cloudflare.go @@ -2,19 +2,19 @@ package provider import ( "errors" - "fmt" + "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) const ( cloudflareDomain = "api.cloudflare.com" // https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility/ - cloudflareChatCompletionPath = "/client/v4/accounts/{account_id}/ai/v1/chat/completions" + cloudflareChatCompletionPath = "/v1/chat/completions" + cloudflareChatCompletionFullPath = "/client/v4/accounts/{account_id}/ai/v1/chat/completions" ) type cloudflareProviderInitializer struct { @@ -47,13 +47,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)) - _ = util.OverwriteRequestHost(cloudflareDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + c.config.GetRandomToken()) - - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + c.config.handleRequestHeaders(c, ctx, apiName, log) return types.ActionContinue, nil } @@ -61,49 +55,20 @@ func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } + return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log) +} - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, c.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - - streaming := request.Stream - if streaming { - _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") - } +func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionFullPath, "{account_id}", c.config.cloudflareAccountId, 1)) + util.OverwriteRequestHostHeader(headers, cloudflareDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx)) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") +} - if c.contextCache == nil { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.cloudflare.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - return types.ActionContinue, nil - } - err := c.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.cloudflare.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.cloudflare.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil +func (c *cloudflareProvider) GetApiName(path string) ApiName { + if strings.Contains(path, cloudflareChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go index 7ffe1708af..72dbaf280b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/cohere.go @@ -3,17 +3,16 @@ package provider import ( "encoding/json" "errors" - "fmt" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) const ( - cohereDomain = "api.cohere.com" - chatCompletionPath = "/v1/chat" + cohereDomain = "api.cohere.com" + cohereChatCompletionPath = "/v1/chat" ) type cohereProviderInitializer struct{} @@ -27,12 +26,14 @@ func (m *cohereProviderInitializer) ValidateConfig(config ProviderConfig) error func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { return &cohereProvider{ - config: config, + config: config, + contextCache: createContextCache(&config), }, nil } type cohereProvider struct { - config ProviderConfig + config ProviderConfig + contextCache *contextCache } type cohereTextGenRequest struct { @@ -57,10 +58,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(cohereDomain) - _ = util.OverwriteRequestPath(chatCompletionPath) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -68,30 +66,7 @@ func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.config.protocol == protocolOriginal { - request := &cohereTextGenRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - return m.handleRequestBody(log, request) - } - origin := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, origin); err != nil { - return types.ActionContinue, err - } - request := m.buildCohereRequest(origin) - return m.handleRequestBody(log, request) -} - -func (m *cohereProvider) handleRequestBody(log wrapper.Log, request interface{}) (types.Action, error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - err := replaceJsonRequestBody(request, log) - if err != nil { - _ = util.SendResponse(500, "ai-proxy.cohere.proxy_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - return types.ActionContinue, err + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohereTextGenRequest { @@ -112,3 +87,27 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe PresencePenalty: origin.PresencePenalty, } } + +func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath) + util.OverwriteRequestHostHeader(headers, cohereDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + request := &chatCompletionRequest{} + if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { + return nil, err + } + + cohereRequest := m.buildCohereRequest(request) + return json.Marshal(cohereRequest) +} + +func (m *cohereProvider) GetApiName(path string) ApiName { + if strings.Contains(path, cohereChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/context.go b/plugins/wasm-go/extensions/ai-proxy/provider/context.go index 2026a9818a..d9fe2e26c4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/context.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/context.go @@ -1,12 +1,15 @@ package provider import ( + "encoding/json" "errors" "fmt" "net/http" "net/url" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/tidwall/gjson" ) @@ -57,6 +60,10 @@ type contextCache struct { content string } +type ContextInserter interface { + insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) +} + func (c *contextCache) GetContent(callback func(string, error), log wrapper.Log) error { if callback == nil { return errors.New("callback is nil") @@ -98,3 +105,79 @@ func createContextCache(providerConfig *ProviderConfig) *contextCache { timeout: providerConfig.timeout, } } + +func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte, log wrapper.Log) error { + // get context will overwrite the original request host and path + // save the original request host and path in case they are needed for apiToken health check + ctx.SetContext(ctxRequestHost, wrapper.GetRequestHost()) + ctx.SetContext(ctxRequestPath, wrapper.GetRequestPath()) + + if c.loaded { + log.Debugf("context file loaded from cache") + insertContext(provider, c.content, nil, body, log) + return nil + } + + log.Infof("loading context file from %s", c.fileUrl.String()) + return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != http.StatusOK { + insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil, log) + return + } + c.content = string(responseBody) + c.loaded = true + log.Debugf("content: %s", c.content) + insertContext(provider, c.content, nil, body, log) + }, c.timeout) +} + +func insertContext(provider Provider, content string, err error, body []byte, log wrapper.Log) { + defer func() { + _ = proxywasm.ResumeHttpRequest() + }() + + typ := provider.GetProviderType() + if err != nil { + log.Errorf("failed to load context file: %v", err) + _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) + } + + if inserter, ok := provider.(ContextInserter); ok { + body, err = inserter.insertHttpContextMessage(body, content, false) + } else { + body, err = defaultInsertHttpContextMessage(body, content) + } + + if err != nil { + _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to insert context message: %v", err)) + } + if err := replaceHttpJsonRequestBody(body, log); err != nil { + _ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) + } +} + +func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) { + request := &chatCompletionRequest{} + if err := json.Unmarshal(body, request); err != nil { + return nil, fmt.Errorf("unable to unmarshal request: %v", err) + } + + fileMessage := chatMessage{ + Role: roleSystem, + Content: content, + } + var firstNonSystemMessageIndex int + for i, message := range request.Messages { + if message.Role != roleSystem { + firstNonSystemMessageIndex = i + break + } + } + if firstNonSystemMessageIndex == 0 { + request.Messages = append([]chatMessage{fileMessage}, request.Messages...) + } else { + request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...) + } + + return json.Marshal(request) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go index 924746c8c9..bafe6b3dde 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepl.go @@ -4,6 +4,8 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "strings" "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -78,49 +80,38 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(deeplChatCompletionPath) - _ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + d.config.handleRequestHeaders(d, ctx, apiName, log) return types.HeaderStopIteration, nil } +func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath) + util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") + headers.Del("Accept-Encoding") +} + func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if d.config.protocol == protocolOriginal { - request := &deeplRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if err := d.overwriteRequestHost(request.Model); err != nil { - return types.ActionContinue, err - } - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - return types.ActionContinue, replaceJsonRequestBody(request, log) - } else { - originRequest := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, originRequest); err != nil { - return types.ActionContinue, err - } - if err := d.overwriteRequestHost(originRequest.Model); err != nil { - return types.ActionContinue, err - } - ctx.SetContext(ctxKeyFinalRequestModel, originRequest.Model) - deeplRequest := &deeplRequest{ - Text: make([]string, 0), - TargetLang: d.config.targetLang, - } - for _, msg := range originRequest.Messages { - if msg.Role == roleSystem { - deeplRequest.Context = msg.StringContent() - } else { - deeplRequest.Text = append(deeplRequest.Text, msg.StringContent()) - } - } - return types.ActionContinue, replaceJsonRequestBody(deeplRequest, log) + return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log) +} + +func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return nil, err } + ctx.SetContext(ctxKeyFinalRequestModel, request.Model) + + err := d.overwriteRequestHost(headers, request.Model) + if err != nil { + return nil, err + } + + baiduRequest := d.deeplTextGenRequest(request) + return json.Marshal(baiduRequest) } func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -164,13 +155,35 @@ func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplRespo } } -func (d *deeplProvider) overwriteRequestHost(model string) error { +func (d *deeplProvider) overwriteRequestHost(headers http.Header, model string) error { if model == "Pro" { - _ = util.OverwriteRequestHost(deeplHostPro) + util.OverwriteRequestHostHeader(headers, deeplHostPro) } else if model == "Free" { - _ = util.OverwriteRequestHost(deeplHostFree) + util.OverwriteRequestHostHeader(headers, deeplHostFree) } else { return errors.New(`deepl model should be "Free" or "Pro"`) } return nil } + +func (d *deeplProvider) deeplTextGenRequest(request *chatCompletionRequest) *deeplRequest { + deeplRequest := &deeplRequest{ + Text: make([]string, 0), + TargetLang: d.config.targetLang, + } + for _, msg := range request.Messages { + if msg.Role == roleSystem { + deeplRequest.Context = msg.StringContent() + } else { + deeplRequest.Text = append(deeplRequest.Text, msg.StringContent()) + } + } + return deeplRequest +} + +func (d *deeplProvider) GetApiName(path string) ApiName { + if strings.Contains(path, deeplChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go index 8cb71462d2..c1eb57fe45 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go @@ -2,12 +2,11 @@ package provider import ( "errors" - "fmt" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) // deepseekProvider is the provider for deepseek Ai service. @@ -47,10 +46,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(deepseekChatCompletionPath) - _ = util.OverwriteRequestHost(deepseekDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -58,28 +54,19 @@ func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.deepseek.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.deepseek.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath) + util.OverwriteRequestHostHeader(headers, deepseekDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *deepseekProvider) GetApiName(path string) ApiName { + if strings.Contains(path, deepseekChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go index 0ca349a773..651b983206 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/doubao.go @@ -2,12 +2,11 @@ package provider import ( "errors" - "fmt" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) const ( @@ -41,17 +40,10 @@ func (m *doubaoProvider) GetProviderType() string { } func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - _ = util.OverwriteRequestHost(doubaoDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - if m.config.protocol == protocolOriginal { - ctx.DontReadRequestBody() - return types.ActionContinue, nil - } if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(doubaoChatCompletionPath) + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -59,44 +51,19 @@ func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - if m.contextCache != nil { - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.doubao.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.doubao.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } else { - return types.ActionContinue, err - } - } else { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.doubao.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - return types.ActionContinue, err - } - _ = proxywasm.ResumeHttpRequest() - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath) + util.OverwriteRequestHostHeader(headers, doubaoDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *doubaoProvider) GetApiName(path string) ApiName { + if strings.Contains(path, doubaoChatCompletionPath) { + return ApiNameChatCompletion } + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go new file mode 100644 index 0000000000..32e92a4db4 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -0,0 +1,594 @@ +package provider + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/google/uuid" + "math/rand" + "net/http" + "strings" + "time" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "github.com/tidwall/gjson" +) + +type failover struct { + // @Title zh-CN 是否启用 apiToken 的 failover 机制 + enabled bool `required:"true" yaml:"enabled" json:"enabled"` + // @Title zh-CN 触发 failover 连续请求失败的阈值 + failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"` + // @Title zh-CN 健康检测的成功阈值 + successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"` + // @Title zh-CN 健康检测的间隔时间,单位毫秒 + healthCheckInterval int64 `required:"false" yaml:"healthCheckInterval" json:"healthCheckInterval"` + // @Title zh-CN 健康检测的超时时间,单位毫秒 + healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"` + // @Title zh-CN 健康检测使用的模型 + healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` + // @Title zh-CN 本次请求使用的 apiToken + ctxApiTokenInUse string + // @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数 + ctxApiTokenRequestFailureCount string + // @Title zh-CN 记录 apiToken 健康检测成功的次数,key 为 apiToken,value 为成功次数 + ctxApiTokenRequestSuccessCount string + // @Title zh-CN 记录所有可用的 apiToken 列表 + ctxApiTokens string + // @Title zh-CN 记录所有不可用的 apiToken 列表 + ctxUnavailableApiTokens string + // @Title zh-CN 记录请求的 cluster, host 和 path,用于在健康检测时构建请求 + ctxHealthCheckEndpoint string + // @Title zh-CN 健康检测选主,只有选到主的 Wasm VM 才执行健康检测 + ctxVmLease string +} + +type Lease struct { + VMID string `json:"vmID"` + Timestamp int64 `json:"timestamp"` +} + +type HealthCheckEndpoint struct { + Host string `json:"host"` + Path string `json:"path"` + Cluster string `json:"cluster"` +} + +const ( + casMaxRetries = 10 + addApiTokenOperation = "addApiToken" + removeApiTokenOperation = "removeApiToken" + addApiTokenRequestCountOperation = "addApiTokenRequestCount" + resetApiTokenRequestCountOperation = "resetApiTokenRequestCount" + ctxRequestHost = "requestHost" + ctxRequestPath = "requestPath" +) + +var ( + healthCheckClient wrapper.HttpClient +) + +func (f *failover) FromJson(json gjson.Result) { + f.enabled = json.Get("enabled").Bool() + f.failureThreshold = json.Get("failureThreshold").Int() + if f.failureThreshold == 0 { + f.failureThreshold = 3 + } + f.successThreshold = json.Get("successThreshold").Int() + if f.successThreshold == 0 { + f.successThreshold = 1 + } + f.healthCheckInterval = json.Get("healthCheckInterval").Int() + if f.healthCheckInterval == 0 { + f.healthCheckInterval = 5000 + } + f.healthCheckTimeout = json.Get("healthCheckTimeout").Int() + if f.healthCheckTimeout == 0 { + f.healthCheckTimeout = 5000 + } + f.healthCheckModel = json.Get("healthCheckModel").String() +} + +func (f *failover) Validate() error { + if f.healthCheckModel == "" { + return errors.New("missing healthCheckModel in failover config") + } + return nil +} + +func (c *ProviderConfig) initVariable() { + // Set provider name as prefix to differentiate shared data + provider := c.GetType() + c.failover.ctxApiTokenInUse = provider + "-apiTokenInUse" + c.failover.ctxApiTokenRequestFailureCount = provider + "-apiTokenRequestFailureCount" + c.failover.ctxApiTokenRequestSuccessCount = provider + "-apiTokenRequestSuccessCount" + c.failover.ctxApiTokens = provider + "-apiTokens" + c.failover.ctxUnavailableApiTokens = provider + "-unavailableApiTokens" + c.failover.ctxHealthCheckEndpoint = provider + "-requestHostAndPath" + c.failover.ctxVmLease = provider + "-vmLease" +} + +func parseConfig(json gjson.Result, config *any, log wrapper.Log) error { + return nil +} + +func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Provider) error { + c.initVariable() + // Reset shared data in case plugin configuration is updated + log.Debugf("ai-proxy plugin configuration is updated, reset shared data") + c.resetSharedData() + + if c.isFailoverEnabled() { + log.Debugf("ai-proxy plugin failover is enabled") + + vmID := generateVMID() + err := c.initApiTokens() + + if err != nil { + return fmt.Errorf("failed to init apiTokens: %v", err) + } + + wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() { + // Only the Wasm VM that successfully acquires the lease will perform health check + if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID, log) { + log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType()) + unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens) + if err != nil { + log.Errorf("Failed to get unavailable tokens: %v", err) + return + } + if len(unavailableTokens) > 0 { + for _, apiToken := range unavailableTokens { + log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", ")) + healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody(log) + healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{ + Host: healthCheckEndpoint.Host, + Cluster: healthCheckEndpoint.Cluster, + }) + + ctx := createHttpContext() + ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) + + modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body, log) + if err != nil { + log.Errorf("Failed to transform request headers and body: %v", err) + } + + // The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion + err = healthCheckClient.Post(healthCheckEndpoint.Path, modifiedHeaders, modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode == 200 { + c.handleAvailableApiToken(apiToken, log) + } + }, uint32(c.failover.healthCheckTimeout)) + if err != nil { + log.Errorf("Failed to perform health check request: %v", err) + } + } + } + } + }) + } + return nil +} + +func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte, log wrapper.Log) ([][2]string, []byte, error) { + originalHeaders := util.SliceToHeader(headers) + if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok { + handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, originalHeaders, log) + } + + var err error + if handler, ok := activeProvider.(TransformRequestBodyHandler); ok { + body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log) + } else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok { + headers := util.GetOriginalHttpHeaders() + body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log) + util.ReplaceOriginalHttpHeaders(headers) + } else { + body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log) + } + if err != nil { + return nil, nil, fmt.Errorf("failed to transform request body: %v", err) + } + + modifiedHeaders := util.HeaderToSlice(originalHeaders) + return modifiedHeaders, body, nil +} + +func createHttpContext() *wrapper.CommonHttpCtx[any] { + setParseConfig := wrapper.ParseConfigBy[any](parseConfig) + vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig) + pluginCtx := vmCtx.NewPluginContext(rand.Uint32()) + ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any]) + return ctx +} + +func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) { + data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint) + if err != nil { + log.Errorf("Failed to get request host and path: %v", err) + } + var healthCheckEndpoint HealthCheckEndpoint + err = json.Unmarshal(data, &healthCheckEndpoint) + if err != nil { + log.Errorf("Failed to unmarshal request host and path: %v", err) + } + + headers := [][2]string{ + {"content-type", "application/json"}, + } + body := []byte(fmt.Sprintf(`{ + "model": "%s", + "messages": [ + { + "role": "user", + "content": "who are you?" + } + ] + }`, c.failover.healthCheckModel)) + return healthCheckEndpoint, headers, body +} + +func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool { + now := time.Now().Unix() + + data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease) + if err != nil { + if errors.Is(err, types.ErrorStatusNotFound) { + return c.setLease(vmID, now, cas, log) + } else { + log.Errorf("Failed to get lease: %v", err) + return false + } + } + if data == nil { + return c.setLease(vmID, now, cas, log) + } + + var lease Lease + err = json.Unmarshal(data, &lease) + if err != nil { + log.Errorf("Failed to unmarshal lease data: %v", err) + return false + } + // If vmID is itself, try to renew the lease directly + // If the lease is expired (60s), try to acquire the lease + if lease.VMID == vmID || now-lease.Timestamp > 60 { + lease.VMID = vmID + lease.Timestamp = now + return c.setLease(vmID, now, cas, log) + } + + return false +} + +func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool { + lease := Lease{ + VMID: vmID, + Timestamp: timestamp, + } + leaseByte, err := json.Marshal(lease) + if err != nil { + log.Errorf("Failed to marshal lease data: %v", err) + return false + } + + if err := proxywasm.SetSharedData(c.failover.ctxVmLease, leaseByte, cas); err != nil { + log.Errorf("Failed to set or renew lease: %v", err) + return false + } + return true +} + +func generateVMID() string { + return uuid.New().String() +} + +// When number of request successes exceeds the threshold during health check, +// add the apiToken back to the available list and remove it from the unavailable list +func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) { + successApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount) + if err != nil { + log.Errorf("Failed to get successApiTokenRequestCount: %v", err) + return + } + + successCount := successApiTokenRequestCount[apiToken] + 1 + if successCount >= c.failover.successThreshold { + log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken) + removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log) + addApiToken(c.failover.ctxApiTokens, apiToken, log) + resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log) + } else { + log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, successCount) + addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log) + } +} + +// When number of request failures exceeds the threshold, +// remove the apiToken from the available list and add it to the unavailable list +func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string, log wrapper.Log) { + failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount) + if err != nil { + log.Errorf("Failed to get failureApiTokenRequestCount: %v", err) + return + } + + availableTokens, _, err := getApiTokens(c.failover.ctxApiTokens) + if err != nil { + log.Errorf("Failed to get available apiToken: %v", err) + return + } + // unavailable apiToken has been removed from the available list + if !containsElement(availableTokens, apiToken) { + return + } + + failureCount := failureApiTokenRequestCount[apiToken] + 1 + if failureCount >= c.failover.failureThreshold { + log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken) + removeApiToken(c.failover.ctxApiTokens, apiToken, log) + addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log) + resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log) + // Set the request host and path to shared data in case they are needed in apiToken health check + c.setHealthCheckEndpoint(ctx, log) + } else { + log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount) + addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log) + } +} + +func addApiToken(key, apiToken string, log wrapper.Log) { + modifyApiToken(key, apiToken, addApiTokenOperation, log) +} + +func removeApiToken(key, apiToken string, log wrapper.Log) { + modifyApiToken(key, apiToken, removeApiTokenOperation, log) +} + +func modifyApiToken(key, apiToken, op string, log wrapper.Log) { + for attempt := 1; attempt <= casMaxRetries; attempt++ { + apiTokens, cas, err := getApiTokens(key) + if err != nil { + log.Errorf("Failed to get %s: %v", key, err) + continue + } + + exists := containsElement(apiTokens, apiToken) + if op == addApiTokenOperation && exists { + log.Debugf("%s already exists in %s", apiToken, key) + return + } else if op == removeApiTokenOperation && !exists { + log.Debugf("%s does not exist in %s", apiToken, key) + return + } + + if op == addApiTokenOperation { + apiTokens = append(apiTokens, apiToken) + } else { + apiTokens = removeElement(apiTokens, apiToken) + } + + if err := setApiTokens(key, apiTokens, cas); err == nil { + log.Debugf("Successfully updated %s in %s", apiToken, key) + return + } else if !errors.Is(err, types.ErrorStatusCasMismatch) { + log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err) + return + } + + log.Errorf("CAS mismatch when setting %s, retrying...", key) + } +} + +func getApiTokens(key string) ([]string, uint32, error) { + data, cas, err := proxywasm.GetSharedData(key) + if err != nil { + if errors.Is(err, types.ErrorStatusNotFound) { + return []string{}, cas, nil + } + return nil, 0, err + } + if data == nil { + return []string{}, cas, nil + } + + var apiTokens []string + if err = json.Unmarshal(data, &apiTokens); err != nil { + return nil, 0, fmt.Errorf("failed to unmarshal tokens: %v", err) + } + + return apiTokens, cas, nil +} + +func setApiTokens(key string, apiTokens []string, cas uint32) error { + data, err := json.Marshal(apiTokens) + if err != nil { + return fmt.Errorf("failed to marshal tokens: %v", err) + } + return proxywasm.SetSharedData(key, data, cas) +} + +func removeElement(slice []string, s string) []string { + for i := 0; i < len(slice); i++ { + if slice[i] == s { + slice = append(slice[:i], slice[i+1:]...) + i-- + } + } + return slice +} + +func containsElement(slice []string, s string) bool { + for _, item := range slice { + if item == s { + return true + } + } + return false +} + +func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) { + data, cas, err := proxywasm.GetSharedData(key) + if err != nil { + if errors.Is(err, types.ErrorStatusNotFound) { + return make(map[string]int64), cas, nil + } + return nil, 0, err + } + + if data == nil { + return make(map[string]int64), cas, nil + } + + var apiTokens map[string]int64 + err = json.Unmarshal(data, &apiTokens) + if err != nil { + return nil, 0, err + } + return apiTokens, cas, nil +} + +func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) { + modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log) +} + +func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) { + modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log) +} + +func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) { + if c.isFailoverEnabled() { + failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount) + if err != nil { + log.Errorf("failed to get failureApiTokenRequestCount: %v", err) + } + if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { + log.Infof("reset apiToken %s request failure count", apiTokenInUse) + resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log) + } + } +} + +func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) { + for attempt := 1; attempt <= casMaxRetries; attempt++ { + apiTokenRequestCount, cas, err := getApiTokenRequestCount(key) + if err != nil { + log.Errorf("Failed to get %s: %v", key, err) + continue + } + + if op == resetApiTokenRequestCountOperation { + delete(apiTokenRequestCount, apiToken) + } else { + apiTokenRequestCount[apiToken]++ + } + + apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount) + if err != nil { + log.Errorf("failed to marshal apiTokenRequestCount: %v", err) + } + + if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil { + log.Debugf("Successfully updated the count of %s in %s", apiToken, key) + return + } else if !errors.Is(err, types.ErrorStatusCasMismatch) { + log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err) + return + } + + log.Errorf("CAS mismatch when setting %s, retrying...", key) + } +} + +func (c *ProviderConfig) initApiTokens() error { + return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0) +} + +func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string { + apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens) + unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens) + log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens) + + if err != nil { + return "" + } + count := len(apiTokens) + switch count { + case 0: + return "" + case 1: + return apiTokens[0] + default: + return apiTokens[rand.Intn(count)] + } +} + +func (c *ProviderConfig) isFailoverEnabled() bool { + return c.failover.enabled +} + +func (c *ProviderConfig) resetSharedData() { + _ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0) + _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0) +} + +func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) { + if c.isFailoverEnabled() { + c.handleUnavailableApiToken(ctx, apiTokenInUse, log) + } +} + +func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { + return ctx.GetContext(c.failover.ctxApiTokenInUse).(string) +} + +func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { + var apiToken string + if c.isFailoverEnabled() { + // if enable apiToken failover, only use available apiToken + apiToken = c.GetGlobalRandomToken(log) + } else { + apiToken = c.GetRandomToken() + } + log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken) + ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) +} + +func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext, log wrapper.Log) { + cluster, err := proxywasm.GetProperty([]string{"cluster_name"}) + if err != nil { + log.Errorf("Failed to get cluster_name: %v", err) + } + + host := wrapper.GetRequestHost() + if host == "" { + host = ctx.GetContext(ctxRequestHost).(string) + } + path := wrapper.GetRequestPath() + if path == "" { + path = ctx.GetContext(ctxRequestPath).(string) + } + + healthCheckEndpoint := HealthCheckEndpoint{ + Host: host, + Path: path, + Cluster: string(cluster), + } + + healthCheckEndpointByte, err := json.Marshal(healthCheckEndpoint) + if err != nil { + log.Errorf("Failed to marshal request host and path: %v", err) + + } + err = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, healthCheckEndpointByte, 0) + if err != nil { + log.Errorf("Failed to set request host and path: %v", err) + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 0d418c16a5..a4c1ef2cd9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -17,8 +18,11 @@ import ( // geminiProvider is the provider for google gemini/gemini flash service. const ( - geminiApiKeyHeader = "x-goog-api-key" - geminiDomain = "generativelanguage.googleapis.com" + geminiApiKeyHeader = "x-goog-api-key" + geminiDomain = "generativelanguage.googleapis.com" + geminiChatCompletionPath = "generateContent" + geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse" + geminiEmbeddingPath = "batchEmbedContents" ) type geminiProviderInitializer struct { @@ -51,157 +55,56 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - - _ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetRandomToken()) - _ = util.OverwriteRequestHost(geminiDomain) - - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + g.config.handleRequestHeaders(g, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } -func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName == ApiNameChatCompletion { - return g.onChatCompletionRequestBody(ctx, body, log) - } else if apiName == ApiNameEmbeddings { - return g.onEmbeddingsRequestBody(ctx, body, log) - } - return types.ActionContinue, errUnsupportedApiName +func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, geminiDomain) + headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } -func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - // 使用gemini接口协议 - if g.config.protocol == protocolOriginal { - request := &geminiChatRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 根据模型重写requestPath - path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) - _ = util.OverwriteRequestPath(path) - - // 移除多余的model和stream字段 - request = &geminiChatRequest{ - Contents: request.Contents, - SafetySettings: request.SafetySettings, - GenerationConfig: request.GenerationConfig, - Tools: request.Tools, - } - if g.config.context == nil { - return types.ActionContinue, replaceJsonRequestBody(request, log) - } - - err := g.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - g.setSystemContent(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err +func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + return types.ActionContinue, errUnsupportedApiName } + return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) +} - // 映射模型重写requestPath - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, g.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") +func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + if apiName == ApiNameChatCompletion { + return g.onChatCompletionRequestBody(ctx, body, headers, log) + } else { + return g.onEmbeddingsRequestBody(ctx, body, headers, log) } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - path := g.getRequestPath(ApiNameChatCompletion, mappedModel, request.Stream) - _ = util.OverwriteRequestPath(path) +} - if g.config.context == nil { - geminiRequest := g.buildGeminiChatRequest(request) - return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log) +func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + request := &chatCompletionRequest{} + err := g.config.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err } + path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) + util.OverwriteRequestPathHeader(headers, path) - err := g.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - geminiRequest := g.buildGeminiChatRequest(request) - if err := replaceJsonRequestBody(geminiRequest, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + geminiRequest := g.buildGeminiChatRequest(request) + return json.Marshal(geminiRequest) } -func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - // 使用gemini接口协议 - if g.config.protocol == protocolOriginal { - request := &geminiBatchEmbeddingRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 根据模型重写requestPath - path := g.getRequestPath(ApiNameEmbeddings, request.Model, false) - _ = util.OverwriteRequestPath(path) - - // 移除多余的model字段 - request = &geminiBatchEmbeddingRequest{ - Requests: request.Requests, - } - return types.ActionContinue, replaceJsonRequestBody(request, log) - } +func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { request := &embeddingsRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) + if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { + return nil, err } - - // 映射模型重写requestPath - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in embeddings request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, g.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - path := g.getRequestPath(ApiNameEmbeddings, mappedModel, false) - _ = util.OverwriteRequestPath(path) + path := g.getRequestPath(ApiNameEmbeddings, request.Model, false) + util.OverwriteRequestPathHeader(headers, path) geminiRequest := g.buildBatchEmbeddingRequest(request) - return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log) + return json.Marshal(geminiRequest) } func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -285,11 +188,11 @@ func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string { action := "" if apiName == ApiNameEmbeddings { - action = "batchEmbedContents" + action = geminiEmbeddingPath } else if stream { - action = "streamGenerateContent?alt=sse" + action = geminiChatCompletionStreamPath } else { - action = "generateContent" + action = geminiChatCompletionPath } return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action) } @@ -605,3 +508,13 @@ func (g *geminiProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, gemini func (g *geminiProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (g *geminiProvider) GetApiName(path string) ApiName { + if strings.Contains(path, geminiChatCompletionPath) || strings.Contains(path, geminiChatCompletionStreamPath) { + return ApiNameChatCompletion + } + if strings.Contains(path, geminiEmbeddingPath) { + return ApiNameEmbeddings + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/github.go b/plugins/wasm-go/extensions/ai-proxy/provider/github.go index 5ee51b2742..0a2b0c84de 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/github.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/github.go @@ -1,14 +1,12 @@ package provider import ( - "encoding/json" "errors" - "fmt" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) // githubProvider is the provider for GitHub OpenAI service. @@ -48,16 +46,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(githubDomain) - if apiName == ApiNameChatCompletion { - _ = util.OverwriteRequestPath(githubCompletionPath) - } - if apiName == ApiNameEmbeddings { - _ = util.OverwriteRequestPath(githubEmbeddingPath) - } - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - _ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken()) + m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } @@ -66,47 +55,28 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { return types.ActionContinue, errUnsupportedApiName } - if apiName == ApiNameChatCompletion { - return m.onChatCompletionRequestBody(ctx, body, log) - } - if apiName == ApiNameEmbeddings { - return m.onEmbeddingsRequestBody(ctx, body, log) - } - return types.ActionContinue, errUnsupportedApiName + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } -func (m *githubProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") +func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, githubDomain) + if apiName == ApiNameChatCompletion { + util.OverwriteRequestPathHeader(headers, githubCompletionPath) } - // 映射模型 - mappedModel := getMappedModel(request.Model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + if apiName == ApiNameEmbeddings { + util.OverwriteRequestPathHeader(headers, githubEmbeddingPath) } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) + util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } -func (m *githubProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - request := &embeddingsRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in embeddings request") +func (m *githubProvider) GetApiName(path string) ApiName { + if strings.Contains(path, githubCompletionPath) { + return ApiNameChatCompletion } - // 映射模型 - mappedModel := getMappedModel(request.Model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") + if strings.Contains(path, githubEmbeddingPath) { + return ApiNameEmbeddings } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go index 644e450ee9..dfbd971261 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/groq.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/groq.go @@ -2,11 +2,11 @@ package provider import ( "errors" - "fmt" + "net/http" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -18,14 +18,14 @@ const ( type groqProviderInitializer struct{} -func (m *groqProviderInitializer) ValidateConfig(config ProviderConfig) error { +func (g *groqProviderInitializer) ValidateConfig(config ProviderConfig) error { if config.apiTokens == nil || len(config.apiTokens) == 0 { return errors.New("no apiToken found in provider config") } return nil } -func (m *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { +func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { return &groqProvider{ config: config, contextCache: createContextCache(&config), @@ -37,47 +37,35 @@ type groqProvider struct { contextCache *contextCache } -func (m *groqProvider) GetProviderType() string { +func (g *groqProvider) GetProviderType() string { return providerTypeGroq } -func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(groqChatCompletionPath) - _ = util.OverwriteRequestHost(groqDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + g.config.handleRequestHeaders(g, ctx, apiName, log) return types.ActionContinue, nil } -func (m *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { +func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.groq.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.groq.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) +} + +func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, groqChatCompletionPath) + util.OverwriteRequestHostHeader(headers, groqDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (g *groqProvider) GetApiName(path string) ApiName { + if strings.Contains(path, groqChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 7640a380b3..99cb135db6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "time" @@ -114,26 +115,27 @@ func (m *hunyuanProvider) GetProviderType() string { } func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - // log.Debugf("hunyuanProvider.OnRequestHeaders called! hunyunSecretKey/id is: %s/%s", m.config.hunyuanAuthKey, m.config.hunyuanAuthId) if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } + m.config.handleRequestHeaders(m, ctx, apiName, log) + // Delay the header processing to allow changing streaming mode in OnRequestBody + return types.HeaderStopIteration, nil +} - _ = util.OverwriteRequestHost(hunyuanDomain) - _ = util.OverwriteRequestPath(hunyuanRequestPath) - - // 添加hunyuan需要的自定义字段 - _ = proxywasm.ReplaceHttpRequestHeader(actionKey, hunyuanChatCompletionTCAction) - _ = proxywasm.ReplaceHttpRequestHeader(versionKey, versionValue) +func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, hunyuanDomain) + util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) - // 删除一些字段 - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + // 添加 hunyuan 需要的自定义字段 + headers.Add(actionKey, hunyuanChatCompletionTCAction) + headers.Add(versionKey, versionValue) - // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + headers.Del("Accept-Encoding") + headers.Del("Content-Length") } +// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName @@ -142,7 +144,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName // 为header添加时间戳字段 (因为需要根据body进行签名时依赖时间戳,故于body处理部分创建时间戳) var timestamp int64 = time.Now().Unix() _ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp)) - // log.Debugf("#debug nash5# OnRequestBody set timestamp header: ", timestamp) // 使用混元本身接口的协议 if m.config.protocol == protocolOriginal { @@ -198,7 +199,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if err := decodeChatCompletionRequest(body, request); err != nil { return types.ActionContinue, err } - // log.Debugf("#debug nash5# OnRequestBody call hunyuan api using openai's api!") model := request.Model if model == "" { @@ -235,18 +235,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName string(body), ) _ = util.OverwriteRequestAuthorization(authorizedValueNew) - // log.Debugf("#debug nash5# OnRequestBody done, body is: ", string(body)) - - // // 打印所有的headers - // headers, err2 := proxywasm.GetHttpRequestHeaders() - // if err2 != nil { - // log.Errorf("failed to get request headers: %v", err2) - // } else { - // // 迭代并打印所有请求头 - // for _, header := range headers { - // log.Infof("#debug nash5# inB Request header - %s: %s", header[0], header[1]) - // } - // } return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest, log) } @@ -277,6 +265,32 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName return types.ActionContinue, err } +// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用 +func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + request := &chatCompletionRequest{} + err := m.config.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err + } + + hunyuanRequest := m.buildHunyuanTextGenerationRequest(request) + + var timestamp int64 = time.Now().Unix() + _ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp)) + // 根据确定好的payload进行签名: + body, _ = json.Marshal(hunyuanRequest) + authorizedValueNew := GetTC3Authorizationcode( + m.config.hunyuanAuthId, + m.config.hunyuanAuthKey, + timestamp, + hunyuanDomain, + hunyuanChatCompletionTCAction, + string(body), + ) + util.OverwriteRequestAuthorizationHeader(headers, authorizedValueNew) + return json.Marshal(hunyuanRequest) +} + func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { _ = proxywasm.RemoveHttpResponseHeader("Content-Length") return types.ActionContinue, nil @@ -561,3 +575,7 @@ func GetTC3Authorizationcode(secretId string, secretKey string, timestamp int64, // fmt.Println(curl) return authorization } + +func (m *hunyuanProvider) GetApiName(path string) ApiName { + return ApiNameChatCompletion +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index ded72d7b51..00aa0f7254 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -78,14 +79,17 @@ func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(minimaxDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + m.config.handleRequestHeaders(m, ctx, apiName, log) // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } +func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, minimaxDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName @@ -107,51 +111,16 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName return m.handleRequestBodyByChatCompletionPro(body, log) } else { // 使用ChatCompletion v2接口 - return m.handleRequestBodyByChatCompletionV2(body, log) + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) } } +func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + return m.handleRequestBodyByChatCompletionV2(body, headers, log) +} + // handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体 func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) { - // 使用minimax接口协议 - if m.config.protocol == protocolOriginal { - request := &minimaxChatCompletionV2Request{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 根据模型重写requestPath - if m.config.minimaxGroupId == "" { - return types.ActionContinue, errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when use %s model ", request.Model)) - } - _ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId)) - - if m.config.context == nil { - return types.ActionContinue, nil - } - - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - m.setBotSettings(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err - } - request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { return types.ActionContinue, err @@ -174,6 +143,9 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log log.Errorf("failed to load context file: %v", err) _ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) } + // 由于 minimaxChatCompletionV2(格式和 OpenAI 一致)和 minimaxChatCompletionPro(格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一 + // 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息 + // minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息 minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content) if err := replaceJsonRequestBody(minimaxRequest, log); err != nil { _ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err)) @@ -186,37 +158,17 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log } // handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体 -func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, log wrapper.Log) (types.Action, error) { +func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + return nil, err } // 映射模型重写requestPath request.Model = getMappedModel(request.Model, m.config.modelMapping, log) - _ = util.OverwriteRequestPath(minimaxChatCompletionV2Path) - - if m.contextCache == nil { - return types.ActionContinue, replaceJsonRequestBody(request, log) - } + util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path) - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + return body, nil } func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -474,3 +426,10 @@ func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Re func (m *minimaxProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (m *minimaxProvider) GetApiName(path string) ApiName { + if strings.Contains(path, minimaxChatCompletionV2Path) || strings.Contains(path, minimaxChatCompletionProPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go index b217d8019e..23b870cff0 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/mistral.go @@ -2,16 +2,16 @@ package provider import ( "errors" - "fmt" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) const ( - mistralDomain = "api.mistral.ai" + mistralDomain = "api.mistral.ai" + mistralChatCompletionPath = "/v1/chat/completions" ) type mistralProviderInitializer struct{} @@ -43,9 +43,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(mistralDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -53,28 +51,18 @@ func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.mistral.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.mistral.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, mistralDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *mistralProvider) GetApiName(path string) ApiName { + if strings.Contains(path, mistralChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go index 6023b4abe8..de40471c92 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" @@ -58,33 +59,29 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(moonshotChatCompletionPath) - _ = util.OverwriteRequestHost(moonshotDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } +func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath) + util.OverwriteRequestHostHeader(headers, moonshotDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +// moonshot 有自己获取 context 的配置(moonshotFileId),因此无法复用 handleRequestBody 方法 +// moonshot 的 body 没有修改,无须实现TransformRequestBody,使用默认的 defaultTransformRequestBody 方法 func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { + if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { return types.ActionContinue, err } - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - if m.config.moonshotFileId == "" && m.contextCache == nil { return types.ActionContinue, replaceJsonRequestBody(request, log) } @@ -154,3 +151,10 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba return errors.New("unsupported method: " + method) } } + +func (m *moonshotProvider) GetApiName(path string) ApiName { + if strings.Contains(path, moonshotChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go index 8895489fbe..3f1303d750 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/ollama.go @@ -3,11 +3,11 @@ package provider import ( "errors" "fmt" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" + "net/http" + "strings" ) // ollamaProvider is the provider for Ollama service. @@ -53,10 +53,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(ollamaChatCompletionPath) - _ = util.OverwriteRequestHost(m.serviceDomain) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -64,51 +61,18 @@ func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} - if m.config.modelMapping == nil && m.contextCache == nil { - return types.ActionContinue, nil - } - - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel +func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath) + util.OverwriteRequestHostHeader(headers, m.serviceDomain) + headers.Del("Content-Length") +} - if m.contextCache != nil { - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.ollama.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.ollama.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } else { - return types.ActionContinue, err - } - } else { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.ollama.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - return types.ActionContinue, err - } - _ = proxywasm.ResumeHttpRequest() - return types.ActionPause, nil +func (m *ollamaProvider) GetApiName(path string) ApiName { + if strings.Contains(path, ollamaChatCompletionPath) { + return ApiNameChatCompletion } + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 9f34932c1a..ab92191dfe 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -1,12 +1,13 @@ package provider import ( + "encoding/json" "fmt" + "net/http" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -57,27 +58,31 @@ func (m *openaiProvider) GetProviderType() string { } func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { + m.config.handleRequestHeaders(m, ctx, apiName, log) + return types.ActionContinue, nil +} + +func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { if m.customPath == "" { switch apiName { case ApiNameChatCompletion: - _ = util.OverwriteRequestPath(defaultOpenaiChatCompletionPath) + util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath) case ApiNameEmbeddings: ctx.DontReadRequestBody() - _ = util.OverwriteRequestPath(defaultOpenaiEmbeddingsPath) + util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath) } } else { - _ = util.OverwriteRequestPath(m.customPath) + util.OverwriteRequestPathHeader(headers, m.customPath) } if m.customDomain == "" { - _ = util.OverwriteRequestHost(defaultOpenaiDomain) + util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain) } else { - _ = util.OverwriteRequestHost(m.customDomain) + util.OverwriteRequestHostHeader(headers, m.customDomain) } if len(m.config.apiTokens) > 0 { - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) } - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - return types.ActionContinue, nil + headers.Del("Content-Length") } func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { @@ -85,9 +90,13 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, // We don't need to process the request body for other APIs. return types.ActionContinue, nil } + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + return nil, err } if m.config.responseJsonSchema != nil { log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema) @@ -101,27 +110,9 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, request.StreamOptions.IncludeUsage = true } } - if m.contextCache == nil { - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - return types.ActionContinue, nil - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.openai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.openai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + return json.Marshal(request) +} + +func (m *openaiProvider) GetApiName(path string) ApiName { + return GetOpenAiApiName(path) } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index facd8bb283..41160464ac 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -1,14 +1,17 @@ package provider import ( + "encoding/json" "errors" "math/rand" + "net/http" "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" - - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) type ApiName string @@ -106,16 +109,31 @@ var ( type Provider interface { GetProviderType() string + GetApiName(path string) ApiName } type RequestHeadersHandler interface { OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) } +type TransformRequestHeadersHandler interface { + TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) +} + type RequestBodyHandler interface { OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) } +type TransformRequestBodyHandler interface { + TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) +} + +// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body. +// Some providers (e.g. baidu, gemini) transform request headers (e.g., path) based on the request body (e.g., model). +type TransformRequestBodyHeadersHandler interface { + TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) +} + type ResponseHeadersHandler interface { OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) } @@ -141,6 +159,9 @@ type ProviderConfig struct { // @Title zh-CN 请求超时 // @Description zh-CN 请求AI服务的超时时间,单位为毫秒。默认值为120000,即2分钟 timeout uint32 `required:"false" yaml:"timeout" json:"timeout"` + // @Title zh-CN apiToken 故障切换 + // @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表 + failover *failover `required:"false" yaml:"failover" json:"failover"` // @Title zh-CN 基于OpenAI协议的自定义后端URL // @Description zh-CN 仅适用于支持 openai 协议的服务。 openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"` @@ -287,6 +308,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } } } + + failoverJson := json.Get("failover") + c.failover = &failover{ + enabled: false, + } + if failoverJson.Exists() { + c.failover.FromJson(failoverJson) + } } func (c *ProviderConfig) Validate() error { @@ -302,6 +331,12 @@ func (c *ProviderConfig) Validate() error { } } + if c.failover.enabled { + if err := c.failover.Validate(); err != nil { + return err + } + } + if c.typ == "" { return errors.New("missing type in provider config") } @@ -353,6 +388,60 @@ func CreateProvider(pc ProviderConfig) (Provider, error) { return initializer.CreateProvider(pc) } +func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error { + switch req := request.(type) { + case *chatCompletionRequest: + if err := decodeChatCompletionRequest(body, req); err != nil { + return err + } + + streaming := req.Stream + if streaming { + _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") + } + + return c.setRequestModel(ctx, req, log) + case *embeddingsRequest: + if err := decodeEmbeddingsRequest(body, req); err != nil { + return err + } + return c.setRequestModel(ctx, req, log) + default: + return errors.New("unsupported request type") + } +} + +func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error { + var model *string + + switch req := request.(type) { + case *chatCompletionRequest: + model = &req.Model + case *embeddingsRequest: + model = &req.Model + default: + return errors.New("unsupported request type") + } + + return c.mapModel(ctx, model, log) +} + +func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error { + if *model == "" { + return errors.New("missing model in request") + } + ctx.SetContext(ctxKeyOriginalRequestModel, *model) + + mappedModel := getMappedModel(*model, c.modelMapping, log) + if mappedModel == "" { + return errors.New("model becomes empty after applying the configured mapping") + } + + *model = mappedModel + ctx.SetContext(ctxKeyFinalRequestModel, *model) + return nil +} + func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string { mappedModel := doGetMappedModel(model, modelMapping, log) if len(mappedModel) != 0 { @@ -389,3 +478,72 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper. return "" } + +func (c *ProviderConfig) handleRequestBody( + provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log, +) (types.Action, error) { + // use original protocol + if c.protocol == protocolOriginal { + return types.ActionContinue, nil + } + + // use openai protocol + var err error + if handler, ok := provider.(TransformRequestBodyHandler); ok { + body, err = handler.TransformRequestBody(ctx, apiName, body, log) + } else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok { + headers := util.GetOriginalHttpHeaders() + body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log) + util.ReplaceOriginalHttpHeaders(headers) + } else { + body, err = c.defaultTransformRequestBody(ctx, apiName, body, log) + } + + if err != nil { + return types.ActionContinue, err + } + + if apiName == ApiNameChatCompletion { + if c.context == nil { + return types.ActionContinue, replaceHttpJsonRequestBody(body, log) + } + err = contextCache.GetContextFromFile(ctx, provider, body, log) + + if err == nil { + return types.ActionPause, nil + } + return types.ActionContinue, err + } + return types.ActionContinue, replaceHttpJsonRequestBody(body, log) +} + +func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) { + if handler, ok := provider.(TransformRequestHeadersHandler); ok { + originalHeaders := util.GetOriginalHttpHeaders() + handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log) + util.ReplaceOriginalHttpHeaders(originalHeaders) + } +} + +func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { + var request interface{} + if apiName == ApiNameChatCompletion { + request = &chatCompletionRequest{} + } else { + request = &embeddingsRequest{} + } + if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil { + return nil, err + } + return json.Marshal(request) +} + +func GetOpenAiApiName(path string) ApiName { + if strings.HasSuffix(path, "/v1/chat/completions") { + return ApiNameChatCompletion + } + if strings.HasSuffix(path, "/v1/embeddings") { + return ApiNameEmbeddings + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index f673fa98b2..771feeb51e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math" + "net/http" "reflect" "strings" "time" @@ -58,35 +59,50 @@ func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provide } type qwenProvider struct { - config ProviderConfig - + config ProviderConfig contextCache *contextCache } +func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestHostHeader(headers, qwenDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + + if m.config.qwenEnableCompatible { + util.OverwriteRequestPathHeader(headers, qwenCompatiblePath) + } else if apiName == ApiNameChatCompletion { + util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath) + } else if apiName == ApiNameEmbeddings { + util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath) + } + + headers.Del("Accept-Encoding") + headers.Del("Content-Length") +} + +func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { + if apiName == ApiNameChatCompletion { + return m.onChatCompletionRequestBody(ctx, body, headers, log) + } else { + return m.onEmbeddingsRequestBody(ctx, body, log) + } +} + func (m *qwenProvider) GetProviderType() string { return providerTypeQwen } func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - _ = util.OverwriteRequestHost(qwenDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) + if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + return types.ActionContinue, errUnsupportedApiName + } + + m.config.handleRequestHeaders(m, ctx, apiName, log) if m.config.protocol == protocolOriginal { ctx.DontReadRequestBody() return types.ActionContinue, nil - } else if m.config.qwenEnableCompatible { - _ = util.OverwriteRequestPath(qwenCompatiblePath) - } else if apiName == ApiNameChatCompletion { - _ = util.OverwriteRequestPath(qwenChatCompletionPath) - } else if apiName == ApiNameEmbeddings { - _ = util.OverwriteRequestPath(qwenTextEmbeddingPath) - } else { - return types.ActionContinue, errUnsupportedApiName } - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") - // Delay the header processing to allow changing streaming mode in OnRequestBody return types.HeaderStopIteration, nil } @@ -121,65 +137,23 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b } return types.ActionContinue, nil } - if apiName == ApiNameChatCompletion { - return m.onChatCompletionRequestBody(ctx, body, log) - } - if apiName == ApiNameEmbeddings { - return m.onEmbeddingsRequestBody(ctx, body, log) - } - return types.ActionContinue, errUnsupportedApiName -} -func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { - if m.config.protocol == protocolOriginal { - if m.config.context == nil { - return types.ActionContinue, nil - } - - request := &qwenTextGenRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - m.insertContextMessage(request, content, false) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { + return types.ActionContinue, errUnsupportedApiName } + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} +func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err + err := m.config.parseRequestAndMapModel(ctx, request, body, log) + if err != nil { + return nil, err } - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) // Use the qwen multimodal model generation API if strings.HasPrefix(request.Model, qwenVlModelPrefixName) { - _ = util.OverwriteRequestPath(qwenMultimodalGenerationPath) + util.OverwriteRequestPathHeader(headers, qwenMultimodalGenerationPath) } streaming := request.Stream @@ -191,62 +165,20 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body _ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE") } - if m.config.context == nil { - qwenRequest := m.buildQwenTextGenerationRequest(request, streaming) - if streaming { - ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput) - } - return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log) - } - - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - qwenRequest := m.buildQwenTextGenerationRequest(request, streaming) - if streaming { - ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput) - } - if err := replaceJsonRequestBody(qwenRequest, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil - } - return types.ActionContinue, err + return m.buildQwenTextGenerationRequest(ctx, request, streaming) } -func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { +func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) { request := &embeddingsRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) + if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil { + return nil, err } - log.Debugf("=== embeddings request: %v", request) - - model := request.Model - if model == "" { - return types.ActionContinue, errors.New("missing model in the request") - } - ctx.SetContext(ctxKeyOriginalRequestModel, model) - mappedModel := getMappedModel(model, m.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - request.Model = mappedModel - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - - if qwenRequest, err := m.buildQwenTextEmbeddingRequest(request); err == nil { - return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log) - } else { - return types.ActionContinue, err + qwenRequest, err := m.buildQwenTextEmbeddingRequest(request) + if err != nil { + return nil, err } + return json.Marshal(qwenRequest) } func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -375,7 +307,7 @@ func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body [] return types.ActionContinue, replaceJsonResponseBody(response, log) } -func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest, streaming bool) *qwenTextGenRequest { +func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) { messages := make([]qwenMessage, 0, len(origRequest.Messages)) for i := range origRequest.Messages { messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i])) @@ -397,6 +329,11 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio Tools: origRequest.Tools, }, } + + if streaming { + ctx.SetContext(ctxKeyIncrementalStreaming, request.Parameters.IncrementalOutput) + } + if len(m.config.qwenFileIds) != 0 && origRequest.Model == qwenLongModelName { builder := strings.Builder{} for _, fileId := range m.config.qwenFileIds { @@ -406,13 +343,15 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio builder.WriteString("fileid://") builder.WriteString(fileId) } - contextMessageId := m.insertContextMessage(request, builder.String(), true) - if contextMessageId == 0 { - // The context message cannot come first. We need to add another dummy system message before it. - request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...) + + body, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("unable to marshal request: %v", err) } + + return m.insertHttpContextMessage(body, builder.String(), true) } - return request + return json.Marshal(request) } func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse { @@ -569,7 +508,12 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild return nil } -func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string, onlyOneSystemBeforeFile bool) int { +func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) { + request := &qwenTextGenRequest{} + if err := json.Unmarshal(body, request); err != nil { + return nil, fmt.Errorf("unable to unmarshal request: %v", err) + } + fileMessage := qwenMessage{ Role: roleSystem, Content: content, @@ -586,10 +530,8 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content } if firstNonSystemMessageIndex == 0 { request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...) - return 0 } else if !onlyOneSystemBeforeFile { request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...) - return firstNonSystemMessageIndex } else { builder := strings.Builder{} for _, message := range request.Input.Messages[:firstNonSystemMessageIndex] { @@ -599,8 +541,15 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content builder.WriteString(message.StringContent()) } request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: builder.String()}, fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...) - return 1 + firstNonSystemMessageIndex = 1 + } + + if firstNonSystemMessageIndex == 0 { + // The context message cannot come first. We need to add another dummy system message before it. + request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...) } + + return json.Marshal(request) } func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) { @@ -804,3 +753,16 @@ func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage { } } } + +func (m *qwenProvider) GetApiName(path string) ApiName { + switch { + case strings.Contains(path, qwenChatCompletionPath), + strings.Contains(path, qwenMultimodalGenerationPath), + strings.Contains(path, qwenCompatiblePath): + return ApiNameChatCompletion + case strings.Contains(path, qwenTextEmbeddingPath): + return ApiNameEmbeddings + default: + return "" + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go index 19060849ac..dd9864702e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/request_helper.go @@ -3,7 +3,6 @@ package provider import ( "encoding/json" "fmt" - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" ) @@ -18,6 +17,13 @@ func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) er return nil } +func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error { + if err := json.Unmarshal(body, request); err != nil { + return fmt.Errorf("unable to unmarshal request: %v", err) + } + return nil +} + func replaceJsonRequestBody(request interface{}, log wrapper.Log) error { body, err := json.Marshal(request) if err != nil { @@ -31,6 +37,15 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error { return err } +func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error { + log.Debugf("request body: %s", string(body)) + err := proxywasm.ReplaceHttpRequestBody(body) + if err != nil { + return fmt.Errorf("unable to replace the original request body: %v", err) + } + return nil +} + func insertContextMessage(request *chatCompletionRequest, content string) { fileMessage := chatMessage{ Role: roleSystem, diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go index fc266dfbaa..e39bdaded9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/spark.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/spark.go @@ -2,8 +2,8 @@ package provider import ( "encoding/json" - "errors" "fmt" + "net/http" "strings" "time" @@ -71,11 +71,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestHost(sparkHost) - _ = util.OverwriteRequestPath(sparkChatCompletionPath) - _ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + p.config.handleRequestHeaders(p, ctx, apiName, log) return types.ActionContinue, nil } @@ -83,36 +79,7 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - // 使用Spark协议 - if p.config.protocol == protocolOriginal { - request := &sparkRequest{} - if err := json.Unmarshal(body, request); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err) - } - if request.Model == "" { - return types.ActionContinue, errors.New("request model is empty") - } - // 目前星火在模型名称错误时,也会调用generalv3,这里还是按照输入的模型名称设置响应里的模型名称 - ctx.SetContext(ctxKeyFinalRequestModel, request.Model) - return types.ActionContinue, replaceJsonRequestBody(request, log) - } else { - // 使用openai协议 - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - if request.Model == "" { - return types.ActionContinue, errors.New("missing model in chat completion request") - } - // 映射模型 - mappedModel := getMappedModel(request.Model, p.config.modelMapping, log) - if mappedModel == "" { - return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping") - } - ctx.SetContext(ctxKeyFinalRequestModel, mappedModel) - request.Model = mappedModel - return types.ActionContinue, replaceJsonRequestBody(request, log) - } + return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) } func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { @@ -205,3 +172,18 @@ func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, resp func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) } + +func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath) + util.OverwriteRequestHostHeader(headers, sparkHost) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx)) + headers.Del("Accept-Encoding") + headers.Del("Content-Length") +} + +func (p *sparkProvider) GetApiName(path string) ApiName { + if strings.Contains(path, sparkChatCompletionPath) { + return ApiNameChatCompletion + } + return "" +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go index dd6792ed65..f96e59e65b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/stepfun.go @@ -2,11 +2,11 @@ package provider import ( "errors" - "fmt" + "net/http" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -45,10 +45,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(stepfunChatCompletionPath) - _ = util.OverwriteRequestHost(stepfunDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -56,28 +53,19 @@ func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.stepfun.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.stepfun.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath) + util.OverwriteRequestHostHeader(headers, stepfunDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *stepfunProvider) GetApiName(path string) ApiName { + if strings.Contains(path, stepfunChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go index 287945d903..ef1141304e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/yi.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/yi.go @@ -2,11 +2,11 @@ package provider import ( "errors" - "fmt" + "net/http" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -45,10 +45,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(yiChatCompletionPath) - _ = util.OverwriteRequestHost(yiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -56,28 +53,19 @@ func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, bod if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.yi.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.yi.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, yiChatCompletionPath) + util.OverwriteRequestHostHeader(headers, yiDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *yiProvider) GetApiName(path string) ApiName { + if strings.Contains(path, yiChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go index 9640cd02f4..40fbe4ef88 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/zhipuai.go @@ -2,11 +2,11 @@ package provider import ( "errors" - "fmt" + "net/http" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) @@ -44,10 +44,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - _ = util.OverwriteRequestPath(zhipuAiChatCompletionPath) - _ = util.OverwriteRequestHost(zhipuAiDomain) - _ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken()) - _ = proxywasm.RemoveHttpRequestHeader("Content-Length") + m.config.handleRequestHeaders(m, ctx, apiName, log) return types.ActionContinue, nil } @@ -55,28 +52,19 @@ func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - if m.contextCache == nil { - return types.ActionContinue, nil - } - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return types.ActionContinue, err - } - err := m.contextCache.GetContent(func(content string, err error) { - defer func() { - _ = proxywasm.ResumeHttpRequest() - }() - if err != nil { - log.Errorf("failed to load context file: %v", err) - _ = util.SendResponse(500, "ai-proxy.zhihupai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err)) - } - insertContextMessage(request, content) - if err := replaceJsonRequestBody(request, log); err != nil { - _ = util.SendResponse(500, "ai-proxy.zhihupai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err)) - } - }, log) - if err == nil { - return types.ActionPause, nil + return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) +} + +func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath) + util.OverwriteRequestHostHeader(headers, zhipuAiDomain) + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) + headers.Del("Content-Length") +} + +func (m *zhipuAiProvider) GetApiName(path string) ApiName { + if strings.Contains(path, zhipuAiChatCompletionPath) { + return ApiNameChatCompletion } - return types.ActionContinue, err + return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/util/http.go b/plugins/wasm-go/extensions/ai-proxy/util/http.go index 43135ec0a2..f0d4c0ce7c 100644 --- a/plugins/wasm-go/extensions/ai-proxy/util/http.go +++ b/plugins/wasm-go/extensions/ai-proxy/util/http.go @@ -1,6 +1,10 @@ package util -import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" +import ( + "net/http" + + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" +) const ( HeaderContentType = "Content-Type" @@ -21,13 +25,6 @@ func CreateHeaders(kvs ...string) [][2]string { return headers } -func OverwriteRequestHost(host string) error { - if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil { - _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-HOST", originHost) - } - return proxywasm.ReplaceHttpRequestHeader(":authority", host) -} - func OverwriteRequestPath(path string) error { if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { _ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath) @@ -43,3 +40,56 @@ func OverwriteRequestAuthorization(credential string) error { } return proxywasm.ReplaceHttpRequestHeader("Authorization", credential) } + +func OverwriteRequestHostHeader(headers http.Header, host string) { + if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil { + headers.Set("X-ENVOY-ORIGINAL-HOST", originHost) + } + headers.Set(":authority", host) +} + +func OverwriteRequestPathHeader(headers http.Header, path string) { + if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil { + headers.Set("X-ENVOY-ORIGINAL-PATH", originPath) + } + headers.Set(":path", path) +} + +func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) { + if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" { + if originAuth := headers.Get("Authorization"); originAuth != "" { + headers.Set("X-HI-ORIGINAL-AUTH", originAuth) + } + } + headers.Set("Authorization", credential) +} + +func HeaderToSlice(header http.Header) [][2]string { + slice := make([][2]string, 0, len(header)) + for key, values := range header { + for _, value := range values { + slice = append(slice, [2]string{key, value}) + } + } + return slice +} + +func SliceToHeader(slice [][2]string) http.Header { + header := make(http.Header) + for _, pair := range slice { + key := pair[0] + value := pair[1] + header.Add(key, value) + } + return header +} + +func GetOriginalHttpHeaders() http.Header { + originalHeaders, _ := proxywasm.GetHttpRequestHeaders() + return SliceToHeader(originalHeaders) +} + +func ReplaceOriginalHttpHeaders(headers http.Header) { + modifiedHeaders := HeaderToSlice(headers) + _ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders) +} diff --git a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go b/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go index 96600192b1..e797394b54 100644 --- a/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/cluster_wrapper.go @@ -45,6 +45,19 @@ func (c RouteCluster) HostName() string { return GetRequestHost() } +type TargetCluster struct { + Host string + Cluster string +} + +func (c TargetCluster) ClusterName() string { + return c.Cluster +} + +func (c TargetCluster) HostName() string { + return c.Host +} + type K8sCluster struct { ServiceName string Namespace string