diff --git a/plugins/wasm-go/extensions/ai-cache/.gitignore b/plugins/wasm-go/extensions/ai-cache/.gitignore index 47db8eedba..8a34bf52ad 100644 --- a/plugins/wasm-go/extensions/ai-cache/.gitignore +++ b/plugins/wasm-go/extensions/ai-cache/.gitignore @@ -1,5 +1,5 @@ # File generated by hgctl. Modify as required. - +docker-compose-test/ * !/.gitignore diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 1de252f12c..ca91bdf5a1 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -1,9 +1,15 @@ +## 简介 --- title: AI 缓存 keywords: [higress,ai cache] description: AI 缓存插件配置参考 --- +**Note** + +> 需要数据面的proxy wasm版本大于等于0.2.100 +> 编译时,需要带上版本的tag,例如:`tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./` +> ## 功能说明 @@ -19,33 +25,113 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 插件执行阶段:`认证阶段` 插件执行优先级:`10` +## 配置说明 +配置分为 3 个部分:向量数据库(vector);文本向量化接口(embedding);缓存数据库(cache),同时也提供了细粒度的 LLM 请求/响应提取参数配置等。 + ## 配置说明 -| Name | Type | Requirement | Default | Description | -| -------- | -------- | -------- | -------- | -------- | -| cacheKeyFrom.requestBody | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheValueFrom.responseBody | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheStreamValueFrom.responseBody | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheKeyPrefix | string | optional | "higress-ai-cache:" | Redis缓存Key的前缀 | -| cacheTTL | integer | optional | 0 | 缓存的过期时间,单位是秒,默认值为0,即永不过期 | -| redis.serviceName | string | requried | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local | -| redis.servicePort | integer | optional | 6379 | redis 服务端口 | -| redis.timeout | integer | optional | 1000 | 请求 redis 的超时时间,单位为毫秒 | -| redis.username | string | optional | - | 登陆 redis 的用户名 | -| redis.password | string | optional | - | 登陆 redis 的密码 | -| returnResponseTemplate | string | optional | `{"id":"from-cache","choices":[%s],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | -| returnStreamResponseTemplate | string | optional | `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | +本插件同时支持基于向量数据库的语义化缓存和基于字符串匹配的缓存方法,如果同时配置了向量数据库和缓存数据库,优先使用向量数据库。 + +*Note*: 向量数据库(vector) 和 缓存数据库(cache) 不能同时为空,否则本插件无法提供缓存服务。 + +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| vector | string | optional | "" | 向量存储服务提供者类型,例如 dashvector | +| embedding | string | optional | "" | 请求文本向量化服务类型,例如 dashscope | +| cache | string | optional | "" | 缓存服务类型,例如 redis | +| cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) | +| enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用字符串匹配的方式来查找缓存,此时需要配置cache服务 | + +根据是否需要启用语义缓存,可以只配置组件的组合为: +1. `cache`: 仅启用字符串匹配缓存 +3. `vector (+ embedding)`: 启用语义化缓存, 其中若 `vector` 未提供字符串表征服务,则需要自行配置 `embedding` 服务 +2. `vector (+ embedding) + cache`: 启用语义化缓存并用缓存服务存储LLM响应以加速 + +注意若不配置相关组件,则可以忽略相应组件的`required`字段。 + + +## 向量数据库服务(vector) +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| vector.type | string | required | "" | 向量存储服务提供者类型,例如 dashvector | +| vector.serviceName | string | required | "" | 向量存储服务名称 | +| vector.serviceHost | string | required | "" | 向量存储服务域名 | +| vector.servicePort | int64 | optional | 443 | 向量存储服务端口 | +| vector.apiKey | string | optional | "" | 向量存储服务 API Key | +| vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 | +| vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 | +| vector.collectionID | string | optional | "" | dashvector 向量存储服务 Collection ID | +| vector.threshold | float64 | optional | 1000 | 向量相似度度量阈值 | +| vector.thresholdRelation | string | optional | lt | 相似度度量方式有 `Cosine`, `DotProduct`, `Euclidean` 等,前两者值越大相似度越高,后者值越小相似度越高。对于 `Cosine` 和 `DotProduct` 选择 `gt`,对于 `Euclidean` 则选择 `lt`。默认为 `lt`,所有条件包括 `lt` (less than,小于)、`lte` (less than or equal to,小等于)、`gt` (greater than,大于)、`gte` (greater than or equal to,大等于) | + +## 文本向量化服务(embedding) +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| embedding.type | string | required | "" | 请求文本向量化服务类型,例如 dashscope | +| embedding.serviceName | string | required | "" | 请求文本向量化服务名称 | +| embedding.serviceHost | string | optional | "" | 请求文本向量化服务域名 | +| embedding.servicePort | int64 | optional | 443 | 请求文本向量化服务端口 | +| embedding.apiKey | string | optional | "" | 请求文本向量化服务的 API Key | +| embedding.timeout | uint32 | optional | 10000 | 请求文本向量化服务的超时时间,单位为毫秒。默认值是10000,即10秒 | +| embedding.model | string | optional | "" | 请求文本向量化服务的模型名称 | + + +## 缓存服务(cache) +| cache.type | string | required | "" | 缓存服务类型,例如 redis | +| --- | --- | --- | --- | --- | +| cache.serviceName | string | required | "" | 缓存服务名称 | +| cache.serviceHost | string | required | "" | 缓存服务域名 | +| cache.servicePort | int64 | optional | 6379 | 缓存服务端口 | +| cache.username | string | optional | "" | 缓存服务用户名 | +| cache.password | string | optional | "" | 缓存服务密码 | +| cache.timeout | uint32 | optional | 10000 | 缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 | +| cache.cacheTTL | int | optional | 0 | 缓存过期时间,单位为秒。默认值是 0,即 永不过期| +| cacheKeyPrefix | string | optional | "higress-ai-cache:" | 缓存 Key 的前缀,默认值为 "higress-ai-cache:" | + + +## 其他配置 +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| cacheKeyFrom | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheValueFrom | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheStreamValueFrom | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheToolCallsFrom | string | optional | "choices.0.delta.content.tool_calls" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | +| streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | + ## 配置示例 +### 基础配置 +```yaml +embedding: + type: dashscope + serviceName: my_dashscope.dns + apiKey: [Your Key] + +vector: + type: dashvector + serviceName: my_dashvector.dns + collectionID: [Your Collection ID] + serviceDomain: [Your domain] + apiKey: [Your key] + +cache: + type: redis + serviceName: my_redis.dns + servicePort: 6379 + timeout: 100 +``` + +旧版本配置兼容 ```yaml redis: - serviceName: my-redis.dns - timeout: 2000 + serviceName: my_redis.dns + servicePort: 6379 + timeout: 100 ``` ## 进阶用法 - 当前默认的缓存 key 是基于 GJSON PATH 的表达式:`messages.@reverse.0.content` 提取,含义是把 messages 数组反转后取第一项的 content; GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user 的 content 作为 key,可以写成: `messages.@reverse.#(role=="user").content`; @@ -55,3 +141,7 @@ GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user 还可以支持管道语法,例如希望取到数第二个 role 为 user 的 content 作为 key,可以写成:`messages.@reverse.#(role=="user")#.content|1`。 更多用法可以参考[官方文档](https://github.com/tidwall/gjson/blob/master/SYNTAX.md),可以使用 [GJSON Playground](https://gjson.dev/) 进行语法测试。 + +## 常见问题 + +1. 如果返回的错误为 `error status returned by host: bad argument`,请检查`serviceName`是否正确包含了服务的类型后缀(.dns等)。 \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go new file mode 100644 index 0000000000..1238d21570 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -0,0 +1,135 @@ +package cache + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + PROVIDER_TYPE_REDIS = "redis" + DEFAULT_CACHE_PREFIX = "higress-ai-cache:" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + PROVIDER_TYPE_REDIS: &redisProviderInitializer{}, + } +) + +type ProviderConfig struct { + // @Title zh-CN redis 缓存服务提供者类型 + // @Description zh-CN 缓存服务提供者类型,例如 redis + typ string + // @Title zh-CN redis 缓存服务名称 + // @Description zh-CN 缓存服务名称 + serviceName string + // @Title zh-CN redis 缓存服务端口 + // @Description zh-CN 缓存服务端口,默认值为6379 + servicePort int + // @Title zh-CN redis 缓存服务地址 + // @Description zh-CN Cache 缓存服务地址,非必填 + serviceHost string + // @Title zh-CN 缓存服务用户名 + // @Description zh-CN 缓存服务用户名,非必填 + username string + // @Title zh-CN 缓存服务密码 + // @Description zh-CN 缓存服务密码,非必填 + password string + // @Title zh-CN 请求超时 + // @Description zh-CN 请求缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 + timeout uint32 + // @Title zh-CN 缓存过期时间 + // @Description zh-CN 缓存过期时间,单位为秒。默认值是0,即永不过期 + cacheTTL int + // @Title 缓存 Key 前缀 + // @Description 缓存 Key 的前缀,默认值为 "higressAiCache:" + cacheKeyPrefix string +} + +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("type").String() + c.serviceName = json.Get("serviceName").String() + c.servicePort = int(json.Get("servicePort").Int()) + if !json.Get("servicePort").Exists() { + c.servicePort = 6379 + } + c.serviceHost = json.Get("serviceHost").String() + c.username = json.Get("username").String() + if !json.Get("username").Exists() { + c.username = "" + } + c.password = json.Get("password").String() + if !json.Get("password").Exists() { + c.password = "" + } + c.timeout = uint32(json.Get("timeout").Int()) + if !json.Get("timeout").Exists() { + c.timeout = 10000 + } + c.cacheTTL = int(json.Get("cacheTTL").Int()) + if !json.Get("cacheTTL").Exists() { + c.cacheTTL = 0 + // c.cacheTTL = 3600000 + } + if json.Get("cacheKeyPrefix").Exists() { + c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String() + } else { + c.cacheKeyPrefix = DEFAULT_CACHE_PREFIX + } + +} + +func (c *ProviderConfig) ConvertLegacyJson(json gjson.Result) { + c.FromJson(json.Get("redis")) + c.typ = "redis" + if json.Get("cacheTTL").Exists() { + c.cacheTTL = int(json.Get("cacheTTL").Int()) + } +} + +func (c *ProviderConfig) Validate() error { + if c.typ == "" { + return errors.New("cache service type is required") + } + if c.serviceName == "" { + return errors.New("cache service name is required") + } + if c.cacheTTL < 0 { + return errors.New("cache TTL must be greater than or equal to 0") + } + initializer, has := providerInitializers[c.typ] + if !has { + return errors.New("unknown cache service provider type: " + c.typ) + } + if err := initializer.ValidateConfig(*c); err != nil { + return err + } + return nil +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} + +type Provider interface { + GetProviderType() string + Init(username string, password string, timeout uint32) error + Get(key string, cb wrapper.RedisResponseCallback) error + Set(key string, value string, cb wrapper.RedisResponseCallback) error + GetCacheKeyPrefix() string +} diff --git a/plugins/wasm-go/extensions/ai-cache/cache/redis.go b/plugins/wasm-go/extensions/ai-cache/cache/redis.go new file mode 100644 index 0000000000..4cb69744e1 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/cache/redis.go @@ -0,0 +1,58 @@ +package cache + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type redisProviderInitializer struct { +} + +func (r *redisProviderInitializer) ValidateConfig(cf ProviderConfig) error { + if len(cf.serviceName) == 0 { + return errors.New("cache service name is required") + } + return nil +} + +func (r *redisProviderInitializer) CreateProvider(cf ProviderConfig) (Provider, error) { + rp := redisProvider{ + config: cf, + client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ + FQDN: cf.serviceName, + Host: cf.serviceHost, + Port: int64(cf.servicePort)}), + } + err := rp.Init(cf.username, cf.password, cf.timeout) + return &rp, err +} + +type redisProvider struct { + config ProviderConfig + client wrapper.RedisClient +} + +func (rp *redisProvider) GetProviderType() string { + return PROVIDER_TYPE_REDIS +} + +func (rp *redisProvider) Init(username string, password string, timeout uint32) error { + return rp.client.Init(rp.config.username, rp.config.password, int64(rp.config.timeout)) +} + +func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error { + return rp.client.Get(key, cb) +} + +func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error { + if rp.config.cacheTTL == 0 { + return rp.client.Set(key, value, cb) + } else { + return rp.client.SetEx(key, value, rp.config.cacheTTL, cb) + } +} + +func (rp *redisProvider) GetCacheKeyPrefix() string { + return rp.config.cacheKeyPrefix +} diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go new file mode 100644 index 0000000000..4bd6e2a18f --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -0,0 +1,225 @@ +package config + +import ( + "fmt" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + CACHE_KEY_STRATEGY_LAST_QUESTION = "lastQuestion" + CACHE_KEY_STRATEGY_ALL_QUESTIONS = "allQuestions" + CACHE_KEY_STRATEGY_DISABLED = "disabled" +) + +type PluginConfig struct { + // @Title zh-CN 返回 HTTP 响应的模版 + // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 + ResponseTemplate string + // @Title zh-CN 返回流式 HTTP 响应的模版 + // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 + StreamResponseTemplate string + + cacheProvider cache.Provider + embeddingProvider embedding.Provider + vectorProvider vector.Provider + + embeddingProviderConfig embedding.ProviderConfig + vectorProviderConfig vector.ProviderConfig + cacheProviderConfig cache.ProviderConfig + + CacheKeyFrom string + CacheValueFrom string + CacheStreamValueFrom string + CacheToolCallsFrom string + + // @Title zh-CN 启用语义化缓存 + // @Description zh-CN 控制是否启用语义化缓存功能。true 表示启用,false 表示禁用。 + EnableSemanticCache bool + + // @Title zh-CN 缓存键策略 + // @Description zh-CN 决定如何生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) + CacheKeyStrategy string +} + +func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) { + + c.vectorProviderConfig.FromJson(json.Get("vector")) + c.embeddingProviderConfig.FromJson(json.Get("embedding")) + c.cacheProviderConfig.FromJson(json.Get("cache")) + if json.Get("redis").Exists() { + // compatible with legacy config + c.cacheProviderConfig.ConvertLegacyJson(json) + } + + c.CacheKeyStrategy = json.Get("cacheKeyStrategy").String() + if c.CacheKeyStrategy == "" { + c.CacheKeyStrategy = CACHE_KEY_STRATEGY_LAST_QUESTION // set default value + } + c.CacheKeyFrom = json.Get("cacheKeyFrom").String() + if c.CacheKeyFrom == "" { + c.CacheKeyFrom = "messages.@reverse.0.content" + } + c.CacheValueFrom = json.Get("cacheValueFrom").String() + if c.CacheValueFrom == "" { + c.CacheValueFrom = "choices.0.message.content" + } + c.CacheStreamValueFrom = json.Get("cacheStreamValueFrom").String() + if c.CacheStreamValueFrom == "" { + c.CacheStreamValueFrom = "choices.0.delta.content" + } + c.CacheToolCallsFrom = json.Get("cacheToolCallsFrom").String() + if c.CacheToolCallsFrom == "" { + c.CacheToolCallsFrom = "choices.0.delta.content.tool_calls" + } + + c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() + if c.StreamResponseTemplate == "" { + c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + } + c.ResponseTemplate = json.Get("responseTemplate").String() + if c.ResponseTemplate == "" { + c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + } + + if json.Get("enableSemanticCache").Exists() { + c.EnableSemanticCache = json.Get("enableSemanticCache").Bool() + } else { + c.EnableSemanticCache = true // set default value to true + } + + // compatible with legacy config + convertLegacyMapFields(c, json, log) +} + +func (c *PluginConfig) Validate() error { + // if cache provider is configured, validate it + if c.cacheProviderConfig.GetProviderType() != "" { + if err := c.cacheProviderConfig.Validate(); err != nil { + return err + } + } + if c.embeddingProviderConfig.GetProviderType() != "" { + if err := c.embeddingProviderConfig.Validate(); err != nil { + return err + } + } + if c.vectorProviderConfig.GetProviderType() != "" { + if err := c.vectorProviderConfig.Validate(); err != nil { + return err + } + } + + // cache, vector, and embedding cannot all be empty + if c.vectorProviderConfig.GetProviderType() == "" && + c.embeddingProviderConfig.GetProviderType() == "" && + c.cacheProviderConfig.GetProviderType() == "" { + return fmt.Errorf("vector, embedding and cache provider cannot be all empty") + } + + // Validate the value of CacheKeyStrategy + if c.CacheKeyStrategy != CACHE_KEY_STRATEGY_LAST_QUESTION && + c.CacheKeyStrategy != CACHE_KEY_STRATEGY_ALL_QUESTIONS && + c.CacheKeyStrategy != CACHE_KEY_STRATEGY_DISABLED { + return fmt.Errorf("invalid CacheKeyStrategy: %s", c.CacheKeyStrategy) + } + + // If semantic cache is enabled, ensure necessary components are configured + // if c.EnableSemanticCache { + // if c.embeddingProviderConfig.GetProviderType() == "" { + // return fmt.Errorf("semantic cache is enabled but embedding provider is not configured") + // } + // // if only configure cache, just warn the user + // } + return nil +} + +func (c *PluginConfig) Complete(log wrapper.Log) error { + var err error + if c.embeddingProviderConfig.GetProviderType() != "" { + log.Debugf("embedding provider is set to %s", c.embeddingProviderConfig.GetProviderType()) + c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig) + if err != nil { + return err + } + } else { + log.Info("embedding provider is not configured") + c.embeddingProvider = nil + } + if c.cacheProviderConfig.GetProviderType() != "" { + log.Debugf("cache provider is set to %s", c.cacheProviderConfig.GetProviderType()) + c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig) + if err != nil { + return err + } + } else { + log.Info("cache provider is not configured") + c.cacheProvider = nil + } + if c.vectorProviderConfig.GetProviderType() != "" { + log.Debugf("vector provider is set to %s", c.vectorProviderConfig.GetProviderType()) + c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) + if err != nil { + return err + } + } else { + log.Info("vector provider is not configured") + c.vectorProvider = nil + } + return nil +} + +func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider { + return c.embeddingProvider +} + +func (c *PluginConfig) GetVectorProvider() vector.Provider { + return c.vectorProvider +} + +func (c *PluginConfig) GetVectorProviderConfig() vector.ProviderConfig { + return c.vectorProviderConfig +} + +func (c *PluginConfig) GetCacheProvider() cache.Provider { + return c.cacheProvider +} + +func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log wrapper.Log) { + keyMap := map[string]string{ + "cacheKeyFrom.requestBody": "cacheKeyFrom", + "cacheValueFrom.requestBody": "cacheValueFrom", + "cacheStreamValueFrom.requestBody": "cacheStreamValueFrom", + "returnResponseTemplate": "responseTemplate", + "returnStreamResponseTemplate": "streamResponseTemplate", + } + + for oldKey, newKey := range keyMap { + if json.Get(oldKey).Exists() { + log.Debugf("[convertLegacyMapFields] mapping %s to %s", oldKey, newKey) + setField(c, newKey, json.Get(oldKey).String(), log) + } else { + log.Debugf("[convertLegacyMapFields] %s not exists", oldKey) + } + } +} + +func setField(c *PluginConfig, fieldName string, value string, log wrapper.Log) { + switch fieldName { + case "cacheKeyFrom": + c.CacheKeyFrom = value + case "cacheValueFrom": + c.CacheValueFrom = value + case "cacheStreamValueFrom": + c.CacheStreamValueFrom = value + case "responseTemplate": + c.ResponseTemplate = value + case "streamResponseTemplate": + c.StreamResponseTemplate = value + } + log.Debugf("[setField] set %s to %s", fieldName, value) +} diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go new file mode 100644 index 0000000000..19a9b2b856 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -0,0 +1,275 @@ +package main + +import ( + "errors" + "fmt" + "strconv" + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/tidwall/resp" +) + +// CheckCacheForKey checks if the key is in the cache, or triggers similarity search if not found. +func CheckCacheForKey(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) error { + activeCacheProvider := c.GetCacheProvider() + if activeCacheProvider == nil { + log.Debugf("[%s] [CheckCacheForKey] no cache provider configured, performing similarity search", PLUGIN_NAME) + return performSimilaritySearch(key, ctx, c, log, key, stream) + } + + queryKey := activeCacheProvider.GetCacheKeyPrefix() + key + log.Debugf("[%s] [CheckCacheForKey] querying cache with key: %s", PLUGIN_NAME, queryKey) + + err := activeCacheProvider.Get(queryKey, func(response resp.Value) { + handleCacheResponse(key, response, ctx, log, stream, c, useSimilaritySearch) + }) + + if err != nil { + log.Errorf("[%s] [CheckCacheForKey] failed to retrieve key: %s from cache, error: %v", PLUGIN_NAME, key, err) + return err + } + + return nil +} + +// handleCacheResponse processes cache response and handles cache hits and misses. +func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, useSimilaritySearch bool) { + if err := response.Error(); err == nil && !response.IsNull() { + log.Infof("[%s] cache hit for key: %s", PLUGIN_NAME, key) + processCacheHit(key, response.String(), stream, ctx, c, log) + return + } + + log.Infof("[%s] [handleCacheResponse] cache miss for key: %s", PLUGIN_NAME, key) + if err := response.Error(); err != nil { + log.Errorf("[%s] [handleCacheResponse] error retrieving key: %s from cache, error: %v", PLUGIN_NAME, key, err) + } + + if useSimilaritySearch && c.EnableSemanticCache { + if err := performSimilaritySearch(key, ctx, c, log, key, stream); err != nil { + log.Errorf("[%s] [handleCacheResponse] failed to perform similarity search for key: %s, error: %v", PLUGIN_NAME, key, err) + proxywasm.ResumeHttpRequest() + } + } else { + proxywasm.ResumeHttpRequest() + } +} + +// processCacheHit handles a successful cache hit. +func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) { + if strings.TrimSpace(response) == "" { + log.Warnf("[%s] [processCacheHit] cached response for key %s is empty", PLUGIN_NAME, key) + proxywasm.ResumeHttpRequest() + return + } + + log.Debugf("[%s] [processCacheHit] cached response for key %s: %s", PLUGIN_NAME, key, response) + + // Escape the response to ensure consistent formatting + escapedResponse := strings.Trim(strconv.Quote(response), "\"") + + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) + + if stream { + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1) + } else { + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.ResponseTemplate, escapedResponse)), -1) + } +} + +// performSimilaritySearch determines the appropriate similarity search method to use. +func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, queryString string, stream bool) error { + activeVectorProvider := c.GetVectorProvider() + if activeVectorProvider == nil { + return logAndReturnError(log, "[performSimilaritySearch] no vector provider configured for similarity search") + } + + // Check if the active vector provider implements the StringQuerier interface. + if _, ok := activeVectorProvider.(vector.StringQuerier); ok { + log.Debugf("[%s] [performSimilaritySearch] active vector provider implements StringQuerier interface, performing string query", PLUGIN_NAME) + return performStringQuery(key, queryString, ctx, c, log, stream) + } + + // Check if the active vector provider implements the EmbeddingQuerier interface. + if _, ok := activeVectorProvider.(vector.EmbeddingQuerier); ok { + log.Debugf("[%s] [performSimilaritySearch] active vector provider implements EmbeddingQuerier interface, performing embedding query", PLUGIN_NAME) + return performEmbeddingQuery(key, ctx, c, log, stream) + } + + return logAndReturnError(log, "[performSimilaritySearch] no suitable querier or embedding provider available for similarity search") +} + +// performStringQuery executes the string-based similarity search. +func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error { + stringQuerier, ok := c.GetVectorProvider().(vector.StringQuerier) + if !ok { + return logAndReturnError(log, "[performStringQuery] active vector provider does not implement StringQuerier interface") + } + + return stringQuerier.QueryString(queryString, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { + handleQueryResults(key, results, ctx, log, stream, c, err) + }) +} + +// performEmbeddingQuery executes the embedding-based similarity search. +func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error { + embeddingQuerier, ok := c.GetVectorProvider().(vector.EmbeddingQuerier) + if !ok { + return logAndReturnError(log, fmt.Sprintf("[performEmbeddingQuery] active vector provider does not implement EmbeddingQuerier interface")) + } + + activeEmbeddingProvider := c.GetEmbeddingProvider() + if activeEmbeddingProvider == nil { + return logAndReturnError(log, fmt.Sprintf("[performEmbeddingQuery] no embedding provider configured for similarity search")) + } + + return activeEmbeddingProvider.GetEmbedding(key, ctx, log, func(textEmbedding []float64, err error) { + log.Debugf("[%s] [performEmbeddingQuery] GetEmbedding success, length of embedding: %d, error: %v", PLUGIN_NAME, len(textEmbedding), err) + if err != nil { + handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error getting embedding for key: %s", PLUGIN_NAME, key), log) + return + } + ctx.SetContext(CACHE_KEY_EMBEDDING_KEY, textEmbedding) + + err = embeddingQuerier.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { + handleQueryResults(key, results, ctx, log, stream, c, err) + }) + if err != nil { + handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error querying vector database for key: %s", PLUGIN_NAME, key), log) + } + }) +} + +// handleQueryResults processes the results of similarity search and determines next actions. +func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, err error) { + if err != nil { + handleInternalError(err, fmt.Sprintf("[%s] [handleQueryResults] error querying vector database for key: %s", PLUGIN_NAME, key), log) + return + } + + if len(results) == 0 { + log.Warnf("[%s] [handleQueryResults] no similar keys found for key: %s", PLUGIN_NAME, key) + proxywasm.ResumeHttpRequest() + return + } + + mostSimilarData := results[0] + log.Debugf("[%s] [handleQueryResults] for key: %s, the most similar key found: %s with score: %f", PLUGIN_NAME, key, mostSimilarData.Text, mostSimilarData.Score) + simThreshold := c.GetVectorProviderConfig().Threshold + simThresholdRelation := c.GetVectorProviderConfig().ThresholdRelation + if compare(simThresholdRelation, mostSimilarData.Score, simThreshold) { + log.Infof("[%s] key accepted: %s with score: %f", PLUGIN_NAME, mostSimilarData.Text, mostSimilarData.Score) + if mostSimilarData.Answer != "" { + // direct return the answer if available + cacheResponse(ctx, c, key, mostSimilarData.Answer, log) + processCacheHit(key, mostSimilarData.Answer, stream, ctx, c, log) + } else { + if c.GetCacheProvider() != nil { + CheckCacheForKey(mostSimilarData.Text, ctx, c, log, stream, false) + } else { + // Otherwise, do not check the cache, directly return + log.Infof("[%s] cache hit for key: %s, but no corresponding answer found in the vector database", PLUGIN_NAME, mostSimilarData.Text) + proxywasm.ResumeHttpRequest() + } + } + } else { + log.Infof("[%s] score not meet the threshold %f: %s with score %f", PLUGIN_NAME, simThreshold, mostSimilarData.Text, mostSimilarData.Score) + proxywasm.ResumeHttpRequest() + } +} + +// logAndReturnError logs an error and returns it. +func logAndReturnError(log wrapper.Log, message string) error { + message = fmt.Sprintf("[%s] %s", PLUGIN_NAME, message) + log.Errorf(message) + return errors.New(message) +} + +// handleInternalError logs an error and resumes the HTTP request. +func handleInternalError(err error, message string, log wrapper.Log) { + if err != nil { + log.Errorf("[%s] [handleInternalError] %s: %v", PLUGIN_NAME, message, err) + } else { + log.Errorf("[%s] [handleInternalError] %s", PLUGIN_NAME, message) + } + // proxywasm.SendHttpResponse(500, [][2]string{{"content-type", "text/plain"}}, []byte("Internal Server Error"), -1) + proxywasm.ResumeHttpRequest() +} + +// Caches the response value +func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) { + if strings.TrimSpace(value) == "" { + log.Warnf("[%s] [cacheResponse] cached value for key %s is empty", PLUGIN_NAME, key) + return + } + + activeCacheProvider := c.GetCacheProvider() + if activeCacheProvider != nil { + queryKey := activeCacheProvider.GetCacheKeyPrefix() + key + _ = activeCacheProvider.Set(queryKey, value, nil) + log.Debugf("[%s] [cacheResponse] cache set success, key: %s, length of value: %d", PLUGIN_NAME, queryKey, len(value)) + } +} + +// Handles embedding upload if available +func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) { + embedding := ctx.GetContext(CACHE_KEY_EMBEDDING_KEY) + if embedding == nil { + return + } + + emb, ok := embedding.([]float64) + if !ok { + log.Errorf("[%s] [uploadEmbeddingAndAnswer] embedding is not of expected type []float64", PLUGIN_NAME) + return + } + + activeVectorProvider := c.GetVectorProvider() + if activeVectorProvider == nil { + log.Debugf("[%s] [uploadEmbeddingAndAnswer] no vector provider configured for uploading embedding", PLUGIN_NAME) + return + } + + // Attempt to upload answer embedding first + if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerAndEmbeddingUploader); ok { + log.Infof("[%s] uploading answer embedding for key: %s", PLUGIN_NAME, key) + err := ansEmbUploader.UploadAnswerAndEmbedding(key, emb, value, ctx, log, nil) + if err != nil { + log.Warnf("[%s] [uploadEmbeddingAndAnswer] failed to upload answer embedding for key: %s, error: %v", PLUGIN_NAME, key, err) + } else { + return // If successful, return early + } + } + + // If answer embedding upload fails, attempt normal embedding upload + if embUploader, ok := activeVectorProvider.(vector.EmbeddingUploader); ok { + log.Infof("[%s] uploading embedding for key: %s", PLUGIN_NAME, key) + err := embUploader.UploadEmbedding(key, emb, ctx, log, nil) + if err != nil { + log.Warnf("[%s] [uploadEmbeddingAndAnswer] failed to upload embedding for key: %s, error: %v", PLUGIN_NAME, key, err) + } + } +} + +// 主要用于相似度/距离/点积判断 +// 余弦相似度度量的是两个向量在方向上的相似程度。相似度越高,两个向量越接近。 +// 距离度量的是两个向量在空间上的远近程度。距离越小,两个向量越接近。 +// compare 函数根据操作符进行判断并返回结果 +func compare(operator string, value1 float64, value2 float64) bool { + switch operator { + case "gt": + return value1 > value2 + case "gte": + return value1 >= value2 + case "lt": + return value1 < value2 + case "lte": + return value1 <= value2 + default: + return false + } +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go new file mode 100644 index 0000000000..35c897cce5 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -0,0 +1,187 @@ +package embedding + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +const ( + DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com" + DASHSCOPE_PORT = 443 + DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v2" + DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding" +) + +type dashScopeProviderInitializer struct { +} + +func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error { + if config.apiKey == "" { + return errors.New("[DashScope] apiKey is required") + } + return nil +} + +func (d *dashScopeProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { + if c.servicePort == 0 { + c.servicePort = DASHSCOPE_PORT + } + if c.serviceHost == "" { + c.serviceHost = DASHSCOPE_DOMAIN + } + return &DSProvider{ + config: c, + client: wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: c.serviceName, + Host: c.serviceHost, + Port: int64(c.servicePort), + }), + }, nil +} + +func (d *DSProvider) GetProviderType() string { + return PROVIDER_TYPE_DASHSCOPE +} + +type Embedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type Input struct { + Texts []string `json:"texts"` +} + +type Params struct { + TextType string `json:"text_type"` +} + +type Response struct { + RequestID string `json:"request_id"` + Output Output `json:"output"` + Usage Usage `json:"usage"` +} + +type Output struct { + Embeddings []Embedding `json:"embeddings"` +} + +type Usage struct { + TotalTokens int `json:"total_tokens"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input Input `json:"input"` + Parameters Params `json:"parameters"` +} + +type Document struct { + Vector []float64 `json:"vector"` + Fields map[string]string `json:"fields"` +} + +type DSProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { + + model := d.config.model + + if model == "" { + model = DASHSCOPE_DEFAULT_MODEL_NAME + } + data := EmbeddingRequest{ + Model: model, + Input: Input{ + Texts: texts, + }, + Parameters: Params{ + TextType: "query", + }, + } + + requestBody, err := json.Marshal(data) + if err != nil { + log.Errorf("failed to marshal request data: %v", err) + return "", nil, nil, err + } + + if d.config.apiKey == "" { + err := errors.New("dashScopeKey is empty") + log.Errorf("failed to construct headers: %v", err) + return "", nil, nil, err + } + + headers := [][2]string{ + {"Authorization", "Bearer " + d.config.apiKey}, + {"Content-Type", "application/json"}, + } + + return DASHSCOPE_ENDPOINT, headers, requestBody, err +} + +type Result struct { + ID string `json:"id"` + Vector []float64 `json:"vector,omitempty"` + Fields map[string]interface{} `json:"fields"` + Score float64 `json:"score"` +} + +func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error) { + var resp Response + err := json.Unmarshal(responseBody, &resp) + if err != nil { + return nil, err + } + return &resp, nil +} + +func (d *DSProvider) GetEmbedding( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(emb []float64, err error)) error { + embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString}, log) + if err != nil { + log.Errorf("failed to construct parameters: %v", err) + return err + } + + var resp *Response + err = d.client.Post(embUrl, embHeaders, embRequestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + + if statusCode != http.StatusOK { + err = errors.New("failed to get embedding due to status code: " + strconv.Itoa(statusCode)) + callback(nil, err) + return + } + + log.Debugf("get embedding response: %d, %s", statusCode, responseBody) + + resp, err = d.parseTextEmbedding(responseBody) + if err != nil { + err = fmt.Errorf("failed to parse response: %v", err) + callback(nil, err) + return + } + + if len(resp.Output.Embeddings) == 0 { + err = errors.New("no embedding found in response") + callback(nil, err) + return + } + + callback(resp.Output.Embeddings[0].Embedding, nil) + + }, d.config.timeout) + return err +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go new file mode 100644 index 0000000000..909edf129c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -0,0 +1,101 @@ +package embedding + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + PROVIDER_TYPE_DASHSCOPE = "dashscope" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{}, + } +) + +type ProviderConfig struct { + // @Title zh-CN 文本特征提取服务提供者类型 + // @Description zh-CN 文本特征提取服务提供者类型,例如 DashScope + typ string + // @Title zh-CN DashScope 文本特征提取服务名称 + // @Description zh-CN 文本特征提取服务名称 + serviceName string + // @Title zh-CN 文本特征提取服务域名 + // @Description zh-CN 文本特征提取服务域名 + serviceHost string + // @Title zh-CN 文本特征提取服务端口 + // @Description zh-CN 文本特征提取服务端口 + servicePort int64 + // @Title zh-CN 文本特征提取服务 API Key + // @Description zh-CN 文本特征提取服务 API Key + apiKey string + // @Title zh-CN 文本特征提取服务超时时间 + // @Description zh-CN 文本特征提取服务超时时间 + timeout uint32 + // @Title zh-CN 文本特征提取服务使用的模型 + // @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1" + model string +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("type").String() + c.serviceName = json.Get("serviceName").String() + c.serviceHost = json.Get("serviceHost").String() + c.servicePort = json.Get("servicePort").Int() + c.apiKey = json.Get("apiKey").String() + c.timeout = uint32(json.Get("timeout").Int()) + c.model = json.Get("model").String() + if c.timeout == 0 { + c.timeout = 10000 + } +} + +func (c *ProviderConfig) Validate() error { + if c.serviceName == "" { + return errors.New("embedding service name is required") + } + if c.apiKey == "" { + return errors.New("embedding service API key is required") + } + if c.typ == "" { + return errors.New("embedding service type is required") + } + initializer, has := providerInitializers[c.typ] + if !has { + return errors.New("unknown embedding service provider type: " + c.typ) + } + if err := initializer.ValidateConfig(*c); err != nil { + return err + } + return nil +} + +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} + +type Provider interface { + GetProviderType() string + GetEmbedding( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(emb []float64, err error)) error +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go b/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go new file mode 100644 index 0000000000..b26d9cea8d --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go @@ -0,0 +1,27 @@ +package embedding + +// import ( +// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +// ) + +// const ( +// weaviateURL = "172.17.0.1:8081" +// ) + +// type weaviateProviderInitializer struct { +// } + +// func (d *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error { +// return nil +// } + +// func (d *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { +// return &DSProvider{ +// config: config, +// client: wrapper.NewClusterClient(wrapper.DnsCluster{ +// ServiceName: config.ServiceName, +// Port: dashScopePort, +// Domain: dashScopeDomain, +// }), +// }, nil +// } diff --git a/plugins/wasm-go/extensions/ai-cache/go.mod b/plugins/wasm-go/extensions/ai-cache/go.mod index c9630cfb8a..e4aae265e0 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.mod +++ b/plugins/wasm-go/extensions/ai-cache/go.mod @@ -7,17 +7,18 @@ go 1.19 replace github.com/alibaba/higress/plugins/wasm-go => ../.. require ( - github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441 + github.com/alibaba/higress/plugins/wasm-go v1.4.2 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 - github.com/tidwall/sjson v1.2.5 +// github.com/weaviate/weaviate-go-client/v4 v4.15.1 ) require ( - github.com/google/uuid v1.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect github.com/magefile/mage v1.14.0 // indirect + github.com/stretchr/testify v1.9.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-cache/go.sum b/plugins/wasm-go/extensions/ai-cache/go.sum index 8246b4de5e..7ada0c8b70 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.sum +++ b/plugins/wasm-go/extensions/ai-cache/go.sum @@ -1,24 +1,21 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 7886d5698f..1aca29f0ec 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -1,33 +1,33 @@ -// File generated by hgctl. Modify as required. -// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6 - +// 这个文件中主要将OnHttpRequestHeaders、OnHttpRequestBody、OnHttpResponseHeaders、OnHttpResponseBody这四个函数实现 +// 其中的缓存思路调用cache.go中的逻辑,然后cache.go中的逻辑会调用textEmbeddingProvider和vectorStoreProvider中的逻辑(实例) package main import ( - "errors" - "fmt" "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" - "github.com/tidwall/resp" ) const ( - CacheKeyContextKey = "cacheKey" - CacheContentContextKey = "cacheContent" - PartialMessageContextKey = "partialMessage" - ToolCallsContextKey = "toolCalls" - StreamContextKey = "stream" - DefaultCacheKeyPrefix = "higress-ai-cache:" - SkipCacheHeader = "x-higress-skip-ai-cache" + PLUGIN_NAME = "ai-cache" + CACHE_KEY_CONTEXT_KEY = "cacheKey" + CACHE_KEY_EMBEDDING_KEY = "cacheKeyEmbedding" + CACHE_CONTENT_CONTEXT_KEY = "cacheContent" + PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" + TOOL_CALLS_CONTEXT_KEY = "toolCalls" + STREAM_CONTEXT_KEY = "stream" + SKIP_CACHE_HEADER = "x-higress-skip-ai-cache" + ERROR_PARTIAL_MESSAGE_KEY = "errorPartialMessage" ) func main() { + // CreateClient() wrapper.SetCtx( - "ai-cache", + PLUGIN_NAME, wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), @@ -36,146 +36,26 @@ func main() { ) } -// @Name ai-cache -// @Category protocol -// @Phase AUTHN -// @Priority 10 -// @Title zh-CN AI Cache -// @Description zh-CN 大模型结果缓存 -// @IconUrl -// @Version 0.1.0 -// -// @Contact.name johnlanni -// @Contact.url -// @Contact.email -// -// @Example -// redis: -// serviceName: my-redis.dns -// timeout: 2000 -// cacheKeyFrom: -// requestBody: "messages.@reverse.0.content" -// cacheValueFrom: -// responseBody: "choices.0.message.content" -// cacheStreamValueFrom: -// responseBody: "choices.0.delta.content" -// returnResponseTemplate: | -// {"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} -// returnStreamResponseTemplate: | -// data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} -// -// data:[DONE] -// -// @End - -type RedisInfo struct { - // @Title zh-CN redis 服务名称 - // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local - ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` - // @Title zh-CN redis 服务端口 - // @Description zh-CN 默认值为6379 - ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` - // @Title zh-CN 用户名 - // @Description zh-CN 登陆 redis 的用户名,非必填 - Username string `required:"false" yaml:"username" json:"username"` - // @Title zh-CN 密码 - // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 - Password string `required:"false" yaml:"password" json:"password"` - // @Title zh-CN 请求超时 - // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 - Timeout int `required:"false" yaml:"timeout" json:"timeout"` -} - -type KVExtractor struct { - // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` - // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` -} - -type PluginConfig struct { - // @Title zh-CN Redis 地址信息 - // @Description zh-CN 用于存储缓存结果的 Redis 地址 - RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"` - // @Title zh-CN 缓存 key 的来源 - // @Description zh-CN 往 redis 里存时,使用的 key 的提取方式 - CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` - // @Title zh-CN 缓存 value 的来源 - // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 - CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` - // @Title zh-CN 流式响应下,缓存 value 的来源 - // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 - CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` - // @Title zh-CN 返回 HTTP 响应的模版 - // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"` - // @Title zh-CN 返回流式 HTTP 响应的模版 - // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"` - // @Title zh-CN 缓存的过期时间 - // @Description zh-CN 单位是秒,默认值为0,即永不过期 - CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` - // @Title zh-CN Redis缓存Key的前缀 - // @Description zh-CN 默认值是"higress-ai-cache:" - CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` - redisClient wrapper.RedisClient `yaml:"-" json:"-"` -} - -func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error { - c.RedisInfo.ServiceName = json.Get("redis.serviceName").String() - if c.RedisInfo.ServiceName == "" { - return errors.New("redis service name must not by empty") - } - c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int()) - if c.RedisInfo.ServicePort == 0 { - if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") { - // use default logic port which is 80 for static service - c.RedisInfo.ServicePort = 80 - } else { - c.RedisInfo.ServicePort = 6379 - } - } - c.RedisInfo.Username = json.Get("redis.username").String() - c.RedisInfo.Password = json.Get("redis.password").String() - c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int()) - if c.RedisInfo.Timeout == 0 { - c.RedisInfo.Timeout = 1000 - } - c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String() - if c.CacheKeyFrom.RequestBody == "" { - c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" - } - c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String() - if c.CacheValueFrom.ResponseBody == "" { - c.CacheValueFrom.ResponseBody = "choices.0.message.content" - } - c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String() - if c.CacheStreamValueFrom.ResponseBody == "" { - c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content" - } - c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String() - if c.ReturnResponseTemplate == "" { - c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` - } - c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String() - if c.ReturnStreamResponseTemplate == "" { - c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" +func parseConfig(json gjson.Result, c *config.PluginConfig, log wrapper.Log) error { + // config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) + // config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) + // config.RedisConfig.FromJson(json.Get("redis")) + c.FromJson(json, log) + if err := c.Validate(); err != nil { + return err } - c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String() - if c.CacheKeyPrefix == "" { - c.CacheKeyPrefix = DefaultCacheKeyPrefix + // Note that initializing the client during the parseConfig phase may cause errors, such as Redis not being usable in Docker Compose. + if err := c.Complete(log); err != nil { + log.Errorf("complete config failed: %v", err) + return err } - c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ - FQDN: c.RedisInfo.ServiceName, - Port: int64(c.RedisInfo.ServicePort), - }) - return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout)) + return nil } -func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { - skipCache, _ := proxywasm.GetHttpRequestHeader(SkipCacheHeader) +func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { + skipCache, _ := proxywasm.GetHttpRequestHeader(SKIP_CACHE_HEADER) if skipCache == "on" { - ctx.SetContext(SkipCacheHeader, struct{}{}) + ctx.SetContext(SKIP_CACHE_HEADER, struct{}{}) ctx.DontReadRequestBody() return types.ActionContinue } @@ -185,199 +65,123 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrap return types.ActionContinue } if !strings.Contains(contentType, "application/json") { - log.Warnf("content is not json, can't process:%s", contentType) + log.Warnf("content is not json, can't process: %s", contentType) ctx.DontReadRequestBody() return types.ActionContinue } - proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") // The request has a body and requires delaying the header transmission until a cache miss occurs, // at which point the header should be sent. return types.HeaderStopIteration } -func TrimQuote(source string) string { - return strings.Trim(source, `"`) -} +func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []byte, log wrapper.Log) types.Action { -func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { bodyJson := gjson.ParseBytes(body) // TODO: It may be necessary to support stream mode determination for different LLM providers. stream := false if bodyJson.Get("stream").Bool() { stream = true - ctx.SetContext(StreamContextKey, struct{}{}) - } else if ctx.GetContext(StreamContextKey) != nil { - stream = true + ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) + } + + var key string + if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_LAST_QUESTION { + log.Debugf("[onHttpRequestBody] cache key strategy is last question, cache key from: %s", c.CacheKeyFrom) + key = bodyJson.Get(c.CacheKeyFrom).String() + } else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_ALL_QUESTIONS { + log.Debugf("[onHttpRequestBody] cache key strategy is all questions, cache key from: messages") + messages := bodyJson.Get("messages").Array() + var userMessages []string + for _, msg := range messages { + if msg.Get("role").String() == "user" { + userMessages = append(userMessages, msg.Get("content").String()) + } + } + key = strings.Join(userMessages, "\n") + } else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_DISABLED { + log.Info("[onHttpRequestBody] cache key strategy is disabled") + ctx.DontReadRequestBody() + return types.ActionContinue + } else { + log.Warnf("[onHttpRequestBody] unknown cache key strategy: %s", c.CacheKeyStrategy) + ctx.DontReadRequestBody() + return types.ActionContinue } - key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) + + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, key) + log.Debugf("[onHttpRequestBody] key: %s", key) if key == "" { - log.Debug("parse key from request body failed") + log.Debug("[onHttpRequestBody] parse key from request body failed") + ctx.DontReadResponseBody() return types.ActionContinue } - ctx.SetContext(CacheKeyContextKey, key) - err := config.redisClient.Get(config.CacheKeyPrefix+key, func(response resp.Value) { - if err := response.Error(); err != nil { - log.Errorf("redis get key:%s failed, err:%v", key, err) - proxywasm.ResumeHttpRequest() - return - } - if response.IsNull() { - log.Debugf("cache miss, key:%s", key) - proxywasm.ResumeHttpRequest() - return - } - log.Debugf("cache hit, key:%s", key) - ctx.SetContext(CacheKeyContextKey, nil) - if !stream { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1) - } else { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1) - } - }) - if err != nil { - log.Error("redis access failed") + + if err := CheckCacheForKey(key, ctx, c, log, stream, true); err != nil { + log.Errorf("[onHttpRequestBody] check cache for key: %s failed, error: %v", key, err) return types.ActionContinue } - return types.ActionPause -} -func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { - subMessages := strings.Split(sseMessage, "\n") - var message string - for _, msg := range subMessages { - if strings.HasPrefix(msg, "data:") { - message = msg - break - } - } - if len(message) < 6 { - log.Errorf("invalid message:%s", message) - return "" - } - // skip the prefix "data:" - bodyJson := message[5:] - if gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Exists() { - tempContentI := ctx.GetContext(CacheContentContextKey) - if tempContentI == nil { - content := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) - ctx.SetContext(CacheContentContextKey, content) - return content - } - append := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) - content := tempContentI.(string) + append - ctx.SetContext(CacheContentContextKey, content) - return content - } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { - // TODO: compatible with other providers - ctx.SetContext(ToolCallsContextKey, struct{}{}) - return "" - } - log.Debugf("unknown message:%s", bodyJson) - return "" + return types.ActionPause } -func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { - skipCache := ctx.GetContext(SkipCacheHeader) +func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { + skipCache := ctx.GetContext(SKIP_CACHE_HEADER) if skipCache != nil { ctx.DontReadResponseBody() return types.ActionContinue } contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { - ctx.SetContext(StreamContextKey, struct{}{}) + ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) } + + if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil { + ctx.DontReadResponseBody() + return types.ActionContinue + } + return types.ActionContinue } -func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { - if ctx.GetContext(ToolCallsContextKey) != nil { - // we should not cache tool call result +func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { + log.Debugf("[onHttpResponseBody] is last chunk: %v", isLastChunk) + log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk)) + + if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { return chunk } - keyI := ctx.GetContext(CacheKeyContextKey) - if keyI == nil { + + key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) + if key == nil { + log.Debug("[onHttpResponseBody] key is nil, skip cache") return chunk } + if !isLastChunk { - stream := ctx.GetContext(StreamContextKey) - if stream == nil { - tempContentI := ctx.GetContext(CacheContentContextKey) - if tempContentI == nil { - ctx.SetContext(CacheContentContextKey, chunk) - return chunk - } - tempContent := tempContentI.([]byte) - tempContent = append(tempContent, chunk...) - ctx.SetContext(CacheContentContextKey, tempContent) - } else { - var partialMessage []byte - partialMessageI := ctx.GetContext(PartialMessageContextKey) - if partialMessageI != nil { - partialMessage = append(partialMessageI.([]byte), chunk...) - } else { - partialMessage = chunk - } - messages := strings.Split(string(partialMessage), "\n\n") - for i, msg := range messages { - if i < len(messages)-1 { - // process complete message - processSSEMessage(ctx, config, msg, log) - } - } - if !strings.HasSuffix(string(partialMessage), "\n\n") { - ctx.SetContext(PartialMessageContextKey, []byte(messages[len(messages)-1])) - } else { - ctx.SetContext(PartialMessageContextKey, nil) - } + if err := handleNonLastChunk(ctx, c, chunk, log); err != nil { + log.Errorf("[onHttpResponseBody] handle non last chunk failed, error: %v", err) + // Set an empty struct in the context to indicate an error in processing the partial message + ctx.SetContext(ERROR_PARTIAL_MESSAGE_KEY, struct{}{}) } return chunk } - // last chunk - key := keyI.(string) - stream := ctx.GetContext(StreamContextKey) + + stream := ctx.GetContext(STREAM_CONTEXT_KEY) var value string + var err error if stream == nil { - var body []byte - tempContentI := ctx.GetContext(CacheContentContextKey) - if tempContentI != nil { - body = append(tempContentI.([]byte), chunk...) - } else { - body = chunk - } - bodyJson := gjson.ParseBytes(body) - - value = TrimQuote(bodyJson.Get(config.CacheValueFrom.ResponseBody).Raw) - if value == "" { - log.Warnf("parse value from response body failded, body:%s", body) - return chunk - } + value, err = processNonStreamLastChunk(ctx, c, chunk, log) } else { - if len(chunk) > 0 { - var lastMessage []byte - partialMessageI := ctx.GetContext(PartialMessageContextKey) - if partialMessageI != nil { - lastMessage = append(partialMessageI.([]byte), chunk...) - } else { - lastMessage = chunk - } - if !strings.HasSuffix(string(lastMessage), "\n\n") { - log.Warnf("invalid lastMessage:%s", lastMessage) - return chunk - } - // remove the last \n\n - lastMessage = lastMessage[:len(lastMessage)-2] - value = processSSEMessage(ctx, config, string(lastMessage), log) - } else { - tempContentI := ctx.GetContext(CacheContentContextKey) - if tempContentI == nil { - return chunk - } - value = tempContentI.(string) - } + value, err = processStreamLastChunk(ctx, c, chunk, log) } - config.redisClient.Set(config.CacheKeyPrefix+key, value, nil) - if config.CacheTTL != 0 { - config.redisClient.Expire(config.CacheKeyPrefix+key, config.CacheTTL, nil) + + if err != nil { + log.Errorf("[onHttpResponseBody] process last chunk failed, error: %v", err) + return chunk } + + cacheResponse(ctx, c, key.(string), value, log) + uploadEmbeddingAndAnswer(ctx, c, key.(string), value, log) return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go new file mode 100644 index 0000000000..983dfbb25a --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -0,0 +1,155 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +func handleNonLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { + stream := ctx.GetContext(STREAM_CONTEXT_KEY) + err := error(nil) + if stream == nil { + err = handleNonStreamChunk(ctx, c, chunk, log) + } else { + err = handleStreamChunk(ctx, c, chunk, log) + } + return err +} + +func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) + return nil + } + tempContent := tempContentI.([]byte) + tempContent = append(tempContent, chunk...) + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) + return nil +} + +func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { + var partialMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + log.Debugf("[handleStreamChunk] cache content: %v", ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)) + if partialMessageI != nil { + partialMessage = append(partialMessageI.([]byte), chunk...) + } else { + partialMessage = chunk + } + messages := strings.Split(string(partialMessage), "\n\n") + for i, msg := range messages { + if i < len(messages)-1 { + _, err := processSSEMessage(ctx, c, msg, log) + if err != nil { + return fmt.Errorf("[handleStreamChunk] processSSEMessage failed, error: %v", err) + } + } + } + if !strings.HasSuffix(string(partialMessage), "\n\n") { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) + } else { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) + } + return nil +} + +func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { + var body []byte + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI != nil { + body = append(tempContentI.([]byte), chunk...) + } else { + body = chunk + } + bodyJson := gjson.ParseBytes(body) + value := bodyJson.Get(c.CacheValueFrom).String() + if strings.TrimSpace(value) == "" { + return "", fmt.Errorf("[processNonStreamLastChunk] parse value from response body failed, body:%s", body) + } + return value, nil +} + +func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { + if len(chunk) > 0 { + var lastMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + if partialMessageI != nil { + lastMessage = append(partialMessageI.([]byte), chunk...) + } else { + lastMessage = chunk + } + if !strings.HasSuffix(string(lastMessage), "\n\n") { + return "", fmt.Errorf("[processStreamLastChunk] invalid lastMessage:%s", lastMessage) + } + lastMessage = lastMessage[:len(lastMessage)-2] + value, err := processSSEMessage(ctx, c, string(lastMessage), log) + if err != nil { + return "", fmt.Errorf("[processStreamLastChunk] processSSEMessage failed, error: %v", err) + } + return value, nil + } + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + return "", nil + } + return tempContentI.(string), nil +} + +func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) { + subMessages := strings.Split(sseMessage, "\n") + var message string + for _, msg := range subMessages { + if strings.HasPrefix(msg, "data:") { + message = msg + break + } + } + if len(message) < 6 { + return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message) + } + + // skip the prefix "data:" + bodyJson := message[5:] + + if strings.TrimSpace(bodyJson) == "[DONE]" { + return "", nil + } + + // Extract values from JSON fields + responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom) + toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom) + + if toolCalls.Exists() { + // TODO: Temporarily store the tool_calls value in the context for processing + ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String()) + } + + // Check if the ResponseBody field exists + if !responseBody.Exists() { + if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil { + log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message) + return "", nil + } + return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message) + } else { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + + // If there is no content in the cache, initialize and set the content + if tempContentI == nil { + content := responseBody.String() + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + return content, nil + } + + // Update the content in the cache + appendMsg := responseBody.String() + content := tempContentI.(string) + appendMsg + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + return content, nil + } +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go new file mode 100644 index 0000000000..7bdb0a76d0 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -0,0 +1,256 @@ +package vector + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type dashVectorProviderInitializer struct { +} + +func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.apiKey) == 0 { + return errors.New("[DashVector] apiKey is required") + } + if len(config.collectionID) == 0 { + return errors.New("[DashVector] collectionID is required") + } + if len(config.serviceName) == 0 { + return errors.New("[DashVector] serviceName is required") + } + if len(config.serviceHost) == 0 { + return errors.New("[DashVector] serviceHost is required") + } + return nil +} + +func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &DvProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: config.serviceName, + Host: config.serviceHost, + Port: int64(config.servicePort), + }), + }, nil +} + +type DvProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (d *DvProvider) GetProviderType() string { + return PROVIDER_TYPE_DASH_VECTOR +} + +// type embeddingRequest struct { +// Model string `json:"model"` +// Input input `json:"input"` +// Parameters params `json:"parameters"` +// } + +// type params struct { +// TextType string `json:"text_type"` +// } + +// type input struct { +// Texts []string `json:"texts"` +// } + +// queryResponse 定义查询响应的结构 +type queryResponse struct { + Code int `json:"code"` + RequestID string `json:"request_id"` + Message string `json:"message"` + Output []result `json:"output"` +} + +// queryRequest 定义查询请求的结构 +type queryRequest struct { + Vector []float64 `json:"vector"` + TopK int `json:"topk"` + IncludeVector bool `json:"include_vector"` +} + +// result 定义查询结果的结构 +type result struct { + ID string `json:"id"` + Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化 + Fields map[string]interface{} `json:"fields"` + Score float64 `json:"score"` +} + +func (d *DvProvider) constructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) { + url := fmt.Sprintf("/v1/collections/%s/query", d.config.collectionID) + + requestData := queryRequest{ + Vector: vector, + TopK: d.config.topK, + IncludeVector: false, + } + + requestBody, err := json.Marshal(requestData) + if err != nil { + return "", nil, nil, err + } + + header := [][2]string{ + {"Content-Type", "application/json"}, + {"dashvector-auth-token", d.config.apiKey}, + } + + return url, requestBody, header, nil +} + +func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, error) { + var queryResp queryResponse + err := json.Unmarshal(responseBody, &queryResp) + if err != nil { + return queryResponse{}, err + } + return queryResp, nil +} + +func (d *DvProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + url, body, headers, err := d.constructEmbeddingQueryParameters(emb) + log.Debugf("url:%s, body:%s, headers:%v", url, string(body), headers) + if err != nil { + err = fmt.Errorf("failed to construct embedding query parameters: %v", err) + return err + } + + err = d.client.Post(url, headers, body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + err = nil + if statusCode != http.StatusOK { + err = fmt.Errorf("failed to query embedding: %d", statusCode) + callback(nil, ctx, log, err) + return + } + log.Debugf("query embedding response: %d, %s", statusCode, responseBody) + results, err := d.ParseQueryResponse(responseBody, ctx, log) + if err != nil { + err = fmt.Errorf("failed to parse query response: %v", err) + } + callback(results, ctx, log, err) + }, + d.config.timeout) + if err != nil { + err = fmt.Errorf("failed to query embedding: %v", err) + } + return err +} + +func getStringValue(fields map[string]interface{}, key string) string { + if val, ok := fields[key]; ok { + return val.(string) + } + return "" +} + +func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) { + resp, err := d.parseQueryResponse(responseBody) + if err != nil { + return nil, err + } + + if len(resp.Output) == 0 { + return nil, errors.New("no query results found in response") + } + + results := make([]QueryResult, 0, len(resp.Output)) + + for _, output := range resp.Output { + result := QueryResult{ + Text: getStringValue(output.Fields, "query"), + Embedding: output.Vector, + Score: output.Score, + Answer: getStringValue(output.Fields, "answer"), + } + results = append(results, result) + } + + return results, nil +} + +type document struct { + Vector []float64 `json:"vector"` + Fields map[string]string `json:"fields"` +} + +type insertRequest struct { + Docs []document `json:"docs"` +} + +func (d *DvProvider) constructUploadParameters(emb []float64, queryString string, answer string) (string, []byte, [][2]string, error) { + url := "/v1/collections/" + d.config.collectionID + "/docs" + + doc := document{ + Vector: emb, + Fields: map[string]string{ + "query": queryString, + "answer": answer, + }, + } + + requestBody, err := json.Marshal(insertRequest{Docs: []document{doc}}) + if err != nil { + return "", nil, nil, err + } + + header := [][2]string{ + {"Content-Type", "application/json"}, + {"dashvector-auth-token", d.config.apiKey}, + } + + return url, requestBody, header, err +} + +func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, "") + if err != nil { + return err + } + err = d.client.Post( + url, + headers, + body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + if statusCode != http.StatusOK { + err = fmt.Errorf("failed to upload embedding: %d", statusCode) + } + callback(ctx, log, err) + }, + d.config.timeout) + return err +} + +func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer) + if err != nil { + return err + } + err = d.client.Post( + url, + headers, + body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + if statusCode != http.StatusOK { + err = fmt.Errorf("failed to upload embedding: %d", statusCode) + } + callback(ctx, log, err) + }, + d.config.timeout) + return err +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go new file mode 100644 index 0000000000..a04123a166 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -0,0 +1,167 @@ +package vector + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + PROVIDER_TYPE_DASH_VECTOR = "dashvector" + PROVIDER_TYPE_CHROMA = "chroma" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{}, + // PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{}, + } +) + +// QueryResult 定义通用的查询结果的结构体 +type QueryResult struct { + Text string // 相似的文本 + Embedding []float64 // 相似文本的向量 + Score float64 // 文本的向量相似度或距离等度量 + Answer string // 相似文本对应的LLM生成的回答 +} + +type Provider interface { + GetProviderType() string +} + +type EmbeddingQuerier interface { + QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type EmbeddingUploader interface { + UploadEmbedding( + queryString string, + queryEmb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type AnswerAndEmbeddingUploader interface { + UploadAnswerAndEmbedding( + queryString string, + queryEmb []float64, + answer string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type StringQuerier interface { + QueryString( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type SimilarityThresholdProvider interface { + GetSimilarityThreshold() float64 +} + +type ProviderConfig struct { + // @Title zh-CN 向量存储服务提供者类型 + // @Description zh-CN 向量存储服务提供者类型,例如 dashvector、chroma + typ string + // @Title zh-CN 向量存储服务名称 + // @Description zh-CN 向量存储服务名称 + serviceName string + // @Title zh-CN 向量存储服务域名 + // @Description zh-CN 向量存储服务域名 + serviceHost string + // @Title zh-CN 向量存储服务端口 + // @Description zh-CN 向量存储服务端口 + servicePort int64 + // @Title zh-CN 向量存储服务 API Key + // @Description zh-CN 向量存储服务 API Key + apiKey string + // @Title zh-CN 返回TopK结果 + // @Description zh-CN 返回TopK结果,默认为 1 + topK int + // @Title zh-CN 请求超时 + // @Description zh-CN 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 + timeout uint32 + // @Title zh-CN DashVector 向量存储服务 Collection ID + // @Description zh-CN DashVector 向量存储服务 Collection ID + collectionID string + // @Title zh-CN 相似度度量阈值 + // @Description zh-CN 默认相似度度量阈值,默认为 1000。 + Threshold float64 + // @Title zh-CN 相似度度量比较方式 + // @Description zh-CN 相似度度量比较方式,默认为小于。 + // 相似度度量方式有 Cosine, DotProduct, Euclidean 等,前两者值越大相似度越高,后者值越小相似度越高。 + // 所以需要允许自定义比较方式,对于 Cosine 和 DotProduct 选择 gt,对于 Euclidean 则选择 lt。 + // 默认为 lt,所有条件包括 lt (less than,小于)、lte (less than or equal to,小等于)、gt (greater than,大于)、gte (greater than or equal to,大等于) + ThresholdRelation string +} + +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("type").String() + // DashVector + c.serviceName = json.Get("serviceName").String() + c.serviceHost = json.Get("serviceHost").String() + c.servicePort = int64(json.Get("servicePort").Int()) + if c.servicePort == 0 { + c.servicePort = 443 + } + c.apiKey = json.Get("apiKey").String() + c.collectionID = json.Get("collectionID").String() + c.topK = int(json.Get("topK").Int()) + if c.topK == 0 { + c.topK = 1 + } + c.timeout = uint32(json.Get("timeout").Int()) + if c.timeout == 0 { + c.timeout = 10000 + } + c.Threshold = json.Get("threshold").Float() + if c.Threshold == 0 { + c.Threshold = 1000 + } + c.ThresholdRelation = json.Get("thresholdRelation").String() + if c.ThresholdRelation == "" { + c.ThresholdRelation = "lt" + } +} + +func (c *ProviderConfig) Validate() error { + if c.typ == "" { + return errors.New("vector database service is required") + } + initializer, has := providerInitializers[c.typ] + if !has { + return errors.New("unknown vector database service provider type: " + c.typ) + } + if err := initializer.ValidateConfig(*c); err != nil { + return err + } + return nil +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/Makefile b/plugins/wasm-go/extensions/ai-proxy/Makefile new file mode 100644 index 0000000000..e5c7fa8de9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/Makefile @@ -0,0 +1,4 @@ +.DEFAULT: +build: + tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' ./main.go + mv ai-proxy.wasm ../../../../docker-compose-test/ \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-proxy/go.mod b/plugins/wasm-go/extensions/ai-proxy/go.mod index 7fed801fab..a5457b90f8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.mod +++ b/plugins/wasm-go/extensions/ai-proxy/go.mod @@ -10,7 +10,7 @@ require ( github.com/alibaba/higress/plugins/wasm-go v0.0.0 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 ) require ( diff --git a/plugins/wasm-go/extensions/ai-proxy/go.sum b/plugins/wasm-go/extensions/ai-proxy/go.sum index e5b8b79175..b2d63b5f4b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.sum +++ b/plugins/wasm-go/extensions/ai-proxy/go.sum @@ -13,8 +13,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 9e0fafe179..7b19d03fc2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -82,7 +82,8 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf providerConfig := pluginConfig.GetProviderConfig() if apiName == "" && !providerConfig.IsOriginal() { log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path) - _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) + // _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) + log.Debugf("[onHttpRequestHeader] no send response") return types.ActionContinue } ctx.SetContext(ctxKeyApiName, apiName) diff --git a/plugins/wasm-go/extensions/request-block/Dockerfile b/plugins/wasm-go/extensions/request-block/Dockerfile new file mode 100644 index 0000000000..9b084e0596 --- /dev/null +++ b/plugins/wasm-go/extensions/request-block/Dockerfile @@ -0,0 +1,2 @@ +FROM scratch +COPY main.wasm plugin.wasm \ No newline at end of file diff --git a/plugins/wasm-go/extensions/request-block/Makefile b/plugins/wasm-go/extensions/request-block/Makefile new file mode 100644 index 0000000000..1210d6ec34 --- /dev/null +++ b/plugins/wasm-go/extensions/request-block/Makefile @@ -0,0 +1,4 @@ +.DEFAULT: +build: + tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer' ./main.go + mv main.wasm ../../../../docker-compose-test/ \ No newline at end of file diff --git a/plugins/wasm-go/extensions/request-block/main.go b/plugins/wasm-go/extensions/request-block/main.go index 224d4b26d6..2a43b4df72 100644 --- a/plugins/wasm-go/extensions/request-block/main.go +++ b/plugins/wasm-go/extensions/request-block/main.go @@ -177,7 +177,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config RequestBlockConfig, lo } func onHttpRequestBody(ctx wrapper.HttpContext, config RequestBlockConfig, body []byte, log wrapper.Log) types.Action { + log.Infof("My request-block body: %s\n", string(body)) bodyStr := string(body) + if !config.caseSensitive { bodyStr = strings.ToLower(bodyStr) } diff --git a/plugins/wasm-go/go.mod b/plugins/wasm-go/go.mod index 999721f3f6..6373ff646e 100644 --- a/plugins/wasm-go/go.mod +++ b/plugins/wasm-go/go.mod @@ -7,7 +7,7 @@ require ( github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 ) diff --git a/plugins/wasm-go/go.sum b/plugins/wasm-go/go.sum index e726b100a5..f396d4d7d9 100644 --- a/plugins/wasm-go/go.sum +++ b/plugins/wasm-go/go.sum @@ -4,6 +4,12 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43 h1:dCw7F/9ciw4NZN7w68wQRaygZ2zGOWMTIEoRvP1tlWs= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= @@ -14,6 +20,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= diff --git a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go index 10aa9020bd..c619c3e191 100644 --- a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go @@ -235,10 +235,11 @@ func (c RedisClusterClient[C]) Set(key string, value interface{}, callback Redis func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error { args := make([]interface{}, 0) - args = append(args, "setex") + args = append(args, "set") args = append(args, key) - args = append(args, ttl) args = append(args, value) + args = append(args, "ex") + args = append(args, ttl) return RedisCall(c.cluster, respString(args), callback) } diff --git a/test/e2e/conformance/tests/go-wasm-ai-cache.go b/test/e2e/conformance/tests/go-wasm-ai-cache.go new file mode 100644 index 0000000000..30ac248916 --- /dev/null +++ b/test/e2e/conformance/tests/go-wasm-ai-cache.go @@ -0,0 +1,76 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "github.com/alibaba/higress/test/e2e/conformance/utils/http" + "github.com/alibaba/higress/test/e2e/conformance/utils/suite" +) + +func init() { + Register(WasmPluginsAiCache) +} + +var WasmPluginsAiCache = suite.ConformanceTest{ + ShortName: "WasmPluginAiCache", + Description: "The Ingress in the higress-conformance-infra namespace test the ai-cache WASM plugin.", + Features: []suite.SupportedFeature{suite.WASMGoConformanceFeature}, + Manifests: []string{"tests/go-wasm-ai-cache.yaml"}, + Test: func(t *testing.T, suite *suite.ConformanceTestSuite) { + testcases := []http.Assertion{ + { + Meta: http.AssertionMeta{ + TestCaseName: "case 1: basic", + TargetBackend: "infra-backend-v1", + TargetNamespace: "higress-conformance-infra", + }, + Request: http.AssertionRequest{ + ActualRequest: http.Request{ + Host: "dashscope.aliyuncs.com", + Path: "/v1/chat/completions", + Method: "POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "qwen-long", + "messages": [{"role":"user","content":"hi"}]}`), + }, + ExpectedRequest: &http.ExpectedRequest{ + Request: http.Request{ + Host: "dashscope.aliyuncs.com", + Path: "/compatible-mode/v1/chat/completions", + Method: "POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "qwen-long", + "messages": [{"role":"user","content":"hi"}]}`), + }, + }, + }, + Response: http.AssertionResponse{ + ExpectedResponse: http.Response{ + StatusCode: 200, + }, + }, + }, + } + t.Run("WasmPlugins ai-cache", func(t *testing.T) { + for _, testcase := range testcases { + http.MakeRequestAndExpectEventuallyConsistentResponse(t, suite.RoundTripper, suite.TimeoutConfig, suite.GatewayAddress, testcase) + } + }) + }, +} diff --git a/test/e2e/conformance/tests/go-wasm-ai-cache.yaml b/test/e2e/conformance/tests/go-wasm-ai-cache.yaml new file mode 100644 index 0000000000..c7d6b0c46b --- /dev/null +++ b/test/e2e/conformance/tests/go-wasm-ai-cache.yaml @@ -0,0 +1,103 @@ +# Copyright (c) 2022 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + annotations: + name: wasmplugin-ai-cache-openai + namespace: higress-conformance-infra +spec: + ingressClassName: higress + rules: + - host: "dashscope.aliyuncs.com" + http: + paths: + - pathType: Prefix + path: "/" + backend: + service: + name: infra-backend-v1 + port: + number: 8080 +--- +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + annotations: + name: wasmplugin-ai-cache-qwen + namespace: higress-conformance-infra +spec: + ingressClassName: higress + rules: + - host: "qwen.ai.com" + http: + paths: + - pathType: Prefix + path: "/" + backend: + service: + name: infra-backend-v1 + port: + number: 8080 +--- +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: ai-cache + namespace: higress-system +spec: + priority: 400 + matchRules: + - config: + embedding: + type: "dashscope" + serviceName: "qwen" + apiKey: "{{secret.qwenApiKey}}" + timeout: 12000 + vector: + type: "dashvector" + serviceName: "dashvector" + collectionID: "{{secret.collectionID}}" + serviceDomain: "{{secret.serviceDomain}}" + apiKey: "{{secret.apiKey}}" + timeout: 12000 + cache: + + ingress: + - higress-conformance-infra/wasmplugin-ai-cache-openai + - higress-conformance-infra/wasmplugin-ai-cache-qwen + # url: file:///opt/plugins/wasm-go/extensions/ai-cache/plugin.wasm + url: oci://registry.cn-shanghai.aliyuncs.com/suchunsv/higress_ai:1.18 +--- +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: ai-proxy + namespace: higress-system +spec: + priority: 201 + matchRules: + - config: + provider: + type: "qwen" + qwenEnableCompatible: true + apiTokens: + - "{{secret.qwenApiKey}}" + timeout: 1200000 + modelMapping: + "*": "qwen-long" + ingress: + - higress-conformance-infra/wasmplugin-ai-cache-openai + - higress-conformance-infra/wasmplugin-ai-cache-qwen + url: oci://higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/ai-proxy:1.0.0