diff --git a/docs/my-website/docs/proxy/guardrails/custom_guardrail.md b/docs/my-website/docs/proxy/guardrails/custom_guardrail.md
index 365fdf81aa58..8b91833dd178 100644
--- a/docs/my-website/docs/proxy/guardrails/custom_guardrail.md
+++ b/docs/my-website/docs/proxy/guardrails/custom_guardrail.md
@@ -6,7 +6,7 @@ import TabItem from '@theme/TabItem';
Use this if you want to write code to run a custom guardrail
-## Quick Start
+## Quick Start
### 1. Write a `CustomGuardrail` Class
@@ -36,7 +36,7 @@ class myCustomGuardrail(CustomGuardrail):
async def apply_guardrail(
self,
text: str, # IMPORTANT: This is the text to check against your guardrail rules. It's extracted from the request or response across all LLM call types.
- language: Optional[str] = None, # ignore
+ language: Optional[str] = None, # ignore
entities: Optional[List[PiiEntityType]] = None, # ignore
request_data: Optional[dict] = None, # ignore
) -> str:
@@ -46,27 +46,27 @@ class myCustomGuardrail(CustomGuardrail):
Return the text (optionally modified) to allow it through.
"""
result = await self._check_with_api(text, request_data)
-
+
if result.get("action") == "BLOCK":
raise Exception(f"Content blocked: {result.get('reason', 'Policy violation')}")
-
+
return text
async def _check_with_api(self, text: str, request_data: Optional[dict]) -> dict:
async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)
-
+
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
-
+
response = await async_client.post(
f"{self.api_base}/check",
headers=headers,
json={"text": text},
timeout=5,
)
-
+
response.raise_for_status()
return response.json()
```
@@ -103,8 +103,8 @@ model_list:
guardrails:
- guardrail_name: "my-custom-guardrail"
litellm_params:
- guardrail: custom_guardrail.myCustomGuardrail # 👈 Key change
- mode: "during_call" # runs apply_guardrail method
+ guardrail: custom_guardrail.myCustomGuardrail # 👈 Key change
+ mode: "during_call" # runs apply_guardrail method
api_key: os.environ/MY_GUARDRAIL_API_KEY
api_base: https://api.myguardrail.com
```
@@ -127,20 +127,20 @@ guardrails:
- guardrail_name: "custom-pre-guard"
litellm_params:
guardrail: custom_guardrail.myCustomGuardrail
- mode: "pre_call" # runs async_pre_call_hook
+ mode: "pre_call" # runs async_pre_call_hook
- guardrail_name: "custom-during-guard"
litellm_params:
- guardrail: custom_guardrail.myCustomGuardrail
- mode: "during_call" # runs async_moderation_hook
+ guardrail: custom_guardrail.myCustomGuardrail
+ mode: "during_call" # runs async_moderation_hook
- guardrail_name: "custom-post-guard"
litellm_params:
guardrail: custom_guardrail.myCustomGuardrail
- mode: "post_call" # runs async_post_call_success_hook
+ mode: "post_call" # runs async_post_call_success_hook
```
-### 3. Start LiteLLM Gateway
+### 3. Start LiteLLM Gateway
@@ -149,7 +149,6 @@ Mount your `custom_guardrail.py` on the LiteLLM Docker container
This mounts your `custom_guardrail.py` file from your local directory to the `/app` directory in the Docker container, making it accessible to the LiteLLM Gateway.
-
```shell
docker run -d \
-p 4000:4000 \
@@ -167,7 +166,6 @@ docker run -d \
-
```shell
litellm --config config.yaml --detailed_debug
```
@@ -176,7 +174,7 @@ litellm --config config.yaml --detailed_debug
-### 4. Test it
+### 4. Test it
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
@@ -265,6 +263,36 @@ curl -i -X POST http://localhost:4000/v1/chat/completions \
}'
```
+Expected response after pre-guard
+
+```json
+{
+ "id": "chatcmpl-9zREDkBIG20RJB4pMlyutmi1hXQWc",
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {
+ "content": "It looks like you've chosen a string of asterisks. This could be a way to censor or hide certain text. However, without more context, I can't provide a specific word or phrase. If there's something specific you'd like me to say or if you need help with a topic, feel free to let me know!",
+ "role": "assistant",
+ "tool_calls": null,
+ "function_call": null
+ }
+ }
+ ],
+ "created": 1724429701,
+ "model": "gpt-4o-2024-05-13",
+ "object": "chat.completion",
+ "system_fingerprint": "fp_3aa7262c27",
+ "usage": {
+ "completion_tokens": 65,
+ "prompt_tokens": 14,
+ "total_tokens": 79
+ },
+ "service_tier": null
+}
+```
+
@@ -288,6 +316,8 @@ curl -i http://localhost:4000/v1/chat/completions \
#### Test `"custom-during-guard"`
+**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
+
@@ -345,6 +375,8 @@ curl -i http://localhost:4000/v1/chat/completions \
#### Test `"custom-post-guard"`
+**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
+
@@ -413,7 +445,6 @@ curl -i -X POST http://localhost:4000/v1/chat/completions \
:::
-
Use this to pass additional parameters to the guardrail API call. e.g. things like success threshold
1. Use `get_guardrail_dynamic_request_body_params`
@@ -483,6 +514,7 @@ response = client.chat.completions.create(
}
)
```
+
@@ -507,22 +539,111 @@ curl 'http://0.0.0.0:4000/chat/completions' \
]
}'
```
+
The `get_guardrail_dynamic_request_body_params` method will return:
+
```json
{
- "success_threshold": 0.9
+ "success_threshold": 0.9
}
```
+---
+
+## ✨ Pass custom HTTP headers to guardrail API
+
+You can pass custom HTTP headers that will be sent to the guardrail API endpoint at runtime by including them in your client request. This allows you to dynamically add headers like authentication tokens, tracking IDs, or override default headers without modifying your configuration.
+
+### How it works
+
+1. Use `get_guardrail_custom_headers` in your guardrail implementation
+
+`get_guardrail_custom_headers` is a method of the `litellm.integrations.custom_guardrail.CustomGuardrail` class that extracts custom headers from the incoming HTTP request.
+
+```python
+from litellm.integrations.custom_guardrail import CustomGuardrail
+
+class myCustomGuardrail(CustomGuardrail):
+ async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
+ # Get custom headers from request
+ custom_headers = self.get_guardrail_custom_headers(request_data=data)
+ # custom_headers will contain: {"X-Custom-Header": "value", "X-Tracking-ID": "123"}
+
+ # Use custom headers when making API calls
+ headers = {
+ "Authorization": f"Bearer {self.api_key}",
+ **custom_headers # Merge custom headers
+ }
+ # ... make API call with headers
+ return data
+```
+
+2. Pass headers in your API requests:
+
+Pass a header named `X-LiteLLM-Guardrail-{guardrail_name}` (where `{guardrail_name}` matches your guardrail's name) with a JSON object containing the HTTP headers to send to the guardrail API.
+
+
+
+
+```bash
+curl -X POST http://localhost:4000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "X-LiteLLM-Guardrail-my-guardrail: {\"X-Custom-Header\": \"custom-value\", \"X-Tracking-ID\": \"track-123\"}" \
+ -H "Authorization: Bearer sk-1234" \
+ -d '{
+ "model": "gpt-4",
+ "messages": [{"role": "user", "content": "Hello"}]
+ }'
+```
+
+
+
+
+```python
+import openai
+import json
+
+client = openai.OpenAI(
+ api_key="anything",
+ base_url="http://localhost:4000"
+)
+
+response = client.chat.completions.create(
+ model="gpt-4",
+ messages=[{"role": "user", "content": "Hello"}],
+ extra_headers={
+ "X-LiteLLM-Guardrail-my-guardrail": json.dumps({
+ "X-Custom-Header": "custom-value",
+ "X-Tracking-ID": "track-123"
+ })
+ }
+)
+```
+
+
+
+
+### Notes
+
+- Header names are case-insensitive (HTTP standard)
+- The header value must be valid JSON
+- Invalid JSON will be ignored and the guardrail will use default headers only
+- All custom headers are merged with default headers and sent to the guardrail API
+- This feature is available for all guardrails that inherit from `CustomGuardrail`
+- Use `extra_body` in request metadata for payload parameters (not HTTP headers)
+
+---
+
## Advanced: Individual Event Hooks
Pro: More flexibility
Con: You need to implement this for each LLM call type (chat completions, text completions, embeddings, image generation, moderation, audio transcription, pass through endpoint, rerank, etc. )
For more fine-grained control over when and how your guardrail runs, you can implement individual event hooks. This gives you flexibility to:
+
- Modify inputs before the LLM call
- Run checks in parallel with the LLM call (lower latency)
- Validate or modify outputs after the LLM call
@@ -650,14 +771,13 @@ class myCustomGuardrail(CustomGuardrail):
## **CustomGuardrail methods**
-| Component | Description | Optional | Checked Data | Can Modify Input | Can Modify Output | Can Fail Call |
-|-----------|-------------|----------|--------------|------------------|-------------------|----------------|
-| `apply_guardrail` | Simple method to check and optionally modify text | ✅ | INPUT or OUTPUT | ✅ | ✅ | ✅ |
-| `async_pre_call_hook` | A hook that runs before the LLM API call | ✅ | INPUT | ✅ | ❌ | ✅ |
-| `async_moderation_hook` | A hook that runs during the LLM API call| ✅ | INPUT | ❌ | ❌ | ✅ |
-| `async_post_call_success_hook` | A hook that runs after a successful LLM API call| ✅ | INPUT, OUTPUT | ❌ | ✅ | ✅ |
-| `async_post_call_streaming_iterator_hook` | A hook that processes streaming responses | ✅ | OUTPUT | ❌ | ✅ | ✅ |
-
+| Component | Description | Optional | Checked Data | Can Modify Input | Can Modify Output | Can Fail Call |
+| ----------------------------------------- | ------------------------------------------------- | -------- | --------------- | ---------------- | ----------------- | ------------- |
+| `apply_guardrail` | Simple method to check and optionally modify text | ✅ | INPUT or OUTPUT | ✅ | ✅ | ✅ |
+| `async_pre_call_hook` | A hook that runs before the LLM API call | ✅ | INPUT | ✅ | ❌ | ✅ |
+| `async_moderation_hook` | A hook that runs during the LLM API call | ✅ | INPUT | ❌ | ❌ | ✅ |
+| `async_post_call_success_hook` | A hook that runs after a successful LLM API call | ✅ | INPUT, OUTPUT | ❌ | ✅ | ✅ |
+| `async_post_call_streaming_iterator_hook` | A hook that processes streaming responses | ✅ | OUTPUT | ❌ | ✅ | ✅ |
## Frequently Asked Questions
@@ -669,10 +789,10 @@ class myCustomGuardrail(CustomGuardrail):
**A.** The main one you should care about is 'text' - this is what you'll want to send to your api for verification - See implementation [here](https://github.com/BerriAI/litellm/blob/0292b84dc47473ddeff29bd5a86f529bc523034b/litellm/llms/anthropic/chat/guardrail_translation/handler.py#L102)
-**Q. Is this function agnostic to the LLM provider? Meaning does it pass the same values for OpenAI and Anthropic for example?
+**Q. Is this function agnostic to the LLM provider? Meaning does it pass the same values for OpenAI and Anthropic for example?**
**A.** Yes
**Q. How do I know if my guardrail is running?**
-**A.** If you implement `apply_guardrail`, you can query the guardrail directly via [the `/apply_guardrail` API](../../apply_guardrail).
\ No newline at end of file
+**A.** If you implement `apply_guardrail`, you can query the guardrail directly via [the `/apply_guardrail` API](../../apply_guardrail).
diff --git a/docs/my-website/docs/proxy/guardrails/grayswan.md b/docs/my-website/docs/proxy/guardrails/grayswan.md
index b510c870a1e3..acceaeb2ec4a 100644
--- a/docs/my-website/docs/proxy/guardrails/grayswan.md
+++ b/docs/my-website/docs/proxy/guardrails/grayswan.md
@@ -27,22 +27,22 @@ Add a guardrail entry that references the Gray Swan integration. Below is a bala
```yaml
model_list:
- - model_name: openai/gpt-4.1-mini
+ - model_name: openai/gpt-5
litellm_params:
- model: openai/gpt-4.1-mini
+ model: openai/gpt-5
api_key: os.environ/OPENAI_API_KEY
guardrails:
- guardrail_name: "cygnal-monitor"
litellm_params:
guardrail: grayswan
- mode: [pre_call, post_call] # monitor both input and output
+ mode: [pre_call, post_call] # monitor both input and output
api_key: os.environ/GRAYSWAN_API_KEY
- api_base: os.environ/GRAYSWAN_API_BASE # optional
+ api_base: os.environ/GRAYSWAN_API_BASE # optional
optional_params:
- on_flagged_action: monitor # or "block"
- violation_threshold: 0.5 # score >= threshold is flagged
- reasoning_mode: hybrid # off | hybrid | thinking
+ on_flagged_action: monitor # or "block"
+ violation_threshold: 0.5 # score >= threshold is flagged
+ reasoning_mode: "off" # off | hybrid | thinking (ensure quotes are used for "off", otherwise it will get parsed as `false`)
categories:
safety: "Detect jailbreaks and policy violations"
policy_id: "your-cygnal-policy-id"
@@ -67,11 +67,11 @@ litellm --config config.yaml --port 4000
Gray Swan can run during `pre_call`, `during_call`, and `post_call` stages. Combine modes based on your latency and coverage requirements.
-| Mode | When it Runs | Protects | Typical Use Case |
-|--------------|-------------------|-----------------------|------------------|
-| `pre_call` | Before LLM call | User input only | Block prompt injection before it reaches the model |
-| `during_call`| Parallel to call | User input only | Low-latency monitoring without blocking |
-| `post_call` | After response | Full conversation | Scan output for policy violations, leaked secrets, or IPI |
+| Mode | When it Runs | Protects | Typical Use Case |
+| ------------- | ---------------- | ----------------- | --------------------------------------------------------- |
+| `pre_call` | Before LLM call | User input only | Block prompt injection before it reaches the model |
+| `during_call` | Parallel to call | User input only | Low-latency monitoring without blocking |
+| `post_call` | After response | Full conversation | Scan output for policy violations, leaked secrets, or IPI |
@@ -138,12 +138,24 @@ Provides the strongest enforcement by inspecting both prompts and responses.
## Configuration Reference
-| Parameter | Type | Description |
-|---------------------------------------|-----------------|-------------|
-| `api_key` | string | Gray Swan Cygnal API key. Reads from `GRAYSWAN_API_KEY` if omitted. |
-| `mode` | string or list | Guardrail stages (`pre_call`, `during_call`, `post_call`). |
-| `optional_params.on_flagged_action` | string | `monitor` (log only) or `block` (raise `HTTPException`). |
-| `.optional_params.violation_threshold`| number (0-1) | Scores at or above this value are considered violations. |
-| `optional_params.reasoning_mode` | string | `off`, `hybrid`, or `thinking`. Enables Cygnal’s reasoning capabilities. |
-| `optional_params.categories` | object | Map of custom category names to descriptions. |
-| `optional_params.policy_id` | string | Gray Swan policy identifier. |
+| Parameter | Type | Description |
+| -------------------------------------- | -------------- | ------------------------------------------------------------------------ |
+| `api_key` | string | Gray Swan Cygnal API key. Reads from `GRAYSWAN_API_KEY` if omitted. |
+| `mode` | string or list | Guardrail stages (`pre_call`, `during_call`, `post_call`). |
+| `optional_params.on_flagged_action` | string | `monitor` (log only) or `block` (raise `HTTPException`). |
+| `.optional_params.violation_threshold` | number (0-1) | Scores at or above this value are considered violations. |
+| `optional_params.reasoning_mode` | string | `off`, `hybrid`, or `thinking`. Enables Cygnal's reasoning capabilities. |
+| `optional_params.categories` | object | Map of custom category names to descriptions. |
+| `optional_params.policy_id` | string | Gray Swan policy identifier. |
+
+---
+
+## Dynamic Configuration
+
+### Custom HTTP Headers
+
+You can pass custom HTTP headers to the Gray Swan API endpoint at runtime. All headers passed via `X-LiteLLM-Guardrail-{guardrail_name}` will be forwarded directly to the Gray Swan API.
+
+For detailed documentation and examples, see [Pass custom HTTP headers to guardrail API](../custom_guardrail.md#-pass-custom-http-headers-to-guardrail-api).
+
+**Note:** For payload parameters like `policy_id`, `reasoning_mode`, and `categories`, use the `extra_body` parameter in the request metadata instead (see [Pass additional parameters to guardrail](../custom_guardrail.md#-pass-additional-parameters-to-guardrail)).
diff --git a/litellm/batches/main.py b/litellm/batches/main.py
index 5279dd70bc42..677f12cd7679 100644
--- a/litellm/batches/main.py
+++ b/litellm/batches/main.py
@@ -18,7 +18,6 @@
import httpx
from openai.types.batch import BatchRequestCounts
-from openai.types.batch import Metadata as BatchMetadata
import litellm
from litellm._logging import verbose_logger
@@ -862,7 +861,6 @@ def cancel_batch(
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
"""
try:
-
try:
if model is not None:
_, custom_llm_provider, _, _ = get_llm_provider(
diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py
index b50d05ed2ec6..e1e47420ce08 100644
--- a/litellm/integrations/custom_guardrail.py
+++ b/litellm/integrations/custom_guardrail.py
@@ -1,3 +1,4 @@
+import json
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, Union, get_args
@@ -279,7 +280,7 @@ def should_run_guardrail(
data, self.event_hook
)
if result is not None:
- return result
+ return result
return True
def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
@@ -333,6 +334,83 @@ def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
return {}
+ def get_guardrail_custom_headers(self, request_data: dict) -> dict:
+ """
+ Returns custom headers from HTTP request header for the Guardrail API call.
+
+ Extracts headers matching pattern `X-LiteLLM-Guardrail-{guardrail_name}` from the
+ incoming HTTP request. The header value should be a JSON object string containing
+ multiple header values.
+
+ Example:
+ Header: `X-LiteLLM-Guardrail-grayswan: {"policy_id": "xyz", "reasoning_mode": "hybrid"}`
+ Will return: `{"policy_id": "xyz", "reasoning_mode": "hybrid"}`
+
+ Priority order:
+ 1. Custom headers from HTTP request header (this method)
+ 2. Dynamic body params from `extra_body` in request metadata
+ 3. Config values from `config.yaml` or initialization
+
+ Args:
+ request_data: The original `request_data` passed to LiteLLM Proxy
+
+ Returns:
+ dict: Dictionary of custom headers, empty dict if not found or invalid
+ """
+ if not self.guardrail_name:
+ return {}
+
+ # Get headers from proxy_server_request
+ proxy_server_request = request_data.get("proxy_server_request", {})
+ headers = proxy_server_request.get("headers", {})
+
+ if not headers:
+ return {}
+
+ # Build expected header name (case-insensitive matching)
+ expected_header_name = f"x-litellm-guardrail-{self.guardrail_name.lower()}"
+
+ # Find matching header (HTTP headers are case-insensitive)
+ header_value = None
+ for header_name, value in headers.items():
+ if header_name.lower() == expected_header_name:
+ header_value = value
+ break
+
+ if header_value is None:
+ return {}
+
+ # Parse header value
+ try:
+ # Handle both string and dict formats
+ if isinstance(header_value, str):
+ parsed = json.loads(header_value)
+ elif isinstance(header_value, dict):
+ parsed = header_value
+ else:
+ verbose_logger.warning(
+ f"Guardrail {self.guardrail_name}: Invalid header value type for {expected_header_name}, expected string or dict"
+ )
+ return {}
+
+ if not isinstance(parsed, dict):
+ verbose_logger.warning(
+ f"Guardrail {self.guardrail_name}: Header {expected_header_name} must contain a JSON object"
+ )
+ return {}
+
+ return parsed
+ except json.JSONDecodeError as e:
+ verbose_logger.warning(
+ f"Guardrail {self.guardrail_name}: Failed to parse JSON from header {expected_header_name}: {e}"
+ )
+ return {}
+ except Exception as e:
+ verbose_logger.warning(
+ f"Guardrail {self.guardrail_name}: Error extracting custom headers: {e}"
+ )
+ return {}
+
def _validate_premium_user(self) -> bool:
"""
Returns True if the user is a premium user
diff --git a/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py b/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py
index fb93dec47b7a..6c7aac8a31a6 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py
@@ -138,6 +138,7 @@ async def async_pre_call_hook(
verbose_proxy_logger.debug("Gray Swan Guardrail: No messages in data")
return data
+ custom_headers = self.get_guardrail_custom_headers(data)
dynamic_body = self.get_guardrail_dynamic_request_body_params(data) or {}
payload = self._prepare_payload(messages, dynamic_body)
@@ -147,7 +148,7 @@ async def async_pre_call_hook(
)
return data
- await self.run_grayswan_guardrail(payload)
+ await self.run_grayswan_guardrail(payload, custom_headers)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
@@ -184,6 +185,7 @@ async def async_moderation_hook(
verbose_proxy_logger.debug("Gray Swan Guardrail: No messages in data")
return data
+ custom_headers = self.get_guardrail_custom_headers(data)
dynamic_body = self.get_guardrail_dynamic_request_body_params(data) or {}
payload = self._prepare_payload(messages, dynamic_body)
@@ -193,7 +195,7 @@ async def async_moderation_hook(
)
return data
- await self.run_grayswan_guardrail(payload)
+ await self.run_grayswan_guardrail(payload, custom_headers)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
@@ -231,6 +233,7 @@ async def async_post_call_success_hook(
)
return response
+ custom_headers = self.get_guardrail_custom_headers(data)
dynamic_body = self.get_guardrail_dynamic_request_body_params(data) or {}
payload = self._prepare_payload(response_messages, dynamic_body)
@@ -240,7 +243,7 @@ async def async_post_call_success_hook(
)
return response
- await self.run_grayswan_guardrail(payload)
+ await self.run_grayswan_guardrail(payload, custom_headers)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
)
@@ -250,8 +253,10 @@ async def async_post_call_success_hook(
# Core GraySwan interaction
# ------------------------------------------------------------------
- async def run_grayswan_guardrail(self, payload: dict):
- headers = self._prepare_headers()
+ async def run_grayswan_guardrail(
+ self, payload: dict, custom_headers: Optional[Dict[str, Any]] = None
+ ):
+ headers = self._prepare_headers(custom_headers)
try:
response = await self.async_handler.post(
@@ -279,28 +284,48 @@ async def run_grayswan_guardrail(self, payload: dict):
# Helpers
# ------------------------------------------------------------------
- def _prepare_headers(self) -> Dict[str, str]:
- return {
+ def _prepare_headers(
+ self, custom_headers: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, str]:
+ headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"grayswan-api-key": self.api_key,
}
+ # Merge all custom headers if provided
+ if custom_headers:
+ for key, value in custom_headers.items():
+ if value is not None:
+ headers[key] = str(value)
+
+ return headers
+
def _prepare_payload(
- self, messages: list[dict], dynamic_body: dict
+ self,
+ messages: list[dict],
+ dynamic_body: dict,
+ custom_headers: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
payload: Dict[str, Any] = {}
payload["messages"] = messages
- categories = dynamic_body.get("categories") or self.categories
+ # Priority: dynamic_body > config
+ categories = dynamic_body.get("categories") if dynamic_body else None
+ if not categories:
+ categories = self.categories
if categories:
payload["categories"] = categories
- policy_id = dynamic_body.get("policy_id") or self.policy_id
+ policy_id = dynamic_body.get("policy_id") if dynamic_body else None
+ if not policy_id:
+ policy_id = self.policy_id
if policy_id:
payload["policy_id"] = policy_id
- reasoning_mode = dynamic_body.get("reasoning_mode") or self.reasoning_mode
+ reasoning_mode = dynamic_body.get("reasoning_mode") if dynamic_body else None
+ if not reasoning_mode:
+ reasoning_mode = self.reasoning_mode
if reasoning_mode:
payload["reasoning_mode"] = reasoning_mode
diff --git a/tests/test_litellm/caching/test_s3_cache.py b/tests/test_litellm/caching/test_s3_cache.py
index 9c902768bfcc..c418e80e2882 100644
--- a/tests/test_litellm/caching/test_s3_cache.py
+++ b/tests/test_litellm/caching/test_s3_cache.py
@@ -266,13 +266,18 @@ async def test_s3_cache_async_set_cache_pipeline(mock_s3_dependencies):
# Should have called put_object 3 times
assert cache.s3_client.put_object.call_count == 3
- # Verify each call
+ # Verify each call (order may vary due to concurrent execution)
calls = cache.s3_client.put_object.call_args_list
- for i, (key, value) in enumerate(cache_list):
- call_args = calls[i][1]
+ expected_items = {key: json.dumps(value) for key, value in cache_list}
+ actual_items = {}
+
+ for call in calls:
+ call_args = call[1]
assert call_args["Bucket"] == "test-bucket"
- assert call_args["Key"] == key
- assert call_args["Body"] == json.dumps(value)
+ actual_items[call_args["Key"]] = call_args["Body"]
+
+ # Verify all expected keys and values are present (order doesn't matter)
+ assert actual_items == expected_items
@pytest.mark.asyncio
@@ -294,11 +299,18 @@ async def test_s3_cache_concurrent_async_operations(mock_s3_dependencies):
assert cache.s3_client.put_object.call_count == 5
# Verify each call had correct parameters
+ # Collect all keys from calls (order may vary due to concurrency)
calls = cache.s3_client.put_object.call_args_list
- for i, call in enumerate(calls):
+ actual_keys = set()
+ expected_keys = {f"concurrent_key_{i}" for i in range(5)}
+
+ for call in calls:
call_args = call[1]
assert call_args["Bucket"] == "test-bucket"
- assert f"concurrent_key_{i}" == call_args["Key"]
+ actual_keys.add(call_args["Key"])
+
+ # Verify all expected keys are present (order doesn't matter for concurrent operations)
+ assert actual_keys == expected_keys
@pytest.mark.asyncio
diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py
index d6bd0a251b5e..a3fbdb0bdbbd 100644
--- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py
+++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_grayswan.py
@@ -22,7 +22,9 @@ def grayswan_guardrail() -> GraySwanGuardrail:
)
-def test_prepare_payload_uses_dynamic_overrides(grayswan_guardrail: GraySwanGuardrail) -> None:
+def test_prepare_payload_uses_dynamic_overrides(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
messages = [{"role": "user", "content": "hello"}]
dynamic_body = {
"categories": {"custom": "override"},
@@ -30,7 +32,7 @@ def test_prepare_payload_uses_dynamic_overrides(grayswan_guardrail: GraySwanGuar
"reasoning_mode": "thinking",
}
- payload = grayswan_guardrail._prepare_payload(messages, dynamic_body)
+ payload = grayswan_guardrail._prepare_payload(messages, dynamic_body, None)
assert payload["messages"] == messages
assert payload["categories"] == {"custom": "override"}
@@ -38,18 +40,24 @@ def test_prepare_payload_uses_dynamic_overrides(grayswan_guardrail: GraySwanGuar
assert payload["reasoning_mode"] == "thinking"
-def test_prepare_payload_falls_back_to_guardrail_defaults(grayswan_guardrail: GraySwanGuardrail) -> None:
+def test_prepare_payload_falls_back_to_guardrail_defaults(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
messages = [{"role": "user", "content": "hello"}]
- payload = grayswan_guardrail._prepare_payload(messages, {})
+ payload = grayswan_guardrail._prepare_payload(messages, {}, None)
assert payload["categories"] == {"safety": "general policy"}
assert payload["policy_id"] == "default-policy"
assert payload["reasoning_mode"] == "hybrid"
-def test_process_response_does_not_block_under_threshold(grayswan_guardrail: GraySwanGuardrail) -> None:
- grayswan_guardrail._process_grayswan_response({"violation": 0.3, "violated_rules": []})
+def test_process_response_does_not_block_under_threshold(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
+ grayswan_guardrail._process_grayswan_response(
+ {"violation": 0.3, "violated_rules": []}
+ )
def test_process_response_blocks_when_threshold_exceeded() -> None:
@@ -85,12 +93,16 @@ def __init__(self, payload: dict):
self.calls: list[dict] = []
async def post(self, *, url: str, headers: dict, json: dict, timeout: float):
- self.calls.append({"url": url, "headers": headers, "json": json, "timeout": timeout})
+ self.calls.append(
+ {"url": url, "headers": headers, "json": json, "timeout": timeout}
+ )
return _DummyResponse(self.payload)
@pytest.mark.asyncio
-async def test_run_guardrail_posts_payload(monkeypatch, grayswan_guardrail: GraySwanGuardrail) -> None:
+async def test_run_guardrail_posts_payload(
+ monkeypatch, grayswan_guardrail: GraySwanGuardrail
+) -> None:
dummy_client = _DummyClient({"violation": 0.1})
grayswan_guardrail.async_handler = dummy_client
@@ -110,7 +122,9 @@ def fake_process(response_json: dict) -> None:
@pytest.mark.asyncio
-async def test_run_guardrail_raises_api_error(grayswan_guardrail: GraySwanGuardrail) -> None:
+async def test_run_guardrail_raises_api_error(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
class _FailingClient:
async def post(self, **_kwargs):
raise RuntimeError("boom")
@@ -121,3 +135,167 @@ async def post(self, **_kwargs):
with pytest.raises(GraySwanGuardrailAPIError):
await grayswan_guardrail.run_grayswan_guardrail(payload)
+
+
+def test_get_guardrail_custom_headers_with_valid_json(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
+ """Test extraction of custom headers from request data with valid JSON."""
+ import json
+
+ request_data = {
+ "proxy_server_request": {
+ "headers": {
+ "x-litellm-guardrail-grayswan-test": json.dumps(
+ {"policy_id": "header-policy-123", "reasoning_mode": "thinking"}
+ )
+ }
+ }
+ }
+
+ custom_headers = grayswan_guardrail.get_guardrail_custom_headers(request_data)
+
+ assert custom_headers == {
+ "policy_id": "header-policy-123",
+ "reasoning_mode": "thinking",
+ }
+
+
+def test_get_guardrail_custom_headers_with_dict_value(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
+ """Test extraction of custom headers when header value is already a dict."""
+ request_data = {
+ "proxy_server_request": {
+ "headers": {
+ "x-litellm-guardrail-grayswan-test": {"policy_id": "header-policy-456"}
+ }
+ }
+ }
+
+ custom_headers = grayswan_guardrail.get_guardrail_custom_headers(request_data)
+
+ assert custom_headers == {"policy_id": "header-policy-456"}
+
+
+def test_get_guardrail_custom_headers_with_invalid_json(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
+ """Test that invalid JSON in header gracefully returns empty dict."""
+ request_data = {
+ "proxy_server_request": {
+ "headers": {"x-litellm-guardrail-grayswan-test": "invalid json {"}
+ }
+ }
+
+ custom_headers = grayswan_guardrail.get_guardrail_custom_headers(request_data)
+
+ assert custom_headers == {}
+
+
+def test_get_guardrail_custom_headers_case_insensitive(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
+ """Test that header matching is case-insensitive."""
+ import json
+
+ request_data = {
+ "proxy_server_request": {
+ "headers": {
+ "X-LiteLLM-Guardrail-GraySwan-Test": json.dumps(
+ {"policy_id": "case-test"}
+ )
+ }
+ }
+ }
+
+ custom_headers = grayswan_guardrail.get_guardrail_custom_headers(request_data)
+
+ assert custom_headers == {"policy_id": "case-test"}
+
+
+def test_get_guardrail_custom_headers_missing_header(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
+ """Test that missing header returns empty dict."""
+ request_data = {"proxy_server_request": {"headers": {}}}
+
+ custom_headers = grayswan_guardrail.get_guardrail_custom_headers(request_data)
+
+ assert custom_headers == {}
+
+
+def test_prepare_payload_ignores_custom_headers_for_config_params(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
+ """Test that custom headers are NOT used for config params like policy_id and reasoning_mode."""
+ messages = [{"role": "user", "content": "hello"}]
+ custom_headers = {"policy_id": "header-policy", "reasoning_mode": "thinking"}
+ dynamic_body = {"policy_id": "dynamic-policy", "reasoning_mode": "hybrid"}
+
+ payload = grayswan_guardrail._prepare_payload(
+ messages, dynamic_body, custom_headers
+ )
+
+ # Custom headers should be ignored - dynamic_body takes priority
+ assert (
+ payload["policy_id"] == "dynamic-policy"
+ ) # Dynamic body wins, not custom header
+ assert payload["reasoning_mode"] == "hybrid" # Dynamic body wins, not custom header
+
+
+def test_prepare_payload_priority_dynamic_body_over_config(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
+ """Test that dynamic_body takes priority over config."""
+ messages = [{"role": "user", "content": "hello"}]
+ custom_headers = {}
+ dynamic_body = {"policy_id": "dynamic-policy"}
+
+ payload = grayswan_guardrail._prepare_payload(
+ messages, dynamic_body, custom_headers
+ )
+
+ assert payload["policy_id"] == "dynamic-policy" # Dynamic body wins
+ assert payload["reasoning_mode"] == "hybrid" # Falls back to config
+
+
+def test_prepare_payload_priority_config_fallback(
+ grayswan_guardrail: GraySwanGuardrail,
+) -> None:
+ """Test that config values are used when dynamic_body is empty."""
+ messages = [{"role": "user", "content": "hello"}]
+ custom_headers = {}
+ dynamic_body = {}
+
+ payload = grayswan_guardrail._prepare_payload(
+ messages, dynamic_body, custom_headers
+ )
+
+ assert payload["policy_id"] == "default-policy" # Config value
+ assert payload["reasoning_mode"] == "hybrid" # Config value
+
+
+@pytest.mark.asyncio
+async def test_run_guardrail_with_custom_headers(
+ monkeypatch, grayswan_guardrail: GraySwanGuardrail
+) -> None:
+ """Test that custom HTTP headers are passed to the API request."""
+ dummy_client = _DummyClient({"violation": 0.1})
+ grayswan_guardrail.async_handler = dummy_client
+
+ def fake_process(response_json: dict) -> None:
+ pass
+
+ monkeypatch.setattr(grayswan_guardrail, "_process_grayswan_response", fake_process)
+
+ payload = {"messages": [{"role": "user", "content": "test"}]}
+ custom_headers = {
+ "custom-header": "test-value",
+ }
+
+ await grayswan_guardrail.run_grayswan_guardrail(payload, custom_headers)
+
+ # Verify custom HTTP header was included
+ assert "custom-header" in dummy_client.calls[0]["headers"]
+ assert dummy_client.calls[0]["headers"]["custom-header"] == "test-value"
diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio.py
index 9543b61ef690..f593da0846b8 100644
--- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio.py
+++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_presidio.py
@@ -544,6 +544,7 @@ async def test_presidio_sets_guardrail_information_in_request_data():
presidio = _OPTIONAL_PresidioPIIMasking(
guardrail_name="test_presidio",
output_parse_pii=True,
+ mock_testing=True,
)
request_data = {
@@ -600,6 +601,7 @@ async def test_request_data_flows_to_apply_guardrail():
presidio = _OPTIONAL_PresidioPIIMasking(
guardrail_name="test_presidio",
output_parse_pii=True,
+ mock_testing=True,
)
request_data = {