feat: Add convenience function for LlamaGuard Filter (#555)
* chore: add new filter utility for LlamaGuard

* chore: update default

* chore: fix test

* chore: cleanup

* chore: refactor

* chore: minor doc fix

* chore: adapt changes as per slack discussion

* chore: add orchestration client test

* chore: add documentation

* chore: add release note

* fix: Changes from lint

* chore: fox typo

* chore: add mutiple filters in sample code

* chore: fix typo

* Update .changeset/hip-tools-smile.md

Co-authored-by: Deeksha Sinha <[email protected]>

* review

* minor

* fix: type tests

* fix: Changes from lint

* reivew

* readme

---------

Co-authored-by: cloud-sdk-js <[email protected]>
Co-authored-by: Zhongpin Wang <[email protected]>
Co-authored-by: Deeksha Sinha <[email protected]>
4 people authored Feb 24, 2025
1 parent b1d244c commit bc51f59
Showing 10 changed files with 287 additions and 51 deletions.
5 changes: 5 additions & 0 deletions .changeset/hip-tools-smile.md
@@ -0,0 +1,5 @@
---
'@sap-ai-sdk/orchestration': minor
---

[New Functionality] Introduce `buildLlamaGuardFilter()` convenience function to build Llama guard filters.
56 changes: 39 additions & 17 deletions packages/orchestration/README.md
@@ -336,20 +336,12 @@ Use the orchestration client with filtering to restrict content that is passed t
This feature allows filtering both the [input](https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/input-filtering) and [output](https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/output-filtering) of a model based on content safety criteria.
#### Azure Content Filter
Use `buildAzureContentSafetyFilter()` function to build an Azure content filter for both input and output.
Each category of the filter can be assigned a specific severity level, which corresponds to an Azure threshold value.
| Severity Level | Azure Threshold Value |
| ----------------------- | --------------------- |
| `ALLOW_SAFE` | 0 |
| `ALLOW_SAFE_LOW` | 2 |
| `ALLOW_SAFE_LOW_MEDIUM` | 4 |
| `ALLOW_ALL` | 6 |
The following example demonstrates how to use content filtering with the orchestration client.
See the sections below for details on the available content filters and how to build them.
```ts
import { OrchestrationClient, ContentFilters } from '@sap-ai-sdk/orchestration';
import { OrchestrationClient } from '@sap-ai-sdk/orchestration';

const llm = {
model_name: 'gpt-4o',
model_params: { max_tokens: 50, temperature: 0.1 }
@@ -358,16 +350,14 @@ const templating = {
template: [{ role: 'user', content: '{{?input}}' }]
};

const filter = buildAzureContentSafetyFilter({
Hate: 'ALLOW_SAFE_LOW',
Violence: 'ALLOW_SAFE_LOW_MEDIUM'
});
const filter = ... // Use a build function to create a content filter

const orchestrationClient = new OrchestrationClient({
llm,
templating,
filtering: {
input: {
filters: [filter]
filters: [filter] // Multiple filters can be applied
},
output: {
filters: [filter]
@@ -385,6 +375,38 @@ try {
}
```
Multiple filters can be applied at the same time for both input and output filtering.
#### Azure Content Filter
Use `buildAzureContentSafetyFilter()` function to build an Azure content filter.
Each category of the filter can be assigned a specific severity level, which corresponds to an Azure threshold value.
| Severity Level | Azure Threshold Value |
| ----------------------- | --------------------- |
| `ALLOW_SAFE` | 0 |
| `ALLOW_SAFE_LOW` | 2 |
| `ALLOW_SAFE_LOW_MEDIUM` | 4 |
| `ALLOW_ALL` | 6 |
```ts
const filter = buildAzureContentSafetyFilter({
Hate: 'ALLOW_SAFE_LOW',
Violence: 'ALLOW_SAFE_LOW_MEDIUM'
});
```
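The mapping from severity level to threshold value can be sketched as a plain lookup. This is a minimal illustration based only on the table above; `toThresholds` is a hypothetical helper and not part of the SDK, and the real `buildAzureContentSafetyFilter()` may produce a different object shape.

```ts
// Severity levels and threshold values as listed in the table above.
const azureThresholds = {
  ALLOW_SAFE: 0,
  ALLOW_SAFE_LOW: 2,
  ALLOW_SAFE_LOW_MEDIUM: 4,
  ALLOW_ALL: 6
} as const;

type SeverityLevel = keyof typeof azureThresholds;

// Hypothetical helper (not an SDK function): resolves each category's
// severity level to its numeric Azure threshold value.
function toThresholds(
  config: Record<string, SeverityLevel>
): Record<string, number> {
  return Object.fromEntries(
    Object.entries(config).map(([category, level]): [string, number] => [
      category,
      azureThresholds[level]
    ])
  );
}

const resolved = toThresholds({
  Hate: 'ALLOW_SAFE_LOW',
  Violence: 'ALLOW_SAFE_LOW_MEDIUM'
});
console.log(resolved); // { Hate: 2, Violence: 4 }
```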
#### Llama Guard Filter
Use `buildLlamaGuardFilter()` function to build a Llama Guard content filter.
Available categories can be found with autocompletion.
Pass the categories as arguments to the function to enable them.
```ts
const filter = buildLlamaGuardFilter('hate', 'violent_crimes');
```
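To make the resulting filter object concrete, the following self-contained sketch mirrors the implementation added in this commit. `buildFilter` and `Category` are local stand-ins so the snippet runs without the SDK, and the category list is only an illustrative subset of what the SDK exposes.

```ts
// Illustrative subset of categories (assumption); the SDK derives the
// full list from its schema types.
type Category = 'hate' | 'violent_crimes' | 'self_harm' | 'privacy';

// Local stand-in for buildLlamaGuardFilter(). The rest parameter is typed
// as a non-empty tuple, so calling it with no arguments fails to compile.
function buildFilter(...categories: [Category, ...Category[]]) {
  return {
    type: 'llama_guard_3_8b' as const,
    // Each category becomes a key set to true; repeated categories
    // collapse into a single key.
    config: Object.fromEntries(
      categories.map((category): [Category, true] => [category, true])
    )
  };
}

const exampleFilter = buildFilter('hate', 'violent_crimes', 'hate');
console.log(JSON.stringify(exampleFilter));
// {"type":"llama_guard_3_8b","config":{"hate":true,"violent_crimes":true}}

// buildFilter(); // compile-time error: at least one category is required
```

Because the configuration is a plain object keyed by category, passing the same category twice has no additional effect.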
#### Error Handling
Both `chatCompletion()` and `getContent()` methods can throw errors.
2 changes: 2 additions & 0 deletions packages/orchestration/src/index.ts
@@ -9,6 +9,7 @@ export type {
DocumentGroundingServiceConfig,
DocumentGroundingServiceFilter,
LlmModelParams,
LlamaGuardCategory,
AzureContentFilter,
AzureFilterThreshold
} from './orchestration-types.js';
@@ -24,6 +25,7 @@ export { OrchestrationClient } from './orchestration-client.js';
export {
buildAzureContentFilter,
buildAzureContentSafetyFilter,
buildLlamaGuardFilter,
buildDocumentGroundingConfig
} from './util/index.js';

56 changes: 55 additions & 1 deletion packages/orchestration/src/orchestration-client.test.ts
@@ -13,7 +13,8 @@ import { OrchestrationResponse } from './orchestration-response.js';
import {
constructCompletionPostRequestFromJsonModuleConfig,
constructCompletionPostRequest,
buildAzureContentSafetyFilter
buildAzureContentSafetyFilter,
buildLlamaGuardFilter
} from './util/index.js';
import type { CompletionPostResponse } from './client/api/schema/index.js';
import type {
@@ -206,6 +207,59 @@ describe('orchestration service client', () => {
expect(response.data).toEqual(mockResponse);
});

it('calls chatCompletion with filter configuration supplied using multiple convenience functions', async () => {
const llamaFilter = buildLlamaGuardFilter('self_harm');
const azureContentFilter = buildAzureContentSafetyFilter({
Sexual: 'ALLOW_SAFE'
});
const config: OrchestrationModuleConfig = {
llm: {
model_name: 'gpt-35-turbo-16k',
model_params: { max_tokens: 50, temperature: 0.1 }
},
templating: {
template: [
{
role: 'user',
content: 'Create {{?number}} paraphrases of {{?phrase}}'
}
]
},
filtering: {
input: {
filters: [llamaFilter, azureContentFilter]
},
output: {
filters: [llamaFilter, azureContentFilter]
}
}
};
const prompt = {
inputParams: { phrase: 'I like myself.', number: '20' }
};
const mockResponse = await parseMockResponse<CompletionPostResponse>(
'orchestration',
'orchestration-chat-completion-multiple-filter-config.json'
);

mockInference(
{
data: constructCompletionPostRequest(config, prompt)
},
{
data: mockResponse,
status: 200
},
{
url: 'inference/deployments/1234/completion'
}
);
const response = await new OrchestrationClient(config).chatCompletion(
prompt
);
expect(response.data).toEqual(mockResponse);
});

it('calls chatCompletion with filtering configuration', async () => {
const config: OrchestrationModuleConfig = {
llm: {
6 changes: 6 additions & 0 deletions packages/orchestration/src/orchestration-types.ts
@@ -8,6 +8,7 @@ import type {
FilteringStreamOptions,
GlobalStreamOptions,
GroundingModuleConfig,
LlamaGuard38B,
MaskingModuleConfig,
LlmModuleConfig as OriginalLlmModuleConfig,
TemplatingModuleConfig
@@ -201,3 +202,8 @@ export const supportedAzureFilterThresholds = {
*
*/
export type AzureFilterThreshold = keyof typeof supportedAzureFilterThresholds;

/**
* The filter categories supported for Llama guard filter.
*/
export type LlamaGuardCategory = keyof LlamaGuard38B;
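
The `keyof` pattern above is what makes the category names available through autocompletion. A minimal sketch of the idea follows; `GuardSchema` is a hypothetical stand-in for the generated `LlamaGuard38B` schema type, whose actual fields are not shown in this diff.

```ts
// Hypothetical stand-in for the generated schema type (assumption);
// only a few field names seen elsewhere in this commit are listed.
interface GuardSchema {
  hate?: boolean;
  self_harm?: boolean;
  elections?: boolean;
}

// Same pattern as LlamaGuardCategory: the union of the schema's property
// names, here 'hate' | 'self_harm' | 'elections'.
type GuardCategory = keyof GuardSchema;

const enabledCategories: GuardCategory[] = ['hate', 'elections'];
// const invalid: GuardCategory = 'unknown'; // rejected by the compiler

console.log(enabledCategories.join(', ')); // hate, elections
```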
33 changes: 32 additions & 1 deletion packages/orchestration/src/util/filtering.test.ts
@@ -1,6 +1,7 @@
import {
buildAzureContentFilter,
buildAzureContentSafetyFilter
buildAzureContentSafetyFilter,
buildLlamaGuardFilter
} from './filtering.js';
import { constructCompletionPostRequest } from './module-config.js';
import type { OrchestrationModuleConfig } from '../orchestration-types.js';
@@ -211,4 +212,34 @@ describe('Content filter util', () => {
);
});
});

describe('Llama Guard filter', () => {
it('builds filter config with custom config', async () => {
const filterConfig = buildLlamaGuardFilter('elections', 'hate');
const expectedFilterConfig = {
type: 'llama_guard_3_8b',
config: {
elections: true,
hate: true
}
};
expect(filterConfig).toEqual(expectedFilterConfig);
});

it('builds filter config without duplicates', async () => {
const filterConfig = buildLlamaGuardFilter(
'non_violent_crimes',
'privacy',
'non_violent_crimes'
);
const expectedFilterConfig = {
type: 'llama_guard_3_8b',
config: {
non_violent_crimes: true,
privacy: true
}
};
expect(filterConfig).toEqual(expectedFilterConfig);
});
});
});
26 changes: 23 additions & 3 deletions packages/orchestration/src/util/filtering.ts
@@ -3,15 +3,17 @@ import type {
AzureContentSafety,
AzureContentSafetyFilterConfig,
InputFilteringConfig,
LlamaGuard38BFilterConfig,
OutputFilteringConfig
} from '../client/api/schema/index.js';
import type {
AzureContentFilter,
AzureFilterThreshold
AzureFilterThreshold,
LlamaGuardCategory
} from '../orchestration-types.js';

/**
* Convenience function to create Azure content filters.
* Convenience function to build Azure content filter.
* @param filter - Filtering configuration for Azure filter. If skipped, the default Azure content filter configuration is used.
* @returns An object with the Azure filtering configuration.
* @deprecated Since 1.8.0. Use {@link buildAzureContentSafetyFilter()} instead.
@@ -33,10 +35,11 @@ export function buildAzureContentFilter(
}

/**
* Convenience function to create Azure content filters.
* Convenience function to build Azure content filter.
* @param config - Configuration for Azure content safety filter.
* If skipped, the default configuration of `ALLOW_SAFE_LOW` is used for all filter categories.
* @returns Filter config object.
* @example "buildAzureContentSafetyFilter({ Hate: 'ALLOW_SAFE', Violence: 'ALLOW_SAFE_LOW_MEDIUM' })"
*/
export function buildAzureContentSafetyFilter(
config?: AzureContentFilter
@@ -58,3 +61,20 @@ export function buildAzureContentSafetyFilter(
})
};
}

/**
* Convenience function to build Llama guard filter.
* @param categories - Categories to be enabled for filtering. Provide at least one category.
* @returns Filter config object.
* @example "buildLlamaGuardFilter('self_harm', 'hate')"
*/
export function buildLlamaGuardFilter(
...categories: [LlamaGuardCategory, ...LlamaGuardCategory[]]
): LlamaGuard38BFilterConfig {
return {
type: 'llama_guard_3_8b',
config: Object.fromEntries(
[...categories].map(category => [category, true])
)
};
}