Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ai security guard #1261

Merged
merged 16 commits into from
Sep 24, 2024
43 changes: 20 additions & 23 deletions plugins/wasm-go/extensions/ai-security-guard/README-en.md
Original file line number Diff line number Diff line change
@@ -1,46 +1,43 @@
## 简介
## Introduction
Integrate with Aliyun content security service for detections of input and output of LLMs, ensuring that application content is legal and compliant.

## 配置说明
## Configuration
| Name | Type | Requirement | Default | Description |
| ------------ | ------------ | ------------ | ------------ | ------------ |
| `serviceSource` | string | requried | - | service source, such as `dns` |
| `serviceName` | string | requried | - | service name |
| `servicePort` | string | requried | - | service port |
rinfx marked this conversation as resolved.
Show resolved Hide resolved
| `domain` | string | requried | - | Host of Aliyun content security service endpoint |
| `ak` | string | requried | - | Aliyun accesskey |
| `sk` | string | requried | - | Aliyun secretkey |
| `checkRequest` | bool | optional | - | check if the input is leagal |
| `checkresponse` | bool | optional | - | check if the output is leagal |
| `serviceHost` | string | requried | - | Host of Aliyun content security service endpoint |
| `accessKey` | string | requried | - | Aliyun accesskey |
| `secretKey` | string | requried | - | Aliyun secretkey |
| `checkRequest` | bool | optional | false | check if the input is leagal |
| `checkResponse` | bool | optional | false | check if the output is leagal |

rinfx marked this conversation as resolved.
Show resolved Hide resolved

## 配置示例
### check if the input is leagal
## Examples of configuration
### Check if the input is leagal

rinfx marked this conversation as resolved.
Show resolved Hide resolved
```yaml
serviceSource: "dns"
serviceName: "safecheck"
serviceName: safecheck.dns
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
```

### check if both the input and output are leagal
### Check if both the input and output are leagal

```yaml
serviceSource: "dns"
serviceName: "safecheck"
serviceName: safecheck.dns
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
serviceHost: green-cip.cn-shanghai.aliyuncs.com
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
checkresponse: true
checkResponse: true
```

## observability
## Observability
### Metric
ai-security-guard plugin provides following metrics:
- `ai_sec_request_deny`: count of requests denied at request phase
Expand Down
71 changes: 53 additions & 18 deletions plugins/wasm-go/extensions/ai-security-guard/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,37 @@
## 配置说明
| Name | Type | Requirement | Default | Description |
| ------------ | ------------ | ------------ | ------------ | ------------ |
| `serviceSource` | string | requried | - | 服务来源,填dns |
| `serviceName` | string | requried | - | 服务名 |
| `servicePort` | string | requried | - | 服务端口 |
| `domain` | string | requried | - | 阿里云内容安全endpoint的域名 |
| `ak` | string | requried | - | 阿里云AK |
| `sk` | string | requried | - | 阿里云SK |
| `checkRequest` | bool | optional | - | 检查提问内容是否合规 |
| `checkresponse` | bool | optional | - | 检查大模型的回答内容是否合规,生效时会使流式响应变为非流式 |
| `serviceHost` | string | requried | - | 阿里云内容安全endpoint的域名 |
| `accessKey` | string | requried | - | 阿里云AK |
| `secretKey` | string | requried | - | 阿里云SK |
| `checkRequest` | bool | optional | false | 检查提问内容是否合规 |
| `checkResponse` | bool | optional | false | 检查大模型的回答内容是否合规,生效时会使流式响应变为非流式 |
rinfx marked this conversation as resolved.
Show resolved Hide resolved


## 配置示例
### 检测输入内容是否合规

```yaml
serviceSource: "dns"
serviceName: "safecheck"
serviceName: safecheck.dns
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
serviceHost: "green-cip.cn-shanghai.aliyuncs.com"
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
```

### 检测输入与输出是否合规

```yaml
serviceSource: "dns"
serviceName: "safecheck"
serviceName: safecheck.dns
servicePort: 443
domain: "green-cip.cn-shanghai.aliyuncs.com"
ak: "XXXXXXXXX"
sk: "XXXXXXXXXXXXXXX"
serviceHost: green-cip.cn-shanghai.aliyuncs.com
accessKey: "XXXXXXXXX"
secretKey: "XXXXXXXXXXXXXXX"
checkRequest: true
checkresponse: true
checkResponse: true
```

## 可观测
Expand All @@ -49,4 +46,42 @@ ai-security-guard 插件提供了以下监控指标:
### Trace
如果开启了链路追踪,ai-security-guard 插件会在请求 span 中添加以下 attributes:
- `ai_sec_risklabel`: 表示请求命中的风险类型
- `ai_sec_deny_phase`: 表示请求被检测到风险的阶段(取值为request或者response)
- `ai_sec_deny_phase`: 表示请求被检测到风险的阶段(取值为request或者response)

## 请求示例
```bash
curl http://localhost/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": "这是一段非法内容"
}
]
}'
```

请求内容会被发送到阿里云内容安全服务进行检测,如果请求内容检测结果为非法,网关将返回形如以下的回答:

```json
{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4o-mini",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "作为一名人工智能助手,我不能提供涉及色情、暴力、政治等敏感话题的内容。如果您有其他相关问题,欢迎您提问。",
},
"logprobs": null,
"finish_reason": "stop"
}
]
}
```
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-security-guard/go.mod
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module ai-security-guard
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard

go 1.18

Expand Down
41 changes: 21 additions & 20 deletions plugins/wasm-go/extensions/ai-security-guard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ func main() {
}

const (
NormalResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","created": 1677652288,"model": "gpt-4o-mini","system_fingerprint": "fp_44709d6fcb","choices": [{"index": 0,"message": {"role": "assistant","content": "%s",},"logprobs": null,"finish_reason": "stop"}]}`
StreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
StreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
StreamResponseFormat = StreamResponseChunk + "\n\n" + StreamResponseEnd
TracingPrefix = "trace_span_tag."
NormalResponseFormat = `{"id": "chatcmpl-123","object": "chat.completion","created": 1677652288,"model": "gpt-4o-mini","system_fingerprint": "fp_44709d6fcb","choices": [{"index": 0,"message": {"role": "assistant","content": "%s",},"logprobs": null,"finish_reason": "stop"}]}`
StreamResponseChunk = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
StreamResponseEnd = `data:{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
StreamResponseFormat = StreamResponseChunk + "\n\n" + StreamResponseEnd
TracingPrefix = "trace_span_tag."
DefaultResponseIfNoAdvice = "很抱歉,我不能对您的问题做出回答。"
)

type AISecurityConfig struct {
Expand Down Expand Up @@ -127,21 +128,21 @@ func generateHexID(length int) (string, error) {
func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) error {
serviceName := json.Get("serviceName").String()
servicePort := json.Get("servicePort").Int()
domain := json.Get("domain").String()
if serviceName == "" || servicePort == 0 || domain == "" {
serviceHost := json.Get("serviceHost").String()
if serviceName == "" || servicePort == 0 || serviceHost == "" {
return errors.New("invalid service config")
}
config.ak = json.Get("ak").String()
config.sk = json.Get("sk").String()
config.ak = json.Get("accessKey").String()
config.sk = json.Get("secretKey").String()
if config.ak == "" || config.sk == "" {
return errors.New("invalid AK/SK config")
}
config.checkRequest = json.Get("checkRequest").Bool()
config.checkResponse = json.Get("checkResponse").Bool()
config.client = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: serviceName,
Port: servicePort,
Domain: domain,
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: servicePort,
Host: serviceHost,
})
config.metrics = make(map[string]proxywasm.MetricCounter)
return nil
Expand Down Expand Up @@ -186,7 +187,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", "CIPFrom/AIGateway"}}, nil,
johnlanni marked this conversation as resolved.
Show resolved Hide resolved
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
respData := gjson.GetBytes(responseBody, "Data")
if respData.Exists() {
Expand All @@ -200,19 +201,19 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
jsonData := []byte(fmt.Sprintf(StreamResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
jsonData := []byte(fmt.Sprintf(NormalResponseFormat, respAdvice.Array()[0].Get("Answer").String()))
johnlanni marked this conversation as resolved.
Show resolved Hide resolved
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
} else if respResult.Array()[0].Get("Label").String() != "nonLabel" {
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
if stream {
jsonData := []byte(fmt.Sprintf(StreamResponseFormat, "很抱歉,我不能对您的问题做出回答。"))
jsonData := []byte(fmt.Sprintf(StreamResponseFormat, DefaultResponseIfNoAdvice))
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
jsonData := []byte(fmt.Sprintf(NormalResponseFormat, "很抱歉,我不能对您的问题做出回答。"))
jsonData := []byte(fmt.Sprintf(NormalResponseFormat, DefaultResponseIfNoAdvice))
rinfx marked this conversation as resolved.
Show resolved Hide resolved
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
Expand Down Expand Up @@ -288,7 +289,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), nil, nil,
config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", "CIPFrom/AIGateway"}}, nil,
rinfx marked this conversation as resolved.
Show resolved Hide resolved
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer proxywasm.ResumeHttpResponse()
respData := gjson.GetBytes(responseBody, "Data")
Expand All @@ -314,9 +315,9 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
hdsMap := ctx.GetContext("headers").(map[string][]string)
var jsonData []byte
if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
jsonData = []byte(fmt.Sprintf(StreamResponseFormat, "很抱歉,我不能对您的问题做出回答。"))
jsonData = []byte(fmt.Sprintf(StreamResponseFormat, DefaultResponseIfNoAdvice))
} else {
jsonData = []byte(fmt.Sprintf(NormalResponseFormat, "很抱歉,我不能对您的问题做出回答。"))
jsonData = []byte(fmt.Sprintf(NormalResponseFormat, DefaultResponseIfNoAdvice))
rinfx marked this conversation as resolved.
Show resolved Hide resolved
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{"200"}
Expand Down
Loading