Update ai security guard #1261

Open
wants to merge 9 commits into
base: main
Changes from 6 commits
52 changes: 52 additions & 0 deletions plugins/wasm-go/extensions/ai-security-guard/README-en.md
@@ -0,0 +1,52 @@
## Introduction
Integrates with the Aliyun content security service to inspect the inputs and outputs of LLMs, helping ensure that application content is legal and compliant.

## Configuration
| Name | Type | Requirement | Default | Description |
| ------------ | ------------ | ------------ | ------------ | ------------ |
| `serviceSource` | string | required | - | Service source, such as `dns` |
| `serviceName` | string | required | - | Service name |
| `servicePort` | int | required | - | Service port |
| `domain` | string | required | - | Host of the Aliyun content security service endpoint |
| `ak` | string | required | - | Aliyun AccessKey |
| `sk` | string | required | - | Aliyun SecretKey |
| `checkRequest` | bool | optional | - | Check whether the input is legal |
| `checkResponse` | bool | optional | - | Check whether the output is legal |


## Configuration Examples
### Check whether the input is legal

```yaml
serviceSource: "dns"
serviceName: "safecheck"
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
checkRequest: true
```

### Check whether both the input and output are legal

```yaml
serviceSource: "dns"
serviceName: "safecheck"
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
checkRequest: true
checkResponse: true
```

## Observability
### Metric
The ai-security-guard plugin provides the following metrics:
- `ai_sec_request_deny`: number of requests denied in the request phase
- `ai_sec_response_deny`: number of requests denied in the response phase

### Trace
The ai-security-guard plugin provides the following span attributes:
- `ai_sec_risklabel`: risk type that the request hit
- `ai_sec_deny_phase`: phase in which the request was denied; the value is `request` or `response`
52 changes: 41 additions & 11 deletions plugins/wasm-go/extensions/ai-security-guard/README.md
@@ -1,22 +1,52 @@
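# Introduction
## Introduction
Integrates with Aliyun content security to inspect the inputs and outputs of large language models, helping ensure that AI application content is legal and compliant.

# Configuration
## Configuration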
| Name | Type | Requirement | Default | Description |
| :-: | :-: | :-: | :-: | :-: |
| serviceSource | string | required | - | Service source; set to dns |
| serviceName | string | required | - | Service name |
| servicePort | int | required | - | Service port |
| domain | string | required | - | Aliyun content security endpoint |
| ak | string | required | - | Aliyun AK |
| sk | string | required | - | Aliyun SK |
| ------------ | ------------ | ------------ | ------------ | ------------ |
| `serviceSource` | string | required | - | Service source; set to dns |
| `serviceName` | string | required | - | Service name |
| `servicePort` | int | required | - | Service port |
| `domain` | string | required | - | Domain of the Aliyun content security endpoint |
| `ak` | string | required | - | Aliyun AK |
| `sk` | string | required | - | Aliyun SK |
| `checkRequest` | bool | optional | - | Check whether the question content is compliant |
| `checkResponse` | bool | optional | - | Check whether the model's answer is compliant; when enabled, streaming responses become non-streaming |


# Configuration Examples
## Configuration Examples
### Check whether the input is compliant

```yaml
serviceSource: "dns"
serviceName: "safecheck"
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
checkRequest: true
```

### Check whether both the input and output are compliant

```yaml
serviceSource: "dns"
serviceName: "safecheck"
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
checkRequest: true
checkResponse: true
```

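## Observability
### Metric
The ai-security-guard plugin provides the following metrics:
- `ai_sec_request_deny`: number of requests that failed the request content security check
- `ai_sec_response_deny`: number of requests whose model answer failed the security check

### Trace
If tracing is enabled, the ai-security-guard plugin adds the following attributes to the request span:
- `ai_sec_risklabel`: the risk type that the request hit
- `ai_sec_deny_phase`: the phase in which the risk was detected (the value is `request` or `response`)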
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-security-guard/go.mod
@@ -1,4 +1,4 @@
module myplugin
module ai-security-guard

go 1.18

7 changes: 1 addition & 6 deletions plugins/wasm-go/extensions/ai-security-guard/go.sum
@@ -1,14 +1,9 @@
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906 h1:RhEmB+ApLKsClZD7joTC4ifmsVgOVz4pFLdPR3xhNaE=
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906/go.mod h1:10jQXKsYFUF7djs+Oy7t82f4dbie9pISfP9FJwpPLuk=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
164 changes: 120 additions & 44 deletions plugins/wasm-go/extensions/ai-security-guard/main.go
@@ -1,12 +1,12 @@
package main

import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -32,16 +32,57 @@ func main() {
)
}

const (
NormalResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","created": 1677652288,"model": "gpt-4o-mini","system_fingerprint": "fp_44709d6fcb","choices": [{"index": 0,"message": {"role": "assistant","content": "%s"},"logprobs": null,"finish_reason": "stop"}]}`
StreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
StreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
StreamResponseFormat = StreamResponseChunk + "\n\n" + StreamResponseEnd
TracingPrefix = "trace_span_tag."
)

type AISecurityConfig struct {
client wrapper.HttpClient
ak string
sk string
client wrapper.HttpClient
ak string
sk string
checkRequest bool
checkResponse bool
metrics map[string]proxywasm.MetricCounter
}

func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64) {
counter, ok := config.metrics[metricName]
if !ok {
counter = proxywasm.DefineCounterMetric(metricName)
config.metrics[metricName] = counter
}
counter.Increment(inc)
}
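
The lazy definition in `incrementCounter` can be exercised outside the proxy with a stand-in counter. The sketch below is an illustration under assumptions: `counter`, `fakeCounter`, and `counterRegistry` are hypothetical names that replace `proxywasm.MetricCounter` and `proxywasm.DefineCounterMetric`, so that the "define once, then reuse" behavior can be checked in a plain Go program:

```go
package main

import "fmt"

// counter is a hypothetical stand-in for proxywasm.MetricCounter.
type counter interface {
	Increment(offset uint64)
}

// fakeCounter records increments in memory.
type fakeCounter struct{ value uint64 }

func (c *fakeCounter) Increment(offset uint64) { c.value += offset }

// counterRegistry mirrors the lookup-or-define logic of incrementCounter.
type counterRegistry struct {
	define  func(name string) counter // stand-in for DefineCounterMetric
	metrics map[string]counter
}

func (r *counterRegistry) increment(name string, inc uint64) {
	c, ok := r.metrics[name]
	if !ok {
		// First use of this metric name: define it and cache the handle.
		c = r.define(name)
		r.metrics[name] = c
	}
	c.Increment(inc)
}

func main() {
	defined := 0
	fake := &fakeCounter{}
	r := &counterRegistry{
		define:  func(string) counter { defined++; return fake },
		metrics: map[string]counter{},
	}
	r.increment("ai_sec_request_deny", 1)
	r.increment("ai_sec_request_deny", 1)
	fmt.Println(defined, fake.value) // prints "1 2"
}
```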

type StandardResponse struct {
Collaborator: Is this struct no longer used?

Collaborator Author: Right, it is no longer used. I'll remove it.


Code int `json:"Code"`
Phase string `json:"BlockPhase"`
Message string `json:"Message"`
ID string `json:"id"`
Choices []Choice `json:"choices"`
Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
Object string `json:"object,omitempty"`
Usage chatCompletionUsage `json:"usage,omitempty"`
}

type Choice struct {
Index int `json:"index"`
Message Message `json:"message"`
FinishReason string `json:"finish_reason"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type chatCompletionUsage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}

func urlEncoding(rawStr string) string {
@@ -71,7 +112,7 @@ func getSign(params map[string]string, secret string) string {
})
canonicalStr := strings.Join(paramArray, "&")
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
fmt.Println(signStr)
// proxywasm.LogInfo(signStr)
return hmacSha1(signStr, secret)
}
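
The visible part of `getSign` builds a string of the form `POST&%2F&<encoded canonical query>` and passes it to `hmacSha1`. The bodies of `urlEncoding` and `hmacSha1` are collapsed in this diff, so the following is only a sketch of that style of signature under stated assumptions: the percent-encoding rule, the `secret + "&"` key, and the Base64 output are assumptions, and `percentEncode`, `stringToSign`, and `sign` are hypothetical names.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"net/url"
	"sort"
	"strings"
)

// percentEncode is an assumed encoding rule: url.QueryEscape already leaves
// '~' unescaped and escapes '*'; only the '+' used for spaces is rewritten.
func percentEncode(s string) string {
	return strings.ReplaceAll(url.QueryEscape(s), "+", "%20")
}

// stringToSign sorts parameters by key, joins them as k=v pairs with "&",
// and prefixes the encoded result with the method and the encoded "/" path.
func stringToSign(params map[string]string) string {
	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	pairs := make([]string, 0, len(keys))
	for _, k := range keys {
		pairs = append(pairs, percentEncode(k)+"="+percentEncode(params[k]))
	}
	return "POST&%2F&" + percentEncode(strings.Join(pairs, "&"))
}

// sign computes HMAC-SHA1 over the string to sign and Base64-encodes it.
// Appending "&" to the secret is an assumption, not shown in the diff.
func sign(params map[string]string, secret string) string {
	mac := hmac.New(sha1.New, []byte(secret+"&"))
	mac.Write([]byte(stringToSign(params)))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}

func main() {
	params := map[string]string{"b": "2 x", "a": "1"}
	fmt.Println(stringToSign(params))
	fmt.Println(sign(params, "example-secret"))
}
```

Sorting by key before joining keeps the result independent of Go's randomized map iteration order.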

@@ -87,25 +128,38 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
serviceName := json.Get("serviceName").String()
servicePort := json.Get("servicePort").Int()
domain := json.Get("domain").String()
config.ak = json.Get("ak").String()
config.sk = json.Get("sk").String()
if serviceName == "" || servicePort == 0 || domain == "" {
return errors.New("invalid service config")
}
config.ak = json.Get("ak").String()
config.sk = json.Get("sk").String()
if config.ak == "" || config.sk == "" {
return errors.New("invalid AK/SK config")
}
config.checkRequest = json.Get("checkRequest").Bool()
config.checkResponse = json.Get("checkResponse").Bool()
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: serviceName,
Port: servicePort,
Domain: domain,
})
config.metrics = make(map[string]proxywasm.MetricCounter)
return nil
}
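
`parseConfig` looks keys up by their exact spelling, so a key written as `checkresponse` would not be picked up as `checkResponse`. The sketch below reproduces the same validation order with only the standard library. It is a simplified illustration, not the plugin's code: `guardConfig` and `parseGuardConfig` are hypothetical names, and the exact-match map lookup stands in for the gjson lookups.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// guardConfig is a hypothetical, trimmed-down counterpart of AISecurityConfig.
type guardConfig struct {
	serviceName   string
	servicePort   int64
	domain        string
	ak, sk        string
	checkRequest  bool
	checkResponse bool
}

// parseGuardConfig mirrors the checks in parseConfig: service fields first,
// then AK/SK, then the two optional switches (false when absent).
func parseGuardConfig(raw []byte) (*guardConfig, error) {
	var m map[string]interface{}
	if err := json.Unmarshal(raw, &m); err != nil {
		return nil, err
	}
	// Map lookups are exact-match, so key case matters here.
	str := func(key string) string { s, _ := m[key].(string); return s }
	flag := func(key string) bool { b, _ := m[key].(bool); return b }
	port, _ := m["servicePort"].(float64) // JSON numbers decode as float64

	c := &guardConfig{
		serviceName:   str("serviceName"),
		servicePort:   int64(port),
		domain:        str("domain"),
		ak:            str("ak"),
		sk:            str("sk"),
		checkRequest:  flag("checkRequest"),
		checkResponse: flag("checkResponse"),
	}
	if c.serviceName == "" || c.servicePort == 0 || c.domain == "" {
		return nil, errors.New("invalid service config")
	}
	if c.ak == "" || c.sk == "" {
		return nil, errors.New("invalid AK/SK config")
	}
	return c, nil
}

func main() {
	raw := []byte(`{"serviceName":"safecheck","servicePort":443,
		"domain":"green-cip.cn-shanghai.aliyuncs.com",
		"ak":"XXXXXXXXX","sk":"XXXXXXXXXXXXXXX","checkresponse":true}`)
	c, err := parseGuardConfig(raw)
	if err != nil {
		panic(err)
	}
	// The misspelled key is ignored, so the response check stays off.
	fmt.Println(c.checkRequest, c.checkResponse) // prints "false false"
}
```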

func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkRequest {
ctx.DontReadRequestBody()
}
if !config.checkResponse {
ctx.DontReadResponseBody()
}
return types.ActionContinue
}

func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
messages := gjson.GetBytes(body, "messages").Array()
stream := gjson.GetBytes(body, "stream").Bool()
if len(messages) > 0 {
role := messages[len(messages)-1].Get("role").String()
content := messages[len(messages)-1].Get("content").String()
@@ -139,24 +193,29 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
if respAdvice.Exists() {
sr := StandardResponse{
Code: 403,
Phase: "Request",
Message: respAdvice.Array()[0].Get("Answer").String(),
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
if stream {
jsonData := []byte(fmt.Sprintf(StreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
jsonData := []byte(fmt.Sprintf(NormalResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
label := respResult.Array()[0].Get("Label").String()
proxywasm.SetProperty([]string{"risklabel"}, []byte(label))
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.label."+label, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
sr := StandardResponse{
Code: 403,
Phase: "Request",
Message: "risk detected",
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
if stream {
jsonData := []byte(fmt.Sprintf(StreamResponseFormat, "很抱歉,我不能对您的问题做出回答。"))
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
jsonData := []byte(fmt.Sprintf(NormalResponseFormat, "很抱歉,我不能对您的问题做出回答。"))
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SendHttpResponseWithDetail(403, "ai-security-guard.risk_detected", [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} else {
proxywasm.ResumeHttpRequest()
}
@@ -206,9 +265,9 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
}

func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
messages := gjson.GetBytes(body, "choices").Array()
if len(messages) > 0 {
content := messages[0].Get("message").Get("content").String()
content := extractResponseMessage(body)
log.Debugf("Raw response content is: %s", content)
if len(content) > 0 {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
@@ -237,31 +296,35 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
if respAdvice.Exists() {
sr := StandardResponse{
Code: 403,
Phase: "Response",
Message: respAdvice.Array()[0].Get("Answer").String(),
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
hdsMap := ctx.GetContext("headers").(map[string][]string)
var jsonData []byte
if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
jsonData = []byte(fmt.Sprintf(StreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
} else {
jsonData = []byte(fmt.Sprintf(NormalResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{"403"}
hdsMap[":status"] = []string{"200"}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response"))
config.incrementCounter("ai_sec_response_deny", 1)
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
sr := StandardResponse{
Code: 403,
Phase: "Response",
Message: "risk detected",
}
jsonData, _ := json.MarshalIndent(sr, "", " ")
hdsMap := ctx.GetContext("headers").(map[string][]string)
var jsonData []byte
if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
jsonData = []byte(fmt.Sprintf(StreamResponseFormat, "很抱歉,我不能对您的问题做出回答。"))
} else {
jsonData = []byte(fmt.Sprintf(NormalResponseFormat, "很抱歉,我不能对您的问题做出回答。"))
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{"403"}
hdsMap[":status"] = []string{"200"}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
proxywasm.SetProperty([]string{"risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response"))
config.incrementCounter("ai_sec_response_deny", 1)
}
}
},
@@ -271,3 +334,16 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
return types.ActionContinue
}
}

func extractResponseMessage(data []byte) string {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
strChunks := []string{}
for _, chunk := range chunks {
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
jsonObj := gjson.GetBytes(chunk, "choices.0.delta.content")
if jsonObj.Exists() {
strChunks = append(strChunks, jsonObj.String())
}
}
return strings.Join(strChunks, "")
}
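
`extractResponseMessage` reads only `choices.0.delta.content`, which is the streaming chunk shape. As a rough illustration of the same chunk-joining idea using only the standard library, the sketch below also reads `choices.0.message.content` so that a non-streaming body yields text too. That extra branch, the `[DONE]` handling, and the names `chunk` and `extractContent` are assumptions for illustration, not behavior taken from this PR.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// chunk covers both the streaming (delta) and non-streaming (message) shapes.
type chunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}

// extractContent splits the body on blank lines, strips an optional "data:"
// prefix from each part, and concatenates the content of the first choice.
func extractContent(body string) string {
	var sb strings.Builder
	for _, part := range strings.Split(strings.TrimSpace(body), "\n\n") {
		part = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(part), "data:"))
		if part == "" || part == "[DONE]" {
			continue
		}
		var c chunk
		if err := json.Unmarshal([]byte(part), &c); err != nil || len(c.Choices) == 0 {
			continue // skip anything that is not a well-formed chunk
		}
		sb.WriteString(c.Choices[0].Delta.Content)
		sb.WriteString(c.Choices[0].Message.Content)
	}
	return sb.String()
}

func main() {
	stream := `data:{"choices":[{"index":0,"delta":{"role":"assistant","content":"Hel"}}]}` + "\n\n" +
		`data:{"choices":[{"index":0,"delta":{"content":"lo"}}]}` + "\n\n" +
		`data:{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}` + "\n\n" +
		`data: [DONE]`
	fmt.Println(extractContent(stream)) // prints "Hello"
}
```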