Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ const aiguard = tracer.aiguard
aiguard.evaluate([
{ role: 'user', content: 'What is 2 + 2' },
]).then(result => {
result.action && result.reason
result.action && result.reason && result.tags
})

aiguard.evaluate([
Expand All @@ -729,11 +729,11 @@ aiguard.evaluate([
],
}
]).then(result => {
result.action && result.reason
result.action && result.reason && result.tags
})

aiguard.evaluate([
{ role: 'tool', tool_call_id: 'call_1', content: '5' },
]).then(result => {
result.action && result.reason
result.action && result.reason && result.tags
})
10 changes: 9 additions & 1 deletion index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,10 @@ declare namespace tracer {
* Human-readable explanation for why this action was chosen.
*/
reason: string;
/**
* List of tags associated with the evaluation (e.g. indirect-prompt-injection)
*/
tags: string[];
}

/**
Expand All @@ -1331,6 +1335,10 @@ declare namespace tracer {
* Human-readable explanation from AI Guard describing why the conversation was blocked.
*/
reason: string;
/**
* List of tags associated with the evaluation (e.g. indirect-prompt-injection)
*/
tags: string[];
}

/**
Expand Down Expand Up @@ -1844,7 +1852,7 @@ declare namespace tracer {
* [@google-cloud/pubsub](https://github.com/googleapis/nodejs-pubsub) module.
*/
interface google_cloud_pubsub extends Integration {}

/**
* This plugin automatically instruments the
* [@google-cloud/vertexai](https://github.com/googleapis/nodejs-vertexai) module.
Expand Down
28 changes: 18 additions & 10 deletions packages/dd-trace/src/aiguard/sdk.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
'use strict'

const rfdc = require('rfdc')({ proto: false, circles: false })
const NoopAIGuard = require('./noop')
const executeRequest = require('./client')
const {
Expand All @@ -22,10 +23,11 @@ const appsecMetrics = telemetryMetrics.manager.namespace('appsec')
const ALLOW = 'ALLOW'

class AIGuardAbortError extends Error {
constructor (reason) {
constructor (reason, tags) {
super(reason)
this.name = 'AIGuardAbortError'
this.reason = reason
this.tags = tags
}
}

Expand Down Expand Up @@ -77,20 +79,26 @@ class AIGuard extends NoopAIGuard {
this.#initialized = true
}

#truncate (messages) {
/**
* Returns a safe copy of the messages to be serialized into the meta struct.
*
* - Clones each message so callers cannot mutate the data set in the meta struct.
* - Truncates the list of messages and `content` fields emitting metrics accordingly.
*/
#buildMessagesForMetaStruct (messages) {
const size = Math.min(messages.length, this.#maxMessagesLength)
if (messages.length > size) {
appsecMetrics.count(AI_GUARD_TELEMETRY_TRUNCATED, { type: 'messages' }).inc(1)
}
const result = messages.slice(-size)

const result = []
let contentTruncated = false
for (let i = 0; i < size; i++) {
const message = result[i]
for (let i = messages.length - size; i < messages.length; i++) {
const message = rfdc(messages[i])
if (message.content?.length > this.#maxContentSize) {
contentTruncated = true
result[i] = { ...message, content: message.content.slice(0, this.#maxContentSize) }
message.content = message.content.slice(0, this.#maxContentSize)
}
result.push(message)
}
if (contentTruncated) {
appsecMetrics.count(AI_GUARD_TELEMETRY_TRUNCATED, { type: 'content' }).inc(1)
Expand Down Expand Up @@ -139,7 +147,7 @@ class AIGuard extends NoopAIGuard {
}
}
const metaStruct = {
messages: this.#truncate(messages)
messages: this.#buildMessagesForMetaStruct(messages)
}
span.meta_struct = {
[AI_GUARD_META_STRUCT_KEY]: metaStruct
Expand Down Expand Up @@ -192,9 +200,9 @@ class AIGuard extends NoopAIGuard {
}
if (shouldBlock) {
span.setTag(AI_GUARD_BLOCKED_TAG_KEY, 'true')
throw new AIGuardAbortError(reason)
throw new AIGuardAbortError(reason, tags)
}
return { action, reason }
return { action, reason, tags }
})
}
}
Expand Down
29 changes: 28 additions & 1 deletion packages/dd-trace/test/aiguard/index.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,15 @@ describe('AIGuard SDK', () => {
if (shouldBlock) {
await rejects(
() => aiguard.evaluate(messages, { block: true }),
err => err.name === 'AIGuardAbortError' && err.reason === reason
err => err.name === 'AIGuardAbortError' && err.reason === reason && err.tags === tags
)
} else {
const evaluation = await aiguard.evaluate(messages, { block: true })
assert.strictEqual(evaluation.action, action)
assert.strictEqual(evaluation.reason, reason)
if (tags) {
assert.strictEqual(evaluation.tags, tags)
}
}

assertTelemetry('ai_guard.requests', { error: false, action, block: shouldBlock })
Expand Down Expand Up @@ -302,6 +305,30 @@ describe('AIGuard SDK', () => {
)
})

it('test message immutability', async () => {
const messages = [{
role: 'assistant',
tool_calls: [{ id: 'call_1', function: { name: 'shell', arguments: '{"cmd": "ls -lah"}' } }]
}]
mockFetch({
body: { data: { attributes: { action: 'ALLOW', reason: 'OK', is_blocking_enabled: false } } }
})

await tracer.trace('test', async () => {
await aiguard.evaluate(messages)
// update messages before flushing
messages[0].tool_calls.push({ id: 'call_2', function: { name: 'shell', arguments: '{"cmd": "rm -rf"}' } })
messages.push({ role: 'tool', tool_call_id: 'call_1', content: 'dir1, dir2, dir3' })
})

await agent.assertSomeTraces(traces => {
const span = traces[0][1] // second span in the trace
const metaStruct = msgpack.decode(span.meta_struct.ai_guard)
assert.equal(metaStruct.messages.length, 1)
assert.equal(metaStruct.messages[0].tool_calls.length, 1)
})
})

it('test missing required fields uses noop as default', async () => {
const client = new AIGuard(tracer, { aiguard: { endpoint: 'http://aiguard' } })
const result = await client.evaluate(toolCall)
Expand Down