diff --git a/plugins/wasm-go/extensions/ai-security-guard/README.md b/plugins/wasm-go/extensions/ai-security-guard/README.md index 4961d527b3..5a8f753f3b 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/README.md +++ b/plugins/wasm-go/extensions/ai-security-guard/README.md @@ -1,22 +1,143 @@ +--- +title: AI内容安全 +keywords: [higress, AI, security] +description: 阿里云内容安全检测 +--- + ## 功能说明 +通过对接阿里云内容安全检测大模型的输入输出,保障AI应用内容合法合规。 + +## 运行属性 + +插件执行阶段:`默认阶段` +插件执行优先级:`300` ## 配置说明 | Name | Type | Requirement | Default | Description | -| :-: | :-: | :-: | :-: | :-: | -| serviceSource | string | requried | - | 服务来源,填dns | -| serviceName | string | requried | - | 服务名 | -| servicePort | string | requried | - | 服务端口 | -| domain | string | requried | - | 阿里云内容安全endpoint | -| ak | string | requried | - | 阿里云AK | -| sk | string | requried | - | 阿里云SK | +| ------------ | ------------ | ------------ | ------------ | ------------ | +| `serviceName` | string | requried | - | 服务名 | +| `servicePort` | string | requried | - | 服务端口 | +| `serviceHost` | string | requried | - | 阿里云内容安全endpoint的域名 | +| `accessKey` | string | requried | - | 阿里云AK | +| `secretKey` | string | requried | - | 阿里云SK | +| `checkRequest` | bool | optional | false | 检查提问内容是否合规 | +| `checkResponse` | bool | optional | false | 检查大模型的回答内容是否合规,生效时会使流式响应变为非流式 | +| `requestCheckService` | string | optional | llm_query_moderation | 指定阿里云内容安全用于检测输入内容的服务 | +| `responseCheckService` | string | optional | llm_response_moderation | 指定阿里云内容安全用于检测输出内容的服务 | +| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | 指定要检测内容在请求body中的jsonpath | +| `responseContentJsonPath` | string | optional | `choices.0.message.content` | 指定要检测内容在响应body中的jsonpath | +| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | 指定要检测内容在流式响应body中的jsonpath | +| `denyCode` | int | optional | 200 | 指定内容非法时的响应状态码 | +| `denyMessage` | string | optional | openai格式的流失/非流式响应,回答内容为阿里云内容安全的建议回答 | 指定内容非法时的响应内容 | ## 配置示例 +### 前提条件 +由于插件中需要调用阿里云内容安全服务,所以需要先创建一个DNS类型的服务,例如: + +![](https://img.alicdn.com/imgextra/i4/O1CN013AbDcn1slCY19inU2_!!6000000005806-0-tps-1754-1320.jpg) + +### 检测输入内容是否合规 + +```yaml +serviceName: safecheck.dns +servicePort: 443 +serviceHost: "green-cip.cn-shanghai.aliyuncs.com" +accessKey: "XXXXXXXXX" +secretKey: "XXXXXXXXXXXXXXX" +checkRequest: true +``` + +### 检测输入与输出是否合规 + +```yaml +serviceName: safecheck.dns +servicePort: 443 +serviceHost: green-cip.cn-shanghai.aliyuncs.com +accessKey: "XXXXXXXXX" +secretKey: "XXXXXXXXXXXXXXX" +checkRequest: true +checkResponse: true +``` + +### 指定自定义内容安全检测服务 +用户可能需要根据不同的场景配置不同的检测规则,该问题可通过为不同域名/路由/服务配置不同的内容安全检测服务实现。如下图所示,我们创建了一个名为 llm_query_moderation_01 的检测服务,其中的检测规则在 llm_query_moderation 之上做了一些改动: + +![](https://img.alicdn.com/imgextra/i4/O1CN01bAtcvn1N9sB16iiZR_!!6000000001528-0-tps-2728-822.jpg) + +接下来在目标域名/路由/服务级别进行以下配置,指定使用我们自定义的 llm_query_moderation_01 中的规则进行检测: + +```yaml +serviceName: safecheck.dns +servicePort: 443 +serviceHost: "green-cip.cn-shanghai.aliyuncs.com" +accessKey: "XXXXXXXXX" +secretKey: "XXXXXXXXXXXXXXX" +checkRequest: true +requestCheckService: llm_query_moderation_01 +``` + +### 配置非openai协议(例如百炼App) + ```yaml -serviceSource: "dns" -serviceName: "safecheck" +serviceName: safecheck.dns servicePort: 443 -domain: "green-cip.cn-shanghai.aliyuncs.com" -ak: "XXXXXXXXX" -sk: "XXXXXXXXXXXXXXX" +serviceHost: "green-cip.cn-shanghai.aliyuncs.com" +accessKey: "XXXXXXXXX" +secretKey: "XXXXXXXXXXXXXXX" +checkRequest: true +checkResponse: true +requestContentJsonPath: "input.prompt" +responseContentJsonPath: "output.text" +denyCode: 200 +denyMessage: "很抱歉,我无法回答您的问题" +``` + +## 可观测 +### Metric +ai-security-guard 插件提供了以下监控指标: +- `ai_sec_request_deny`: 请求内容安全检测失败请求数 +- `ai_sec_response_deny`: 模型回答安全检测失败请求数 + +### Trace +如果开启了链路追踪,ai-security-guard 插件会在请求 span 中添加以下 attributes: +- `ai_sec_risklabel`: 表示请求命中的风险类型 +- `ai_sec_deny_phase`: 表示请求被检测到风险的阶段(取值为request或者response) + +## 请求示例 +```bash +curl http://localhost/v1/chat/completions \ +-H "Content-Type: application/json" \ +-d '{ + "model": "gpt-4o-mini", + "messages": [ + { + "role": "user", + "content": "这是一段非法内容" + } + ] +}' +``` + +请求内容会被发送到阿里云内容安全服务进行检测,如果请求内容检测结果为非法,网关将返回形如以下的回答: + +```json +{ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。", + }, + "logprobs": null, + "finish_reason": "stop" + } + ] +} ``` diff --git a/plugins/wasm-go/extensions/ai-security-guard/README_EN.md b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md new file mode 100644 index 0000000000..450b554179 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-security-guard/README_EN.md @@ -0,0 +1,69 @@ +--- +title: AI Content Security +keywords: [higress, AI, security] +description: Alibaba Cloud content security +--- + + +## Introduction +Integrate with Aliyun content security service for detections of input and output of LLMs, ensuring that application content is legal and compliant. + +## Runtime Properties + +Plugin Phase: `CUSTOM` +Plugin Priority: `300` + +## Configuration +| Name | Type | Requirement | Default | Description | +| ------------ | ------------ | ------------ | ------------ | ------------ | +| `serviceName` | string | requried | - | service name | +| `servicePort` | string | requried | - | service port | +| `serviceHost` | string | requried | - | Host of Aliyun content security service endpoint | +| `accessKey` | string | requried | - | Aliyun accesskey | +| `secretKey` | string | requried | - | Aliyun secretkey | +| `checkRequest` | bool | optional | false | check if the input is legal | +| `checkResponse` | bool | optional | false | check if the output is legal | +| `requestCheckService` | string | optional | llm_query_moderation | Aliyun yundun service name for input check | +| `responseCheckService` | string | optional | llm_response_moderation | Aliyun yundun service name for output check | +| `requestContentJsonPath` | string | optional | `messages.@reverse.0.content` | Specify the jsonpath of the content to be detected in the request body | +| `responseContentJsonPath` | string | optional | `choices.0.message.content` | Specify the jsonpath of the content to be detected in the response body | +| `responseStreamContentJsonPath` | string | optional | `choices.0.delta.content` | Specify the jsonpath of the content to be detected in the streaming response body | +| `denyCode` | int | optional | 200 | Response status code when the specified content is illegal | +| `denyMessage` | string | optional | Drainage/non-streaming response in openai format, the answer content is the suggested answer from Alibaba Cloud content security + | Response content when the specified content is illegal | + + +## Examples of configuration +### Check if the input is legal + +```yaml +serviceName: safecheck.dns +servicePort: 443 +serviceHost: "green-cip.cn-shanghai.aliyuncs.com" +accessKey: "XXXXXXXXX" +secretKey: "XXXXXXXXXXXXXXX" +checkRequest: true +``` + +### Check if both the input and output are legal + +```yaml +serviceName: safecheck.dns +servicePort: 443 +serviceHost: green-cip.cn-shanghai.aliyuncs.com +accessKey: "XXXXXXXXX" +secretKey: "XXXXXXXXXXXXXXX" +checkRequest: true +checkResponse: true +``` + +## Observability +### Metric +ai-security-guard plugin provides following metrics: +- `ai_sec_request_deny`: count of requests denied at request phase +- `ai_sec_response_deny`: count of requests denied at response phase + +### Trace +ai-security-guard plugin provides following span attributes: +- `ai_sec_risklabel`: risk type of this request +- `ai_sec_deny_phase`: denied phase of this request, value can be request/response \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.mod b/plugins/wasm-go/extensions/ai-security-guard/go.mod index cd70355982..f2bc5a1436 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/go.mod +++ b/plugins/wasm-go/extensions/ai-security-guard/go.mod @@ -1,4 +1,4 @@ -module myplugin +module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard go 1.18 diff --git a/plugins/wasm-go/extensions/ai-security-guard/go.sum b/plugins/wasm-go/extensions/ai-security-guard/go.sum index 1924b268fc..f473e12b2d 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/go.sum +++ b/plugins/wasm-go/extensions/ai-security-guard/go.sum @@ -1,14 +1,9 @@ -github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY= -github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc= -github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE= -github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index a97fda5770..e1ccfeb572 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -1,12 +1,12 @@ package main import ( + "bytes" "crypto/hmac" "crypto/rand" "crypto/sha1" "encoding/base64" "encoding/hex" - "encoding/json" "errors" "fmt" "net/http" @@ -32,16 +32,47 @@ func main() { ) } +const ( + OpenAIResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","model": "gpt-4o-mini","choices": [{"index": 0,"message": {"role": "assistant","content": "%s"},"logprobs": null,"finish_reason": "stop"}]}` + OpenAIStreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` + OpenAIStreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4o-mini", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}` + OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + + TracingPrefix = "trace_span_tag." + + DefaultRequestCheckService = "llm_query_moderation" + DefaultResponseCheckService = "llm_response_moderation" + DefaultRequestJsonPath = "messages.@reverse.0.content" + DefaultResponseJsonPath = "choices.0.message.content" + DefaultStreamingResponseJsonPath = "choices.0.delta.content" + DefaultDenyCode = 200 + + AliyunUserAgent = "CIPFrom/AIGateway" +) + type AISecurityConfig struct { - client wrapper.HttpClient - ak string - sk string + client wrapper.HttpClient + ak string + sk string + checkRequest bool + requestCheckService string + requestContentJsonPath string + checkResponse bool + responseCheckService string + responseContentJsonPath string + responseStreamContentJsonPath string + denyCode int64 + denyMessage string + metrics map[string]proxywasm.MetricCounter } -type StandardResponse struct { - Code int `json:"Code"` - Phase string `json:"BlockPhase"` - Message string `json:"Message"` +func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) { + counter, ok := config.metrics[metricName] + if !ok { + counter = proxywasm.DefineCounterMetric(metricName) + config.metrics[metricName] = counter + } + counter.Increment(inc) } func urlEncoding(rawStr string) string { @@ -71,7 +102,7 @@ func getSign(params map[string]string, secret string) string { }) canonicalStr := strings.Join(paramArray, "&") signStr := "POST&%2F&" + urlEncoding(canonicalStr) - fmt.Println(signStr) + // proxywasm.LogInfo(signStr) return hmacSha1(signStr, secret) } @@ -86,32 +117,70 @@ func generateHexID(length int) (string, error) { func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error { serviceName := json.Get("serviceName").String() servicePort := json.Get("servicePort").Int() - domain := json.Get("domain").String() - config.ak = json.Get("ak").String() - config.sk = json.Get("sk").String() - if serviceName == "" || servicePort == 0 || domain == "" { + serviceHost := json.Get("serviceHost").String() + if serviceName == "" || servicePort == 0 || serviceHost == "" { return errors.New("invalid service config") } - config.client = wrapper.NewClusterClient(wrapper.DnsCluster{ - ServiceName: serviceName, - Port: servicePort, - Domain: domain, + config.ak = json.Get("accessKey").String() + config.sk = json.Get("secretKey").String() + if config.ak == "" || config.sk == "" { + return errors.New("invalid AK/SK config") + } + config.checkRequest = json.Get("checkRequest").Bool() + config.checkResponse = json.Get("checkResponse").Bool() + config.denyMessage = json.Get("denyMessage").String() + if obj := json.Get("denyCode"); obj.Exists() { + config.denyCode = obj.Int() + } else { + config.denyCode = DefaultDenyCode + } + if obj := json.Get("requestCheckService"); obj.Exists() { + config.requestCheckService = obj.String() + } else { + config.requestCheckService = DefaultRequestCheckService + } + if obj := json.Get("responseCheckService"); obj.Exists() { + config.responseCheckService = obj.String() + } else { + config.responseCheckService = DefaultResponseCheckService + } + if obj := json.Get("requestContentJsonPath"); obj.Exists() { + config.requestContentJsonPath = obj.String() + } else { + config.requestContentJsonPath = DefaultRequestJsonPath + } + if obj := json.Get("responseContentJsonPath"); obj.Exists() { + config.responseContentJsonPath = obj.String() + } else { + config.responseContentJsonPath = DefaultResponseJsonPath + } + if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() { + config.responseStreamContentJsonPath = obj.String() + } else { + config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath + } + config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: serviceName, + Port: servicePort, + Host: serviceHost, }) + config.metrics = make(map[string]proxywasm.MetricCounter) return nil } func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { + if !config.checkRequest { + ctx.DontReadRequestBody() + } + if !config.checkResponse { + ctx.DontReadResponseBody() + } return types.ActionContinue } func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { - messages := gjson.GetBytes(body, "messages").Array() - if len(messages) > 0 { - role := messages[len(messages)-1].Get("role").String() - content := messages[len(messages)-1].Get("content").String() - if role != "user" { - return types.ActionContinue - } + content := gjson.GetBytes(body, config.requestContentJsonPath).String() + if content != "" { timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") randomID, _ := generateHexID(16) params := map[string]string{ @@ -123,7 +192,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] "Action": "TextModerationPlus", "AccessKeyId": config.ak, "Timestamp": timestamp, - "Service": "llm_query_moderation", + "Service": config.requestCheckService, "ServiceParameters": `{"content": "` + content + `"}`, } signature := getSign(params, config.sk+"&") @@ -132,31 +201,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] reqParams.Add(k, v) } reqParams.Add("Signature", signature) - config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil, + config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) { respData := gjson.GetBytes(responseBody, "Data") if respData.Exists() { respAdvice := respData.Get("Advice") respResult := respData.Get("Result") if respAdvice.Exists() { - sr := StandardResponse{ - Code: 403, - Phase: "Request", - Message: respAdvice.Array()[0].Get("Answer").String(), - } - jsonData, _ := json.MarshalIndent(sr, "", " ") - label := respResult.Array()[0].Get("Label").String() - proxywasm.SetProperty([]string{"risklabel"}, []byte(label)) - proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.label."+label, [][2]string{{"content-type", "application/json"}}, jsonData, -1) - } else if respResult.Array()[0].Get("Label").String() != "nonLabel" { - sr := StandardResponse{ - Code: 403, - Phase: "Request", - Message: "risk detected", + proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) + proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request")) + config.incrementCounter("ai_sec_request_deny", 1) + if config.denyMessage != "" { + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(config.denyMessage), -1) + } else { + if gjson.GetBytes(body, "stream").Bool() { + jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String())) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) + } else { + jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String())) + proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) + } } - jsonData, _ := json.MarshalIndent(sr, "", " ") - proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) - proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.risk_detected", [][2]string{{"content-type", "application/json"}}, jsonData, -1) } else { proxywasm.ResumeHttpRequest() } @@ -206,9 +271,16 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log } func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { - messages := gjson.GetBytes(body, "choices").Array() - if len(messages) > 0 { - content := messages[0].Get("message").Get("content").String() + hdsMap := ctx.GetContext("headers").(map[string][]string) + isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") + var content string + if isStreamingResponse { + content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath) + } else { + content = gjson.GetBytes(body, config.responseContentJsonPath).String() + } + log.Debugf("Raw response content is: %s", content) + if len(content) > 0 { timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") randomID, _ := generateHexID(16) params := map[string]string{ @@ -220,7 +292,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ "Action": "TextModerationPlus", "AccessKeyId": config.ak, "Timestamp": timestamp, - "Service": "llm_response_moderation", + "Service": config.responseCheckService, "ServiceParameters": `{"content": "` + content + `"}`, } signature := getSign(params, config.sk+"&") @@ -229,7 +301,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ reqParams.Add(k, v) } reqParams.Add("Signature", signature) - config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil, + config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) { defer proxywasm.ResumeHttpResponse() respData := gjson.GetBytes(responseBody, "Data") @@ -237,31 +309,23 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ respAdvice := respData.Get("Advice") respResult := respData.Get("Result") if respAdvice.Exists() { - sr := StandardResponse{ - Code: 403, - Phase: "Response", - Message: respAdvice.Array()[0].Get("Answer").String(), + var jsonData []byte + if config.denyMessage != "" { + jsonData = []byte(config.denyMessage) + } else { + if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") { + jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, respAdvice.Array()[0].Get("Answer").String())) + } else { + jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, respAdvice.Array()[0].Get("Answer").String())) + } } - jsonData, _ := json.MarshalIndent(sr, "", " ") - hdsMap := ctx.GetContext("headers").(map[string][]string) delete(hdsMap, "content-length") - hdsMap[":status"] = []string{"403"} + hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)} proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap)) proxywasm.ReplaceHttpResponseBody(jsonData) - proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) - } else if respResult.Array()[0].Get("Label").String() != "nonLabel" { - sr := StandardResponse{ - Code: 403, - Phase: "Response", - Message: "risk detected", - } - jsonData, _ := json.MarshalIndent(sr, "", " ") - hdsMap := ctx.GetContext("headers").(map[string][]string) - delete(hdsMap, "content-length") - hdsMap[":status"] = []string{"403"} - proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap)) - proxywasm.ReplaceHttpResponseBody(jsonData) - proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) + proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String())) + proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response")) + config.incrementCounter("ai_sec_response_deny", 1) } } }, @@ -271,3 +335,16 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ return types.ActionContinue } } + +func extractMessageFromStreamingBody(data []byte, jsonPath string) string { + chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) + strChunks := []string{} + for _, chunk := range chunks { + // Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}] + jsonObj := gjson.GetBytes(chunk, jsonPath) + if jsonObj.Exists() { + strChunks = append(strChunks, jsonObj.String()) + } + } + return strings.Join(strChunks, "") +}