From 79489cf501ab6ba453dc775393abbb6780b0fd26 Mon Sep 17 00:00:00 2001 From: Oren Hazai Date: Tue, 16 Sep 2025 18:25:53 +0300 Subject: [PATCH 01/19] add new nemoguardrails/cache folder with lfu cache implementation (and interface) add tests for lfu cache new content safety dynamic cache + integration add stats logging remove redundant test thread safety support for content-safety caching fixed failing tests update documentation to reflect thread-safety support for cache fixes following test failures on race conditions fixes following test failures remove a test update cache interface per model config without defaults --- examples/configs/content_safety/README.md | 119 +- examples/configs/content_safety/config.yml | 15 + nemoguardrails/cache/README.md | 271 +++ nemoguardrails/cache/__init__.py | 21 + nemoguardrails/cache/interface.py | 236 +++ nemoguardrails/cache/lfu.py | 677 ++++++ .../library/content_safety/actions.py | 57 +- .../library/content_safety/manager.py | 98 + nemoguardrails/rails/llm/config.py | 64 + nemoguardrails/rails/llm/llmrails.py | 66 + tests/test_cache_lfu.py | 1861 +++++++++++++++++ 11 files changed, 3477 insertions(+), 8 deletions(-) create mode 100644 nemoguardrails/cache/README.md create mode 100644 nemoguardrails/cache/__init__.py create mode 100644 nemoguardrails/cache/interface.py create mode 100644 nemoguardrails/cache/lfu.py create mode 100644 nemoguardrails/library/content_safety/manager.py create mode 100644 tests/test_cache_lfu.py diff --git a/examples/configs/content_safety/README.md b/examples/configs/content_safety/README.md index 35a2d2a45..0645e9043 100644 --- a/examples/configs/content_safety/README.md +++ b/examples/configs/content_safety/README.md @@ -1,10 +1,117 @@ -# NemoGuard ContentSafety Usage Example +# Content Safety Configuration -This example showcases the use of NVIDIA's [NemoGuard ContentSafety model](./../../../docs/user-guides/advanced/nemoguard-contentsafety-deployment.md) for topical and dialogue moderation. +This example demonstrates how to configure content safety rails with NeMo Guardrails, including optional cache persistence. -The structure of the config folder is the following: +## Features -- `config.yml` - The config file holding all the configuration options for the model. -- `prompts.yml` - The config file holding the topical rules used for topical and dialogue moderation by the current guardrail configuration. +- **Input Safety Checks**: Validates user inputs before processing +- **Output Safety Checks**: Ensures bot responses are appropriate +- **Caching**: Reduces redundant API calls with LFU cache +- **Persistence**: Optional cache persistence for resilience across restarts +- **Thread Safety**: Fully thread-safe for use in multi-threaded web servers -Please see the docs for more details about the [recommended ContentSafety deployment](./../../../docs/user-guides/advanced/nemoguard-contentsafety-deployment.md) methods, either using locally downloaded NIMs or NVIDIA AI Enterprise (NVAIE). +## Configuration Overview + +The configuration includes: + +1. **Main Model**: The primary LLM for conversations (Llama 3.3 70B) +2. **Content Safety Model**: Dedicated model for safety checks (NemoGuard 8B) +3. **Rails**: Input and output safety check flows +4. **Cache Configuration**: Memory cache with optional persistence + +## How It Works + +1. **User Input**: When a user sends a message, it's checked by the content safety model +2. **Cache Check**: The system first checks if this content was already evaluated (cache hit) +3. **Safety Evaluation**: If not cached, the content safety model evaluates the input +4. **Result Caching**: The safety check result is cached for future use +5. **Response Generation**: If safe, the main model generates a response +6. **Output Check**: The response is also checked for safety before returning to the user + +## Cache Persistence + +The cache configuration includes: + +- **Automatic Saves**: Every 5 minutes (configurable) +- **Shutdown Saves**: Caches are automatically persisted when the application closes +- **Crash Recovery**: Cache reloads from disk on restart +- **Per-Model Storage**: Each model gets its own cache file + +To disable persistence, you can either: + +1. Set `enabled: false` in the persistence section +2. Remove the `persistence` section entirely +3. Set `interval` to `null` or remove it + +Note: Persistence requires both `enabled: true` and a valid `interval` value to be active. + +## Thread Safety + +The content safety implementation is fully thread-safe: + +- **Concurrent Requests**: Safely handles multiple simultaneous safety checks +- **No Data Corruption**: Thread-safe cache operations prevent data corruption +- **Efficient Locking**: Uses RLock for minimal performance impact +- **Atomic Operations**: Prevents duplicate LLM calls for the same content + +This makes it suitable for: + +- Multi-threaded web servers (FastAPI, Flask, Django) +- Concurrent request processing +- High-traffic applications + +### Proper Shutdown + +For best results, use one of these patterns: + +```python +# Context manager (recommended) +with LLMRails(config) as rails: + # Your code here + pass +# Caches automatically persisted on exit + +# Or manual cleanup +rails = LLMRails(config) +# Your code here +rails.close() # Persist caches +``` + +## Running the Example + +```bash +# From the NeMo-Guardrails root directory +nemoguardrails server --config examples/configs/content_safety/ +``` + +## Customization + +### Adjust Cache Settings + +```yaml +cache: + enabled: true # Enable/disable caching + capacity_per_model: 5000 # Maximum entries per model + persistence: + interval: 300.0 # Seconds between saves + path: ./my_cache.json # Custom path +``` + +### Memory-Only Cache + +For memory-only caching without persistence: + +```yaml +cache: + enabled: true + capacity_per_model: 5000 + store: memory + # No persistence section +``` + +## Benefits + +1. **Performance**: Avoid redundant content safety API calls +2. **Cost Savings**: Reduce API usage for repeated content +3. **Reliability**: Cache survives process restarts +4. **Flexibility**: Easy to enable/disable features as needed diff --git a/examples/configs/content_safety/config.yml b/examples/configs/content_safety/config.yml index f6808bf14..5ca7c8608 100644 --- a/examples/configs/content_safety/config.yml +++ b/examples/configs/content_safety/config.yml @@ -14,3 +14,18 @@ rails: output: flows: - content safety check output $model=content_safety + + # Content safety cache configuration with persistence and stats + config: + content_safety: + cache: + enabled: true + capacity_per_model: 5000 + store: memory # In-memory cache with optional disk persistence + persistence: + enabled: true # Enable persistence (requires interval to be set) + interval: 300.0 # Persist every 5 minutes + path: ./content_safety_cache.json # Where to save cache + stats: + enabled: true # Enable statistics tracking + log_interval: 60.0 # Log cache statistics every minute diff --git a/nemoguardrails/cache/README.md b/nemoguardrails/cache/README.md new file mode 100644 index 000000000..369e82f5e --- /dev/null +++ b/nemoguardrails/cache/README.md @@ -0,0 +1,271 @@ +# Content Safety LLM Call Caching + +## Overview + +The content safety checks in `actions.py` now use an LFU (Least Frequently Used) cache to improve performance by avoiding redundant LLM calls for identical safety checks. The cache supports optional persistence to disk for resilience across restarts. + +## Implementation Details + +### Cache Configuration + +- Per-model caches: Each model gets its own LFU cache instance +- Default capacity: 50,000 entries per model +- Eviction policy: LFU with LRU tiebreaker +- Statistics tracking: Enabled by default +- Tracks timestamps: `created_at` and `accessed_at` for each entry +- Cache creation: Automatic when a model is first used +- Persistence: Optional periodic save to disk with configurable interval + +### Cached Functions + +1. `content_safety_check_input()` - Caches safety checks for user inputs + +Note: `content_safety_check_output()` does not use caching to ensure fresh evaluation of bot responses. + +### Cache Key Components + +The cache key is a SHA256 hash of: + +- The rendered prompt only (can be a string or list of strings) + +Since temperature is fixed (1e-20) and stop/max_tokens are derived from the model configuration, they don't need to be part of the cache key. + +### How It Works + +1. **Before LLM Call**: + - Generate cache key from request parameters + - Check if result exists in cache + - If found, return cached result (cache hit) + +2. **After LLM Call**: + - If not in cache, make the actual LLM call + - Store the result in cache for future use + +### Cache Management + +The caching system automatically creates and manages separate caches for each model. Key features: + +- **Automatic Creation**: Caches are created on first use for each model +- **Isolated Storage**: Each model maintains its own cache, preventing cross-model interference +- **Default Settings**: Each cache has 50,000 entry capacity with stats tracking enabled + +```python +# Internal cache access (for debugging/monitoring): +from nemoguardrails.library.content_safety.actions import _MODEL_CACHES + +# View which models have caches +models_with_caches = list(_MODEL_CACHES.keys()) + +# Get stats for a specific model's cache +if "llama_guard" in _MODEL_CACHES: + stats = _MODEL_CACHES["llama_guard"].get_stats() +``` + +### Persistence Configuration + +The cache supports optional persistence to disk for resilience across restarts: + +```yaml +rails: + config: + content_safety: + cache: + enabled: true + capacity_per_model: 5000 + persistence: + interval: 300.0 # Persist every 5 minutes + path: ./cache_{model_name}.json # {model_name} is replaced +``` + +**Configuration Options:** + +- `persistence.interval`: Seconds between automatic saves (None = no persistence) +- `persistence.path`: Where to save cache data (can include `{model_name}` placeholder) + +**How Persistence Works:** + +1. **Automatic Saves**: Cache checks trigger persistence if interval has passed +2. **On Shutdown**: Caches are automatically persisted when LLMRails is closed or garbage collected +3. **On Restart**: Cache loads from disk if persistence file exists +4. **Preserves State**: Frequencies and access patterns are maintained +5. **Per-Model Files**: Each model gets its own persistence file + +**Manual Persistence:** + +```python +# Force immediate persistence of all caches +content_safety_manager.persist_all_caches() +``` + +This is useful for graceful shutdown scenarios. + +**Notes on Persistence:** + +- Persistence only works with "memory" store type +- Cache files are JSON format for easy inspection and debugging +- Set `persistence.interval` to None to disable persistence +- The cache automatically persists on each check if the interval has passed + +### Statistics and Monitoring + +The cache supports detailed statistics tracking and periodic logging for monitoring cache performance: + +```yaml +rails: + config: + content_safety: + cache: + enabled: true + capacity_per_model: 10000 + stats: + enabled: true # Enable stats tracking + log_interval: 60.0 # Log stats every minute +``` + +**Statistics Features:** + +1. **Tracking Only**: Set `stats.enabled: true` with no `log_interval` to track stats without logging +2. **Automatic Logging**: Set both `stats.enabled: true` and `log_interval` for periodic logging +3. **Manual Logging**: Force immediate stats logging with `cache.log_stats_now()` + +**Statistics Tracked:** + +- **Hits**: Number of cache hits (successful lookups) +- **Misses**: Number of cache misses (failed lookups) +- **Hit Rate**: Percentage of requests served from cache +- **Evictions**: Number of items removed due to capacity +- **Puts**: Number of new items added to cache +- **Updates**: Number of existing items updated +- **Current Size**: Number of items currently in cache + +**Log Format:** + +``` +LFU Cache Statistics - Size: 2456/10000 | Hits: 15234 | Misses: 2456 | Hit Rate: 86.11% | Evictions: 0 | Puts: 2456 | Updates: 0 +``` + +**Usage Examples:** + +```python +# Programmatically access stats +if "safety_model" in _MODEL_CACHES: + cache = _MODEL_CACHES["safety_model"] + stats = cache.get_stats() + print(f"Cache hit rate: {stats['hit_rate']:.2%}") + + # Force immediate stats logging + if cache.supports_stats_logging(): + cache.log_stats_now() +``` + +**Configuration Options:** + +- `stats.enabled`: Enable/disable statistics tracking (default: false) +- `stats.log_interval`: Seconds between automatic stats logs (None = no logging) + +**Notes:** + +- Stats logging requires stats tracking to be enabled +- Logs appear at INFO level in the `nemoguardrails.cache.lfu` logger +- Stats are reset when cache is cleared or when `reset_stats()` is called +- Each model maintains independent statistics + +### Example Configuration Usage + +```python +from nemoguardrails import RailsConfig, LLMRails + +# Method 1: Using context manager (recommended - ensures cleanup) +config = RailsConfig.from_path("./config.yml") +with LLMRails(config) as rails: + # Content safety checks will be cached and persisted automatically + response = await rails.generate_async( + messages=[{"role": "user", "content": "Hello, how are you?"}] + ) +# Caches are automatically persisted on exit + +# Method 2: Manual cleanup +rails = LLMRails(config) +response = await rails.generate_async( + messages=[{"role": "user", "content": "Hello, how are you?"}] +) +rails.close() # Manually persist caches + +# Note: If neither method is used, caches will still be persisted +# when the object is garbage collected (__del__) +``` + +### Thread Safety + +The content safety caching system is **thread-safe** for single-node deployments: + +1. **LFUCache Implementation**: + - Uses `threading.RLock` for all operations + - All public methods (`get`, `put`, `size`, `clear`, etc.) are protected by locks + - Supports atomic `get_or_compute()` operations that prevent duplicate computations + +2. **ContentSafetyManager**: + - Thread-safe cache creation using double-checked locking pattern + - Ensures only one cache instance per model across all threads + - Thread-safe persistence operations + +3. **Key Features**: + - **No Data Corruption**: Concurrent operations maintain data integrity + - **No Race Conditions**: Proper locking prevents race conditions + - **Atomic Operations**: `get_or_compute()` ensures expensive computations happen only once + - **Minimal Lock Contention**: Efficient locking patterns minimize performance impact + +4. **Usage in Web Servers**: + - Safe for use in multi-threaded web servers (FastAPI, Flask, etc.) + - Handles concurrent requests without issues + - Each thread sees consistent cache state + +**Note**: This implementation is designed for single-node deployments. For distributed systems, consider using external caching solutions like Redis. + +### Benefits + +1. **Performance**: Eliminates redundant LLM calls for identical inputs +2. **Cost Savings**: Reduces API calls to LLM services +3. **Consistency**: Ensures identical inputs always produce identical outputs +4. **Smart Eviction**: LFU policy keeps frequently checked content in cache +5. **Model Isolation**: Each model has its own cache, preventing interference between different safety models +6. **Statistics Tracking**: Monitor cache performance with hit rates, evictions, and more per model +7. **Timestamp Tracking**: Track when entries were created and last accessed +8. **Resilience**: Cache survives process restarts without losing data when persistence is enabled +9. **Efficiency**: LFU eviction algorithm ensures the most useful entries remain in cache +10. **Thread Safety**: Safe for concurrent access in multi-threaded environments + +### Example Usage Pattern + +```python +# First call - takes ~500ms (LLM API call) +result = await content_safety_check_input( + llms=llms, + llm_task_manager=task_manager, + model_name="safety_model", + context={"user_message": "Hello world"} +) + +# Subsequent identical calls - takes ~1ms (cache hit) +result = await content_safety_check_input( + llms=llms, + llm_task_manager=task_manager, + model_name="safety_model", + context={"user_message": "Hello world"} +) +``` + +### Logging + +The implementation includes debug logging: + +- Cache creation: `"Created cache for model '{model_name}' with capacity {capacity}"` +- Cache hits: `"Content safety cache hit for model '{model_name}', key: {key[:8]}..."` +- Cache stores: `"Content safety result cached for model '{model_name}', key: {key[:8]}..."` + +Enable debug logging to monitor cache behavior: + +```python +import logging +logging.getLogger("nemoguardrails.library.content_safety.actions").setLevel(logging.DEBUG) +``` diff --git a/nemoguardrails/cache/__init__.py b/nemoguardrails/cache/__init__.py new file mode 100644 index 000000000..e7f22f070 --- /dev/null +++ b/nemoguardrails/cache/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General-purpose caching utilities for NeMo Guardrails.""" + +from nemoguardrails.cache.interface import CacheInterface +from nemoguardrails.cache.lfu import LFUCache + +__all__ = ["CacheInterface", "LFUCache"] diff --git a/nemoguardrails/cache/interface.py b/nemoguardrails/cache/interface.py new file mode 100644 index 000000000..d724d6999 --- /dev/null +++ b/nemoguardrails/cache/interface.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cache interface for NeMo Guardrails caching system. + +This module defines the abstract base class for cache implementations +that can be used interchangeably throughout the guardrails system. + +Cache implementations may optionally support persistence by overriding +the persist_now() method and supports_persistence() method. Persistence +allows cache state to be saved to and loaded from external storage. +""" + +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional + + +class CacheInterface(ABC): + """ + Abstract base class defining the interface for cache implementations. + + All cache implementations must inherit from this class and implement + the required methods to ensure compatibility with the caching system. + """ + + @abstractmethod + def get(self, key: Any, default: Any = None) -> Any: + """ + Retrieve an item from the cache. + + Args: + key: The key to look up in the cache. + default: Value to return if key is not found (default: None). + + Returns: + The value associated with the key, or default if not found. + """ + pass + + @abstractmethod + def put(self, key: Any, value: Any) -> None: + """ + Store an item in the cache. + + If the cache is at capacity, this method should evict an item + according to the cache's eviction policy (e.g., LFU, LRU, etc.). + + Args: + key: The key to store. + value: The value to associate with the key. + """ + pass + + @abstractmethod + def size(self) -> int: + """ + Get the current number of items in the cache. + + Returns: + The number of items currently stored in the cache. + """ + pass + + @abstractmethod + def is_empty(self) -> bool: + """ + Check if the cache is empty. + + Returns: + True if the cache contains no items, False otherwise. + """ + pass + + @abstractmethod + def clear(self) -> None: + """ + Remove all items from the cache. + + After calling this method, the cache should be empty. + """ + pass + + def contains(self, key: Any) -> bool: + """ + Check if a key exists in the cache. + + This is an optional method that can be overridden for efficiency. + The default implementation uses get() to check existence. + + Args: + key: The key to check. + + Returns: + True if the key exists in the cache, False otherwise. + """ + # Default implementation - can be overridden for efficiency + sentinel = object() + return self.get(key, sentinel) is not sentinel + + @property + @abstractmethod + def capacity(self) -> int: + """ + Get the maximum capacity of the cache. + + Returns: + The maximum number of items the cache can hold. + """ + pass + + def persist_now(self) -> None: + """ + Force immediate persistence of cache to storage. + + This is an optional method that cache implementations can override + if they support persistence. The default implementation does nothing. + + Implementations that support persistence should save the current + cache state to their configured storage backend. + """ + # Default no-op implementation + pass + + def supports_persistence(self) -> bool: + """ + Check if this cache implementation supports persistence. + + Returns: + True if the cache supports persistence, False otherwise. + + The default implementation returns False. Cache implementations + that support persistence should override this to return True. + """ + return False + + def get_stats(self) -> dict: + """ + Get cache statistics. + + Returns: + Dictionary with cache statistics. The format and contents + may vary by implementation. Common fields include: + - hits: Number of cache hits + - misses: Number of cache misses + - evictions: Number of items evicted + - hit_rate: Percentage of requests that were hits + - current_size: Current number of items in cache + - capacity: Maximum capacity of the cache + + The default implementation returns a message indicating that + statistics tracking is not supported. + """ + return { + "message": "Statistics tracking is not supported by this cache implementation" + } + + def reset_stats(self) -> None: + """ + Reset cache statistics. + + This is an optional method that cache implementations can override + if they support statistics tracking. The default implementation does nothing. + """ + # Default no-op implementation + pass + + def log_stats_now(self) -> None: + """ + Force immediate logging of cache statistics. + + This is an optional method that cache implementations can override + if they support statistics logging. The default implementation does nothing. + + Implementations that support statistics logging should output the + current cache statistics to their configured logging backend. + """ + # Default no-op implementation + pass + + def supports_stats_logging(self) -> bool: + """ + Check if this cache implementation supports statistics logging. + + Returns: + True if the cache supports statistics logging, False otherwise. + + The default implementation returns False. Cache implementations + that support statistics logging should override this to return True + when logging is enabled. + """ + return False + + async def get_or_compute( + self, key: Any, compute_fn: Callable[[], Any], default: Any = None + ) -> Any: + """ + Atomically get a value from the cache or compute it if not present. + + This method ensures that the compute function is called at most once + even in the presence of concurrent requests for the same key. + + Args: + key: The key to look up + compute_fn: Async function to compute the value if key is not found + default: Value to return if compute_fn raises an exception + + Returns: + The cached value or the computed value + + This is an optional method with a default implementation. Cache + implementations should override this for better thread-safety guarantees. + """ + # Default implementation - not thread-safe for computation + value = self.get(key) + if value is not None: + return value + + try: + computed_value = await compute_fn() + self.put(key, computed_value) + return computed_value + except Exception: + return default diff --git a/nemoguardrails/cache/lfu.py b/nemoguardrails/cache/lfu.py new file mode 100644 index 000000000..4f8e450c0 --- /dev/null +++ b/nemoguardrails/cache/lfu.py @@ -0,0 +1,677 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Least Frequently Used (LFU) cache implementation.""" + +import asyncio +import json +import logging +import os +import threading +import time +from typing import Any, Callable, Optional + +from nemoguardrails.cache.interface import CacheInterface + +log = logging.getLogger(__name__) + + +class LFUNode: + """Node for the LFU cache doubly linked list.""" + + def __init__(self, key: Any, value: Any) -> None: + self.key = key + self.value = value + self.freq = 1 + self.prev: Optional["LFUNode"] = None + self.next: Optional["LFUNode"] = None + self.created_at = time.time() + self.accessed_at = self.created_at + + +class DoublyLinkedList: + """Doubly linked list to maintain nodes with the same frequency.""" + + def __init__(self) -> None: + # Create dummy head and tail nodes + self.head = LFUNode(None, None) + self.tail = LFUNode(None, None) + self.head.next = self.tail + self.tail.prev = self.head + self.size = 0 + + def append(self, node: LFUNode) -> None: + """Add node to the end of the list (before tail).""" + node.prev = self.tail.prev + node.next = self.tail + self.tail.prev.next = node + self.tail.prev = node + self.size += 1 + + def pop(self, node: Optional[LFUNode] = None) -> Optional[LFUNode]: + """Remove and return a node. If no node specified, removes the first node.""" + if self.size == 0: + return None + + if node is None: + node = self.head.next + + # Remove node from the list + node.prev.next = node.next + node.next.prev = node.prev + self.size -= 1 + + return node + + +class LFUCache(CacheInterface): + """ + Least Frequently Used (LFU) Cache implementation. + + When the cache reaches capacity, it evicts the least frequently used item. + If there are ties in frequency, it evicts the least recently used among them. + """ + + def __init__( + self, + capacity: int, + track_stats: bool = False, + persistence_interval: Optional[float] = None, + persistence_path: Optional[str] = None, + stats_logging_interval: Optional[float] = None, + ) -> None: + """ + Initialize the LFU cache. + + Args: + capacity: Maximum number of items the cache can hold + track_stats: Enable tracking of cache statistics + persistence_interval: Seconds between periodic dumps to disk (None disables persistence) + persistence_path: Path to persistence file (defaults to 'lfu_cache.json' if persistence enabled) + stats_logging_interval: Seconds between periodic stats logging (None disables logging) + """ + if capacity < 0: + raise ValueError("Capacity must be non-negative") + + self._capacity = capacity + self.track_stats = track_stats + self._lock = threading.RLock() # Thread-safe access + self._computing: dict[Any, asyncio.Future] = {} # Track keys being computed + + self.key_map: dict[Any, LFUNode] = {} # key -> node mapping + self.freq_map: dict[int, DoublyLinkedList] = {} # frequency -> list of nodes + self.min_freq = 0 # Track minimum frequency for eviction + + # Persistence configuration + self.persistence_interval = persistence_interval + self.persistence_path = persistence_path or "lfu_cache.json" + # Initialize to None to ensure first check doesn't trigger immediately + self.last_persist_time = None + + # Stats logging configuration + self.stats_logging_interval = stats_logging_interval + # Initialize to None to ensure first check doesn't trigger immediately + self.last_stats_log_time = None + + # Statistics tracking + if self.track_stats: + self.stats = { + "hits": 0, + "misses": 0, + "evictions": 0, + "puts": 0, + "updates": 0, + } + + # Load from disk if persistence is enabled and file exists + if self.persistence_interval is not None: + self._load_from_disk() + + def _update_node_freq(self, node: LFUNode) -> None: + """Update the frequency of a node and move it to the appropriate frequency list.""" + old_freq = node.freq + old_list = self.freq_map[old_freq] + + # Remove node from current frequency list + old_list.pop(node) + + # Update min_freq if necessary + if self.min_freq == old_freq and old_list.size == 0: + self.min_freq += 1 + # Clean up empty frequency lists + del self.freq_map[old_freq] + + # Increment frequency and add to new list + node.freq += 1 + new_freq = node.freq + node.accessed_at = time.time() # Update access time + + if new_freq not in self.freq_map: + self.freq_map[new_freq] = DoublyLinkedList() + + self.freq_map[new_freq].append(node) + + def get(self, key: Any, default: Any = None) -> Any: + """ + Get an item from the cache. + + Args: + key: The key to look up + default: Value to return if key is not found + + Returns: + The value associated with the key, or default if not found + """ + with self._lock: + # Check if we should persist + self._check_and_persist() + + # Check if we should log stats + self._check_and_log_stats() + + if key not in self.key_map: + if self.track_stats: + self.stats["misses"] += 1 + return default + + node = self.key_map[key] + + if self.track_stats: + self.stats["hits"] += 1 + + self._update_node_freq(node) + return node.value + + def put(self, key: Any, value: Any) -> None: + """ + Put an item into the cache. + + Args: + key: The key to store + value: The value to associate with the key + """ + with self._lock: + # Check if we should persist + self._check_and_persist() + + # Check if we should log stats + self._check_and_log_stats() + + if self._capacity == 0: + return + + if key in self.key_map: + # Update existing key + node = self.key_map[key] + node.value = value + node.created_at = time.time() # Reset creation time on update + self._update_node_freq(node) + if self.track_stats: + self.stats["updates"] += 1 + else: + # Add new key + if len(self.key_map) >= self._capacity: + # Need to evict least frequently used item + self._evict_lfu() + + # Create new node and add to cache + new_node = LFUNode(key, value) + self.key_map[key] = new_node + + # Add to frequency 1 list + if 1 not in self.freq_map: + self.freq_map[1] = DoublyLinkedList() + + self.freq_map[1].append(new_node) + self.min_freq = 1 + + if self.track_stats: + self.stats["puts"] += 1 + + def _evict_lfu(self) -> None: + """Evict the least frequently used item from the cache.""" + if self.min_freq in self.freq_map: + lfu_list = self.freq_map[self.min_freq] + node_to_evict = lfu_list.pop() # Remove least recently used among LFU + + if node_to_evict: + del self.key_map[node_to_evict.key] + + if self.track_stats: + self.stats["evictions"] += 1 + + # Clean up empty frequency list + if lfu_list.size == 0: + del self.freq_map[self.min_freq] + + def size(self) -> int: + """Return the current size of the cache.""" + with self._lock: + return len(self.key_map) + + def is_empty(self) -> bool: + """Check if the cache is empty.""" + with self._lock: + return len(self.key_map) == 0 + + def clear(self) -> None: + """Clear all items from the cache.""" + with self._lock: + if self.track_stats: + # Track number of items evicted + self.stats["evictions"] += len(self.key_map) + + self.key_map.clear() + self.freq_map.clear() + self.min_freq = 0 + + def get_stats(self) -> dict: + """ + Get cache statistics. + + Returns: + Dictionary with cache statistics (if tracking is enabled) + """ + with self._lock: + if not self.track_stats: + return {"message": "Statistics tracking is disabled"} + + stats = self.stats.copy() + stats["current_size"] = len(self.key_map) # Direct access within lock + stats["capacity"] = self._capacity + + # Calculate hit rate + total_requests = stats["hits"] + stats["misses"] + stats["hit_rate"] = ( + stats["hits"] / total_requests if total_requests > 0 else 0.0 + ) + + return stats + + def reset_stats(self) -> None: + """Reset cache statistics.""" + with self._lock: + if self.track_stats: + self.stats = { + "hits": 0, + "misses": 0, + "evictions": 0, + "puts": 0, + "updates": 0, + } + + def _check_and_persist(self) -> None: + """Check if enough time has passed and persist to disk if needed.""" + if self.persistence_interval is None: + return + + current_time = time.time() + + # Initialize timestamp on first check + if self.last_persist_time is None: + self.last_persist_time = current_time + return + + if current_time - self.last_persist_time >= self.persistence_interval: + self._persist_to_disk() + self.last_persist_time = current_time + + def _persist_to_disk(self) -> None: + """ + Serialize cache to disk. + + Stores cache data as JSON with node information including keys, values, + frequencies, and timestamps for reconstruction. + """ + if not self.key_map: + # If cache is empty, remove the persistence file + if os.path.exists(self.persistence_path): + os.remove(self.persistence_path) + return + + cache_data = { + "capacity": self._capacity, + "min_freq": self.min_freq, + "nodes": [], + } + + # Serialize all nodes + for key, node in self.key_map.items(): + cache_data["nodes"].append( + { + "key": key, + "value": node.value, + "freq": node.freq, + "created_at": node.created_at, + "accessed_at": node.accessed_at, + } + ) + + # Write to disk + try: + with open(self.persistence_path, "w") as f: + json.dump(cache_data, f, indent=2) + except Exception as e: + # Silently fail on persistence errors to not disrupt cache operations + pass + + def _load_from_disk(self) -> None: + """ + Load cache from disk if persistence file exists. + + Reconstructs the cache state including frequency lists and node relationships. + """ + if not os.path.exists(self.persistence_path): + return + + try: + with open(self.persistence_path, "r") as f: + cache_data = json.load(f) + + # Reconstruct cache + self.min_freq = cache_data.get("min_freq", 0) + + for node_data in cache_data.get("nodes", []): + # Create node + node = LFUNode(node_data["key"], node_data["value"]) + node.freq = node_data["freq"] + node.created_at = node_data["created_at"] + node.accessed_at = node_data["accessed_at"] + + # Add to key map + self.key_map[node.key] = node + + # Add to appropriate frequency list + if node.freq not in self.freq_map: + self.freq_map[node.freq] = DoublyLinkedList() + self.freq_map[node.freq].append(node) + + except Exception as e: + # If loading fails, start with empty cache + self.key_map.clear() + self.freq_map.clear() + self.min_freq = 0 + + def persist_now(self) -> None: + """Force immediate persistence to disk (useful for shutdown).""" + with self._lock: + if self.persistence_interval is not None: + self._persist_to_disk() + self.last_persist_time = time.time() + + def supports_persistence(self) -> bool: + """Check if this cache instance supports persistence.""" + return self.persistence_interval is not None + + def _check_and_log_stats(self) -> None: + """Check if enough time has passed and log stats if needed.""" + if not self.track_stats or self.stats_logging_interval is None: + return + + current_time = time.time() + + # Initialize timestamp on first check + if self.last_stats_log_time is None: + self.last_stats_log_time = current_time + return + + if current_time - self.last_stats_log_time >= self.stats_logging_interval: + self._log_stats() + self.last_stats_log_time = current_time + + def _log_stats(self) -> None: + """Log current cache statistics.""" + stats = self.get_stats() + + # Format the log message + log_msg = ( + f"LFU Cache Statistics - " + f"Size: {stats['current_size']}/{stats['capacity']} | " + f"Hits: {stats['hits']} | " + f"Misses: {stats['misses']} | " + f"Hit Rate: {stats['hit_rate']:.2%} | " + f"Evictions: {stats['evictions']} | " + f"Puts: {stats['puts']} | " + f"Updates: {stats['updates']}" + ) + + log.info(log_msg) + + def log_stats_now(self) -> None: + """Force immediate logging of cache statistics.""" + if self.track_stats: + self._log_stats() + self.last_stats_log_time = time.time() + + def supports_stats_logging(self) -> bool: + """Check if this cache instance supports stats logging.""" + return self.track_stats and self.stats_logging_interval is not None + + async def get_or_compute( + self, key: Any, compute_fn: Callable[[], Any], default: Any = None + ) -> Any: + """ + Atomically get a value from the cache or compute it if not present. + + This method ensures that the compute function is called at most once + even in the presence of concurrent requests for the same key. + + Args: + key: The key to look up + compute_fn: Async function to compute the value if key is not found + default: Value to return if compute_fn raises an exception + + Returns: + The cached value or the computed value + """ + # First check if the value is already in cache + future = None + with self._lock: + if key in self.key_map: + node = self.key_map[key] + if self.track_stats: + self.stats["hits"] += 1 + self._update_node_freq(node) + return node.value + + # Check if this key is already being computed + if key in self._computing: + future = self._computing[key] + + # If the key is being computed, wait for it outside the lock + if future is not None: + try: + return await future + except Exception: + return default + + # Create a future for this computation + future = asyncio.Future() + with self._lock: + # Double-check the cache and computing dict + if key in self.key_map: + node = self.key_map[key] + if self.track_stats: + self.stats["hits"] += 1 + self._update_node_freq(node) + return node.value + + if key in self._computing: + # Another thread started computing while we were waiting + future = self._computing[key] + else: + # We'll be the ones computing + self._computing[key] = future + + # If another thread is computing, wait for it + if not future.done() and self._computing.get(key) is not future: + try: + return await self._computing[key] + except Exception: + return default + + # We're responsible for computing the value + try: + computed_value = await compute_fn() + + # Store the computed value in cache + with self._lock: + # Remove from computing dict + self._computing.pop(key, None) + + # Check one more time if someone else added it + if key in self.key_map: + node = self.key_map[key] + if self.track_stats: + self.stats["hits"] += 1 + self._update_node_freq(node) + future.set_result(node.value) + return node.value + + # Now add to cache using internal logic + if self._capacity == 0: + future.set_result(computed_value) + return computed_value + + # Add new key + if len(self.key_map) >= self._capacity: + self._evict_lfu() + + # Create new node and add to cache + new_node = LFUNode(key, computed_value) + self.key_map[key] = new_node + + # Add to frequency 1 list + if 1 not in self.freq_map: + self.freq_map[1] = DoublyLinkedList() + + self.freq_map[1].append(new_node) + self.min_freq = 1 + + if self.track_stats: + self.stats["puts"] += 1 + + # Set the result in the future + future.set_result(computed_value) + return computed_value + + except Exception as e: + with self._lock: + self._computing.pop(key, None) + future.set_exception(e) + return default + + def contains(self, key: Any) -> bool: + """ + Check if a key exists in the cache without updating its frequency. + + This is more efficient than the default implementation which uses get() + and has the side effect of updating frequency counts. + + Args: + key: The key to check + + Returns: + True if the key exists in the cache, False otherwise + """ + with self._lock: + return key in self.key_map + + @property + def capacity(self) -> int: + """Get the maximum capacity of the cache.""" + return self._capacity + + +# Example usage and testing +if __name__ == "__main__": + print("=== Basic LFU Cache Example ===") + # Create a basic LFU cache + cache = LFUCache(3) + + cache.put("a", 1) + cache.put("b", 2) + cache.put("c", 3) + + print(f"Get 'a': {cache.get('a')}") # Returns 1, frequency of 'a' becomes 2 + print(f"Get 'b': {cache.get('b')}") # Returns 2, frequency of 'b' becomes 2 + + cache.put("d", 4) # Evicts 'c' (least frequently used) + + print(f"Get 'c': {cache.get('c', 'Not found')}") # Returns 'Not found' + print(f"Get 'd': {cache.get('d')}") # Returns 4 + print(f"Cache size: {cache.size()}") # Returns 3 + + print("\n=== Cache with Statistics Tracking ===") + + # Create cache with statistics tracking + stats_cache = LFUCache(capacity=5, track_stats=True) + + # Add some items + for i in range(6): + stats_cache.put(f"key{i}", f"value{i}") + + # Access some items to change frequencies + for _ in range(3): + stats_cache.get("key4") # Increase frequency + stats_cache.get("key5") # Increase frequency + + # Some cache misses + stats_cache.get("nonexistent1") + stats_cache.get("nonexistent2") + + # Check statistics + print(f"\nCache statistics: {stats_cache.get_stats()}") + + # Update existing key + stats_cache.put("key4", "updated_value4") + + # Check updated statistics + print(f"\nUpdated statistics: {stats_cache.get_stats()}") + + # Reset statistics + stats_cache.reset_stats() + print(f"\nAfter reset: {stats_cache.get_stats()}") + + print("\n=== Cache with Persistence ===") + + # Create cache with persistence (5 second interval) + persist_cache = LFUCache( + capacity=3, persistence_interval=5.0, persistence_path="test_cache.json" + ) + + # Add some items + persist_cache.put("item1", "value1") + persist_cache.put("item2", "value2") + persist_cache.put("item3", "value3") + + # Force immediate persistence + persist_cache.persist_now() + print("Cache persisted to disk") + + # Create new cache instance that will load from disk + new_cache = LFUCache( + capacity=3, persistence_interval=5.0, persistence_path="test_cache.json" + ) + + # Verify data was loaded + print(f"Loaded item1: {new_cache.get('item1')}") # Should return 'value1' + print(f"Loaded item2: {new_cache.get('item2')}") # Should return 'value2' + print(f"Cache size after loading: {new_cache.size()}") # Should return 3 + + # Clean up + if os.path.exists("test_cache.json"): + os.remove("test_cache.json") + print("Cleaned up test persistence file") diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index 2407210fa..3b8e97734 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging -from typing import Dict, Optional +import re +from typing import Dict, List, Optional, Union from langchain_core.language_models.llms import BaseLLM @@ -27,6 +29,32 @@ log = logging.getLogger(__name__) +PROMPT_PATTERN_WHITESPACES = re.compile(r"\s+") + + +def _create_cache_key(prompt: Union[str, List[str]]) -> str: + """Create a cache key from the prompt.""" + # can the prompt really be a list? + if isinstance(prompt, list): + prompt_str = json.dumps(prompt) + else: + prompt_str = prompt + + # normalize the prompt to a string + # should we do more normalizations? + return PROMPT_PATTERN_WHITESPACES.sub(" ", prompt_str).strip() + + +# Thread Safety Note: +# The content safety caching mechanism is thread-safe for single-node deployments. +# The underlying LFUCache uses threading.RLock to ensure atomic operations. +# ContentSafetyManager uses double-checked locking for efficient cache creation. +# +# However, this implementation is NOT suitable for distributed environments. +# For multi-node deployments, consider using distributed caching solutions +# like Redis or a shared database. + + @action() async def content_safety_check_input( llms: Dict[str, BaseLLM], @@ -75,6 +103,24 @@ async def content_safety_check_input( max_tokens = max_tokens or _MAX_TOKENS + # Check cache if content safety manager is available for this model + cached_result = None + cache_key = None + cache = None + + # Try to get the model-specific content safety manager + content_safety_manager = kwargs.get(f"content_safety_manager_{model_name}") + + if content_safety_manager: + cache = content_safety_manager.get_cache() + if cache: + cache_key = _create_cache_key(check_input_prompt) + cached_result = cache.get(cache_key) + if cached_result is not None: + log.debug(f"Content safety cache hit for model '{model_name}'") + return cached_result + + # Make the actual LLM call result = await llm_call( llm, check_input_prompt, @@ -86,7 +132,14 @@ async def content_safety_check_input( is_safe, *violated_policies = result - return {"allowed": is_safe, "policy_violations": violated_policies} + final_result = {"allowed": is_safe, "policy_violations": violated_policies} + + # Store in cache if available + if cache_key and cache: + cache.put(cache_key, final_result) + log.debug(f"Content safety result cached for model '{model_name}'") + + return final_result def content_safety_check_output_mapping(result: dict) -> bool: diff --git a/nemoguardrails/library/content_safety/manager.py b/nemoguardrails/library/content_safety/manager.py new file mode 100644 index 000000000..16a0ae12f --- /dev/null +++ b/nemoguardrails/library/content_safety/manager.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Optional + +from nemoguardrails.cache.interface import CacheInterface +from nemoguardrails.cache.lfu import LFUCache +from nemoguardrails.rails.llm.config import ModelCacheConfig + +log = logging.getLogger(__name__) + + +class ContentSafetyManager: + """Manages content safety functionality for a specific model.""" + + def __init__( + self, model_name: str, cache_config: Optional[ModelCacheConfig] = None + ): + self.model_name = model_name + self.cache_config = cache_config + self._cache: Optional[CacheInterface] = None + self._initialize_cache() + + def _initialize_cache(self): + """Initialize cache based on configuration.""" + if not self.cache_config or not self.cache_config.enabled: + log.debug( + f"Content safety caching is disabled for model '{self.model_name}'" + ) + return + + # Create cache based on store type + if self.cache_config.store == "memory": + # Determine persistence settings + persistence_path = None + persistence_interval = None + + if ( + self.cache_config.persistence.enabled + and self.cache_config.persistence.interval is not None + ): + persistence_interval = self.cache_config.persistence.interval + + if self.cache_config.persistence.path: + # Use configured path, replacing {model_name} if present + persistence_path = self.cache_config.persistence.path.replace( + "{model_name}", self.model_name + ) + else: + # Default path if persistence is enabled but no path specified + persistence_path = f"cache_{self.model_name}.json" + + # Determine stats logging settings + stats_logging_interval = None + if ( + self.cache_config.stats.enabled + and self.cache_config.stats.log_interval is not None + ): + stats_logging_interval = self.cache_config.stats.log_interval + + self._cache = LFUCache( + capacity=self.cache_config.capacity_per_model, + track_stats=self.cache_config.stats.enabled, + persistence_interval=persistence_interval, + persistence_path=persistence_path, + stats_logging_interval=stats_logging_interval, + ) + + log.info( + f"Created cache for model '{self.model_name}' with capacity {self.cache_config.capacity_per_model}" + ) + # elif self.cache_config.store == "filesystem": + # self._cache = FilesystemCache(...) + # elif self.cache_config.store == "redis": + # self._cache = RedisCache(...) + + def get_cache(self) -> Optional[CacheInterface]: + """Get the cache for this model.""" + return self._cache + + def persist_cache(self): + """Force immediate persistence of cache if it supports it.""" + if self._cache and self._cache.supports_persistence(): + self._cache.persist_now() + log.info(f"Persisted cache for model: {self.model_name}") diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 6b6a9b64a..1bc92aa93 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -15,6 +15,8 @@ """Module for the configuration of rails.""" +from __future__ import annotations + import logging import os import warnings @@ -97,6 +99,12 @@ class Model(BaseModel): description="Whether the mode is 'text' completion or 'chat' completion. Allowed values are 'chat' or 'text'.", ) + # Cache configuration specific to this model (for content safety models) + cache: Optional["ModelCacheConfig"] = Field( + default=None, + description="Cache configuration for this specific model (primarily used for content safety models)", + ) + @model_validator(mode="before") @classmethod def set_and_validate_model(cls, data: Any) -> Any: @@ -870,6 +878,62 @@ class AIDefenseRailConfig(BaseModel): ) +class CachePersistenceConfig(BaseModel): + """Configuration for cache persistence to disk.""" + + enabled: bool = Field( + default=True, + description="Whether cache persistence is enabled (persistence requires both enabled=True and a valid interval)", + ) + interval: Optional[float] = Field( + default=None, + description="Seconds between periodic cache persistence to disk (None disables persistence)", + ) + path: Optional[str] = Field( + default=None, + description="Path to persistence file for cache data (defaults to 'cache_{model_name}.json' if persistence is enabled)", + ) + + +class CacheStatsConfig(BaseModel): + """Configuration for cache statistics tracking and logging.""" + + enabled: bool = Field( + default=False, + description="Whether cache statistics tracking is enabled", + ) + log_interval: Optional[float] = Field( + default=None, + description="Seconds between periodic cache stats logging to logs (None disables logging)", + ) + + +class ModelCacheConfig(BaseModel): + """Configuration for model caching.""" + + enabled: bool = Field( + default=False, + description="Whether caching is enabled (default: False - no caching)", + ) + capacity_per_model: int = Field( + default=50000, description="Maximum number of entries in the cache per model" + ) + store: str = Field( + default="memory", description="Cache store: 'memory', 'filesystem', 'redis'" + ) + store_config: Dict[str, Any] = Field( + default_factory=dict, description="Backend-specific configuration" + ) + persistence: CachePersistenceConfig = Field( + default_factory=CachePersistenceConfig, + description="Configuration for cache persistence", + ) + stats: CacheStatsConfig = Field( + default_factory=CacheStatsConfig, + description="Configuration for cache statistics tracking and logging", + ) + + class RailsConfigData(BaseModel): """Configuration data for specific rails that are supported out-of-the-box.""" diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index e736a32df..03f8c0dd4 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -131,6 +131,7 @@ def __init__( self.config = config self.llm = llm self.verbose = verbose + self._content_safety_managers = {} if self.verbose: set_verbose(True, llm_calls=True) @@ -500,6 +501,7 @@ def _init_llms(self): kwargs=kwargs, ) + # If the model is a content safety model, we need to create a ContentSafetyManager for it if llm_config.type == "main": # If a main LLM was already injected, skip creating another # one. Otherwise, create and register it. @@ -525,6 +527,36 @@ def _init_llms(self): self.runtime.register_action_param("llms", llms) + # Register content safety managers if content safety features are used + if self._has_content_safety_rails(): + from nemoguardrails.library.content_safety.manager import ( + ContentSafetyManager, + ) + + # Create a ContentSafetyManager for each content safety model + for model in self.config.models: + if model.type not in ["main", "embeddings"]: + # Use model's cache config if available, otherwise None (no caching) + cache_config = model.cache + + manager = ContentSafetyManager( + model_name=model.type, cache_config=cache_config + ) + + self._content_safety_managers[model.type] = manager + + # Register the manager for this specific model + self.runtime.register_action_param( + f"content_safety_manager_{model.type}", manager + ) + + log.info( + f"Initialized ContentSafetyManager for model '{model.type}' with cache %s", + "enabled" + if cache_config and cache_config.enabled + else "disabled", + ) + def _get_embeddings_search_provider_instance( self, esp_config: Optional[EmbeddingSearchProvider] = None ) -> EmbeddingsIndex: @@ -1477,6 +1509,16 @@ def register_embedding_provider( register_embedding_provider(engine_name=name, model=cls) return self + def _has_content_safety_rails(self) -> bool: + """Check if any content safety rails are configured in flows. + At the moment, we only support content safety manager in input flows. + """ + flows = self.config.rails.input.flows + for flow in flows: + if "content safety check input" in flow: + return True + return False + def explain(self) -> ExplainInfo: """Helper function to return the latest ExplainInfo object.""" if self.explain_info is None: @@ -1760,3 +1802,27 @@ def _prepare_params( # yield the individual chunks directly from the buffer strategy for chunk in user_output_chunks: yield chunk + + def close(self): + """Properly close and clean up resources, including persisting caches.""" + if self._content_safety_managers: + log.info("Persisting content safety caches on close") + for model_name, manager in self._content_safety_managers.items(): + manager.persist_cache() + + def __del__(self): + """Ensure caches are persisted when the object is garbage collected.""" + try: + self.close() + except Exception as e: + # Silently fail in destructor to avoid issues during shutdown + log.debug(f"Error during LLMRails cleanup: {e}") + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - ensure cleanup.""" + self.close() + return False diff --git a/tests/test_cache_lfu.py b/tests/test_cache_lfu.py new file mode 100644 index 000000000..f2941482a --- /dev/null +++ b/tests/test_cache_lfu.py @@ -0,0 +1,1861 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive test suite for LFU Cache implementation. + +Tests all functionality including basic operations, eviction policies, +capacity management, edge cases, and persistence functionality. +""" + +import asyncio +import json +import os +import tempfile +import threading +import time +import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Any +from unittest.mock import MagicMock, patch + +from nemoguardrails.cache.lfu import LFUCache +from nemoguardrails.library.content_safety.manager import ContentSafetyManager + + +class TestLFUCache(unittest.TestCase): + """Test cases for LFU Cache implementation.""" + + def setUp(self): + """Set up test fixtures.""" + self.cache = LFUCache(3) + + def test_initialization(self): + """Test cache initialization with various capacities.""" + # Normal capacity + cache = LFUCache(5) + self.assertEqual(cache.size(), 0) + self.assertTrue(cache.is_empty()) + + # Zero capacity + cache_zero = LFUCache(0) + self.assertEqual(cache_zero.size(), 0) + + # Negative capacity should raise error + with self.assertRaises(ValueError): + LFUCache(-1) + + def test_basic_put_get(self): + """Test basic put and get operations.""" + # Put and get single item + self.cache.put("key1", "value1") + self.assertEqual(self.cache.get("key1"), "value1") + self.assertEqual(self.cache.size(), 1) + + # Put and get multiple items + self.cache.put("key2", "value2") + self.cache.put("key3", "value3") + + self.assertEqual(self.cache.get("key1"), "value1") + self.assertEqual(self.cache.get("key2"), "value2") + self.assertEqual(self.cache.get("key3"), "value3") + self.assertEqual(self.cache.size(), 3) + + def test_get_nonexistent_key(self): + """Test getting non-existent keys.""" + # Default behavior (returns None) + self.assertIsNone(self.cache.get("nonexistent")) + + # With custom default + self.assertEqual(self.cache.get("nonexistent", "default"), "default") + + # After adding some items + self.cache.put("key1", "value1") + self.assertIsNone(self.cache.get("key2")) + self.assertEqual(self.cache.get("key2", 42), 42) + + def test_update_existing_key(self): + """Test updating values for existing keys.""" + self.cache.put("key1", "value1") + self.cache.put("key2", "value2") + + # Update existing key + self.cache.put("key1", "new_value1") + self.assertEqual(self.cache.get("key1"), "new_value1") + + # Size should not change + self.assertEqual(self.cache.size(), 2) + + def test_lfu_eviction_basic(self): + """Test basic LFU eviction when cache is full.""" + # Fill cache + self.cache.put("a", 1) + self.cache.put("b", 2) + self.cache.put("c", 3) + + # Access 'a' and 'b' to increase their frequency + self.cache.get("a") # freq: 2 + self.cache.get("b") # freq: 2 + # 'c' remains at freq: 1 + + # Add new item - should evict 'c' (lowest frequency) + self.cache.put("d", 4) + + self.assertEqual(self.cache.get("a"), 1) + self.assertEqual(self.cache.get("b"), 2) + self.assertEqual(self.cache.get("d"), 4) + self.assertIsNone(self.cache.get("c")) # Should be evicted + + def test_lfu_with_lru_tiebreaker(self): + """Test LRU eviction among items with same frequency.""" + # Fill cache - all items have frequency 1 + self.cache.put("a", 1) + self.cache.put("b", 2) + self.cache.put("c", 3) + + # Add new item - should evict 'a' (least recently used among freq 1) + self.cache.put("d", 4) + + self.assertIsNone(self.cache.get("a")) # Should be evicted + self.assertEqual(self.cache.get("b"), 2) + self.assertEqual(self.cache.get("c"), 3) + self.assertEqual(self.cache.get("d"), 4) + + def test_complex_eviction_scenario(self): + """Test complex eviction scenario with multiple frequency levels.""" + # Create a new cache for this test + cache = LFUCache(4) + + # Add items and create different frequency levels + cache.put("a", 1) + cache.put("b", 2) + cache.put("c", 3) + cache.put("d", 4) + + # Create frequency pattern: + # a: freq 3 (accessed 2 more times) + # b: freq 2 (accessed 1 more time) + # c: freq 2 (accessed 1 more time) + # d: freq 1 (not accessed) + + cache.get("a") + cache.get("a") + cache.get("b") + cache.get("c") + + # Add new item - should evict 'd' (freq 1) + cache.put("e", 5) + self.assertIsNone(cache.get("d")) + + # Add another item - should evict one of the least frequently used + cache.put("f", 6) + + # After eviction, we should have: + # - 'a' (freq 3) - definitely kept + # - 'b' (freq 2) and 'c' (freq 2) - higher frequency, both kept + # - 'f' (freq 1) - just added + # - 'e' (freq 1) was evicted as it was least recently used among freq 1 items + + # Check that we're at capacity + self.assertEqual(cache.size(), 4) + + # 'a' should definitely still be there (highest frequency) + self.assertEqual(cache.get("a"), 1) + + # 'b' and 'c' should both be there (freq 2) + self.assertEqual(cache.get("b"), 2) + self.assertEqual(cache.get("c"), 3) + + # 'f' should be there (just added) + self.assertEqual(cache.get("f"), 6) + + # 'e' should have been evicted (freq 1, LRU among freq 1 items) + self.assertIsNone(cache.get("e")) + + def test_zero_capacity_cache(self): + """Test cache with zero capacity.""" + cache = LFUCache(0) + + # Put should not store anything + cache.put("key", "value") + self.assertEqual(cache.size(), 0) + self.assertIsNone(cache.get("key")) + + # Multiple puts + for i in range(10): + cache.put(f"key{i}", f"value{i}") + + self.assertEqual(cache.size(), 0) + self.assertTrue(cache.is_empty()) + + def test_clear_method(self): + """Test clearing the cache.""" + # Add items + self.cache.put("a", 1) + self.cache.put("b", 2) + self.cache.put("c", 3) + + # Verify items exist + self.assertEqual(self.cache.size(), 3) + self.assertFalse(self.cache.is_empty()) + + # Clear cache + self.cache.clear() + + # Verify cache is empty + self.assertEqual(self.cache.size(), 0) + self.assertTrue(self.cache.is_empty()) + + # Verify items are gone + self.assertIsNone(self.cache.get("a")) + self.assertIsNone(self.cache.get("b")) + self.assertIsNone(self.cache.get("c")) + + # Can still use cache after clear + self.cache.put("new_key", "new_value") + self.assertEqual(self.cache.get("new_key"), "new_value") + + def test_various_data_types(self): + """Test cache with various data types as keys and values.""" + # Integer keys + self.cache.put(1, "one") + self.cache.put(2, "two") + self.assertEqual(self.cache.get(1), "one") + self.assertEqual(self.cache.get(2), "two") + + # Tuple keys + self.cache.put((1, 2), "tuple_value") + self.assertEqual(self.cache.get((1, 2)), "tuple_value") + + # Clear for more tests + self.cache.clear() + + # Complex values + self.cache.put("list", [1, 2, 3]) + self.cache.put("dict", {"a": 1, "b": 2}) + self.cache.put("set", {1, 2, 3}) + + self.assertEqual(self.cache.get("list"), [1, 2, 3]) + self.assertEqual(self.cache.get("dict"), {"a": 1, "b": 2}) + self.assertEqual(self.cache.get("set"), {1, 2, 3}) + + def test_none_values(self): + """Test storing None as a value.""" + self.cache.put("key", None) + # get should return None for the value, not the default + self.assertIsNone(self.cache.get("key")) + self.assertEqual(self.cache.get("key", "default"), None) + + # Verify key exists + self.assertEqual(self.cache.size(), 1) + + def test_size_and_capacity(self): + """Test size tracking and capacity limits.""" + # Start empty + self.assertEqual(self.cache.size(), 0) + + # Add items up to capacity + for i in range(3): + self.cache.put(f"key{i}", f"value{i}") + self.assertEqual(self.cache.size(), i + 1) + + # Add more items - size should stay at capacity + for i in range(3, 10): + self.cache.put(f"key{i}", f"value{i}") + self.assertEqual(self.cache.size(), 3) + + def test_is_empty(self): + """Test is_empty method in various states.""" + # Initially empty + self.assertTrue(self.cache.is_empty()) + + # After adding item + self.cache.put("key", "value") + self.assertFalse(self.cache.is_empty()) + + # After clearing + self.cache.clear() + self.assertTrue(self.cache.is_empty()) + + def test_repeated_puts_same_key(self): + """Test repeated puts with the same key maintain size=1 and update frequency.""" + self.cache.put("key", "value1") + self.assertEqual(self.cache.size(), 1) + + # Track initial state + initial_stats = self.cache.get_stats() if self.cache.track_stats else None + + # Update same key multiple times + for i in range(10): + self.cache.put("key", f"value{i}") + self.assertEqual(self.cache.size(), 1) + + # Final value should be the last one + self.assertEqual(self.cache.get("key"), "value9") + + # Verify stats if tracking enabled + if self.cache.track_stats: + final_stats = self.cache.get_stats() + # Should have 10 updates (after initial put) + self.assertEqual(final_stats["updates"], 10) + + def test_access_pattern_preserves_frequently_used(self): + """Test that frequently accessed items are preserved during evictions.""" + # Create specific access pattern + cache = LFUCache(3) + + # Add three items + cache.put("rarely_used", 1) + cache.put("sometimes_used", 2) + cache.put("frequently_used", 3) + + # Create access pattern + # frequently_used: access 10 times + for _ in range(10): + cache.get("frequently_used") + + # sometimes_used: access 3 times + for _ in range(3): + cache.get("sometimes_used") + + # rarely_used: no additional access (freq = 1) + + # Add new items to trigger evictions + cache.put("new1", 4) # Should evict rarely_used + cache.put("new2", 5) # Should evict new1 (freq = 1) + + # frequently_used and sometimes_used should still be there + self.assertEqual(cache.get("frequently_used"), 3) + self.assertEqual(cache.get("sometimes_used"), 2) + + # rarely_used and new1 should be evicted + self.assertIsNone(cache.get("rarely_used")) + self.assertIsNone(cache.get("new1")) + + # new2 should be there + self.assertEqual(cache.get("new2"), 5) + + +class TestLFUCacheInterface(unittest.TestCase): + """Test that LFUCache properly implements CacheInterface.""" + + def test_interface_methods_exist(self): + """Verify all interface methods are implemented.""" + cache = LFUCache(5) + + # Check all required methods exist and are callable + self.assertTrue(callable(getattr(cache, "get", None))) + self.assertTrue(callable(getattr(cache, "put", None))) + self.assertTrue(callable(getattr(cache, "size", None))) + self.assertTrue(callable(getattr(cache, "is_empty", None))) + self.assertTrue(callable(getattr(cache, "clear", None))) + + # Check property + self.assertEqual(cache.capacity, 5) + + def test_persistence_interface_methods(self): + """Verify persistence interface methods are implemented.""" + # Cache without persistence + cache_no_persist = LFUCache(5) + self.assertTrue(callable(getattr(cache_no_persist, "persist_now", None))) + self.assertTrue( + callable(getattr(cache_no_persist, "supports_persistence", None)) + ) + self.assertFalse(cache_no_persist.supports_persistence()) + + # Cache with persistence + temp_file = os.path.join(tempfile.mkdtemp(), "test_interface.json") + try: + cache_with_persist = LFUCache( + 5, persistence_interval=10.0, persistence_path=temp_file + ) + self.assertTrue(cache_with_persist.supports_persistence()) + + # persist_now should work without errors + cache_with_persist.put("key", "value") + cache_with_persist.persist_now() # Should not raise any exception + finally: + if os.path.exists(temp_file): + os.remove(temp_file) + if os.path.exists(os.path.dirname(temp_file)): + os.rmdir(os.path.dirname(temp_file)) + + +class TestLFUCachePersistence(unittest.TestCase): + """Test cases for LFU Cache persistence functionality.""" + + def setUp(self): + """Set up test fixtures.""" + # Create temporary directory for test files + self.temp_dir = tempfile.mkdtemp() + self.test_file = os.path.join(self.temp_dir, "test_cache.json") + + def tearDown(self): + """Clean up test files.""" + # Clean up any created files + if os.path.exists(self.test_file): + os.remove(self.test_file) + # Remove temporary directory + if os.path.exists(self.temp_dir): + os.rmdir(self.temp_dir) + + def test_basic_persistence(self): + """Test basic save and load functionality.""" + # Create cache and add items + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + + cache.put("key1", "value1") + cache.put("key2", {"nested": "value"}) + cache.put("key3", [1, 2, 3]) + + # Force persistence + cache.persist_now() + + # Verify file was created + self.assertTrue(os.path.exists(self.test_file)) + + # Load into new cache + new_cache = LFUCache( + 5, persistence_interval=10.0, persistence_path=self.test_file + ) + + # Verify data was loaded correctly + self.assertEqual(new_cache.size(), 3) + self.assertEqual(new_cache.get("key1"), "value1") + self.assertEqual(new_cache.get("key2"), {"nested": "value"}) + self.assertEqual(new_cache.get("key3"), [1, 2, 3]) + + def test_frequency_preservation(self): + """Test that frequencies are preserved across persistence.""" + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + + # Create different frequency levels + cache.put("freq1", "value1") + cache.put("freq3", "value3") + cache.put("freq5", "value5") + + # Access items to create different frequencies + cache.get("freq3") # freq = 2 + cache.get("freq3") # freq = 3 + + cache.get("freq5") # freq = 2 + cache.get("freq5") # freq = 3 + cache.get("freq5") # freq = 4 + cache.get("freq5") # freq = 5 + + # Force persistence + cache.persist_now() + + # Load into new cache + new_cache = LFUCache( + 5, persistence_interval=10.0, persistence_path=self.test_file + ) + + # Add new items to test eviction order + new_cache.put("new1", "newvalue1") + new_cache.put("new2", "newvalue2") + new_cache.put("new3", "newvalue3") + + # freq1 should be evicted first (lowest frequency) + self.assertIsNone(new_cache.get("freq1")) + # freq3 and freq5 should still be there + self.assertEqual(new_cache.get("freq3"), "value3") + self.assertEqual(new_cache.get("freq5"), "value5") + + def test_periodic_persistence(self): + """Test automatic periodic persistence.""" + # Use short interval for testing + cache = LFUCache(5, persistence_interval=0.5, persistence_path=self.test_file) + + cache.put("key1", "value1") + + # File shouldn't exist yet + self.assertFalse(os.path.exists(self.test_file)) + + # Wait for interval to pass + time.sleep(0.6) + + # Access cache to trigger persistence check + cache.get("key1") + + # File should now exist + self.assertTrue(os.path.exists(self.test_file)) + + # Verify content + with open(self.test_file, "r") as f: + data = json.load(f) + + self.assertEqual(data["capacity"], 5) + self.assertEqual(len(data["nodes"]), 1) + self.assertEqual(data["nodes"][0]["key"], "key1") + + def test_persistence_with_empty_cache(self): + """Test persistence behavior with empty cache.""" + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + + # Add and remove items + cache.put("key1", "value1") + cache.clear() + + # Force persistence + cache.persist_now() + + # File should be removed when cache is empty + self.assertFalse(os.path.exists(self.test_file)) + + def test_no_persistence_when_disabled(self): + """Test that persistence doesn't occur when not configured.""" + # Create cache without persistence + cache = LFUCache(5) + + cache.put("key1", "value1") + cache.persist_now() # Should do nothing + + # No file should be created + self.assertFalse(os.path.exists("lfu_cache.json")) + + def test_load_from_nonexistent_file(self): + """Test loading when persistence file doesn't exist.""" + # Create cache with non-existent file + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + + # Should start empty + self.assertEqual(cache.size(), 0) + self.assertTrue(cache.is_empty()) + + def test_persistence_with_complex_data(self): + """Test persistence with various data types.""" + cache = LFUCache(10, persistence_interval=10.0, persistence_path=self.test_file) + + # Add various data types + test_data = { + "string": "hello world", + "int": 42, + "float": 3.14, + "bool": True, + "none": None, + "list": [1, 2, [3, 4]], + "dict": {"a": 1, "b": {"c": 2}}, + "tuple_key": "value_for_tuple", # Will use string key since tuples aren't JSON serializable + } + + for key, value in test_data.items(): + cache.put(key, value) + + # Force persistence + cache.persist_now() + + # Load into new cache + new_cache = LFUCache( + 10, persistence_interval=10.0, persistence_path=self.test_file + ) + + # Verify all data types + for key, value in test_data.items(): + self.assertEqual(new_cache.get(key), value) + + def test_persistence_file_corruption_handling(self): + """Test handling of corrupted persistence files.""" + # Create invalid JSON file + with open(self.test_file, "w") as f: + f.write("{ invalid json content") + + # Should handle gracefully and start with empty cache + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + self.assertEqual(cache.size(), 0) + + # Cache should still be functional + cache.put("key1", "value1") + self.assertEqual(cache.get("key1"), "value1") + + def test_multiple_persistence_cycles(self): + """Test multiple save/load cycles.""" + # First cycle + cache1 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + cache1.put("key1", "value1") + cache1.put("key2", "value2") + cache1.persist_now() + + # Second cycle - load and modify + cache2 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + self.assertEqual(cache2.size(), 2) + cache2.put("key3", "value3") + cache2.persist_now() + + # Third cycle - verify all changes + cache3 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + self.assertEqual(cache3.size(), 3) + self.assertEqual(cache3.get("key1"), "value1") + self.assertEqual(cache3.get("key2"), "value2") + self.assertEqual(cache3.get("key3"), "value3") + + def test_capacity_change_on_load(self): + """Test loading cache data into cache with different capacity.""" + # Create cache with capacity 5 + cache1 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + for i in range(5): + cache1.put(f"key{i}", f"value{i}") + cache1.persist_now() + + # Load into cache with smaller capacity + cache2 = LFUCache(3, persistence_interval=10.0, persistence_path=self.test_file) + + # Current design: loads all persisted items regardless of new capacity + # This is a valid design choice - preserve data integrity on load + self.assertEqual(cache2.size(), 5) + + # The cache continues to operate with loaded items + # New items can still be added, and the cache will manage its size + cache2.put("new_key", "new_value") + + # Verify the cache is still functional and contains the new item + self.assertEqual(cache2.get("new_key"), "new_value") + self.assertGreaterEqual( + cache2.size(), 4 + ) # At least has the new item plus some old ones + + def test_persistence_timing(self): + """Test that persistence doesn't happen too frequently.""" + cache = LFUCache(5, persistence_interval=1.0, persistence_path=self.test_file) + + cache.put("key1", "value1") + + # Multiple operations within interval shouldn't trigger persistence + for i in range(10): + cache.get("key1") + self.assertFalse(os.path.exists(self.test_file)) + time.sleep(0.05) # Total time still less than interval + + # Wait for interval to pass + time.sleep(0.6) + cache.get("key1") + + # Now file should exist + self.assertTrue(os.path.exists(self.test_file)) + + def test_persistence_with_statistics(self): + """Test persistence doesn't interfere with statistics tracking.""" + cache = LFUCache( + 5, + track_stats=True, + persistence_interval=0.5, + persistence_path=self.test_file, + ) + + # Perform operations + cache.put("key1", "value1") + cache.put("key2", "value2") + cache.get("key1") + cache.get("nonexistent") + + # Wait for persistence + time.sleep(0.6) + cache.get("key1") # Trigger persistence + + # Check stats are still correct + stats = cache.get_stats() + self.assertEqual(stats["puts"], 2) + self.assertEqual(stats["hits"], 2) + self.assertEqual(stats["misses"], 1) + + # Load into new cache with stats + new_cache = LFUCache( + 5, + track_stats=True, + persistence_interval=0.5, + persistence_path=self.test_file, + ) + + # Stats should be reset in new instance + new_stats = new_cache.get_stats() + self.assertEqual(new_stats["puts"], 0) + self.assertEqual(new_stats["hits"], 0) + + # But data should be loaded + self.assertEqual(new_cache.size(), 2) + + +class TestLFUCacheStatsLogging(unittest.TestCase): + """Test cases for LFU Cache statistics logging functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.test_file = tempfile.mktemp() + + def tearDown(self): + """Clean up test files.""" + if os.path.exists(self.test_file): + os.remove(self.test_file) + + def test_stats_logging_disabled_by_default(self): + """Test that stats logging is disabled when not configured.""" + cache = LFUCache(5, track_stats=True) + self.assertFalse(cache.supports_stats_logging()) + + def test_stats_logging_requires_tracking(self): + """Test that stats logging requires stats tracking to be enabled.""" + # Logging without tracking + cache = LFUCache(5, track_stats=False, stats_logging_interval=1.0) + self.assertFalse(cache.supports_stats_logging()) + + # Both enabled + cache = LFUCache(5, track_stats=True, stats_logging_interval=1.0) + self.assertTrue(cache.supports_stats_logging()) + + def test_log_stats_now(self): + """Test immediate stats logging.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=60.0) + + # Add some data + cache.put("key1", "value1") + cache.put("key2", "value2") + cache.get("key1") + cache.get("nonexistent") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + # Verify log was called + self.assertEqual(mock_log.call_count, 1) + log_message = mock_log.call_args[0][0] + + # Check log format + self.assertIn("LFU Cache Statistics", log_message) + self.assertIn("Size: 2/5", log_message) + self.assertIn("Hits: 1", log_message) + self.assertIn("Misses: 1", log_message) + self.assertIn("Hit Rate: 50.00%", log_message) + self.assertIn("Evictions: 0", log_message) + self.assertIn("Puts: 2", log_message) + self.assertIn("Updates: 0", log_message) + + def test_periodic_stats_logging(self): + """Test automatic periodic stats logging.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.5) + + # Add some data + cache.put("key1", "value1") + cache.put("key2", "value2") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + # Initial operations shouldn't trigger logging + cache.get("key1") + self.assertEqual(mock_log.call_count, 0) + + # Wait for interval to pass + time.sleep(0.6) + + # Next operation should trigger logging + cache.get("key1") + self.assertEqual(mock_log.call_count, 1) + + # Another operation without waiting shouldn't trigger + cache.get("key2") + self.assertEqual(mock_log.call_count, 1) + + # Wait again + time.sleep(0.6) + cache.put("key3", "value3") + self.assertEqual(mock_log.call_count, 2) + + def test_stats_logging_with_empty_cache(self): + """Test stats logging with empty cache.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + # Generate a miss first + cache.get("nonexistent") + + # Wait for interval to pass + time.sleep(0.2) + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + # This will trigger stats logging with the previous miss already counted + cache.get("another_nonexistent") # Trigger check + + self.assertEqual(mock_log.call_count, 1) + log_message = mock_log.call_args[0][0] + + self.assertIn("Size: 0/5", log_message) + self.assertIn("Hits: 0", log_message) + self.assertIn("Misses: 1", log_message) # The first miss is logged + self.assertIn("Hit Rate: 0.00%", log_message) + + def test_stats_logging_with_full_cache(self): + """Test stats logging when cache is at capacity.""" + import logging + from unittest.mock import patch + + cache = LFUCache(3, track_stats=True, stats_logging_interval=0.1) + + # Fill cache + cache.put("key1", "value1") + cache.put("key2", "value2") + cache.put("key3", "value3") + + # Cause eviction + cache.put("key4", "value4") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + time.sleep(0.2) + cache.get("key4") # Trigger check + + log_message = mock_log.call_args[0][0] + self.assertIn("Size: 3/3", log_message) + self.assertIn("Evictions: 1", log_message) + self.assertIn("Puts: 4", log_message) + + def test_stats_logging_high_hit_rate(self): + """Test stats logging with high hit rate.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + cache.put("key1", "value1") + + # Many hits + for _ in range(99): + cache.get("key1") + + # One miss + cache.get("nonexistent") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + log_message = mock_log.call_args[0][0] + self.assertIn("Hit Rate: 99.00%", log_message) + self.assertIn("Hits: 99", log_message) + self.assertIn("Misses: 1", log_message) + + def test_stats_logging_without_tracking(self): + """Test that log_stats_now does nothing when tracking is disabled.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=False) + + cache.put("key1", "value1") + cache.get("key1") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + # Should not log anything + self.assertEqual(mock_log.call_count, 0) + + def test_stats_logging_interval_timing(self): + """Test that stats logging respects the interval timing.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=1.0) + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + # Multiple operations within interval + for i in range(10): + cache.put(f"key{i}", f"value{i}") + cache.get(f"key{i}") + time.sleep(0.05) # Total time < 1.0 + + # Should not have logged yet + self.assertEqual(mock_log.call_count, 0) + + # Wait for interval to pass + time.sleep(0.6) + cache.get("key1") # Trigger check + + # Now should have logged once + self.assertEqual(mock_log.call_count, 1) + + def test_stats_logging_with_updates(self): + """Test stats logging includes update counts.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + cache.put("key1", "value1") + cache.put("key1", "updated_value1") # Update + cache.put("key1", "updated_again") # Another update + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + log_message = mock_log.call_args[0][0] + self.assertIn("Updates: 2", log_message) + self.assertIn("Puts: 1", log_message) + + def test_stats_logging_combined_with_persistence(self): + """Test that stats logging and persistence work together.""" + import logging + from unittest.mock import patch + + cache = LFUCache( + 5, + track_stats=True, + persistence_interval=1.0, + persistence_path=self.test_file, + stats_logging_interval=0.5, + ) + + cache.put("key1", "value1") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + # Wait for stats logging interval + time.sleep(0.6) + cache.get("key1") # Trigger stats log + + self.assertEqual(mock_log.call_count, 1) + self.assertFalse(os.path.exists(self.test_file)) # Not persisted yet + + # Wait for persistence interval + time.sleep(0.5) + cache.get("key1") # Trigger persistence + + self.assertTrue(os.path.exists(self.test_file)) # Now persisted + # Stats log might trigger again if interval passed + self.assertGreaterEqual(mock_log.call_count, 1) + + def test_stats_log_format_percentages(self): + """Test that percentages in stats log are formatted correctly.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + # Test various hit rates + test_cases = [ + (0, 0, "0.00%"), # No requests + (1, 0, "100.00%"), # All hits + (0, 1, "0.00%"), # All misses + (1, 1, "50.00%"), # 50/50 + (2, 1, "66.67%"), # 2/3 + (99, 1, "99.00%"), # High hit rate + ] + + for hits, misses, expected_rate in test_cases: + cache.reset_stats() + + # Generate hits + if hits > 0: + cache.put("hit_key", "value") + for _ in range(hits): + cache.get("hit_key") + + # Generate misses + for i in range(misses): + cache.get(f"miss_key_{i}") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + if hits > 0 or misses > 0: + log_message = mock_log.call_args[0][0] + self.assertIn(f"Hit Rate: {expected_rate}", log_message) + + +class TestContentSafetyCacheStatsConfig(unittest.TestCase): + """Test cache stats configuration in content safety context.""" + + def setUp(self): + """Set up test fixtures.""" + self.test_file = tempfile.mktemp() + + def tearDown(self): + """Clean up test files.""" + if os.path.exists(self.test_file): + os.remove(self.test_file) + + def test_cache_config_with_stats_disabled(self): + """Test cache configuration with stats disabled.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, capacity_per_model=1000, stats=CacheStatsConfig(enabled=False) + ) + + model_config = ModelConfig(cache=cache_config) + manager = ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertFalse(cache.track_stats) + self.assertFalse(cache.supports_stats_logging()) + + def test_cache_config_with_stats_tracking_only(self): + """Test cache configuration with stats tracking but no logging.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, + capacity_per_model=1000, + stats=CacheStatsConfig(enabled=True, log_interval=None), + ) + + model_config = ModelConfig(cache=cache_config) + manager = ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertTrue(cache.track_stats) + self.assertFalse(cache.supports_stats_logging()) + self.assertIsNone(cache.stats_logging_interval) + + def test_cache_config_with_stats_logging(self): + """Test cache configuration with stats tracking and logging.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, + capacity_per_model=1000, + stats=CacheStatsConfig(enabled=True, log_interval=60.0), + ) + + model_config = ModelConfig(cache=cache_config) + manager = ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertTrue(cache.track_stats) + self.assertTrue(cache.supports_stats_logging()) + self.assertEqual(cache.stats_logging_interval, 60.0) + + def test_cache_config_default_stats(self): + """Test cache configuration with default stats settings.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ModelCacheConfig, ModelConfig + + cache_config = ModelCacheConfig(enabled=True) + + model_config = ModelConfig(cache=cache_config) + manager = ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertFalse(cache.track_stats) # Default is disabled + self.assertFalse(cache.supports_stats_logging()) + + def test_cache_config_stats_with_persistence(self): + """Test cache configuration with both stats and persistence.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CachePersistenceConfig, + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, + capacity_per_model=1000, + stats=CacheStatsConfig(enabled=True, log_interval=30.0), + persistence=CachePersistenceConfig( + enabled=True, interval=60.0, path=self.test_file + ), + ) + + model_config = ModelConfig(cache=cache_config) + manager = ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertTrue(cache.track_stats) + self.assertTrue(cache.supports_stats_logging()) + self.assertEqual(cache.stats_logging_interval, 30.0) + self.assertTrue(cache.supports_persistence()) + self.assertEqual(cache.persistence_interval, 60.0) + + def test_cache_config_from_dict(self): + """Test cache configuration creation from dictionary.""" + from nemoguardrails.rails.llm.config import ModelCacheConfig + + config_dict = { + "enabled": True, + "capacity_per_model": 5000, + "stats": {"enabled": True, "log_interval": 120.0}, + } + + cache_config = ModelCacheConfig(**config_dict) + self.assertTrue(cache_config.enabled) + self.assertEqual(cache_config.capacity_per_model, 5000) + self.assertTrue(cache_config.stats.enabled) + self.assertEqual(cache_config.stats.log_interval, 120.0) + + def test_cache_config_stats_validation(self): + """Test cache configuration validation for stats settings.""" + from nemoguardrails.rails.llm.config import CacheStatsConfig + + # Valid configurations + stats1 = CacheStatsConfig(enabled=True, log_interval=60.0) + self.assertTrue(stats1.enabled) + self.assertEqual(stats1.log_interval, 60.0) + + stats2 = CacheStatsConfig(enabled=True, log_interval=None) + self.assertTrue(stats2.enabled) + self.assertIsNone(stats2.log_interval) + + stats3 = CacheStatsConfig(enabled=False, log_interval=60.0) + self.assertFalse(stats3.enabled) + self.assertEqual(stats3.log_interval, 60.0) + + def test_multiple_model_caches_with_stats(self): + """Test multiple model caches each with their own stats configuration.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, + capacity_per_model=1000, + stats=CacheStatsConfig(enabled=True, log_interval=30.0), + ) + + model_config = ModelConfig( + cache=cache_config, model_mapping={"model_alias": "actual_model"} + ) + manager = ContentSafetyManager(model_config) + + # Get caches for different models + cache1 = manager.get_cache_for_model("model1") + cache2 = manager.get_cache_for_model("model2") + cache_alias = manager.get_cache_for_model("model_alias") + cache_actual = manager.get_cache_for_model("actual_model") + + # All should have stats enabled + self.assertTrue(cache1.track_stats) + self.assertTrue(cache2.track_stats) + self.assertTrue(cache_alias.track_stats) + + # Alias should resolve to same cache as actual + self.assertIs(cache_alias, cache_actual) + + +class TestLFUCacheThreadSafety(unittest.TestCase): + """Test thread safety of LFU Cache implementation.""" + + def setUp(self): + """Set up test fixtures.""" + self.cache = LFUCache(100, track_stats=True) + + def test_concurrent_reads_writes(self): + """Test that concurrent reads and writes don't corrupt the cache.""" + num_threads = 10 + operations_per_thread = 100 + # Use a larger cache to avoid evictions during the test + large_cache = LFUCache(2000, track_stats=True) + errors = [] + + def worker(thread_id): + """Worker function that performs cache operations.""" + for i in range(operations_per_thread): + key = f"thread_{thread_id}_key_{i}" + value = f"thread_{thread_id}_value_{i}" + + # Put operation + large_cache.put(key, value) + + # Get operation - should always succeed with large cache + retrieved = large_cache.get(key) + + # Verify data integrity + if retrieved != value: + errors.append( + f"Data corruption for {key}: expected {value}, got {retrieved}" + ) + + # Access some shared keys + shared_key = f"shared_key_{i % 10}" + large_cache.put(shared_key, f"shared_value_{thread_id}_{i}") + large_cache.get(shared_key) + + # Run threads + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + for future in futures: + future.result() # Wait for completion and raise any exceptions + + # Check for any errors + self.assertEqual(len(errors), 0, f"Errors occurred: {errors[:5]}...") + + # Verify cache is still functional + test_key = "test_after_concurrent" + test_value = "test_value" + large_cache.put(test_key, test_value) + self.assertEqual(large_cache.get(test_key), test_value) + + # Check statistics are reasonable + stats = large_cache.get_stats() + self.assertGreater(stats["hits"], 0) + self.assertGreater(stats["puts"], 0) + + def test_concurrent_evictions(self): + """Test that concurrent operations during evictions don't corrupt the cache.""" + # Use a small cache to trigger frequent evictions + small_cache = LFUCache(10) + num_threads = 5 + operations_per_thread = 50 + + def worker(thread_id): + """Worker that adds many items to trigger evictions.""" + for i in range(operations_per_thread): + key = f"t{thread_id}_k{i}" + value = f"t{thread_id}_v{i}" + small_cache.put(key, value) + + # Try to get recently added items + if i > 0: + prev_key = f"t{thread_id}_k{i-1}" + small_cache.get(prev_key) # May or may not exist + + # Run threads + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + for future in futures: + future.result() + + # Cache should still be at capacity + self.assertEqual(small_cache.size(), 10) + + def test_concurrent_clear_operations(self): + """Test concurrent clear operations with other operations.""" + + def writer(): + """Continuously write to cache.""" + for i in range(100): + self.cache.put(f"key_{i}", f"value_{i}") + time.sleep(0.001) # Small delay + + def clearer(): + """Periodically clear the cache.""" + for _ in range(5): + time.sleep(0.01) + self.cache.clear() + + def reader(): + """Continuously read from cache.""" + for i in range(100): + self.cache.get(f"key_{i}") + time.sleep(0.001) + + # Run operations concurrently + threads = [ + threading.Thread(target=writer), + threading.Thread(target=clearer), + threading.Thread(target=reader), + ] + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # Cache should still be functional + self.cache.put("final_key", "final_value") + self.assertEqual(self.cache.get("final_key"), "final_value") + + def test_concurrent_stats_operations(self): + """Test that concurrent operations don't corrupt statistics.""" + + def worker(thread_id): + """Worker that performs operations and checks stats.""" + for i in range(50): + key = f"stats_key_{thread_id}_{i}" + self.cache.put(key, i) + self.cache.get(key) # Hit + self.cache.get(f"nonexistent_{thread_id}_{i}") # Miss + + # Periodically check stats + if i % 10 == 0: + stats = self.cache.get_stats() + # Just verify we can get stats without error + self.assertIsInstance(stats, dict) + self.assertIn("hits", stats) + self.assertIn("misses", stats) + + # Run threads + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Final stats check + final_stats = self.cache.get_stats() + self.assertGreater(final_stats["hits"], 0) + self.assertGreater(final_stats["misses"], 0) + self.assertGreater(final_stats["puts"], 0) + + def test_get_or_compute_thread_safety(self): + """Test thread safety of get_or_compute method.""" + compute_count = threading.local() + compute_count.value = 0 + total_computes = [] + lock = threading.Lock() + + async def expensive_compute(): + """Simulate expensive computation that should only run once.""" + # Track how many times this is called + if not hasattr(compute_count, "value"): + compute_count.value = 0 + compute_count.value += 1 + + with lock: + total_computes.append(1) + + # Simulate expensive operation + await asyncio.sleep(0.1) + return f"computed_value_{len(total_computes)}" + + async def worker(thread_id): + """Worker that tries to get or compute the same key.""" + result = await self.cache.get_or_compute( + "shared_compute_key", expensive_compute, default="default" + ) + return result + + async def run_test(): + """Run the async test.""" + # Run multiple workers concurrently + tasks = [worker(i) for i in range(10)] + results = await asyncio.gather(*tasks) + + # All should get the same value + self.assertTrue( + all(r == results[0] for r in results), + f"All threads should get same value, got: {results}", + ) + + # Compute should have been called only once + self.assertEqual( + len(total_computes), + 1, + f"Compute should be called once, called {len(total_computes)} times", + ) + + return results[0] + + # Run the async test + result = asyncio.run(run_test()) + self.assertEqual(result, "computed_value_1") + + def test_get_or_compute_exception_handling(self): + """Test get_or_compute handles exceptions properly.""" + call_count = [0] + + async def failing_compute(): + """Compute function that fails.""" + call_count[0] += 1 + raise ValueError("Computation failed") + + async def worker(): + """Worker that tries to compute.""" + result = await self.cache.get_or_compute( + "failing_key", failing_compute, default="fallback" + ) + return result + + async def run_test(): + """Run the async test.""" + # Multiple workers should all get the default value + tasks = [worker() for _ in range(5)] + results = await asyncio.gather(*tasks) + + # All should get the default value + self.assertTrue(all(r == "fallback" for r in results)) + + # The compute function might be called multiple times + # since failed computations aren't cached + self.assertGreaterEqual(call_count[0], 1) + + asyncio.run(run_test()) + + def test_concurrent_persistence(self): + """Test thread safety of persistence operations.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + cache_file = f.name + + try: + # Create cache with persistence + cache = LFUCache( + capacity=50, + track_stats=True, + persistence_interval=0.1, # Short interval for testing + persistence_path=cache_file, + ) + + def worker(thread_id): + """Worker that performs operations.""" + for i in range(20): + cache.put(f"persist_key_{thread_id}_{i}", f"value_{thread_id}_{i}") + cache.get(f"persist_key_{thread_id}_{i}") + + # Force persistence sometimes + if i % 5 == 0: + cache.persist_now() + + # Run workers + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Final persist + cache.persist_now() + + # Load the persisted data + new_cache = LFUCache( + capacity=50, persistence_interval=1.0, persistence_path=cache_file + ) + + # Verify some data was persisted correctly + # (Due to capacity limits, not all items will be present) + self.assertGreater(new_cache.size(), 0) + self.assertLessEqual(new_cache.size(), 50) + + finally: + # Clean up + if os.path.exists(cache_file): + os.unlink(cache_file) + + def test_thread_safe_size_operations(self): + """Test that size-related operations are thread-safe.""" + results = [] + + def worker(thread_id): + """Worker that checks size consistency.""" + for i in range(100): + # Add item + self.cache.put(f"size_key_{thread_id}_{i}", i) + + # Check size + size = self.cache.size() + is_empty = self.cache.is_empty() + + # Size should never be negative or exceed capacity + if size < 0 or size > 100: + results.append(f"Invalid size: {size}") + + # is_empty should match size + if (size == 0) != is_empty: + results.append(f"Size {size} but is_empty={is_empty}") + + # Run workers + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, i) for i in range(10)] + for future in futures: + future.result() + + # Check for any inconsistencies + self.assertEqual(len(results), 0, f"Inconsistencies found: {results}") + + def test_concurrent_contains_operations(self): + """Test thread safety of contains method.""" + # Use a larger cache to avoid evictions during the test + # Need capacity for: 50 existing + (5 threads × 100 new keys) = 550+ + large_cache = LFUCache(1000, track_stats=True) + + # Pre-populate cache + for i in range(50): + large_cache.put(f"existing_key_{i}", f"value_{i}") + + results = [] + eviction_warnings = [] + + def worker(thread_id): + """Worker that checks contains and manipulates cache.""" + for i in range(100): + # Check existing keys + key = f"existing_key_{i % 50}" + if not large_cache.contains(key): + results.append(f"Thread {thread_id}: Missing key {key}") + + # Add new keys + new_key = f"new_key_{thread_id}_{i}" + large_cache.put(new_key, f"value_{thread_id}_{i}") + + # Check new key immediately + if not large_cache.contains(new_key): + # This could happen if cache is full and eviction occurred + # Track it separately as it's not a thread safety issue + eviction_warnings.append( + f"Thread {thread_id}: Key {new_key} possibly evicted" + ) + + # Check non-existent keys + if large_cache.contains(f"non_existent_{thread_id}_{i}"): + results.append(f"Thread {thread_id}: Found non-existent key") + + # Run workers + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Check for any errors (not counting eviction warnings) + self.assertEqual(len(results), 0, f"Errors found: {results}") + + # Eviction warnings should be minimal with large cache + if eviction_warnings: + print(f"Note: {len(eviction_warnings)} keys were evicted during test") + + def test_concurrent_reset_stats(self): + """Test thread safety of reset_stats operations.""" + errors = [] + + def worker(thread_id): + """Worker that performs operations and resets stats.""" + for i in range(50): + # Perform operations + self.cache.put(f"key_{thread_id}_{i}", i) + self.cache.get(f"key_{thread_id}_{i}") + self.cache.get("non_existent") + + # Periodically reset stats + if i % 10 == 0: + self.cache.reset_stats() + + # Check stats integrity + stats = self.cache.get_stats() + if any(v < 0 for v in stats.values() if isinstance(v, (int, float))): + errors.append(f"Thread {thread_id}: Negative stat value: {stats}") + + # Run workers + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Verify no errors + self.assertEqual(len(errors), 0, f"Stats errors: {errors[:5]}") + + def test_get_or_compute_concurrent_different_keys(self): + """Test get_or_compute with different keys being computed concurrently.""" + compute_counts = {} + lock = threading.Lock() + + async def compute_for_key(key): + """Compute function that tracks calls per key.""" + with lock: + compute_counts[key] = compute_counts.get(key, 0) + 1 + await asyncio.sleep(0.05) # Simulate work + return f"value_for_{key}" + + async def worker(thread_id, key_id): + """Worker that computes values for specific keys.""" + key = f"key_{key_id}" + result = await self.cache.get_or_compute( + key, lambda: compute_for_key(key), default="error" + ) + return key, result + + async def run_test(): + """Run concurrent computations for different keys.""" + # Create tasks for multiple keys, with some overlap + tasks = [] + for key_id in range(5): + for thread_id in range(3): # 3 threads per key + tasks.append(worker(thread_id, key_id)) + + results = await asyncio.gather(*tasks) + + # Verify each key was computed exactly once + for key_id in range(5): + key = f"key_{key_id}" + self.assertEqual( + compute_counts.get(key, 0), + 1, + f"{key} should be computed exactly once", + ) + + # Verify all threads got correct values + for key, value in results: + expected = f"value_for_{key}" + self.assertEqual(value, expected) + + asyncio.run(run_test()) + + def test_concurrent_operations_with_evictions(self): + """Test thread safety when cache is at capacity and evictions occur.""" + # Small cache to force evictions + small_cache = LFUCache(50, track_stats=True) + data_integrity_errors = [] + + def worker(thread_id): + """Worker that handles potential evictions gracefully.""" + for i in range(100): + key = f"t{thread_id}_k{i}" + value = f"t{thread_id}_v{i}" + + # Put value + small_cache.put(key, value) + + # Immediately access to increase frequency + retrieved = small_cache.get(key) + + # Value might be None if evicted immediately (unlikely but possible) + if retrieved is not None and retrieved != value: + # This would indicate actual data corruption + data_integrity_errors.append( + f"Wrong value for {key}: expected {value}, got {retrieved}" + ) + + # Also work with some persistent keys (access multiple times) + persistent_key = f"persistent_{thread_id % 5}" + for _ in range(3): # Access 3 times to increase frequency + small_cache.put(persistent_key, f"persistent_value_{thread_id}") + small_cache.get(persistent_key) + + # Run workers + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, i) for i in range(10)] + for future in futures: + future.result() + + # Should have no data integrity errors (wrong values) + self.assertEqual( + len(data_integrity_errors), + 0, + f"Data integrity errors: {data_integrity_errors}", + ) + + # Cache should be at capacity + self.assertEqual(small_cache.size(), 50) + + # Stats should show many evictions + stats = small_cache.get_stats() + self.assertGreater(stats["evictions"], 0) + self.assertGreater(stats["puts"], 0) + + +class TestContentSafetyManagerThreadSafety(unittest.TestCase): + """Test thread safety of ContentSafetyManager.""" + + def setUp(self): + """Set up test fixtures.""" + # Create mock cache config + self.cache_config = MagicMock() + self.cache_config.enabled = True + self.cache_config.store = "memory" + self.cache_config.capacity_per_model = 100 + self.cache_config.stats.enabled = True + self.cache_config.stats.log_interval = None + self.cache_config.persistence.enabled = False + self.cache_config.persistence.interval = None + self.cache_config.persistence.path = None + + # Create mock model config + self.model_config = MagicMock() + self.model_config.cache = self.cache_config + self.model_config.model_mapping = {"alias_model": "actual_model"} + + def test_concurrent_cache_creation(self): + """Test that concurrent cache creation returns the same instance.""" + manager = ContentSafetyManager(self.model_config) + caches = [] + + def worker(thread_id): + """Worker that gets cache for model.""" + cache = manager.get_cache_for_model("test_model") + caches.append((thread_id, cache)) + return cache + + # Run many threads to increase chance of race condition + with ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(worker, i) for i in range(20)] + for future in futures: + future.result() + + # All caches should be the same instance + first_cache = caches[0][1] + for thread_id, cache in caches: + self.assertIs( + cache, first_cache, f"Thread {thread_id} got different cache instance" + ) + + def test_concurrent_multi_model_caches(self): + """Test concurrent access to caches for different models.""" + manager = ContentSafetyManager(self.model_config) + results = [] + + def worker(thread_id): + """Worker that accesses multiple model caches.""" + model_names = [f"model_{i}" for i in range(5)] + + for model_name in model_names: + cache = manager.get_cache_for_model(model_name) + + # Perform operations + key = f"thread_{thread_id}_key" + value = f"thread_{thread_id}_value" + cache.put(key, value) + retrieved = cache.get(key) + + if retrieved != value: + results.append(f"Mismatch for {model_name}: {retrieved} != {value}") + + # Run workers + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, i) for i in range(10)] + for future in futures: + future.result() + + # Check for errors + self.assertEqual(len(results), 0, f"Errors found: {results}") + + def test_concurrent_persist_all_caches(self): + """Test thread safety of persist_all_caches method.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create mock config with persistence + cache_config = MagicMock() + cache_config.enabled = True + cache_config.store = "memory" + cache_config.capacity_per_model = 50 + cache_config.persistence.enabled = True + cache_config.persistence.interval = 1.0 + cache_config.persistence.path = f"{temp_dir}/cache_{{model_name}}.json" + cache_config.stats.enabled = True + cache_config.stats.log_interval = None + + model_config = MagicMock() + model_config.cache = cache_config + model_config.model_mapping = {} + + manager = ContentSafetyManager(model_config) + + # Create caches for multiple models + for i in range(5): + cache = manager.get_cache_for_model(f"model_{i}") + for j in range(10): + cache.put(f"key_{j}", f"value_{j}") + + persist_count = [0] + + def persist_worker(): + """Worker that calls persist_all_caches.""" + manager.persist_all_caches() + persist_count[0] += 1 + + def modify_worker(): + """Worker that modifies caches while persistence happens.""" + for i in range(20): + model_name = f"model_{i % 5}" + cache = manager.get_cache_for_model(model_name) + cache.put(f"new_key_{i}", f"new_value_{i}") + time.sleep(0.001) + + # Run persistence and modifications concurrently + threads = [] + + # Multiple persist threads + for _ in range(3): + t = threading.Thread(target=persist_worker) + threads.append(t) + t.start() + + # Modification thread + t = threading.Thread(target=modify_worker) + threads.append(t) + t.start() + + # Wait for all threads + for t in threads: + t.join() + + # Verify persistence was called + self.assertEqual(persist_count[0], 3) + + def test_model_alias_thread_safety(self): + """Test thread safety when using model aliases.""" + manager = ContentSafetyManager(self.model_config) + caches = [] + + def worker(use_alias): + """Worker that gets cache using alias or actual name.""" + if use_alias: + cache = manager.get_cache_for_model("alias_model") + else: + cache = manager.get_cache_for_model("actual_model") + caches.append(cache) + + # Mix of threads using alias and actual name + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + for i in range(10): + use_alias = i % 2 == 0 + futures.append(executor.submit(worker, use_alias)) + + for future in futures: + future.result() + + # All should get the same cache instance + first_cache = caches[0] + for cache in caches: + self.assertIs( + cache, + first_cache, + "Alias and actual model should resolve to same cache", + ) + + +if __name__ == "__main__": + unittest.main() From b07a0225f4c47ed2ac759a34acb49a01b387b6c5 Mon Sep 17 00:00:00 2001 From: Oren Hazai Date: Sat, 4 Oct 2025 16:04:10 +0300 Subject: [PATCH 02/19] remove cache persistence --- examples/configs/content_safety/README.md | 167 +++-- examples/configs/content_safety/config.yml | 41 +- examples/configs/content_safety/prompts.yml | 6 +- nemoguardrails/cache/README.md | 65 +- .../library/content_safety/actions.py | 26 +- .../library/content_safety/manager.py | 98 --- nemoguardrails/rails/llm/config.py | 21 - nemoguardrails/rails/llm/llmrails.py | 59 +- tests/test_cache_lfu.py | 693 +----------------- 9 files changed, 203 insertions(+), 973 deletions(-) delete mode 100644 nemoguardrails/library/content_safety/manager.py diff --git a/examples/configs/content_safety/README.md b/examples/configs/content_safety/README.md index 0645e9043..06de8b372 100644 --- a/examples/configs/content_safety/README.md +++ b/examples/configs/content_safety/README.md @@ -1,56 +1,116 @@ # Content Safety Configuration -This example demonstrates how to configure content safety rails with NeMo Guardrails, including optional cache persistence. +This example demonstrates how to configure content safety rails with NeMo Guardrails, from basic setup to advanced per-model configurations. ## Features - **Input Safety Checks**: Validates user inputs before processing - **Output Safety Checks**: Ensures bot responses are appropriate -- **Caching**: Reduces redundant API calls with LFU cache -- **Persistence**: Optional cache persistence for resilience across restarts - **Thread Safety**: Fully thread-safe for use in multi-threaded web servers +- **Per-Model Caching**: Optional caching with configurable settings per model +- **Multiple Models**: Support for different content safety models with different configurations ## Configuration Overview -The configuration includes: +### Basic Configuration -1. **Main Model**: The primary LLM for conversations (Llama 3.3 70B) -2. **Content Safety Model**: Dedicated model for safety checks (NemoGuard 8B) -3. **Rails**: Input and output safety check flows -4. **Cache Configuration**: Memory cache with optional persistence +The simplest configuration uses a single content safety model: + +```yaml +models: + - type: main + engine: nim + model: meta/llama-3.3-70b-instruct + + - type: content_safety + engine: nim + model: nvidia/llama-3.1-nemoguard-8b-content-safety + +rails: + input: + flows: + - content safety check input $model=content_safety + output: + flows: + - content safety check output $model=content_safety +``` + +### Advanced Configuration with Per-Model Caching + +For production environments, you can configure caching per model: + +```yaml +models: + - type: content_safety + engine: nim + model: nvidia/llama-3.1-nemoguard-8b-content-safety + cache: + enabled: true + capacity_per_model: 50000 # Larger cache for primary model + stats: + enabled: true + log_interval: 60.0 # Log stats every 60 seconds + + - type: llama_guard + engine: vllm_openai + model: meta-llama/Llama-Guard-7b + cache: + enabled: true + capacity_per_model: 25000 # Medium cache + stats: + enabled: false # No stats for this model +``` ## How It Works -1. **User Input**: When a user sends a message, it's checked by the content safety model -2. **Cache Check**: The system first checks if this content was already evaluated (cache hit) -3. **Safety Evaluation**: If not cached, the content safety model evaluates the input -4. **Result Caching**: The safety check result is cached for future use -5. **Response Generation**: If safe, the main model generates a response -6. **Output Check**: The response is also checked for safety before returning to the user +1. **User Input**: When a user sends a message, it's checked by the content safety model(s) +2. **Safety Evaluation**: The content safety model evaluates the input +3. **Caching** (optional): Results are cached to avoid duplicate API calls +4. **Response Generation**: If safe, the main model generates a response +5. **Output Check**: The response is also checked for safety before returning to the user + +## Cache Configuration Options -## Cache Persistence +### Default Behavior (No Caching) -The cache configuration includes: +By default, caching is **disabled**. Models without cache configuration will have no caching: -- **Automatic Saves**: Every 5 minutes (configurable) -- **Shutdown Saves**: Caches are automatically persisted when the application closes -- **Crash Recovery**: Cache reloads from disk on restart -- **Per-Model Storage**: Each model gets its own cache file +```yaml +models: + - type: shieldgemma + engine: google + model: google/shieldgemma-2b + # No cache config = no caching (default) +``` -To disable persistence, you can either: +### Enabling Cache -1. Set `enabled: false` in the persistence section -2. Remove the `persistence` section entirely -3. Set `interval` to `null` or remove it +Add cache configuration to any model definition: -Note: Persistence requires both `enabled: true` and a valid `interval` value to be active. +```yaml +cache: + enabled: true # Enable caching + capacity_per_model: 10000 # Cache capacity (number of entries) + store: "memory" # Cache storage type (currently only memory) + stats: + enabled: true # Enable statistics tracking + log_interval: 300.0 # Log stats every 5 minutes (optional) +``` + +## Architecture + +Each content safety model gets its own dedicated cache instance, providing: + +- **Isolated cache management** per model +- **Different cache capacities** for different models +- **Model-specific performance tuning** +- **Thread-safe concurrent access** ## Thread Safety The content safety implementation is fully thread-safe: - **Concurrent Requests**: Safely handles multiple simultaneous safety checks -- **No Data Corruption**: Thread-safe cache operations prevent data corruption - **Efficient Locking**: Uses RLock for minimal performance impact - **Atomic Operations**: Prevents duplicate LLM calls for the same content @@ -60,23 +120,6 @@ This makes it suitable for: - Concurrent request processing - High-traffic applications -### Proper Shutdown - -For best results, use one of these patterns: - -```python -# Context manager (recommended) -with LLMRails(config) as rails: - # Your code here - pass -# Caches automatically persisted on exit - -# Or manual cleanup -rails = LLMRails(config) -# Your code here -rails.close() # Persist caches -``` - ## Running the Example ```bash @@ -84,34 +127,18 @@ rails.close() # Persist caches nemoguardrails server --config examples/configs/content_safety/ ``` -## Customization - -### Adjust Cache Settings - -```yaml -cache: - enabled: true # Enable/disable caching - capacity_per_model: 5000 # Maximum entries per model - persistence: - interval: 300.0 # Seconds between saves - path: ./my_cache.json # Custom path -``` - -### Memory-Only Cache - -For memory-only caching without persistence: - -```yaml -cache: - enabled: true - capacity_per_model: 5000 - store: memory - # No persistence section -``` - ## Benefits 1. **Performance**: Avoid redundant content safety API calls 2. **Cost Savings**: Reduce API usage for repeated content -3. **Reliability**: Cache survives process restarts -4. **Flexibility**: Easy to enable/disable features as needed +3. **Flexibility**: Enable caching only for models that benefit from it +4. **Clean Architecture**: Each model has its own dedicated cache +5. **Scalability**: Easy to add new models with different caching strategies + +## Tips + +- Start with no caching to establish baseline performance +- Enable caching for frequently-used models first +- Use stats logging to monitor cache effectiveness +- Adjust cache capacity based on your usage patterns +- Consider different cache sizes for different models based on their usage diff --git a/examples/configs/content_safety/config.yml b/examples/configs/content_safety/config.yml index 5ca7c8608..018d1aade 100644 --- a/examples/configs/content_safety/config.yml +++ b/examples/configs/content_safety/config.yml @@ -3,29 +3,40 @@ models: engine: nim model: meta/llama-3.3-70b-instruct + # Multiple content safety models with different cache configurations - type: content_safety engine: nim model: nvidia/llama-3.1-nemoguard-8b-content-safety + # Model-specific cache configuration (optional) + cache: + enabled: true + capacity_per_model: 50000 # Larger cache for primary model + stats: + enabled: true + log_interval: 60.0 # Log stats every minute + + - type: llama_guard + engine: vllm_openai + model: meta-llama/Llama-Guard-7b + # Different cache settings for this model + cache: + enabled: true + capacity_per_model: 25000 # Medium cache + stats: + enabled: false # No stats for this model + + - type: shieldgemma + engine: google + model: google/shieldgemma-2b + # No cache configuration = no caching (default behavior) rails: input: flows: + # You can use multiple content safety models - content safety check input $model=content_safety + # - content safety check input $model=llama_guard + # - content safety check input $model=shieldgemma output: flows: - content safety check output $model=content_safety - - # Content safety cache configuration with persistence and stats - config: - content_safety: - cache: - enabled: true - capacity_per_model: 5000 - store: memory # In-memory cache with optional disk persistence - persistence: - enabled: true # Enable persistence (requires interval to be set) - interval: 300.0 # Persist every 5 minutes - path: ./content_safety_cache.json # Where to save cache - stats: - enabled: true # Enable statistics tracking - log_interval: 60.0 # Log cache statistics every minute diff --git a/examples/configs/content_safety/prompts.yml b/examples/configs/content_safety/prompts.yml index 1321a6461..61adc43cb 100644 --- a/examples/configs/content_safety/prompts.yml +++ b/examples/configs/content_safety/prompts.yml @@ -1,6 +1,10 @@ +# Default content safety prompts for nvidia/llama-3.1-nemoguard-8b-content-safety # These are the default prompts released by Meta, except for policy O7, which was added to address direct insults. +# +# To add prompts for other content safety models, add them below with the appropriate model name: +# - task: content_safety_check_input $model=llama_guard +# - task: content_safety_check_input $model=shieldgemma prompts: - - task: content_safety_check_input $model=content_safety content: | Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below. diff --git a/nemoguardrails/cache/README.md b/nemoguardrails/cache/README.md index 369e82f5e..c1ac575ae 100644 --- a/nemoguardrails/cache/README.md +++ b/nemoguardrails/cache/README.md @@ -2,7 +2,7 @@ ## Overview -The content safety checks in `actions.py` now use an LFU (Least Frequently Used) cache to improve performance by avoiding redundant LLM calls for identical safety checks. The cache supports optional persistence to disk for resilience across restarts. +The content safety checks in `actions.py` now use an LFU (Least Frequently Used) cache to improve performance by avoiding redundant LLM calls for identical safety checks. ## Implementation Details @@ -14,7 +14,6 @@ The content safety checks in `actions.py` now use an LFU (Least Frequently Used) - Statistics tracking: Enabled by default - Tracks timestamps: `created_at` and `accessed_at` for each entry - Cache creation: Automatic when a model is first used -- Persistence: Optional periodic save to disk with configurable interval ### Cached Functions @@ -61,51 +60,6 @@ if "llama_guard" in _MODEL_CACHES: stats = _MODEL_CACHES["llama_guard"].get_stats() ``` -### Persistence Configuration - -The cache supports optional persistence to disk for resilience across restarts: - -```yaml -rails: - config: - content_safety: - cache: - enabled: true - capacity_per_model: 5000 - persistence: - interval: 300.0 # Persist every 5 minutes - path: ./cache_{model_name}.json # {model_name} is replaced -``` - -**Configuration Options:** - -- `persistence.interval`: Seconds between automatic saves (None = no persistence) -- `persistence.path`: Where to save cache data (can include `{model_name}` placeholder) - -**How Persistence Works:** - -1. **Automatic Saves**: Cache checks trigger persistence if interval has passed -2. **On Shutdown**: Caches are automatically persisted when LLMRails is closed or garbage collected -3. **On Restart**: Cache loads from disk if persistence file exists -4. **Preserves State**: Frequencies and access patterns are maintained -5. **Per-Model Files**: Each model gets its own persistence file - -**Manual Persistence:** - -```python -# Force immediate persistence of all caches -content_safety_manager.persist_all_caches() -``` - -This is useful for graceful shutdown scenarios. - -**Notes on Persistence:** - -- Persistence only works with "memory" store type -- Cache files are JSON format for easy inspection and debugging -- Set `persistence.interval` to None to disable persistence -- The cache automatically persists on each check if the interval has passed - ### Statistics and Monitoring The cache supports detailed statistics tracking and periodic logging for monitoring cache performance: @@ -175,24 +129,19 @@ if "safety_model" in _MODEL_CACHES: ```python from nemoguardrails import RailsConfig, LLMRails -# Method 1: Using context manager (recommended - ensures cleanup) +# Method 1: Using context manager config = RailsConfig.from_path("./config.yml") with LLMRails(config) as rails: - # Content safety checks will be cached and persisted automatically + # Content safety checks will be cached automatically response = await rails.generate_async( messages=[{"role": "user", "content": "Hello, how are you?"}] ) -# Caches are automatically persisted on exit -# Method 2: Manual cleanup +# Method 2: Direct usage rails = LLMRails(config) response = await rails.generate_async( messages=[{"role": "user", "content": "Hello, how are you?"}] ) -rails.close() # Manually persist caches - -# Note: If neither method is used, caches will still be persisted -# when the object is garbage collected (__del__) ``` ### Thread Safety @@ -207,7 +156,6 @@ The content safety caching system is **thread-safe** for single-node deployments 2. **ContentSafetyManager**: - Thread-safe cache creation using double-checked locking pattern - Ensures only one cache instance per model across all threads - - Thread-safe persistence operations 3. **Key Features**: - **No Data Corruption**: Concurrent operations maintain data integrity @@ -231,9 +179,8 @@ The content safety caching system is **thread-safe** for single-node deployments 5. **Model Isolation**: Each model has its own cache, preventing interference between different safety models 6. **Statistics Tracking**: Monitor cache performance with hit rates, evictions, and more per model 7. **Timestamp Tracking**: Track when entries were created and last accessed -8. **Resilience**: Cache survives process restarts without losing data when persistence is enabled -9. **Efficiency**: LFU eviction algorithm ensures the most useful entries remain in cache -10. **Thread Safety**: Safe for concurrent access in multi-threaded environments +8. **Efficiency**: LFU eviction algorithm ensures the most useful entries remain in cache +9. **Thread Safety**: Safe for concurrent access in multi-threaded environments ### Example Usage Pattern diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index 3b8e97734..4b755a947 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -48,7 +48,6 @@ def _create_cache_key(prompt: Union[str, List[str]]) -> str: # Thread Safety Note: # The content safety caching mechanism is thread-safe for single-node deployments. # The underlying LFUCache uses threading.RLock to ensure atomic operations. -# ContentSafetyManager uses double-checked locking for efficient cache creation. # # However, this implementation is NOT suitable for distributed environments. # For multi-node deployments, consider using distributed caching solutions @@ -103,22 +102,19 @@ async def content_safety_check_input( max_tokens = max_tokens or _MAX_TOKENS - # Check cache if content safety manager is available for this model + # Check cache if available for this model cached_result = None cache_key = None - cache = None - - # Try to get the model-specific content safety manager - content_safety_manager = kwargs.get(f"content_safety_manager_{model_name}") - - if content_safety_manager: - cache = content_safety_manager.get_cache() - if cache: - cache_key = _create_cache_key(check_input_prompt) - cached_result = cache.get(cache_key) - if cached_result is not None: - log.debug(f"Content safety cache hit for model '{model_name}'") - return cached_result + + # Try to get the model-specific cache + cache = kwargs.get(f"model_cache_{model_name}") + + if cache: + cache_key = _create_cache_key(check_input_prompt) + cached_result = cache.get(cache_key) + if cached_result is not None: + log.debug(f"Content safety cache hit for model '{model_name}'") + return cached_result # Make the actual LLM call result = await llm_call( diff --git a/nemoguardrails/library/content_safety/manager.py b/nemoguardrails/library/content_safety/manager.py deleted file mode 100644 index 16a0ae12f..000000000 --- a/nemoguardrails/library/content_safety/manager.py +++ /dev/null @@ -1,98 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Optional - -from nemoguardrails.cache.interface import CacheInterface -from nemoguardrails.cache.lfu import LFUCache -from nemoguardrails.rails.llm.config import ModelCacheConfig - -log = logging.getLogger(__name__) - - -class ContentSafetyManager: - """Manages content safety functionality for a specific model.""" - - def __init__( - self, model_name: str, cache_config: Optional[ModelCacheConfig] = None - ): - self.model_name = model_name - self.cache_config = cache_config - self._cache: Optional[CacheInterface] = None - self._initialize_cache() - - def _initialize_cache(self): - """Initialize cache based on configuration.""" - if not self.cache_config or not self.cache_config.enabled: - log.debug( - f"Content safety caching is disabled for model '{self.model_name}'" - ) - return - - # Create cache based on store type - if self.cache_config.store == "memory": - # Determine persistence settings - persistence_path = None - persistence_interval = None - - if ( - self.cache_config.persistence.enabled - and self.cache_config.persistence.interval is not None - ): - persistence_interval = self.cache_config.persistence.interval - - if self.cache_config.persistence.path: - # Use configured path, replacing {model_name} if present - persistence_path = self.cache_config.persistence.path.replace( - "{model_name}", self.model_name - ) - else: - # Default path if persistence is enabled but no path specified - persistence_path = f"cache_{self.model_name}.json" - - # Determine stats logging settings - stats_logging_interval = None - if ( - self.cache_config.stats.enabled - and self.cache_config.stats.log_interval is not None - ): - stats_logging_interval = self.cache_config.stats.log_interval - - self._cache = LFUCache( - capacity=self.cache_config.capacity_per_model, - track_stats=self.cache_config.stats.enabled, - persistence_interval=persistence_interval, - persistence_path=persistence_path, - stats_logging_interval=stats_logging_interval, - ) - - log.info( - f"Created cache for model '{self.model_name}' with capacity {self.cache_config.capacity_per_model}" - ) - # elif self.cache_config.store == "filesystem": - # self._cache = FilesystemCache(...) - # elif self.cache_config.store == "redis": - # self._cache = RedisCache(...) - - def get_cache(self) -> Optional[CacheInterface]: - """Get the cache for this model.""" - return self._cache - - def persist_cache(self): - """Force immediate persistence of cache if it supports it.""" - if self._cache and self._cache.supports_persistence(): - self._cache.persist_now() - log.info(f"Persisted cache for model: {self.model_name}") diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 1bc92aa93..50c1190c1 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -878,23 +878,6 @@ class AIDefenseRailConfig(BaseModel): ) -class CachePersistenceConfig(BaseModel): - """Configuration for cache persistence to disk.""" - - enabled: bool = Field( - default=True, - description="Whether cache persistence is enabled (persistence requires both enabled=True and a valid interval)", - ) - interval: Optional[float] = Field( - default=None, - description="Seconds between periodic cache persistence to disk (None disables persistence)", - ) - path: Optional[str] = Field( - default=None, - description="Path to persistence file for cache data (defaults to 'cache_{model_name}.json' if persistence is enabled)", - ) - - class CacheStatsConfig(BaseModel): """Configuration for cache statistics tracking and logging.""" @@ -924,10 +907,6 @@ class ModelCacheConfig(BaseModel): store_config: Dict[str, Any] = Field( default_factory=dict, description="Backend-specific configuration" ) - persistence: CachePersistenceConfig = Field( - default_factory=CachePersistenceConfig, - description="Configuration for cache persistence", - ) stats: CacheStatsConfig = Field( default_factory=CacheStatsConfig, description="Configuration for cache statistics tracking and logging", diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 03f8c0dd4..99a097dba 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -131,7 +131,7 @@ def __init__( self.config = config self.llm = llm self.verbose = verbose - self._content_safety_managers = {} + self._model_caches = {} if self.verbose: set_verbose(True, llm_calls=True) @@ -501,7 +501,7 @@ def _init_llms(self): kwargs=kwargs, ) - # If the model is a content safety model, we need to create a ContentSafetyManager for it + # Configure the model based on its type if llm_config.type == "main": # If a main LLM was already injected, skip creating another # one. Otherwise, create and register it. @@ -529,32 +529,44 @@ def _init_llms(self): # Register content safety managers if content safety features are used if self._has_content_safety_rails(): - from nemoguardrails.library.content_safety.manager import ( - ContentSafetyManager, - ) + from nemoguardrails.cache.lfu import LFUCache - # Create a ContentSafetyManager for each content safety model + # Create a cache for each content safety model for model in self.config.models: if model.type not in ["main", "embeddings"]: - # Use model's cache config if available, otherwise None (no caching) - cache_config = model.cache - - manager = ContentSafetyManager( - model_name=model.type, cache_config=cache_config - ) + cache = None + + # Create cache if configured + if model.cache and model.cache.enabled: + if model.cache.store == "memory": + stats_logging_interval = None + if ( + model.cache.stats.enabled + and model.cache.stats.log_interval is not None + ): + stats_logging_interval = model.cache.stats.log_interval + + cache = LFUCache( + capacity=model.cache.capacity_per_model, + track_stats=model.cache.stats.enabled, + stats_logging_interval=stats_logging_interval, + ) - self._content_safety_managers[model.type] = manager + log.info( + f"Created cache for model '{model.type}' with capacity {model.cache.capacity_per_model}" + ) - # Register the manager for this specific model + # Register the cache for this specific model self.runtime.register_action_param( - f"content_safety_manager_{model.type}", manager + f"model_cache_{model.type}", cache ) + if cache: + self._model_caches[model.type] = cache + log.info( - f"Initialized ContentSafetyManager for model '{model.type}' with cache %s", - "enabled" - if cache_config and cache_config.enabled - else "disabled", + f"Initialized content safety for model '{model.type}' with cache %s", + "enabled" if cache else "disabled", ) def _get_embeddings_search_provider_instance( @@ -1804,14 +1816,11 @@ def _prepare_params( yield chunk def close(self): - """Properly close and clean up resources, including persisting caches.""" - if self._content_safety_managers: - log.info("Persisting content safety caches on close") - for model_name, manager in self._content_safety_managers.items(): - manager.persist_cache() + """Properly close and clean up resources.""" + pass def __del__(self): - """Ensure caches are persisted when the object is garbage collected.""" + """Clean up resources when the object is garbage collected.""" try: self.close() except Exception as e: diff --git a/tests/test_cache_lfu.py b/tests/test_cache_lfu.py index f2941482a..83ef81097 100644 --- a/tests/test_cache_lfu.py +++ b/tests/test_cache_lfu.py @@ -32,7 +32,6 @@ from unittest.mock import MagicMock, patch from nemoguardrails.cache.lfu import LFUCache -from nemoguardrails.library.content_safety.manager import ContentSafetyManager class TestLFUCache(unittest.TestCase): @@ -365,327 +364,6 @@ def test_interface_methods_exist(self): # Check property self.assertEqual(cache.capacity, 5) - def test_persistence_interface_methods(self): - """Verify persistence interface methods are implemented.""" - # Cache without persistence - cache_no_persist = LFUCache(5) - self.assertTrue(callable(getattr(cache_no_persist, "persist_now", None))) - self.assertTrue( - callable(getattr(cache_no_persist, "supports_persistence", None)) - ) - self.assertFalse(cache_no_persist.supports_persistence()) - - # Cache with persistence - temp_file = os.path.join(tempfile.mkdtemp(), "test_interface.json") - try: - cache_with_persist = LFUCache( - 5, persistence_interval=10.0, persistence_path=temp_file - ) - self.assertTrue(cache_with_persist.supports_persistence()) - - # persist_now should work without errors - cache_with_persist.put("key", "value") - cache_with_persist.persist_now() # Should not raise any exception - finally: - if os.path.exists(temp_file): - os.remove(temp_file) - if os.path.exists(os.path.dirname(temp_file)): - os.rmdir(os.path.dirname(temp_file)) - - -class TestLFUCachePersistence(unittest.TestCase): - """Test cases for LFU Cache persistence functionality.""" - - def setUp(self): - """Set up test fixtures.""" - # Create temporary directory for test files - self.temp_dir = tempfile.mkdtemp() - self.test_file = os.path.join(self.temp_dir, "test_cache.json") - - def tearDown(self): - """Clean up test files.""" - # Clean up any created files - if os.path.exists(self.test_file): - os.remove(self.test_file) - # Remove temporary directory - if os.path.exists(self.temp_dir): - os.rmdir(self.temp_dir) - - def test_basic_persistence(self): - """Test basic save and load functionality.""" - # Create cache and add items - cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) - - cache.put("key1", "value1") - cache.put("key2", {"nested": "value"}) - cache.put("key3", [1, 2, 3]) - - # Force persistence - cache.persist_now() - - # Verify file was created - self.assertTrue(os.path.exists(self.test_file)) - - # Load into new cache - new_cache = LFUCache( - 5, persistence_interval=10.0, persistence_path=self.test_file - ) - - # Verify data was loaded correctly - self.assertEqual(new_cache.size(), 3) - self.assertEqual(new_cache.get("key1"), "value1") - self.assertEqual(new_cache.get("key2"), {"nested": "value"}) - self.assertEqual(new_cache.get("key3"), [1, 2, 3]) - - def test_frequency_preservation(self): - """Test that frequencies are preserved across persistence.""" - cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) - - # Create different frequency levels - cache.put("freq1", "value1") - cache.put("freq3", "value3") - cache.put("freq5", "value5") - - # Access items to create different frequencies - cache.get("freq3") # freq = 2 - cache.get("freq3") # freq = 3 - - cache.get("freq5") # freq = 2 - cache.get("freq5") # freq = 3 - cache.get("freq5") # freq = 4 - cache.get("freq5") # freq = 5 - - # Force persistence - cache.persist_now() - - # Load into new cache - new_cache = LFUCache( - 5, persistence_interval=10.0, persistence_path=self.test_file - ) - - # Add new items to test eviction order - new_cache.put("new1", "newvalue1") - new_cache.put("new2", "newvalue2") - new_cache.put("new3", "newvalue3") - - # freq1 should be evicted first (lowest frequency) - self.assertIsNone(new_cache.get("freq1")) - # freq3 and freq5 should still be there - self.assertEqual(new_cache.get("freq3"), "value3") - self.assertEqual(new_cache.get("freq5"), "value5") - - def test_periodic_persistence(self): - """Test automatic periodic persistence.""" - # Use short interval for testing - cache = LFUCache(5, persistence_interval=0.5, persistence_path=self.test_file) - - cache.put("key1", "value1") - - # File shouldn't exist yet - self.assertFalse(os.path.exists(self.test_file)) - - # Wait for interval to pass - time.sleep(0.6) - - # Access cache to trigger persistence check - cache.get("key1") - - # File should now exist - self.assertTrue(os.path.exists(self.test_file)) - - # Verify content - with open(self.test_file, "r") as f: - data = json.load(f) - - self.assertEqual(data["capacity"], 5) - self.assertEqual(len(data["nodes"]), 1) - self.assertEqual(data["nodes"][0]["key"], "key1") - - def test_persistence_with_empty_cache(self): - """Test persistence behavior with empty cache.""" - cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) - - # Add and remove items - cache.put("key1", "value1") - cache.clear() - - # Force persistence - cache.persist_now() - - # File should be removed when cache is empty - self.assertFalse(os.path.exists(self.test_file)) - - def test_no_persistence_when_disabled(self): - """Test that persistence doesn't occur when not configured.""" - # Create cache without persistence - cache = LFUCache(5) - - cache.put("key1", "value1") - cache.persist_now() # Should do nothing - - # No file should be created - self.assertFalse(os.path.exists("lfu_cache.json")) - - def test_load_from_nonexistent_file(self): - """Test loading when persistence file doesn't exist.""" - # Create cache with non-existent file - cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) - - # Should start empty - self.assertEqual(cache.size(), 0) - self.assertTrue(cache.is_empty()) - - def test_persistence_with_complex_data(self): - """Test persistence with various data types.""" - cache = LFUCache(10, persistence_interval=10.0, persistence_path=self.test_file) - - # Add various data types - test_data = { - "string": "hello world", - "int": 42, - "float": 3.14, - "bool": True, - "none": None, - "list": [1, 2, [3, 4]], - "dict": {"a": 1, "b": {"c": 2}}, - "tuple_key": "value_for_tuple", # Will use string key since tuples aren't JSON serializable - } - - for key, value in test_data.items(): - cache.put(key, value) - - # Force persistence - cache.persist_now() - - # Load into new cache - new_cache = LFUCache( - 10, persistence_interval=10.0, persistence_path=self.test_file - ) - - # Verify all data types - for key, value in test_data.items(): - self.assertEqual(new_cache.get(key), value) - - def test_persistence_file_corruption_handling(self): - """Test handling of corrupted persistence files.""" - # Create invalid JSON file - with open(self.test_file, "w") as f: - f.write("{ invalid json content") - - # Should handle gracefully and start with empty cache - cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) - self.assertEqual(cache.size(), 0) - - # Cache should still be functional - cache.put("key1", "value1") - self.assertEqual(cache.get("key1"), "value1") - - def test_multiple_persistence_cycles(self): - """Test multiple save/load cycles.""" - # First cycle - cache1 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) - cache1.put("key1", "value1") - cache1.put("key2", "value2") - cache1.persist_now() - - # Second cycle - load and modify - cache2 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) - self.assertEqual(cache2.size(), 2) - cache2.put("key3", "value3") - cache2.persist_now() - - # Third cycle - verify all changes - cache3 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) - self.assertEqual(cache3.size(), 3) - self.assertEqual(cache3.get("key1"), "value1") - self.assertEqual(cache3.get("key2"), "value2") - self.assertEqual(cache3.get("key3"), "value3") - - def test_capacity_change_on_load(self): - """Test loading cache data into cache with different capacity.""" - # Create cache with capacity 5 - cache1 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) - for i in range(5): - cache1.put(f"key{i}", f"value{i}") - cache1.persist_now() - - # Load into cache with smaller capacity - cache2 = LFUCache(3, persistence_interval=10.0, persistence_path=self.test_file) - - # Current design: loads all persisted items regardless of new capacity - # This is a valid design choice - preserve data integrity on load - self.assertEqual(cache2.size(), 5) - - # The cache continues to operate with loaded items - # New items can still be added, and the cache will manage its size - cache2.put("new_key", "new_value") - - # Verify the cache is still functional and contains the new item - self.assertEqual(cache2.get("new_key"), "new_value") - self.assertGreaterEqual( - cache2.size(), 4 - ) # At least has the new item plus some old ones - - def test_persistence_timing(self): - """Test that persistence doesn't happen too frequently.""" - cache = LFUCache(5, persistence_interval=1.0, persistence_path=self.test_file) - - cache.put("key1", "value1") - - # Multiple operations within interval shouldn't trigger persistence - for i in range(10): - cache.get("key1") - self.assertFalse(os.path.exists(self.test_file)) - time.sleep(0.05) # Total time still less than interval - - # Wait for interval to pass - time.sleep(0.6) - cache.get("key1") - - # Now file should exist - self.assertTrue(os.path.exists(self.test_file)) - - def test_persistence_with_statistics(self): - """Test persistence doesn't interfere with statistics tracking.""" - cache = LFUCache( - 5, - track_stats=True, - persistence_interval=0.5, - persistence_path=self.test_file, - ) - - # Perform operations - cache.put("key1", "value1") - cache.put("key2", "value2") - cache.get("key1") - cache.get("nonexistent") - - # Wait for persistence - time.sleep(0.6) - cache.get("key1") # Trigger persistence - - # Check stats are still correct - stats = cache.get_stats() - self.assertEqual(stats["puts"], 2) - self.assertEqual(stats["hits"], 2) - self.assertEqual(stats["misses"], 1) - - # Load into new cache with stats - new_cache = LFUCache( - 5, - track_stats=True, - persistence_interval=0.5, - persistence_path=self.test_file, - ) - - # Stats should be reset in new instance - new_stats = new_cache.get_stats() - self.assertEqual(new_stats["puts"], 0) - self.assertEqual(new_stats["hits"], 0) - - # But data should be loaded - self.assertEqual(new_cache.size(), 2) - class TestLFUCacheStatsLogging(unittest.TestCase): """Test cases for LFU Cache statistics logging functionality.""" @@ -923,39 +601,6 @@ def test_stats_logging_with_updates(self): self.assertIn("Updates: 2", log_message) self.assertIn("Puts: 1", log_message) - def test_stats_logging_combined_with_persistence(self): - """Test that stats logging and persistence work together.""" - import logging - from unittest.mock import patch - - cache = LFUCache( - 5, - track_stats=True, - persistence_interval=1.0, - persistence_path=self.test_file, - stats_logging_interval=0.5, - ) - - cache.put("key1", "value1") - - with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" - ) as mock_log: - # Wait for stats logging interval - time.sleep(0.6) - cache.get("key1") # Trigger stats log - - self.assertEqual(mock_log.call_count, 1) - self.assertFalse(os.path.exists(self.test_file)) # Not persisted yet - - # Wait for persistence interval - time.sleep(0.5) - cache.get("key1") # Trigger persistence - - self.assertTrue(os.path.exists(self.test_file)) # Now persisted - # Stats log might trigger again if interval passed - self.assertGreaterEqual(mock_log.call_count, 1) - def test_stats_log_format_percentages(self): """Test that percentages in stats log are formatted correctly.""" import logging @@ -1010,33 +655,25 @@ def tearDown(self): def test_cache_config_with_stats_disabled(self): """Test cache configuration with stats disabled.""" - from nemoguardrails.library.content_safety.manager import ContentSafetyManager - from nemoguardrails.rails.llm.config import ( - CacheStatsConfig, - ModelCacheConfig, - ModelConfig, - ) + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig cache_config = ModelCacheConfig( enabled=True, capacity_per_model=1000, stats=CacheStatsConfig(enabled=False) ) - model_config = ModelConfig(cache=cache_config) - manager = ContentSafetyManager(model_config) + cache = LFUCache( + capacity=cache_config.capacity_per_model, + track_stats=cache_config.stats.enabled, + stats_logging_interval=None, + ) - cache = manager.get_cache_for_model("test_model") self.assertIsNotNone(cache) self.assertFalse(cache.track_stats) self.assertFalse(cache.supports_stats_logging()) def test_cache_config_with_stats_tracking_only(self): """Test cache configuration with stats tracking but no logging.""" - from nemoguardrails.library.content_safety.manager import ContentSafetyManager - from nemoguardrails.rails.llm.config import ( - CacheStatsConfig, - ModelCacheConfig, - ModelConfig, - ) + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig cache_config = ModelCacheConfig( enabled=True, @@ -1044,10 +681,12 @@ def test_cache_config_with_stats_tracking_only(self): stats=CacheStatsConfig(enabled=True, log_interval=None), ) - model_config = ModelConfig(cache=cache_config) - manager = ContentSafetyManager(model_config) + cache = LFUCache( + capacity=cache_config.capacity_per_model, + track_stats=cache_config.stats.enabled, + stats_logging_interval=cache_config.stats.log_interval, + ) - cache = manager.get_cache_for_model("test_model") self.assertIsNotNone(cache) self.assertTrue(cache.track_stats) self.assertFalse(cache.supports_stats_logging()) @@ -1055,12 +694,7 @@ def test_cache_config_with_stats_tracking_only(self): def test_cache_config_with_stats_logging(self): """Test cache configuration with stats tracking and logging.""" - from nemoguardrails.library.content_safety.manager import ContentSafetyManager - from nemoguardrails.rails.llm.config import ( - CacheStatsConfig, - ModelCacheConfig, - ModelConfig, - ) + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig cache_config = ModelCacheConfig( enabled=True, @@ -1068,10 +702,12 @@ def test_cache_config_with_stats_logging(self): stats=CacheStatsConfig(enabled=True, log_interval=60.0), ) - model_config = ModelConfig(cache=cache_config) - manager = ContentSafetyManager(model_config) + cache = LFUCache( + capacity=cache_config.capacity_per_model, + track_stats=cache_config.stats.enabled, + stats_logging_interval=cache_config.stats.log_interval, + ) - cache = manager.get_cache_for_model("test_model") self.assertIsNotNone(cache) self.assertTrue(cache.track_stats) self.assertTrue(cache.supports_stats_logging()) @@ -1079,49 +715,20 @@ def test_cache_config_with_stats_logging(self): def test_cache_config_default_stats(self): """Test cache configuration with default stats settings.""" - from nemoguardrails.library.content_safety.manager import ContentSafetyManager - from nemoguardrails.rails.llm.config import ModelCacheConfig, ModelConfig + from nemoguardrails.rails.llm.config import ModelCacheConfig cache_config = ModelCacheConfig(enabled=True) - model_config = ModelConfig(cache=cache_config) - manager = ContentSafetyManager(model_config) + cache = LFUCache( + capacity=cache_config.capacity_per_model, + track_stats=cache_config.stats.enabled, + stats_logging_interval=None, + ) - cache = manager.get_cache_for_model("test_model") self.assertIsNotNone(cache) self.assertFalse(cache.track_stats) # Default is disabled self.assertFalse(cache.supports_stats_logging()) - def test_cache_config_stats_with_persistence(self): - """Test cache configuration with both stats and persistence.""" - from nemoguardrails.library.content_safety.manager import ContentSafetyManager - from nemoguardrails.rails.llm.config import ( - CachePersistenceConfig, - CacheStatsConfig, - ModelCacheConfig, - ModelConfig, - ) - - cache_config = ModelCacheConfig( - enabled=True, - capacity_per_model=1000, - stats=CacheStatsConfig(enabled=True, log_interval=30.0), - persistence=CachePersistenceConfig( - enabled=True, interval=60.0, path=self.test_file - ), - ) - - model_config = ModelConfig(cache=cache_config) - manager = ContentSafetyManager(model_config) - - cache = manager.get_cache_for_model("test_model") - self.assertIsNotNone(cache) - self.assertTrue(cache.track_stats) - self.assertTrue(cache.supports_stats_logging()) - self.assertEqual(cache.stats_logging_interval, 30.0) - self.assertTrue(cache.supports_persistence()) - self.assertEqual(cache.persistence_interval, 60.0) - def test_cache_config_from_dict(self): """Test cache configuration creation from dictionary.""" from nemoguardrails.rails.llm.config import ModelCacheConfig @@ -1155,40 +762,6 @@ def test_cache_config_stats_validation(self): self.assertFalse(stats3.enabled) self.assertEqual(stats3.log_interval, 60.0) - def test_multiple_model_caches_with_stats(self): - """Test multiple model caches each with their own stats configuration.""" - from nemoguardrails.library.content_safety.manager import ContentSafetyManager - from nemoguardrails.rails.llm.config import ( - CacheStatsConfig, - ModelCacheConfig, - ModelConfig, - ) - - cache_config = ModelCacheConfig( - enabled=True, - capacity_per_model=1000, - stats=CacheStatsConfig(enabled=True, log_interval=30.0), - ) - - model_config = ModelConfig( - cache=cache_config, model_mapping={"model_alias": "actual_model"} - ) - manager = ContentSafetyManager(model_config) - - # Get caches for different models - cache1 = manager.get_cache_for_model("model1") - cache2 = manager.get_cache_for_model("model2") - cache_alias = manager.get_cache_for_model("model_alias") - cache_actual = manager.get_cache_for_model("actual_model") - - # All should have stats enabled - self.assertTrue(cache1.track_stats) - self.assertTrue(cache2.track_stats) - self.assertTrue(cache_alias.track_stats) - - # Alias should resolve to same cache as actual - self.assertIs(cache_alias, cache_actual) - class TestLFUCacheThreadSafety(unittest.TestCase): """Test thread safety of LFU Cache implementation.""" @@ -1429,54 +1002,6 @@ async def run_test(): asyncio.run(run_test()) - def test_concurrent_persistence(self): - """Test thread safety of persistence operations.""" - with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: - cache_file = f.name - - try: - # Create cache with persistence - cache = LFUCache( - capacity=50, - track_stats=True, - persistence_interval=0.1, # Short interval for testing - persistence_path=cache_file, - ) - - def worker(thread_id): - """Worker that performs operations.""" - for i in range(20): - cache.put(f"persist_key_{thread_id}_{i}", f"value_{thread_id}_{i}") - cache.get(f"persist_key_{thread_id}_{i}") - - # Force persistence sometimes - if i % 5 == 0: - cache.persist_now() - - # Run workers - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(worker, i) for i in range(5)] - for future in futures: - future.result() - - # Final persist - cache.persist_now() - - # Load the persisted data - new_cache = LFUCache( - capacity=50, persistence_interval=1.0, persistence_path=cache_file - ) - - # Verify some data was persisted correctly - # (Due to capacity limits, not all items will be present) - self.assertGreater(new_cache.size(), 0) - self.assertLessEqual(new_cache.size(), 50) - - finally: - # Clean up - if os.path.exists(cache_file): - os.unlink(cache_file) - def test_thread_safe_size_operations(self): """Test that size-related operations are thread-safe.""" results = [] @@ -1687,175 +1212,5 @@ def worker(thread_id): self.assertGreater(stats["puts"], 0) -class TestContentSafetyManagerThreadSafety(unittest.TestCase): - """Test thread safety of ContentSafetyManager.""" - - def setUp(self): - """Set up test fixtures.""" - # Create mock cache config - self.cache_config = MagicMock() - self.cache_config.enabled = True - self.cache_config.store = "memory" - self.cache_config.capacity_per_model = 100 - self.cache_config.stats.enabled = True - self.cache_config.stats.log_interval = None - self.cache_config.persistence.enabled = False - self.cache_config.persistence.interval = None - self.cache_config.persistence.path = None - - # Create mock model config - self.model_config = MagicMock() - self.model_config.cache = self.cache_config - self.model_config.model_mapping = {"alias_model": "actual_model"} - - def test_concurrent_cache_creation(self): - """Test that concurrent cache creation returns the same instance.""" - manager = ContentSafetyManager(self.model_config) - caches = [] - - def worker(thread_id): - """Worker that gets cache for model.""" - cache = manager.get_cache_for_model("test_model") - caches.append((thread_id, cache)) - return cache - - # Run many threads to increase chance of race condition - with ThreadPoolExecutor(max_workers=20) as executor: - futures = [executor.submit(worker, i) for i in range(20)] - for future in futures: - future.result() - - # All caches should be the same instance - first_cache = caches[0][1] - for thread_id, cache in caches: - self.assertIs( - cache, first_cache, f"Thread {thread_id} got different cache instance" - ) - - def test_concurrent_multi_model_caches(self): - """Test concurrent access to caches for different models.""" - manager = ContentSafetyManager(self.model_config) - results = [] - - def worker(thread_id): - """Worker that accesses multiple model caches.""" - model_names = [f"model_{i}" for i in range(5)] - - for model_name in model_names: - cache = manager.get_cache_for_model(model_name) - - # Perform operations - key = f"thread_{thread_id}_key" - value = f"thread_{thread_id}_value" - cache.put(key, value) - retrieved = cache.get(key) - - if retrieved != value: - results.append(f"Mismatch for {model_name}: {retrieved} != {value}") - - # Run workers - with ThreadPoolExecutor(max_workers=10) as executor: - futures = [executor.submit(worker, i) for i in range(10)] - for future in futures: - future.result() - - # Check for errors - self.assertEqual(len(results), 0, f"Errors found: {results}") - - def test_concurrent_persist_all_caches(self): - """Test thread safety of persist_all_caches method.""" - with tempfile.TemporaryDirectory() as temp_dir: - # Create mock config with persistence - cache_config = MagicMock() - cache_config.enabled = True - cache_config.store = "memory" - cache_config.capacity_per_model = 50 - cache_config.persistence.enabled = True - cache_config.persistence.interval = 1.0 - cache_config.persistence.path = f"{temp_dir}/cache_{{model_name}}.json" - cache_config.stats.enabled = True - cache_config.stats.log_interval = None - - model_config = MagicMock() - model_config.cache = cache_config - model_config.model_mapping = {} - - manager = ContentSafetyManager(model_config) - - # Create caches for multiple models - for i in range(5): - cache = manager.get_cache_for_model(f"model_{i}") - for j in range(10): - cache.put(f"key_{j}", f"value_{j}") - - persist_count = [0] - - def persist_worker(): - """Worker that calls persist_all_caches.""" - manager.persist_all_caches() - persist_count[0] += 1 - - def modify_worker(): - """Worker that modifies caches while persistence happens.""" - for i in range(20): - model_name = f"model_{i % 5}" - cache = manager.get_cache_for_model(model_name) - cache.put(f"new_key_{i}", f"new_value_{i}") - time.sleep(0.001) - - # Run persistence and modifications concurrently - threads = [] - - # Multiple persist threads - for _ in range(3): - t = threading.Thread(target=persist_worker) - threads.append(t) - t.start() - - # Modification thread - t = threading.Thread(target=modify_worker) - threads.append(t) - t.start() - - # Wait for all threads - for t in threads: - t.join() - - # Verify persistence was called - self.assertEqual(persist_count[0], 3) - - def test_model_alias_thread_safety(self): - """Test thread safety when using model aliases.""" - manager = ContentSafetyManager(self.model_config) - caches = [] - - def worker(use_alias): - """Worker that gets cache using alias or actual name.""" - if use_alias: - cache = manager.get_cache_for_model("alias_model") - else: - cache = manager.get_cache_for_model("actual_model") - caches.append(cache) - - # Mix of threads using alias and actual name - with ThreadPoolExecutor(max_workers=10) as executor: - futures = [] - for i in range(10): - use_alias = i % 2 == 0 - futures.append(executor.submit(worker, use_alias)) - - for future in futures: - future.result() - - # All should get the same cache instance - first_cache = caches[0] - for cache in caches: - self.assertIs( - cache, - first_cache, - "Alias and actual model should resolve to same cache", - ) - - if __name__ == "__main__": unittest.main() From 53a1c8253bd7869fd92b1d6c9404ba0a9ffdd02a Mon Sep 17 00:00:00 2001 From: Oren Hazai Date: Mon, 6 Oct 2025 16:02:26 +0300 Subject: [PATCH 03/19] update README and test --- examples/configs/content_safety/README.md | 15 ++++++++++++--- tests/test_cache_lfu.py | 11 ++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/examples/configs/content_safety/README.md b/examples/configs/content_safety/README.md index 06de8b372..2194fa571 100644 --- a/examples/configs/content_safety/README.md +++ b/examples/configs/content_safety/README.md @@ -10,6 +10,13 @@ This example demonstrates how to configure content safety rails with NeMo Guardr - **Per-Model Caching**: Optional caching with configurable settings per model - **Multiple Models**: Support for different content safety models with different configurations +## Folder Structure + +The structure of the config folder is the following: + +- `config.yml` - The main configuration file with model definitions, rails configuration, and cache settings +- `prompts.yml` - Contains the content safety prompt templates used by the safety models to evaluate content + ## Configuration Overview ### Basic Configuration @@ -77,9 +84,9 @@ By default, caching is **disabled**. Models without cache configuration will hav ```yaml models: - - type: shieldgemma - engine: google - model: google/shieldgemma-2b + - type: content_safety + engine: nim + model: nvidia/llama-3.1-nemoguard-8b-content-safety # No cache config = no caching (default) ``` @@ -127,6 +134,8 @@ This makes it suitable for: nemoguardrails server --config examples/configs/content_safety/ ``` +Please see the docs for more details about the [recommended ContentSafety deployment](./../../../docs/user-guides/advanced/nemoguard-contentsafety-deployment.md) methods, either using locally downloaded NIMs or NVIDIA AI Enterprise (NVAIE). + ## Benefits 1. **Performance**: Avoid redundant content safety API calls diff --git a/tests/test_cache_lfu.py b/tests/test_cache_lfu.py index 83ef81097..3f9340452 100644 --- a/tests/test_cache_lfu.py +++ b/tests/test_cache_lfu.py @@ -972,7 +972,16 @@ async def run_test(): self.assertEqual(result, "computed_value_1") def test_get_or_compute_exception_handling(self): - """Test get_or_compute handles exceptions properly.""" + """Test get_or_compute handles exceptions properly. + + NOTE: This test will produce "ValueError: Computation failed" messages in the test output. + These are EXPECTED and NORMAL - the test intentionally triggers failures to verify + that the cache handles exceptions correctly. Each of the 5 workers will generate one + error message, but all workers should receive the fallback value successfully. + """ + # Optional: Uncomment to see a message before the expected errors + # print("\n[test_get_or_compute_exception_handling] Note: The following 5 'ValueError: Computation failed' messages are expected...") + call_count = [0] async def failing_compute(): From 60088c581b3ba3ea734f53a1647d6bc1243b7cd4 Mon Sep 17 00:00:00 2001 From: Oren Hazai Date: Mon, 6 Oct 2025 16:14:11 +0300 Subject: [PATCH 04/19] update yml files --- examples/configs/content_safety/config.yml | 18 ------------------ examples/configs/content_safety/prompts.yml | 4 ---- 2 files changed, 22 deletions(-) diff --git a/examples/configs/content_safety/config.yml b/examples/configs/content_safety/config.yml index 018d1aade..474a43dad 100644 --- a/examples/configs/content_safety/config.yml +++ b/examples/configs/content_safety/config.yml @@ -3,7 +3,6 @@ models: engine: nim model: meta/llama-3.3-70b-instruct - # Multiple content safety models with different cache configurations - type: content_safety engine: nim model: nvidia/llama-3.1-nemoguard-8b-content-safety @@ -15,28 +14,11 @@ models: enabled: true log_interval: 60.0 # Log stats every minute - - type: llama_guard - engine: vllm_openai - model: meta-llama/Llama-Guard-7b - # Different cache settings for this model - cache: - enabled: true - capacity_per_model: 25000 # Medium cache - stats: - enabled: false # No stats for this model - - - type: shieldgemma - engine: google - model: google/shieldgemma-2b - # No cache configuration = no caching (default behavior) - rails: input: flows: # You can use multiple content safety models - content safety check input $model=content_safety - # - content safety check input $model=llama_guard - # - content safety check input $model=shieldgemma output: flows: - content safety check output $model=content_safety diff --git a/examples/configs/content_safety/prompts.yml b/examples/configs/content_safety/prompts.yml index 61adc43cb..0a70e1676 100644 --- a/examples/configs/content_safety/prompts.yml +++ b/examples/configs/content_safety/prompts.yml @@ -1,9 +1,5 @@ # Default content safety prompts for nvidia/llama-3.1-nemoguard-8b-content-safety # These are the default prompts released by Meta, except for policy O7, which was added to address direct insults. -# -# To add prompts for other content safety models, add them below with the appropriate model name: -# - task: content_safety_check_input $model=llama_guard -# - task: content_safety_check_input $model=shieldgemma prompts: - task: content_safety_check_input $model=content_safety content: | From b8f1ad3c778c662b1ceba6308241b914236b2fea Mon Sep 17 00:00:00 2001 From: Oren Hazai Date: Mon, 6 Oct 2025 16:15:08 +0300 Subject: [PATCH 05/19] create cache per any model with such config --- nemoguardrails/rails/llm/llmrails.py | 79 +++++++++++----------------- 1 file changed, 32 insertions(+), 47 deletions(-) diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 99a097dba..fb7e3ea7e 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -50,6 +50,7 @@ ) from nemoguardrails.actions.output_mapping import is_output_blocked from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx +from nemoguardrails.cache.lfu import LFUCache from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id, compute_context from nemoguardrails.colang.v1_0.runtime.runtime import Runtime, RuntimeV1_0 @@ -527,47 +528,41 @@ def _init_llms(self): self.runtime.register_action_param("llms", llms) - # Register content safety managers if content safety features are used - if self._has_content_safety_rails(): - from nemoguardrails.cache.lfu import LFUCache - - # Create a cache for each content safety model - for model in self.config.models: - if model.type not in ["main", "embeddings"]: - cache = None - - # Create cache if configured - if model.cache and model.cache.enabled: - if model.cache.store == "memory": - stats_logging_interval = None - if ( - model.cache.stats.enabled - and model.cache.stats.log_interval is not None - ): - stats_logging_interval = model.cache.stats.log_interval - - cache = LFUCache( - capacity=model.cache.capacity_per_model, - track_stats=model.cache.stats.enabled, - stats_logging_interval=stats_logging_interval, - ) + # Create cache per model + for model in self.config.models: + if model.type not in ["main", "embeddings"]: + cache = None + + # Create cache if configured + if model.cache and model.cache.enabled: + if model.cache.store == "memory": + stats_logging_interval = None + if ( + model.cache.stats.enabled + and model.cache.stats.log_interval is not None + ): + stats_logging_interval = model.cache.stats.log_interval + + cache = LFUCache( + capacity=model.cache.capacity_per_model, + track_stats=model.cache.stats.enabled, + stats_logging_interval=stats_logging_interval, + ) - log.info( - f"Created cache for model '{model.type}' with capacity {model.cache.capacity_per_model}" - ) + log.info( + f"Created cache for model '{model.type}' with capacity {model.cache.capacity_per_model}" + ) - # Register the cache for this specific model - self.runtime.register_action_param( - f"model_cache_{model.type}", cache - ) + # Register the cache for this specific model + self.runtime.register_action_param(f"model_cache_{model.type}", cache) - if cache: - self._model_caches[model.type] = cache + if cache: + self._model_caches[model.type] = cache - log.info( - f"Initialized content safety for model '{model.type}' with cache %s", - "enabled" if cache else "disabled", - ) + log.info( + f"Initialized model '{model.type}' with cache %s", + "enabled" if cache else "disabled", + ) def _get_embeddings_search_provider_instance( self, esp_config: Optional[EmbeddingSearchProvider] = None @@ -1521,16 +1516,6 @@ def register_embedding_provider( register_embedding_provider(engine_name=name, model=cls) return self - def _has_content_safety_rails(self) -> bool: - """Check if any content safety rails are configured in flows. - At the moment, we only support content safety manager in input flows. - """ - flows = self.config.rails.input.flows - for flow in flows: - if "content safety check input" in flow: - return True - return False - def explain(self) -> ExplainInfo: """Helper function to return the latest ExplainInfo object.""" if self.explain_info is None: From 9e41c78303146a664042757e221db1f5e7600974 Mon Sep 17 00:00:00 2001 From: Oren Hazai Date: Mon, 6 Oct 2025 16:23:05 +0300 Subject: [PATCH 06/19] minor fixes --- examples/configs/content_safety/prompts.yml | 1 - nemoguardrails/rails/llm/llmrails.py | 25 --------------------- 2 files changed, 26 deletions(-) diff --git a/examples/configs/content_safety/prompts.yml b/examples/configs/content_safety/prompts.yml index 0a70e1676..dfd8b45a8 100644 --- a/examples/configs/content_safety/prompts.yml +++ b/examples/configs/content_safety/prompts.yml @@ -1,4 +1,3 @@ -# Default content safety prompts for nvidia/llama-3.1-nemoguard-8b-content-safety # These are the default prompts released by Meta, except for policy O7, which was added to address direct insults. prompts: - task: content_safety_check_input $model=content_safety diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index fb7e3ea7e..7d3c350c5 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -132,7 +132,6 @@ def __init__( self.config = config self.llm = llm self.verbose = verbose - self._model_caches = {} if self.verbose: set_verbose(True, llm_calls=True) @@ -556,9 +555,6 @@ def _init_llms(self): # Register the cache for this specific model self.runtime.register_action_param(f"model_cache_{model.type}", cache) - if cache: - self._model_caches[model.type] = cache - log.info( f"Initialized model '{model.type}' with cache %s", "enabled" if cache else "disabled", @@ -1799,24 +1795,3 @@ def _prepare_params( # yield the individual chunks directly from the buffer strategy for chunk in user_output_chunks: yield chunk - - def close(self): - """Properly close and clean up resources.""" - pass - - def __del__(self): - """Clean up resources when the object is garbage collected.""" - try: - self.close() - except Exception as e: - # Silently fail in destructor to avoid issues during shutdown - log.debug(f"Error during LLMRails cleanup: {e}") - - def __enter__(self): - """Context manager entry.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit - ensure cleanup.""" - self.close() - return False From 2e93fde18474d62e06253d87f4f10ffaca602881 Mon Sep 17 00:00:00 2001 From: Oren Hazai Date: Sun, 12 Oct 2025 15:55:51 +0300 Subject: [PATCH 07/19] changes following PR --- nemoguardrails/cache/README.md | 119 +++++++++--------- .../library/content_safety/actions.py | 5 +- nemoguardrails/rails/llm/config.py | 72 ++++++----- nemoguardrails/rails/llm/llmrails.py | 14 ++- 4 files changed, 111 insertions(+), 99 deletions(-) diff --git a/nemoguardrails/cache/README.md b/nemoguardrails/cache/README.md index c1ac575ae..dfa18c036 100644 --- a/nemoguardrails/cache/README.md +++ b/nemoguardrails/cache/README.md @@ -11,21 +11,20 @@ The content safety checks in `actions.py` now use an LFU (Least Frequently Used) - Per-model caches: Each model gets its own LFU cache instance - Default capacity: 50,000 entries per model - Eviction policy: LFU with LRU tiebreaker -- Statistics tracking: Enabled by default +- Statistics tracking: Disabled by default (configurable) - Tracks timestamps: `created_at` and `accessed_at` for each entry -- Cache creation: Automatic when a model is first used +- Cache creation: Automatic when a model is initialized with cache enabled +- Supported model types: Any non-`main` and non-`embeddings` model type (typically content safety models) ### Cached Functions -1. `content_safety_check_input()` - Caches safety checks for user inputs - -Note: `content_safety_check_output()` does not use caching to ensure fresh evaluation of bot responses. +`content_safety_check_input()` - Caches safety checks for user inputs ### Cache Key Components -The cache key is a SHA256 hash of: +The cache key is generated from: -- The rendered prompt only (can be a string or list of strings) +- The rendered prompt (normalized for whitespace) Since temperature is fixed (1e-20) and stop/max_tokens are derived from the model configuration, they don't need to be part of the cache key. @@ -44,43 +43,33 @@ Since temperature is fixed (1e-20) and stop/max_tokens are derived from the mode The caching system automatically creates and manages separate caches for each model. Key features: -- **Automatic Creation**: Caches are created on first use for each model +- **Automatic Creation**: Caches are created when the model is initialized with cache configuration - **Isolated Storage**: Each model maintains its own cache, preventing cross-model interference -- **Default Settings**: Each cache has 50,000 entry capacity with stats tracking enabled - -```python -# Internal cache access (for debugging/monitoring): -from nemoguardrails.library.content_safety.actions import _MODEL_CACHES - -# View which models have caches -models_with_caches = list(_MODEL_CACHES.keys()) - -# Get stats for a specific model's cache -if "llama_guard" in _MODEL_CACHES: - stats = _MODEL_CACHES["llama_guard"].get_stats() -``` +- **Default Settings**: Each cache has 50,000 entry capacity (configurable) +- **Per-Model Configuration**: Cache is configured per model in the YAML configuration ### Statistics and Monitoring The cache supports detailed statistics tracking and periodic logging for monitoring cache performance: ```yaml -rails: - config: - content_safety: - cache: - enabled: true - capacity_per_model: 10000 - stats: - enabled: true # Enable stats tracking - log_interval: 60.0 # Log stats every minute +models: + - type: content_safety + engine: nim + model: nvidia/llama-3.1-nemoguard-8b-content-safety + cache: + enabled: true + capacity_per_model: 10000 + store: memory # Currently only 'memory' is supported + stats: + enabled: true # Enable stats tracking + log_interval: 60.0 # Log stats every minute ``` **Statistics Features:** 1. **Tracking Only**: Set `stats.enabled: true` with no `log_interval` to track stats without logging 2. **Automatic Logging**: Set both `stats.enabled: true` and `log_interval` for periodic logging -3. **Manual Logging**: Force immediate stats logging with `cache.log_stats_now()` **Statistics Tracked:** @@ -100,17 +89,12 @@ LFU Cache Statistics - Size: 2456/10000 | Hits: 15234 | Misses: 2456 | Hit Rate: **Usage Examples:** -```python -# Programmatically access stats -if "safety_model" in _MODEL_CACHES: - cache = _MODEL_CACHES["safety_model"] - stats = cache.get_stats() - print(f"Cache hit rate: {stats['hit_rate']:.2%}") - - # Force immediate stats logging - if cache.supports_stats_logging(): - cache.log_stats_now() -``` +The cache is managed internally by the NeMo Guardrails framework. When you configure a model with caching enabled, the framework automatically: + +1. Creates an LFU cache instance for that model +2. Passes the cache to content safety actions via kwargs +3. Tracks statistics if configured +4. Logs statistics at the specified interval **Configuration Options:** @@ -124,21 +108,42 @@ if "safety_model" in _MODEL_CACHES: - Stats are reset when cache is cleared or when `reset_stats()` is called - Each model maintains independent statistics -### Example Configuration Usage +### Example Configuration + +```yaml +# config.yml +models: + - type: main + engine: openai + model: gpt-4 + + - type: content_safety + engine: nim + model: nvidia/llama-3.1-nemoguard-8b-content-safety + cache: + enabled: true + capacity_per_model: 50000 + store: memory + stats: + enabled: true + log_interval: 300.0 # Log stats every 5 minutes + +rails: + input: + flows: + - content safety check input model="content_safety" +``` + +### Example Usage ```python from nemoguardrails import RailsConfig, LLMRails -# Method 1: Using context manager +# The cache is automatically configured based on your YAML config config = RailsConfig.from_path("./config.yml") -with LLMRails(config) as rails: - # Content safety checks will be cached automatically - response = await rails.generate_async( - messages=[{"role": "user", "content": "Hello, how are you?"}] - ) - -# Method 2: Direct usage rails = LLMRails(config) + +# Content safety checks will be cached automatically response = await rails.generate_async( messages=[{"role": "user", "content": "Hello, how are you?"}] ) @@ -153,8 +158,8 @@ The content safety caching system is **thread-safe** for single-node deployments - All public methods (`get`, `put`, `size`, `clear`, etc.) are protected by locks - Supports atomic `get_or_compute()` operations that prevent duplicate computations -2. **ContentSafetyManager**: - - Thread-safe cache creation using double-checked locking pattern +2. **LLMRails Model Initialization**: + - Thread-safe cache creation during model initialization - Ensures only one cache instance per model across all threads 3. **Key Features**: @@ -189,7 +194,7 @@ The content safety caching system is **thread-safe** for single-node deployments result = await content_safety_check_input( llms=llms, llm_task_manager=task_manager, - model_name="safety_model", + model_name="content_safety", context={"user_message": "Hello world"} ) @@ -197,7 +202,7 @@ result = await content_safety_check_input( result = await content_safety_check_input( llms=llms, llm_task_manager=task_manager, - model_name="safety_model", + model_name="content_safety", context={"user_message": "Hello world"} ) ``` @@ -207,8 +212,8 @@ result = await content_safety_check_input( The implementation includes debug logging: - Cache creation: `"Created cache for model '{model_name}' with capacity {capacity}"` -- Cache hits: `"Content safety cache hit for model '{model_name}', key: {key[:8]}..."` -- Cache stores: `"Content safety result cached for model '{model_name}', key: {key[:8]}..."` +- Cache hits: `"Content safety cache hit for model '{model_name}'"` +- Cache stores: `"Content safety result cached for model '{model_name}'"` Enable debug logging to monitor cache behavior: diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index 4b755a947..c8fb0e662 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -40,8 +40,9 @@ def _create_cache_key(prompt: Union[str, List[str]]) -> str: else: prompt_str = prompt - # normalize the prompt to a string - # should we do more normalizations? + # Normalize the prompt by collapsing all whitespace sequences to a single space + # and stripping leading/trailing whitespace. This ensures semantically equivalent + # prompts map to the same cache key. No further normalization is currently needed. return PROMPT_PATTERN_WHITESPACES.sub(" ", prompt_str).strip() diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 50c1190c1..0a14c33cd 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -15,8 +15,6 @@ """Module for the configuration of rails.""" -from __future__ import annotations - import logging import os import warnings @@ -71,6 +69,41 @@ colang_path_dirs.append(guardrails_stdlib_path) +class CacheStatsConfig(BaseModel): + """Configuration for cache statistics tracking and logging.""" + + enabled: bool = Field( + default=False, + description="Whether cache statistics tracking is enabled", + ) + log_interval: Optional[float] = Field( + default=None, + description="Seconds between periodic cache stats logging to logs (None disables logging)", + ) + + +class ModelCacheConfig(BaseModel): + """Configuration for model caching.""" + + enabled: bool = Field( + default=False, + description="Whether caching is enabled (default: False - no caching)", + ) + capacity_per_model: int = Field( + default=50000, description="Maximum number of entries in the cache per model" + ) + store: str = Field( + default="memory", description="Cache store: 'memory', 'filesystem', 'redis'" + ) + store_config: Dict[str, Any] = Field( + default_factory=dict, description="Backend-specific configuration" + ) + stats: CacheStatsConfig = Field( + default_factory=CacheStatsConfig, + description="Configuration for cache statistics tracking and logging", + ) + + class Model(BaseModel): """Configuration of a model used by the rails engine. @@ -878,41 +911,6 @@ class AIDefenseRailConfig(BaseModel): ) -class CacheStatsConfig(BaseModel): - """Configuration for cache statistics tracking and logging.""" - - enabled: bool = Field( - default=False, - description="Whether cache statistics tracking is enabled", - ) - log_interval: Optional[float] = Field( - default=None, - description="Seconds between periodic cache stats logging to logs (None disables logging)", - ) - - -class ModelCacheConfig(BaseModel): - """Configuration for model caching.""" - - enabled: bool = Field( - default=False, - description="Whether caching is enabled (default: False - no caching)", - ) - capacity_per_model: int = Field( - default=50000, description="Maximum number of entries in the cache per model" - ) - store: str = Field( - default="memory", description="Cache store: 'memory', 'filesystem', 'redis'" - ) - store_config: Dict[str, Any] = Field( - default_factory=dict, description="Backend-specific configuration" - ) - stats: CacheStatsConfig = Field( - default_factory=CacheStatsConfig, - description="Configuration for cache statistics tracking and logging", - ) - - class RailsConfigData(BaseModel): """Configuration data for specific rails that are supported out-of-the-box.""" diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 7d3c350c5..660907ed1 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -527,7 +527,16 @@ def _init_llms(self): self.runtime.register_action_param("llms", llms) - # Create cache per model + # Initialize caches for models + self._init_model_caches() + + def _init_model_caches(self): + """ + Initialize caches for models that have caching configured. + + Creates per-model cache instances and registers them as action parameters. + Only models that are not 'main' or 'embeddings' types are eligible for caching. + """ for model in self.config.models: if model.type not in ["main", "embeddings"]: cache = None @@ -556,8 +565,7 @@ def _init_llms(self): self.runtime.register_action_param(f"model_cache_{model.type}", cache) log.info( - f"Initialized model '{model.type}' with cache %s", - "enabled" if cache else "disabled", + f"Initialized model '{model.type}' with cache {'enabled' if cache else 'disabled'}" ) def _get_embeddings_search_provider_instance( From 026980dafaae62b02300a2ca8eb70e9e075ac216 Mon Sep 17 00:00:00 2001 From: Oren Hazai Date: Sun, 12 Oct 2025 17:07:55 +0300 Subject: [PATCH 08/19] completely remove persistence --- .vscode/settings.json | 3 +- nemoguardrails/cache/interface.py | 29 ------ nemoguardrails/cache/lfu.py | 156 ------------------------------ tests/test_cache_lfu.py | 30 +----- 4 files changed, 6 insertions(+), 212 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 38eb07063..a19072329 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -46,7 +46,7 @@ "editor.defaultFormatter": "ms-python.black-formatter" }, "python.envFile": "${workspaceFolder}/.venv", - "python.languageServer": "Pylance", + "python.languageServer": "None", "python.testing.pytestEnabled": true, "python.testing.pytestArgs": [ "${workspaceFolder}/tests", @@ -55,7 +55,6 @@ "python.testing.unittestEnabled": false, //"python.envFile": "${workspaceFolder}/python_release.env", - // MYPY "mypy-type-checker.args": [ "--ignore-missing-imports", diff --git a/nemoguardrails/cache/interface.py b/nemoguardrails/cache/interface.py index d724d6999..f07040098 100644 --- a/nemoguardrails/cache/interface.py +++ b/nemoguardrails/cache/interface.py @@ -18,10 +18,6 @@ This module defines the abstract base class for cache implementations that can be used interchangeably throughout the guardrails system. - -Cache implementations may optionally support persistence by overriding -the persist_now() method and supports_persistence() method. Persistence -allows cache state to be saved to and loaded from external storage. """ from abc import ABC, abstractmethod @@ -121,31 +117,6 @@ def capacity(self) -> int: """ pass - def persist_now(self) -> None: - """ - Force immediate persistence of cache to storage. - - This is an optional method that cache implementations can override - if they support persistence. The default implementation does nothing. - - Implementations that support persistence should save the current - cache state to their configured storage backend. - """ - # Default no-op implementation - pass - - def supports_persistence(self) -> bool: - """ - Check if this cache implementation supports persistence. - - Returns: - True if the cache supports persistence, False otherwise. - - The default implementation returns False. Cache implementations - that support persistence should override this to return True. - """ - return False - def get_stats(self) -> dict: """ Get cache statistics. diff --git a/nemoguardrails/cache/lfu.py b/nemoguardrails/cache/lfu.py index 4f8e450c0..fb76c6b10 100644 --- a/nemoguardrails/cache/lfu.py +++ b/nemoguardrails/cache/lfu.py @@ -16,9 +16,7 @@ """Least Frequently Used (LFU) cache implementation.""" import asyncio -import json import logging -import os import threading import time from typing import Any, Callable, Optional @@ -88,8 +86,6 @@ def __init__( self, capacity: int, track_stats: bool = False, - persistence_interval: Optional[float] = None, - persistence_path: Optional[str] = None, stats_logging_interval: Optional[float] = None, ) -> None: """ @@ -98,8 +94,6 @@ def __init__( Args: capacity: Maximum number of items the cache can hold track_stats: Enable tracking of cache statistics - persistence_interval: Seconds between periodic dumps to disk (None disables persistence) - persistence_path: Path to persistence file (defaults to 'lfu_cache.json' if persistence enabled) stats_logging_interval: Seconds between periodic stats logging (None disables logging) """ if capacity < 0: @@ -114,12 +108,6 @@ def __init__( self.freq_map: dict[int, DoublyLinkedList] = {} # frequency -> list of nodes self.min_freq = 0 # Track minimum frequency for eviction - # Persistence configuration - self.persistence_interval = persistence_interval - self.persistence_path = persistence_path or "lfu_cache.json" - # Initialize to None to ensure first check doesn't trigger immediately - self.last_persist_time = None - # Stats logging configuration self.stats_logging_interval = stats_logging_interval # Initialize to None to ensure first check doesn't trigger immediately @@ -135,10 +123,6 @@ def __init__( "updates": 0, } - # Load from disk if persistence is enabled and file exists - if self.persistence_interval is not None: - self._load_from_disk() - def _update_node_freq(self, node: LFUNode) -> None: """Update the frequency of a node and move it to the appropriate frequency list.""" old_freq = node.freq @@ -175,9 +159,6 @@ def get(self, key: Any, default: Any = None) -> Any: The value associated with the key, or default if not found """ with self._lock: - # Check if we should persist - self._check_and_persist() - # Check if we should log stats self._check_and_log_stats() @@ -203,9 +184,6 @@ def put(self, key: Any, value: Any) -> None: value: The value to associate with the key """ with self._lock: - # Check if we should persist - self._check_and_persist() - # Check if we should log stats self._check_and_log_stats() @@ -312,109 +290,6 @@ def reset_stats(self) -> None: "updates": 0, } - def _check_and_persist(self) -> None: - """Check if enough time has passed and persist to disk if needed.""" - if self.persistence_interval is None: - return - - current_time = time.time() - - # Initialize timestamp on first check - if self.last_persist_time is None: - self.last_persist_time = current_time - return - - if current_time - self.last_persist_time >= self.persistence_interval: - self._persist_to_disk() - self.last_persist_time = current_time - - def _persist_to_disk(self) -> None: - """ - Serialize cache to disk. - - Stores cache data as JSON with node information including keys, values, - frequencies, and timestamps for reconstruction. - """ - if not self.key_map: - # If cache is empty, remove the persistence file - if os.path.exists(self.persistence_path): - os.remove(self.persistence_path) - return - - cache_data = { - "capacity": self._capacity, - "min_freq": self.min_freq, - "nodes": [], - } - - # Serialize all nodes - for key, node in self.key_map.items(): - cache_data["nodes"].append( - { - "key": key, - "value": node.value, - "freq": node.freq, - "created_at": node.created_at, - "accessed_at": node.accessed_at, - } - ) - - # Write to disk - try: - with open(self.persistence_path, "w") as f: - json.dump(cache_data, f, indent=2) - except Exception as e: - # Silently fail on persistence errors to not disrupt cache operations - pass - - def _load_from_disk(self) -> None: - """ - Load cache from disk if persistence file exists. - - Reconstructs the cache state including frequency lists and node relationships. - """ - if not os.path.exists(self.persistence_path): - return - - try: - with open(self.persistence_path, "r") as f: - cache_data = json.load(f) - - # Reconstruct cache - self.min_freq = cache_data.get("min_freq", 0) - - for node_data in cache_data.get("nodes", []): - # Create node - node = LFUNode(node_data["key"], node_data["value"]) - node.freq = node_data["freq"] - node.created_at = node_data["created_at"] - node.accessed_at = node_data["accessed_at"] - - # Add to key map - self.key_map[node.key] = node - - # Add to appropriate frequency list - if node.freq not in self.freq_map: - self.freq_map[node.freq] = DoublyLinkedList() - self.freq_map[node.freq].append(node) - - except Exception as e: - # If loading fails, start with empty cache - self.key_map.clear() - self.freq_map.clear() - self.min_freq = 0 - - def persist_now(self) -> None: - """Force immediate persistence to disk (useful for shutdown).""" - with self._lock: - if self.persistence_interval is not None: - self._persist_to_disk() - self.last_persist_time = time.time() - - def supports_persistence(self) -> bool: - """Check if this cache instance supports persistence.""" - return self.persistence_interval is not None - def _check_and_log_stats(self) -> None: """Check if enough time has passed and log stats if needed.""" if not self.track_stats or self.stats_logging_interval is None: @@ -644,34 +519,3 @@ def capacity(self) -> int: # Reset statistics stats_cache.reset_stats() print(f"\nAfter reset: {stats_cache.get_stats()}") - - print("\n=== Cache with Persistence ===") - - # Create cache with persistence (5 second interval) - persist_cache = LFUCache( - capacity=3, persistence_interval=5.0, persistence_path="test_cache.json" - ) - - # Add some items - persist_cache.put("item1", "value1") - persist_cache.put("item2", "value2") - persist_cache.put("item3", "value3") - - # Force immediate persistence - persist_cache.persist_now() - print("Cache persisted to disk") - - # Create new cache instance that will load from disk - new_cache = LFUCache( - capacity=3, persistence_interval=5.0, persistence_path="test_cache.json" - ) - - # Verify data was loaded - print(f"Loaded item1: {new_cache.get('item1')}") # Should return 'value1' - print(f"Loaded item2: {new_cache.get('item2')}") # Should return 'value2' - print(f"Cache size after loading: {new_cache.size()}") # Should return 3 - - # Clean up - if os.path.exists("test_cache.json"): - os.remove("test_cache.json") - print("Cleaned up test persistence file") diff --git a/tests/test_cache_lfu.py b/tests/test_cache_lfu.py index 3f9340452..bb01bcbac 100644 --- a/tests/test_cache_lfu.py +++ b/tests/test_cache_lfu.py @@ -17,13 +17,11 @@ Comprehensive test suite for LFU Cache implementation. Tests all functionality including basic operations, eviction policies, -capacity management, edge cases, and persistence functionality. +capacity management, and edge cases. """ import asyncio -import json import os -import tempfile import threading import time import unittest @@ -368,15 +366,6 @@ def test_interface_methods_exist(self): class TestLFUCacheStatsLogging(unittest.TestCase): """Test cases for LFU Cache statistics logging functionality.""" - def setUp(self): - """Set up test fixtures.""" - self.test_file = tempfile.mktemp() - - def tearDown(self): - """Clean up test files.""" - if os.path.exists(self.test_file): - os.remove(self.test_file) - def test_stats_logging_disabled_by_default(self): """Test that stats logging is disabled when not configured.""" cache = LFUCache(5, track_stats=True) @@ -644,15 +633,6 @@ def test_stats_log_format_percentages(self): class TestContentSafetyCacheStatsConfig(unittest.TestCase): """Test cache stats configuration in content safety context.""" - def setUp(self): - """Set up test fixtures.""" - self.test_file = tempfile.mktemp() - - def tearDown(self): - """Clean up test files.""" - if os.path.exists(self.test_file): - os.remove(self.test_file) - def test_cache_config_with_stats_disabled(self): """Test cache configuration with stats disabled.""" from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig @@ -1193,11 +1173,11 @@ def worker(thread_id): f"Wrong value for {key}: expected {value}, got {retrieved}" ) - # Also work with some persistent keys (access multiple times) - persistent_key = f"persistent_{thread_id % 5}" + # Also work with some high-frequency keys (access multiple times) + high_freq_key = f"high_freq_{thread_id % 5}" for _ in range(3): # Access 3 times to increase frequency - small_cache.put(persistent_key, f"persistent_value_{thread_id}") - small_cache.get(persistent_key) + small_cache.put(high_freq_key, f"high_freq_value_{thread_id}") + small_cache.get(high_freq_key) # Run workers with ThreadPoolExecutor(max_workers=10) as executor: From f56a8e2abb21673e4c8ec52d380d18903e56336b Mon Sep 17 00:00:00 2001 From: Pouyan <13303554+Pouyanpi@users.noreply.github.com> Date: Thu, 16 Oct 2025 20:34:42 +0200 Subject: [PATCH 09/19] review: PR #1436 (#1451) --- examples/configs/content_safety/config.yml | 4 +- nemoguardrails/cache/__init__.py | 3 +- nemoguardrails/cache/utils.py | 160 +++++++++ .../library/content_safety/actions.py | 60 ++-- nemoguardrails/logging/explain.py | 4 + nemoguardrails/rails/llm/config.py | 8 +- nemoguardrails/rails/llm/llmrails.py | 77 +++-- nemoguardrails/tracing/constants.py | 5 +- nemoguardrails/tracing/span_extractors.py | 3 + nemoguardrails/tracing/spans.py | 7 + tests/test_cache_lfu.py | 18 +- tests/test_cache_utils.py | 315 ++++++++++++++++++ tests/test_content_safety_cache.py | 215 ++++++++++++ tests/test_integration_cache.py | 150 +++++++++ tests/test_llmrails.py | 185 ++++++++++ tests/tracing/spans/test_span_extractors.py | 73 ++++ 16 files changed, 1195 insertions(+), 92 deletions(-) create mode 100644 nemoguardrails/cache/utils.py create mode 100644 tests/test_cache_utils.py create mode 100644 tests/test_content_safety_cache.py create mode 100644 tests/test_integration_cache.py diff --git a/examples/configs/content_safety/config.yml b/examples/configs/content_safety/config.yml index 474a43dad..8e6e4a59c 100644 --- a/examples/configs/content_safety/config.yml +++ b/examples/configs/content_safety/config.yml @@ -9,7 +9,7 @@ models: # Model-specific cache configuration (optional) cache: enabled: true - capacity_per_model: 50000 # Larger cache for primary model + maxsize: 50000 # Larger cache for primary model stats: enabled: true log_interval: 60.0 # Log stats every minute @@ -22,3 +22,5 @@ rails: output: flows: - content safety check output $model=content_safety +tracing: + enabled: True diff --git a/nemoguardrails/cache/__init__.py b/nemoguardrails/cache/__init__.py index e7f22f070..91042bcb1 100644 --- a/nemoguardrails/cache/__init__.py +++ b/nemoguardrails/cache/__init__.py @@ -17,5 +17,6 @@ from nemoguardrails.cache.interface import CacheInterface from nemoguardrails.cache.lfu import LFUCache +from nemoguardrails.cache.utils import create_normalized_cache_key -__all__ = ["CacheInterface", "LFUCache"] +__all__ = ["CacheInterface", "LFUCache", "create_normalized_cache_key"] diff --git a/nemoguardrails/cache/utils.py b/nemoguardrails/cache/utils.py new file mode 100644 index 000000000..0291e2fa3 --- /dev/null +++ b/nemoguardrails/cache/utils.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import json +import re +from time import time +from typing import TYPE_CHECKING, List, Optional, TypedDict, Union + +from nemoguardrails.context import llm_call_info_var, llm_stats_var +from nemoguardrails.logging.processing_log import processing_log_var +from nemoguardrails.logging.stats import LLMStats + +if TYPE_CHECKING: + from nemoguardrails.cache.interface import CacheInterface + +PROMPT_PATTERN_WHITESPACES = re.compile(r"\s+") + + +class LLMStatsDict(TypedDict): + total_tokens: int + prompt_tokens: int + completion_tokens: int + + +class CacheEntry(TypedDict): + result: dict + llm_stats: Optional[LLMStatsDict] + + +def create_normalized_cache_key( + prompt: Union[str, List[dict]], normalize_whitespace: bool = True +) -> str: + """ + Create a normalized, hashed cache key from a prompt. + + This function generates a deterministic cache key by normalizing the prompt + and applying SHA-256 hashing. The normalization ensures that semantically + equivalent prompts produce the same cache key. + + Args: + prompt: The prompt to be cached. Can be: + - str: A single prompt string (for completion models) + - List[dict]: A list of message dictionaries for chat models + (e.g., [{"type": "user", "content": "Hello"}]) + Note: render_task_prompt() returns Union[str, List[dict]] + normalize_whitespace: Whether to normalize whitespace characters. + When True, collapses all whitespace sequences to single spaces and + strips leading/trailing whitespace. Default: True + + Returns: + A SHA-256 hex digest string (64 characters) suitable for use as a cache key + + Raises: + TypeError: If prompt is not a str or List[dict] + + Examples: + >>> create_normalized_cache_key("Hello world") + '64ec88ca00b268e5ba1a35678a1b5316d212f4f366b2477232534a8aeca37f3c' + + >>> create_normalized_cache_key([{"type": "user", "content": "Hello"}]) + 'b2f5c9d8e3a1f7b6c4d2e5f8a9c1d3e5f7b9a2c4d6e8f1a3b5c7d9e2f4a6b8' + """ + if isinstance(prompt, str): + prompt_str = prompt + elif isinstance(prompt, list): + if not all(isinstance(p, dict) for p in prompt): + raise TypeError( + f"All elements in prompt list must be dictionaries (messages). " + f"Got types: {[type(p).__name__ for p in prompt]}" + ) + prompt_str = json.dumps(prompt, sort_keys=True) + else: + raise TypeError( + f"Invalid type for prompt: {type(prompt).__name__}. " + f"Expected str or List[dict]." + ) + + if normalize_whitespace: + prompt_str = PROMPT_PATTERN_WHITESPACES.sub(" ", prompt_str).strip() + + return hashlib.sha256(prompt_str.encode("utf-8")).hexdigest() + + +def restore_llm_stats_from_cache( + cached_stats: LLMStatsDict, cache_read_duration: float +) -> None: + llm_stats = llm_stats_var.get() + if llm_stats is None: + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + llm_stats.inc("total_calls") + llm_stats.inc("total_time", cache_read_duration) + llm_stats.inc("total_tokens", cached_stats.get("total_tokens", 0)) + llm_stats.inc("total_prompt_tokens", cached_stats.get("prompt_tokens", 0)) + llm_stats.inc("total_completion_tokens", cached_stats.get("completion_tokens", 0)) + + llm_call_info = llm_call_info_var.get() + if llm_call_info: + llm_call_info.duration = cache_read_duration + llm_call_info.total_tokens = cached_stats.get("total_tokens", 0) + llm_call_info.prompt_tokens = cached_stats.get("prompt_tokens", 0) + llm_call_info.completion_tokens = cached_stats.get("completion_tokens", 0) + llm_call_info.from_cache = True + llm_call_info.started_at = time() - cache_read_duration + llm_call_info.finished_at = time() + + +def extract_llm_stats_for_cache() -> Optional[LLMStatsDict]: + llm_call_info = llm_call_info_var.get() + if llm_call_info: + return { + "total_tokens": llm_call_info.total_tokens or 0, + "prompt_tokens": llm_call_info.prompt_tokens or 0, + "completion_tokens": llm_call_info.completion_tokens or 0, + } + return None + + +def get_from_cache_and_restore_stats( + cache: "CacheInterface", cache_key: str +) -> Optional[dict]: + cached_entry = cache.get(cache_key) + if cached_entry is None: + return None + + cache_read_start = time() + final_result = cached_entry["result"] + cached_stats = cached_entry.get("llm_stats") + cache_read_duration = time() - cache_read_start + + if cached_stats: + restore_llm_stats_from_cache(cached_stats, cache_read_duration) + + processing_log = processing_log_var.get() + if processing_log: + llm_call_info = llm_call_info_var.get() + if llm_call_info: + processing_log.append( + { + "type": "llm_call_info", + "timestamp": time(), + "data": llm_call_info, + } + ) + + return final_result diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index c8fb0e662..93f5908c9 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -13,15 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging -import re -from typing import Dict, List, Optional, Union +from typing import Dict, Optional from langchain_core.language_models.llms import BaseLLM from nemoguardrails.actions.actions import action from nemoguardrails.actions.llm.utils import llm_call +from nemoguardrails.cache import CacheInterface +from nemoguardrails.cache.utils import ( + CacheEntry, + create_normalized_cache_key, + extract_llm_stats_for_cache, + get_from_cache_and_restore_stats, +) from nemoguardrails.context import llm_call_info_var from nemoguardrails.llm.taskmanager import LLMTaskManager from nemoguardrails.logging.explain import LLMCallInfo @@ -29,38 +34,13 @@ log = logging.getLogger(__name__) -PROMPT_PATTERN_WHITESPACES = re.compile(r"\s+") - - -def _create_cache_key(prompt: Union[str, List[str]]) -> str: - """Create a cache key from the prompt.""" - # can the prompt really be a list? - if isinstance(prompt, list): - prompt_str = json.dumps(prompt) - else: - prompt_str = prompt - - # Normalize the prompt by collapsing all whitespace sequences to a single space - # and stripping leading/trailing whitespace. This ensures semantically equivalent - # prompts map to the same cache key. No further normalization is currently needed. - return PROMPT_PATTERN_WHITESPACES.sub(" ", prompt_str).strip() - - -# Thread Safety Note: -# The content safety caching mechanism is thread-safe for single-node deployments. -# The underlying LFUCache uses threading.RLock to ensure atomic operations. -# -# However, this implementation is NOT suitable for distributed environments. -# For multi-node deployments, consider using distributed caching solutions -# like Redis or a shared database. - - @action() async def content_safety_check_input( llms: Dict[str, BaseLLM], llm_task_manager: LLMTaskManager, model_name: Optional[str] = None, context: Optional[dict] = None, + model_caches: Optional[Dict[str, CacheInterface]] = None, **kwargs, ) -> dict: _MAX_TOKENS = 3 @@ -103,21 +83,15 @@ async def content_safety_check_input( max_tokens = max_tokens or _MAX_TOKENS - # Check cache if available for this model - cached_result = None - cache_key = None - - # Try to get the model-specific cache - cache = kwargs.get(f"model_cache_{model_name}") + cache = model_caches.get(model_name) if model_caches else None if cache: - cache_key = _create_cache_key(check_input_prompt) - cached_result = cache.get(cache_key) + cache_key = create_normalized_cache_key(check_input_prompt) + cached_result = get_from_cache_and_restore_stats(cache, cache_key) if cached_result is not None: log.debug(f"Content safety cache hit for model '{model_name}'") return cached_result - # Make the actual LLM call result = await llm_call( llm, check_input_prompt, @@ -131,9 +105,13 @@ async def content_safety_check_input( final_result = {"allowed": is_safe, "policy_violations": violated_policies} - # Store in cache if available - if cache_key and cache: - cache.put(cache_key, final_result) + if cache: + cache_key = create_normalized_cache_key(check_input_prompt) + cache_entry: CacheEntry = { + "result": final_result, + "llm_stats": extract_llm_stats_for_cache(), + } + cache.put(cache_key, cache_entry) log.debug(f"Content safety result cached for model '{model_name}'") return final_result diff --git a/nemoguardrails/logging/explain.py b/nemoguardrails/logging/explain.py index edf7825c2..ad75e5c45 100644 --- a/nemoguardrails/logging/explain.py +++ b/nemoguardrails/logging/explain.py @@ -63,6 +63,10 @@ class LLMCallInfo(LLMCallSummary): default="unknown", description="The provider of the model used for the LLM call, e.g. 'openai', 'nvidia'.", ) + from_cache: bool = Field( + default=False, + description="Whether this response was retrieved from cache.", + ) class ExplainInfo(BaseModel): diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index 0a14c33cd..0c037c092 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -89,15 +89,9 @@ class ModelCacheConfig(BaseModel): default=False, description="Whether caching is enabled (default: False - no caching)", ) - capacity_per_model: int = Field( + maxsize: int = Field( default=50000, description="Maximum number of entries in the cache per model" ) - store: str = Field( - default="memory", description="Cache store: 'memory', 'filesystem', 'redis'" - ) - store_config: Dict[str, Any] = Field( - default_factory=dict, description="Backend-specific configuration" - ) stats: CacheStatsConfig = Field( default_factory=CacheStatsConfig, description="Configuration for cache statistics tracking and logging", diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 660907ed1..c9c87b077 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -50,7 +50,7 @@ ) from nemoguardrails.actions.output_mapping import is_output_blocked from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx -from nemoguardrails.cache.lfu import LFUCache +from nemoguardrails.cache import CacheInterface, LFUCache from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id, compute_context from nemoguardrails.colang.v1_0.runtime.runtime import Runtime, RuntimeV1_0 @@ -527,47 +527,60 @@ def _init_llms(self): self.runtime.register_action_param("llms", llms) - # Initialize caches for models - self._init_model_caches() + self._initialize_model_caches() - def _init_model_caches(self): + def _create_model_cache(self, model) -> LFUCache: """ - Initialize caches for models that have caching configured. + Create cache instance for a model based on its configuration. - Creates per-model cache instances and registers them as action parameters. - Only models that are not 'main' or 'embeddings' types are eligible for caching. + Args: + model: The model configuration object + + Returns: + LFUCache: The cache instance """ - for model in self.config.models: - if model.type not in ["main", "embeddings"]: - cache = None - - # Create cache if configured - if model.cache and model.cache.enabled: - if model.cache.store == "memory": - stats_logging_interval = None - if ( - model.cache.stats.enabled - and model.cache.stats.log_interval is not None - ): - stats_logging_interval = model.cache.stats.log_interval - - cache = LFUCache( - capacity=model.cache.capacity_per_model, - track_stats=model.cache.stats.enabled, - stats_logging_interval=stats_logging_interval, - ) - log.info( - f"Created cache for model '{model.type}' with capacity {model.cache.capacity_per_model}" - ) + if model.cache.maxsize <= 0: + raise ValueError( + f"Invalid cache capacity for model '{model.type}': {model.cache.maxsize}. " + "Capacity must be greater than 0. Skipping cache creation." + ) - # Register the cache for this specific model - self.runtime.register_action_param(f"model_cache_{model.type}", cache) + stats_logging_interval = None + if model.cache.stats.enabled and model.cache.stats.log_interval is not None: + stats_logging_interval = model.cache.stats.log_interval + + cache = LFUCache( + capacity=model.cache.maxsize, + track_stats=model.cache.stats.enabled, + stats_logging_interval=stats_logging_interval, + ) + + log.info( + f"Created cache for model '{model.type}' with capacity {model.cache.maxsize}" + ) + + return cache + + def _initialize_model_caches(self) -> None: + """Initialize caches for configured models.""" + model_caches: Optional[Dict[str, CacheInterface]] = dict() + for model in self.config.models: + if model.type in ["main", "embeddings"]: + continue + + if model.cache and model.cache.enabled: + cache = self._create_model_cache(model) + model_caches[model.type] = cache log.info( - f"Initialized model '{model.type}' with cache {'enabled' if cache else 'disabled'}" + f"Initialized model '{model.type}' with cache %s", + "enabled" if cache else "disabled", ) + if model_caches: + self.runtime.register_action_param("model_caches", model_caches) + def _get_embeddings_search_provider_instance( self, esp_config: Optional[EmbeddingSearchProvider] = None ) -> EmbeddingsIndex: diff --git a/nemoguardrails/tracing/constants.py b/nemoguardrails/tracing/constants.py index 3e0bf3179..2cf411413 100644 --- a/nemoguardrails/tracing/constants.py +++ b/nemoguardrails/tracing/constants.py @@ -119,7 +119,10 @@ class GuardrailsAttributes: ACTION_NAME = "action.name" ACTION_HAS_LLM_CALLS = "action.has_llm_calls" ACTION_LLM_CALLS_COUNT = "action.llm_calls_count" - ACTION_PARAM_PREFIX = "action.param." # For dynamic action parameters + ACTION_PARAM_PREFIX = "action.param." + + # llm attributes (application-level, not provider-level) + LLM_CACHE_HIT = "llm.cache.hit" class SpanNames: diff --git a/nemoguardrails/tracing/span_extractors.py b/nemoguardrails/tracing/span_extractors.py index cca40024f..e40318d01 100644 --- a/nemoguardrails/tracing/span_extractors.py +++ b/nemoguardrails/tracing/span_extractors.py @@ -257,6 +257,8 @@ def extract_spans( max_tokens = llm_call.raw_response.get("max_tokens") top_p = llm_call.raw_response.get("top_p") + cache_hit = hasattr(llm_call, "from_cache") and llm_call.from_cache + llm_span = LLMSpan( span_id=new_uuid(), name=span_name, @@ -276,6 +278,7 @@ def extract_spans( top_p=top_p, response_id=response_id, response_finish_reasons=finish_reasons, + cache_hit=cache_hit, # TODO: add error to LLMCallInfo for future release # error=( # True diff --git a/nemoguardrails/tracing/spans.py b/nemoguardrails/tracing/spans.py index fb89fb394..e8e7ec565 100644 --- a/nemoguardrails/tracing/spans.py +++ b/nemoguardrails/tracing/spans.py @@ -270,6 +270,11 @@ class LLMSpan(BaseSpan): default=None, description="Finish reasons for each choice" ) + cache_hit: bool = Field( + default=False, + description="Whether this LLM response was served from application cache", + ) + def to_otel_attributes(self) -> Dict[str, Any]: """Convert to OTel attributes.""" attributes = self._base_attributes() @@ -320,6 +325,8 @@ def to_otel_attributes(self) -> Dict[str, Any]: GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS ] = self.response_finish_reasons + attributes[GuardrailsAttributes.LLM_CACHE_HIT] = self.cache_hit + return attributes diff --git a/tests/test_cache_lfu.py b/tests/test_cache_lfu.py index bb01bcbac..2b1d986f7 100644 --- a/tests/test_cache_lfu.py +++ b/tests/test_cache_lfu.py @@ -638,11 +638,11 @@ def test_cache_config_with_stats_disabled(self): from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig cache_config = ModelCacheConfig( - enabled=True, capacity_per_model=1000, stats=CacheStatsConfig(enabled=False) + enabled=True, maxsize=1000, stats=CacheStatsConfig(enabled=False) ) cache = LFUCache( - capacity=cache_config.capacity_per_model, + capacity=cache_config.maxsize, track_stats=cache_config.stats.enabled, stats_logging_interval=None, ) @@ -657,12 +657,12 @@ def test_cache_config_with_stats_tracking_only(self): cache_config = ModelCacheConfig( enabled=True, - capacity_per_model=1000, + maxsize=1000, stats=CacheStatsConfig(enabled=True, log_interval=None), ) cache = LFUCache( - capacity=cache_config.capacity_per_model, + capacity=cache_config.maxsize, track_stats=cache_config.stats.enabled, stats_logging_interval=cache_config.stats.log_interval, ) @@ -678,12 +678,12 @@ def test_cache_config_with_stats_logging(self): cache_config = ModelCacheConfig( enabled=True, - capacity_per_model=1000, + maxsize=1000, stats=CacheStatsConfig(enabled=True, log_interval=60.0), ) cache = LFUCache( - capacity=cache_config.capacity_per_model, + capacity=cache_config.maxsize, track_stats=cache_config.stats.enabled, stats_logging_interval=cache_config.stats.log_interval, ) @@ -700,7 +700,7 @@ def test_cache_config_default_stats(self): cache_config = ModelCacheConfig(enabled=True) cache = LFUCache( - capacity=cache_config.capacity_per_model, + capacity=cache_config.maxsize, track_stats=cache_config.stats.enabled, stats_logging_interval=None, ) @@ -715,13 +715,13 @@ def test_cache_config_from_dict(self): config_dict = { "enabled": True, - "capacity_per_model": 5000, + "maxsize": 5000, "stats": {"enabled": True, "log_interval": 120.0}, } cache_config = ModelCacheConfig(**config_dict) self.assertTrue(cache_config.enabled) - self.assertEqual(cache_config.capacity_per_model, 5000) + self.assertEqual(cache_config.maxsize, 5000) self.assertTrue(cache_config.stats.enabled) self.assertEqual(cache_config.stats.log_interval, 120.0) diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py new file mode 100644 index 000000000..3e0d98a57 --- /dev/null +++ b/tests/test_cache_utils.py @@ -0,0 +1,315 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +import pytest + +from nemoguardrails.cache.lfu import LFUCache +from nemoguardrails.cache.utils import ( + create_normalized_cache_key, + extract_llm_stats_for_cache, + get_from_cache_and_restore_stats, + restore_llm_stats_from_cache, +) +from nemoguardrails.context import llm_call_info_var, llm_stats_var +from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.logging.stats import LLMStats + + +class TestCacheUtils: + def test_create_normalized_cache_key_returns_sha256_hash(self): + key = create_normalized_cache_key("Hello world") + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + @pytest.mark.parametrize( + "prompt", + [ + "Hello world", + "", + " Hello world ", + "Hello world test", + "Hello\t\n\r world", + "Hello \n\t world", + ], + ) + def test_create_normalized_cache_key_with_whitespace_normalization(self, prompt): + key = create_normalized_cache_key(prompt, normalize_whitespace=True) + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + @pytest.mark.parametrize( + "prompt", + [ + "Hello world", + "Hello \n\t world", + " spaces ", + ], + ) + def test_create_normalized_cache_key_without_whitespace_normalization(self, prompt): + key = create_normalized_cache_key(prompt, normalize_whitespace=False) + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + @pytest.mark.parametrize( + "prompt1,prompt2", + [ + ("Hello \n world", "Hello world"), + ("test\t\nstring", "test string"), + (" leading", "leading"), + ], + ) + def test_create_normalized_cache_key_consistent_for_same_input( + self, prompt1, prompt2 + ): + key1 = create_normalized_cache_key(prompt1, normalize_whitespace=True) + key2 = create_normalized_cache_key(prompt2, normalize_whitespace=True) + assert key1 == key2 + + @pytest.mark.parametrize( + "prompt1,prompt2", + [ + ("Hello world", "Hello world!"), + ("test", "testing"), + ("case", "Case"), + ], + ) + def test_create_normalized_cache_key_different_for_different_input( + self, prompt1, prompt2 + ): + key1 = create_normalized_cache_key(prompt1) + key2 = create_normalized_cache_key(prompt2) + assert key1 != key2 + + def test_create_normalized_cache_key_invalid_type_raises_error(self): + with pytest.raises(TypeError, match="Invalid type for prompt: int"): + create_normalized_cache_key(123) + + with pytest.raises(TypeError, match="Invalid type for prompt: dict"): + create_normalized_cache_key({"key": "value"}) + + def test_create_normalized_cache_key_list_of_dicts(self): + messages = [ + {"type": "user", "content": "Hello"}, + {"type": "assistant", "content": "Hi there!"}, + ] + key = create_normalized_cache_key(messages) + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + def test_create_normalized_cache_key_list_of_dicts_order_independent(self): + messages1 = [ + {"content": "Hello", "role": "user"}, + {"content": "Hi there!", "role": "assistant"}, + ] + messages2 = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + key1 = create_normalized_cache_key(messages1) + key2 = create_normalized_cache_key(messages2) + assert key1 == key2 + + def test_create_normalized_cache_key_invalid_list_raises_error(self): + with pytest.raises( + TypeError, + match="All elements in prompt list must be dictionaries", + ): + create_normalized_cache_key(["hello", "world"]) + + with pytest.raises( + TypeError, + match="All elements in prompt list must be dictionaries", + ): + create_normalized_cache_key([{"key": "value"}, "test"]) + + with pytest.raises( + TypeError, + match="All elements in prompt list must be dictionaries", + ): + create_normalized_cache_key([123, 456]) + + def test_extract_llm_stats_for_cache_with_llm_call_info(self): + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info.total_tokens = 100 + llm_call_info.prompt_tokens = 50 + llm_call_info.completion_tokens = 50 + llm_call_info_var.set(llm_call_info) + + stats = extract_llm_stats_for_cache() + + assert stats is not None + assert stats["total_tokens"] == 100 + assert stats["prompt_tokens"] == 50 + assert stats["completion_tokens"] == 50 + + llm_call_info_var.set(None) + + def test_extract_llm_stats_for_cache_without_llm_call_info(self): + llm_call_info_var.set(None) + + stats = extract_llm_stats_for_cache() + + assert stats is None + + def test_extract_llm_stats_for_cache_with_none_values(self): + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info.total_tokens = None + llm_call_info.prompt_tokens = None + llm_call_info.completion_tokens = None + llm_call_info_var.set(llm_call_info) + + stats = extract_llm_stats_for_cache() + + assert stats is not None + assert stats["total_tokens"] == 0 + assert stats["prompt_tokens"] == 0 + assert stats["completion_tokens"] == 0 + + llm_call_info_var.set(None) + + def test_restore_llm_stats_from_cache_creates_new_llm_stats(self): + llm_stats_var.set(None) + llm_call_info_var.set(None) + + cached_stats = { + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + } + + restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.01) + + llm_stats = llm_stats_var.get() + assert llm_stats is not None + assert llm_stats.get_stat("total_calls") == 1 + assert llm_stats.get_stat("total_time") == 0.01 + assert llm_stats.get_stat("total_tokens") == 100 + assert llm_stats.get_stat("total_prompt_tokens") == 50 + assert llm_stats.get_stat("total_completion_tokens") == 50 + + llm_stats_var.set(None) + + def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self): + llm_stats = LLMStats() + llm_stats.inc("total_calls", 5) + llm_stats.inc("total_time", 1.0) + llm_stats.inc("total_tokens", 200) + llm_stats_var.set(llm_stats) + + cached_stats = { + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + } + + restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.5) + + llm_stats = llm_stats_var.get() + assert llm_stats.get_stat("total_calls") == 6 + assert llm_stats.get_stat("total_time") == 1.5 + assert llm_stats.get_stat("total_tokens") == 300 + + llm_stats_var.set(None) + + def test_restore_llm_stats_from_cache_updates_llm_call_info(self): + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info_var.set(llm_call_info) + llm_stats_var.set(None) + + cached_stats = { + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + } + + restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.02) + + updated_info = llm_call_info_var.get() + assert updated_info is not None + assert updated_info.duration == 0.02 + assert updated_info.total_tokens == 100 + assert updated_info.prompt_tokens == 50 + assert updated_info.completion_tokens == 50 + assert updated_info.from_cache is True + assert updated_info.started_at is not None + assert updated_info.finished_at is not None + + llm_call_info_var.set(None) + llm_stats_var.set(None) + + def test_get_from_cache_and_restore_stats_cache_miss(self): + cache = LFUCache(capacity=10) + llm_call_info_var.set(None) + llm_stats_var.set(None) + + result = get_from_cache_and_restore_stats(cache, "nonexistent_key") + + assert result is None + + llm_call_info_var.set(None) + llm_stats_var.set(None) + + def test_get_from_cache_and_restore_stats_cache_hit(self): + cache = LFUCache(capacity=10) + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": { + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + }, + } + cache.put("test_key", cache_entry) + + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info_var.set(llm_call_info) + llm_stats_var.set(None) + + result = get_from_cache_and_restore_stats(cache, "test_key") + + assert result is not None + assert result == {"allowed": True, "policy_violations": []} + + llm_stats = llm_stats_var.get() + assert llm_stats is not None + assert llm_stats.get_stat("total_calls") == 1 + assert llm_stats.get_stat("total_tokens") == 100 + + updated_info = llm_call_info_var.get() + assert updated_info.from_cache is True + + llm_call_info_var.set(None) + llm_stats_var.set(None) + + def test_get_from_cache_and_restore_stats_without_llm_stats(self): + cache = LFUCache(capacity=10) + cache_entry = { + "result": {"allowed": False, "policy_violations": ["policy1"]}, + "llm_stats": None, + } + cache.put("test_key", cache_entry) + + llm_call_info_var.set(None) + llm_stats_var.set(None) + + result = get_from_cache_and_restore_stats(cache, "test_key") + + assert result is not None + assert result == {"allowed": False, "policy_violations": ["policy1"]} + + llm_call_info_var.set(None) + llm_stats_var.set(None) diff --git a/tests/test_content_safety_cache.py b/tests/test_content_safety_cache.py new file mode 100644 index 000000000..d81465b3d --- /dev/null +++ b/tests/test_content_safety_cache.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +import pytest + +from nemoguardrails.cache.lfu import LFUCache +from nemoguardrails.cache.utils import create_normalized_cache_key +from nemoguardrails.context import llm_call_info_var, llm_stats_var +from nemoguardrails.library.content_safety.actions import content_safety_check_input +from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.logging.stats import LLMStats +from tests.utils import FakeLLM + + +@pytest.fixture +def mock_task_manager(): + tm = MagicMock() + tm.render_task_prompt.return_value = "test prompt" + tm.get_stop_tokens.return_value = [] + tm.get_max_tokens.return_value = 3 + tm.parse_task_output.return_value = [True, "policy1"] + return tm + + +@pytest.fixture +def fake_llm_with_stats(): + llm = FakeLLM(responses=["safe"]) + return {"test_model": llm} + + +@pytest.mark.asyncio +async def test_content_safety_cache_stores_result_and_stats( + fake_llm_with_stats, mock_task_manager +): + cache = LFUCache(capacity=10) + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + result = await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + model_caches={"test_model": cache}, + ) + + assert result["allowed"] is True + assert result["policy_violations"] == ["policy1"] + assert cache.size() == 1 + + llm_call_info = llm_call_info_var.get() + + cache_key = create_normalized_cache_key("test prompt") + cached_entry = cache.get(cache_key) + assert cached_entry is not None + assert "result" in cached_entry + assert "llm_stats" in cached_entry + + if llm_call_info and ( + llm_call_info.total_tokens + or llm_call_info.prompt_tokens + or llm_call_info.completion_tokens + ): + assert cached_entry["llm_stats"] is not None + else: + assert cached_entry["llm_stats"] is None or all( + v == 0 for v in cached_entry["llm_stats"].values() + ) + + +@pytest.mark.asyncio +async def test_content_safety_cache_retrieves_result_and_restores_stats( + fake_llm_with_stats, mock_task_manager +): + cache = LFUCache(capacity=10) + + cache_entry = { + "result": {"allowed": True, "policy_violations": ["policy1"]}, + "llm_stats": { + "total_tokens": 100, + "prompt_tokens": 80, + "completion_tokens": 20, + }, + } + cache_key = create_normalized_cache_key("test prompt") + cache.put(cache_key, cache_entry) + + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + result = await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + model_caches={"test_model": cache}, + ) + + llm_call_info = llm_call_info_var.get() + + assert result == cache_entry["result"] + assert llm_stats.get_stat("total_calls") == 1 + assert llm_stats.get_stat("total_tokens") == 100 + assert llm_stats.get_stat("total_prompt_tokens") == 80 + assert llm_stats.get_stat("total_completion_tokens") == 20 + + assert llm_call_info.from_cache is True + assert llm_call_info.total_tokens == 100 + assert llm_call_info.prompt_tokens == 80 + assert llm_call_info.completion_tokens == 20 + assert llm_call_info.duration is not None + + +@pytest.mark.asyncio +async def test_content_safety_cache_duration_reflects_cache_read_time( + fake_llm_with_stats, mock_task_manager +): + cache = LFUCache(capacity=10) + + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": { + "total_tokens": 50, + "prompt_tokens": 40, + "completion_tokens": 10, + }, + } + cache_key = create_normalized_cache_key("test prompt") + cache.put(cache_key, cache_entry) + + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + model_caches={"test_model": cache}, + ) + + llm_call_info = llm_call_info_var.get() + cache_duration = llm_call_info.duration + + assert cache_duration is not None + assert cache_duration < 0.1 + assert llm_call_info.from_cache is True + + +@pytest.mark.asyncio +async def test_content_safety_without_cache_does_not_store( + fake_llm_with_stats, mock_task_manager +): + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + llm_call_info = LLMCallInfo(task="content_safety_check_input $model=test_model") + llm_call_info_var.set(llm_call_info) + + result = await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + ) + + assert result["allowed"] is True + assert llm_call_info.from_cache is False + + +@pytest.mark.asyncio +async def test_content_safety_cache_handles_missing_stats_gracefully( + fake_llm_with_stats, mock_task_manager +): + cache = LFUCache(capacity=10) + + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": None, + } + cache_key = create_normalized_cache_key("test_key") + cache.put(cache_key, cache_entry) + + mock_task_manager.render_task_prompt.return_value = "test_key" + + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + llm_call_info = LLMCallInfo(task="content_safety_check_input $model=test_model") + llm_call_info_var.set(llm_call_info) + + result = await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + model_caches={"test_model": cache}, + ) + + assert result["allowed"] is True + assert llm_stats.get_stat("total_calls") == 0 diff --git a/tests/test_integration_cache.py b/tests/test_integration_cache.py new file mode 100644 index 000000000..3c022031e --- /dev/null +++ b/tests/test_integration_cache.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest + +from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails.rails.llm.config import CacheStatsConfig, Model, ModelCacheConfig +from tests.utils import FakeLLM + + +@pytest.mark.asyncio +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +async def test_end_to_end_cache_integration_with_content_safety(mock_init_llm_model): + mock_llm = FakeLLM(responses=["express greeting"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake-content-safety", + cache=ModelCacheConfig( + enabled=True, + maxsize=100, + stats=CacheStatsConfig(enabled=True), + ), + ), + ] + ) + + llm = FakeLLM(responses=["express greeting"]) + + rails = LLMRails(config=config, llm=llm, verbose=False) + + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + assert "content_safety" in model_caches + cache = model_caches["content_safety"] + assert cache is not None + + assert cache.size() == 0 + + messages = [{"role": "user", "content": "Hello!"}] + await rails.generate_async(messages=messages) + + if cache.size() > 0: + initial_stats = cache.get_stats() + await rails.generate_async(messages=messages) + second_stats = cache.get_stats() + assert second_stats["hits"] >= initial_stats["hits"] + + +@pytest.mark.asyncio +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +async def test_cache_isolation_between_models(mock_init_llm_model): + mock_llm = FakeLLM(responses=["safe"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake-content-safety", + cache=ModelCacheConfig(enabled=True, maxsize=50), + ), + Model( + type="jailbreak_detection", + engine="fake", + model="fake-jailbreak", + cache=ModelCacheConfig(enabled=True, maxsize=100), + ), + ] + ) + + llm = FakeLLM(responses=["safe"]) + + rails = LLMRails(config=config, llm=llm, verbose=False) + + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + assert "content_safety" in model_caches + assert "jailbreak_detection" in model_caches + + content_safety_cache = model_caches["content_safety"] + jailbreak_cache = model_caches["jailbreak_detection"] + + assert content_safety_cache is not jailbreak_cache + assert content_safety_cache.capacity == 50 + assert jailbreak_cache.capacity == 100 + + content_safety_cache.put("key1", "value1") + assert content_safety_cache.get("key1") == "value1" + assert jailbreak_cache.get("key1") is None + + +@pytest.mark.asyncio +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +async def test_cache_disabled_for_main_model_in_integration(mock_init_llm_model): + mock_llm = FakeLLM(responses=["safe"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=100), + ), + Model( + type="content_safety", + engine="fake", + model="fake-content-safety", + cache=ModelCacheConfig(enabled=True, maxsize=100), + ), + ] + ) + + llm = FakeLLM(responses=["safe"]) + + rails = LLMRails(config=config, llm=llm, verbose=False) + + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + assert "main" not in model_caches or model_caches["main"] is None + assert "content_safety" in model_caches + assert model_caches["content_safety"] is not None diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py index 9b8a2b300..0e87853d6 100644 --- a/tests/test_llmrails.py +++ b/tests/test_llmrails.py @@ -1187,3 +1187,188 @@ def test_explain_calls_ensure_explain_info(): info = rails.explain() assert info == ExplainInfo() assert rails.explain_info == ExplainInfo() + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_disabled_by_default(mock_init_llm_model): + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake", + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches") + + assert model_caches is None or len(model_caches) == 0 + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_with_enabled_cache(mock_init_llm_model): + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake", + cache=ModelCacheConfig( + enabled=True, + maxsize=1000, + stats=CacheStatsConfig(enabled=False), + ), + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + + assert "content_safety" in model_caches + assert model_caches["content_safety"] is not None + assert model_caches["content_safety"].capacity == 1000 + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_not_created_for_main_and_embeddings_models(mock_init_llm_model): + from nemoguardrails.rails.llm.config import ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=1000), + ), + Model( + type="embeddings", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=1000), + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + + assert "main" not in model_caches + assert "embeddings" not in model_caches + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_with_zero_capacity_raises_error(mock_init_llm_model): + from nemoguardrails.rails.llm.config import ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="content_safety", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=0), + ), + ] + ) + + with pytest.raises(ValueError, match="Invalid cache capacity"): + LLMRails(config=config, verbose=False) + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_with_stats_enabled(mock_init_llm_model): + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="content_safety", + engine="fake", + model="fake", + cache=ModelCacheConfig( + enabled=True, + maxsize=5000, + stats=CacheStatsConfig(enabled=True, log_interval=60.0), + ), + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + + cache = model_caches["content_safety"] + assert cache is not None + assert cache.track_stats is True + assert cache.stats_logging_interval == 60.0 + assert cache.supports_stats_logging() is True + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_with_multiple_models(mock_init_llm_model): + from nemoguardrails.rails.llm.config import ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=1000), + ), + Model( + type="jailbreak_detection", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=2000), + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + + assert "main" not in model_caches + assert "content_safety" in model_caches + assert "jailbreak_detection" in model_caches + assert model_caches["content_safety"].capacity == 1000 + assert model_caches["jailbreak_detection"].capacity == 2000 diff --git a/tests/tracing/spans/test_span_extractors.py b/tests/tracing/spans/test_span_extractors.py index 9c9c85c05..9b2c24db1 100644 --- a/tests/tracing/spans/test_span_extractors.py +++ b/tests/tracing/spans/test_span_extractors.py @@ -25,6 +25,7 @@ SpanLegacy, create_span_extractor, ) +from nemoguardrails.tracing.constants import GuardrailsAttributes from nemoguardrails.tracing.spans import LLMSpan, is_opentelemetry_span @@ -181,6 +182,78 @@ def test_span_extractor_conversation_events(self, test_data): # Content not included by default (privacy) assert "final_transcript" not in user_event.body + def test_span_extractor_cache_hit_attribute(self): + """Test that cached LLM calls are marked with cache_hit typed field.""" + llm_call_cached = LLMCallInfo( + task="generate_user_intent", + prompt="What is the weather?", + completion="I cannot provide weather information.", + llm_model_name="gpt-4", + llm_provider_name="openai", + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + started_at=time.time(), + finished_at=time.time() + 0.001, + duration=0.001, + from_cache=True, + ) + + llm_call_not_cached = LLMCallInfo( + task="generate_bot_message", + prompt="Generate a response", + completion="Here is a response", + llm_model_name="gpt-3.5-turbo", + llm_provider_name="openai", + prompt_tokens=5, + completion_tokens=15, + total_tokens=20, + started_at=time.time(), + finished_at=time.time() + 1.0, + duration=1.0, + from_cache=False, + ) + + action = ExecutedAction( + action_name="test_action", + action_params={}, + llm_calls=[llm_call_cached, llm_call_not_cached], + started_at=time.time(), + finished_at=time.time() + 1.5, + duration=1.5, + ) + + rail = ActivatedRail( + type="input", + name="test_rail", + decisions=["continue"], + executed_actions=[action], + stop=False, + started_at=time.time(), + finished_at=time.time() + 2.0, + duration=2.0, + ) + + extractor = SpanExtractorV2() + spans = extractor.extract_spans([rail]) + + llm_spans = [s for s in spans if isinstance(s, LLMSpan)] + assert len(llm_spans) == 2 + + cached_span = next(s for s in llm_spans if "gpt-4" in s.name) + assert cached_span.cache_hit is True + + attributes = cached_span.to_otel_attributes() + assert GuardrailsAttributes.LLM_CACHE_HIT in attributes + assert attributes[GuardrailsAttributes.LLM_CACHE_HIT] is True + + not_cached_span = next(s for s in llm_spans if "gpt-3.5-turbo" in s.name) + assert not_cached_span.cache_hit is False + + attributes = not_cached_span.to_otel_attributes() + assert GuardrailsAttributes.LLM_CACHE_HIT in attributes + assert attributes[GuardrailsAttributes.LLM_CACHE_HIT] is False + class TestSpanFormatConfiguration: """Test span format configuration and factory.""" From 1bf9f63a03eacb516c1b0bfbcbacf20cfa643cf4 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Thu, 16 Oct 2025 21:03:04 +0200 Subject: [PATCH 10/19] revert vscode changes --- .vscode/settings.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index a19072329..38eb07063 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -46,7 +46,7 @@ "editor.defaultFormatter": "ms-python.black-formatter" }, "python.envFile": "${workspaceFolder}/.venv", - "python.languageServer": "None", + "python.languageServer": "Pylance", "python.testing.pytestEnabled": true, "python.testing.pytestArgs": [ "${workspaceFolder}/tests", @@ -55,6 +55,7 @@ "python.testing.unittestEnabled": false, //"python.envFile": "${workspaceFolder}/python_release.env", + // MYPY "mypy-type-checker.args": [ "--ignore-missing-imports", From 333d54867660d1d9635a37f08bb4bfe1d9142ae6 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:47:58 +0200 Subject: [PATCH 11/19] move nemoguardrails.cache to nemoguardrails.llm --- .../library/content_safety/actions.py | 6 +++--- nemoguardrails/{ => llm}/cache/README.md | 0 nemoguardrails/{ => llm}/cache/__init__.py | 6 +++--- nemoguardrails/{ => llm}/cache/interface.py | 0 nemoguardrails/{ => llm}/cache/lfu.py | 2 +- nemoguardrails/{ => llm}/cache/utils.py | 2 +- nemoguardrails/rails/llm/llmrails.py | 2 +- tests/test_cache_lfu.py | 20 +++++++++---------- tests/test_cache_utils.py | 6 +++--- tests/test_content_safety_cache.py | 4 ++-- 10 files changed, 24 insertions(+), 24 deletions(-) rename nemoguardrails/{ => llm}/cache/README.md (100%) rename nemoguardrails/{ => llm}/cache/__init__.py (81%) rename nemoguardrails/{ => llm}/cache/interface.py (100%) rename nemoguardrails/{ => llm}/cache/lfu.py (99%) rename nemoguardrails/{ => llm}/cache/utils.py (98%) diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index 93f5908c9..c1ac780b3 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -20,14 +20,14 @@ from nemoguardrails.actions.actions import action from nemoguardrails.actions.llm.utils import llm_call -from nemoguardrails.cache import CacheInterface -from nemoguardrails.cache.utils import ( +from nemoguardrails.context import llm_call_info_var +from nemoguardrails.llm.cache import CacheInterface +from nemoguardrails.llm.cache.utils import ( CacheEntry, create_normalized_cache_key, extract_llm_stats_for_cache, get_from_cache_and_restore_stats, ) -from nemoguardrails.context import llm_call_info_var from nemoguardrails.llm.taskmanager import LLMTaskManager from nemoguardrails.logging.explain import LLMCallInfo diff --git a/nemoguardrails/cache/README.md b/nemoguardrails/llm/cache/README.md similarity index 100% rename from nemoguardrails/cache/README.md rename to nemoguardrails/llm/cache/README.md diff --git a/nemoguardrails/cache/__init__.py b/nemoguardrails/llm/cache/__init__.py similarity index 81% rename from nemoguardrails/cache/__init__.py rename to nemoguardrails/llm/cache/__init__.py index 91042bcb1..0bb1aeb9f 100644 --- a/nemoguardrails/cache/__init__.py +++ b/nemoguardrails/llm/cache/__init__.py @@ -15,8 +15,8 @@ """General-purpose caching utilities for NeMo Guardrails.""" -from nemoguardrails.cache.interface import CacheInterface -from nemoguardrails.cache.lfu import LFUCache -from nemoguardrails.cache.utils import create_normalized_cache_key +from nemoguardrails.llm.cache.interface import CacheInterface +from nemoguardrails.llm.cache.lfu import LFUCache +from nemoguardrails.llm.cache.utils import create_normalized_cache_key __all__ = ["CacheInterface", "LFUCache", "create_normalized_cache_key"] diff --git a/nemoguardrails/cache/interface.py b/nemoguardrails/llm/cache/interface.py similarity index 100% rename from nemoguardrails/cache/interface.py rename to nemoguardrails/llm/cache/interface.py diff --git a/nemoguardrails/cache/lfu.py b/nemoguardrails/llm/cache/lfu.py similarity index 99% rename from nemoguardrails/cache/lfu.py rename to nemoguardrails/llm/cache/lfu.py index fb76c6b10..f1f61e667 100644 --- a/nemoguardrails/cache/lfu.py +++ b/nemoguardrails/llm/cache/lfu.py @@ -21,7 +21,7 @@ import time from typing import Any, Callable, Optional -from nemoguardrails.cache.interface import CacheInterface +from nemoguardrails.llm.cache.interface import CacheInterface log = logging.getLogger(__name__) diff --git a/nemoguardrails/cache/utils.py b/nemoguardrails/llm/cache/utils.py similarity index 98% rename from nemoguardrails/cache/utils.py rename to nemoguardrails/llm/cache/utils.py index 0291e2fa3..0f6818b53 100644 --- a/nemoguardrails/cache/utils.py +++ b/nemoguardrails/llm/cache/utils.py @@ -24,7 +24,7 @@ from nemoguardrails.logging.stats import LLMStats if TYPE_CHECKING: - from nemoguardrails.cache.interface import CacheInterface + from nemoguardrails.llm.cache.interface import CacheInterface PROMPT_PATTERN_WHITESPACES = re.compile(r"\s+") diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index c9c87b077..35e81bec1 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -50,7 +50,6 @@ ) from nemoguardrails.actions.output_mapping import is_output_blocked from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx -from nemoguardrails.cache import CacheInterface, LFUCache from nemoguardrails.colang import parse_colang_file from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id, compute_context from nemoguardrails.colang.v1_0.runtime.runtime import Runtime, RuntimeV1_0 @@ -71,6 +70,7 @@ from nemoguardrails.embeddings.providers import register_embedding_provider from nemoguardrails.embeddings.providers.base import EmbeddingModel from nemoguardrails.kb.kb import KnowledgeBase +from nemoguardrails.llm.cache import CacheInterface, LFUCache from nemoguardrails.llm.models.initializer import ( ModelInitializationError, init_llm_model, diff --git a/tests/test_cache_lfu.py b/tests/test_cache_lfu.py index 2b1d986f7..e137b7f1a 100644 --- a/tests/test_cache_lfu.py +++ b/tests/test_cache_lfu.py @@ -29,7 +29,7 @@ from typing import Any from unittest.mock import MagicMock, patch -from nemoguardrails.cache.lfu import LFUCache +from nemoguardrails.llm.cache.lfu import LFUCache class TestLFUCache(unittest.TestCase): @@ -395,7 +395,7 @@ def test_log_stats_now(self): cache.get("nonexistent") with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" ) as mock_log: cache.log_stats_now() @@ -425,7 +425,7 @@ def test_periodic_stats_logging(self): cache.put("key2", "value2") with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" ) as mock_log: # Initial operations shouldn't trigger logging cache.get("key1") @@ -461,7 +461,7 @@ def test_stats_logging_with_empty_cache(self): time.sleep(0.2) with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" ) as mock_log: # This will trigger stats logging with the previous miss already counted cache.get("another_nonexistent") # Trigger check @@ -490,7 +490,7 @@ def test_stats_logging_with_full_cache(self): cache.put("key4", "value4") with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" ) as mock_log: time.sleep(0.2) cache.get("key4") # Trigger check @@ -517,7 +517,7 @@ def test_stats_logging_high_hit_rate(self): cache.get("nonexistent") with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" ) as mock_log: cache.log_stats_now() @@ -537,7 +537,7 @@ def test_stats_logging_without_tracking(self): cache.get("key1") with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" ) as mock_log: cache.log_stats_now() @@ -552,7 +552,7 @@ def test_stats_logging_interval_timing(self): cache = LFUCache(5, track_stats=True, stats_logging_interval=1.0) with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" ) as mock_log: # Multiple operations within interval for i in range(10): @@ -582,7 +582,7 @@ def test_stats_logging_with_updates(self): cache.put("key1", "updated_again") # Another update with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" ) as mock_log: cache.log_stats_now() @@ -621,7 +621,7 @@ def test_stats_log_format_percentages(self): cache.get(f"miss_key_{i}") with patch.object( - logging.getLogger("nemoguardrails.cache.lfu"), "info" + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" ) as mock_log: cache.log_stats_now() diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 3e0d98a57..5e824074d 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -17,14 +17,14 @@ import pytest -from nemoguardrails.cache.lfu import LFUCache -from nemoguardrails.cache.utils import ( +from nemoguardrails.context import llm_call_info_var, llm_stats_var +from nemoguardrails.llm.cache.lfu import LFUCache +from nemoguardrails.llm.cache.utils import ( create_normalized_cache_key, extract_llm_stats_for_cache, get_from_cache_and_restore_stats, restore_llm_stats_from_cache, ) -from nemoguardrails.context import llm_call_info_var, llm_stats_var from nemoguardrails.logging.explain import LLMCallInfo from nemoguardrails.logging.stats import LLMStats diff --git a/tests/test_content_safety_cache.py b/tests/test_content_safety_cache.py index d81465b3d..5261f0251 100644 --- a/tests/test_content_safety_cache.py +++ b/tests/test_content_safety_cache.py @@ -17,10 +17,10 @@ import pytest -from nemoguardrails.cache.lfu import LFUCache -from nemoguardrails.cache.utils import create_normalized_cache_key from nemoguardrails.context import llm_call_info_var, llm_stats_var from nemoguardrails.library.content_safety.actions import content_safety_check_input +from nemoguardrails.llm.cache.lfu import LFUCache +from nemoguardrails.llm.cache.utils import create_normalized_cache_key from nemoguardrails.logging.explain import LLMCallInfo from nemoguardrails.logging.stats import LLMStats from tests.utils import FakeLLM From 5494d6d147486438f7f32213ba1cfd8cb8153631 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:51:00 +0200 Subject: [PATCH 12/19] fix api --- nemoguardrails/llm/cache/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemoguardrails/llm/cache/__init__.py b/nemoguardrails/llm/cache/__init__.py index 0bb1aeb9f..9530009da 100644 --- a/nemoguardrails/llm/cache/__init__.py +++ b/nemoguardrails/llm/cache/__init__.py @@ -19,4 +19,4 @@ from nemoguardrails.llm.cache.lfu import LFUCache from nemoguardrails.llm.cache.utils import create_normalized_cache_key -__all__ = ["CacheInterface", "LFUCache", "create_normalized_cache_key"] +__all__ = ["CacheInterface", "LFUCache"] From c3aff64fcb191e61ddf9f8094a114cef6aa6f5b7 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Thu, 16 Oct 2025 21:04:51 +0200 Subject: [PATCH 13/19] remove README.md from cache package --- nemoguardrails/llm/cache/README.md | 223 ----------------------------- 1 file changed, 223 deletions(-) delete mode 100644 nemoguardrails/llm/cache/README.md diff --git a/nemoguardrails/llm/cache/README.md b/nemoguardrails/llm/cache/README.md deleted file mode 100644 index dfa18c036..000000000 --- a/nemoguardrails/llm/cache/README.md +++ /dev/null @@ -1,223 +0,0 @@ -# Content Safety LLM Call Caching - -## Overview - -The content safety checks in `actions.py` now use an LFU (Least Frequently Used) cache to improve performance by avoiding redundant LLM calls for identical safety checks. - -## Implementation Details - -### Cache Configuration - -- Per-model caches: Each model gets its own LFU cache instance -- Default capacity: 50,000 entries per model -- Eviction policy: LFU with LRU tiebreaker -- Statistics tracking: Disabled by default (configurable) -- Tracks timestamps: `created_at` and `accessed_at` for each entry -- Cache creation: Automatic when a model is initialized with cache enabled -- Supported model types: Any non-`main` and non-`embeddings` model type (typically content safety models) - -### Cached Functions - -`content_safety_check_input()` - Caches safety checks for user inputs - -### Cache Key Components - -The cache key is generated from: - -- The rendered prompt (normalized for whitespace) - -Since temperature is fixed (1e-20) and stop/max_tokens are derived from the model configuration, they don't need to be part of the cache key. - -### How It Works - -1. **Before LLM Call**: - - Generate cache key from request parameters - - Check if result exists in cache - - If found, return cached result (cache hit) - -2. **After LLM Call**: - - If not in cache, make the actual LLM call - - Store the result in cache for future use - -### Cache Management - -The caching system automatically creates and manages separate caches for each model. Key features: - -- **Automatic Creation**: Caches are created when the model is initialized with cache configuration -- **Isolated Storage**: Each model maintains its own cache, preventing cross-model interference -- **Default Settings**: Each cache has 50,000 entry capacity (configurable) -- **Per-Model Configuration**: Cache is configured per model in the YAML configuration - -### Statistics and Monitoring - -The cache supports detailed statistics tracking and periodic logging for monitoring cache performance: - -```yaml -models: - - type: content_safety - engine: nim - model: nvidia/llama-3.1-nemoguard-8b-content-safety - cache: - enabled: true - capacity_per_model: 10000 - store: memory # Currently only 'memory' is supported - stats: - enabled: true # Enable stats tracking - log_interval: 60.0 # Log stats every minute -``` - -**Statistics Features:** - -1. **Tracking Only**: Set `stats.enabled: true` with no `log_interval` to track stats without logging -2. **Automatic Logging**: Set both `stats.enabled: true` and `log_interval` for periodic logging - -**Statistics Tracked:** - -- **Hits**: Number of cache hits (successful lookups) -- **Misses**: Number of cache misses (failed lookups) -- **Hit Rate**: Percentage of requests served from cache -- **Evictions**: Number of items removed due to capacity -- **Puts**: Number of new items added to cache -- **Updates**: Number of existing items updated -- **Current Size**: Number of items currently in cache - -**Log Format:** - -``` -LFU Cache Statistics - Size: 2456/10000 | Hits: 15234 | Misses: 2456 | Hit Rate: 86.11% | Evictions: 0 | Puts: 2456 | Updates: 0 -``` - -**Usage Examples:** - -The cache is managed internally by the NeMo Guardrails framework. When you configure a model with caching enabled, the framework automatically: - -1. Creates an LFU cache instance for that model -2. Passes the cache to content safety actions via kwargs -3. Tracks statistics if configured -4. Logs statistics at the specified interval - -**Configuration Options:** - -- `stats.enabled`: Enable/disable statistics tracking (default: false) -- `stats.log_interval`: Seconds between automatic stats logs (None = no logging) - -**Notes:** - -- Stats logging requires stats tracking to be enabled -- Logs appear at INFO level in the `nemoguardrails.cache.lfu` logger -- Stats are reset when cache is cleared or when `reset_stats()` is called -- Each model maintains independent statistics - -### Example Configuration - -```yaml -# config.yml -models: - - type: main - engine: openai - model: gpt-4 - - - type: content_safety - engine: nim - model: nvidia/llama-3.1-nemoguard-8b-content-safety - cache: - enabled: true - capacity_per_model: 50000 - store: memory - stats: - enabled: true - log_interval: 300.0 # Log stats every 5 minutes - -rails: - input: - flows: - - content safety check input model="content_safety" -``` - -### Example Usage - -```python -from nemoguardrails import RailsConfig, LLMRails - -# The cache is automatically configured based on your YAML config -config = RailsConfig.from_path("./config.yml") -rails = LLMRails(config) - -# Content safety checks will be cached automatically -response = await rails.generate_async( - messages=[{"role": "user", "content": "Hello, how are you?"}] -) -``` - -### Thread Safety - -The content safety caching system is **thread-safe** for single-node deployments: - -1. **LFUCache Implementation**: - - Uses `threading.RLock` for all operations - - All public methods (`get`, `put`, `size`, `clear`, etc.) are protected by locks - - Supports atomic `get_or_compute()` operations that prevent duplicate computations - -2. **LLMRails Model Initialization**: - - Thread-safe cache creation during model initialization - - Ensures only one cache instance per model across all threads - -3. **Key Features**: - - **No Data Corruption**: Concurrent operations maintain data integrity - - **No Race Conditions**: Proper locking prevents race conditions - - **Atomic Operations**: `get_or_compute()` ensures expensive computations happen only once - - **Minimal Lock Contention**: Efficient locking patterns minimize performance impact - -4. **Usage in Web Servers**: - - Safe for use in multi-threaded web servers (FastAPI, Flask, etc.) - - Handles concurrent requests without issues - - Each thread sees consistent cache state - -**Note**: This implementation is designed for single-node deployments. For distributed systems, consider using external caching solutions like Redis. - -### Benefits - -1. **Performance**: Eliminates redundant LLM calls for identical inputs -2. **Cost Savings**: Reduces API calls to LLM services -3. **Consistency**: Ensures identical inputs always produce identical outputs -4. **Smart Eviction**: LFU policy keeps frequently checked content in cache -5. **Model Isolation**: Each model has its own cache, preventing interference between different safety models -6. **Statistics Tracking**: Monitor cache performance with hit rates, evictions, and more per model -7. **Timestamp Tracking**: Track when entries were created and last accessed -8. **Efficiency**: LFU eviction algorithm ensures the most useful entries remain in cache -9. **Thread Safety**: Safe for concurrent access in multi-threaded environments - -### Example Usage Pattern - -```python -# First call - takes ~500ms (LLM API call) -result = await content_safety_check_input( - llms=llms, - llm_task_manager=task_manager, - model_name="content_safety", - context={"user_message": "Hello world"} -) - -# Subsequent identical calls - takes ~1ms (cache hit) -result = await content_safety_check_input( - llms=llms, - llm_task_manager=task_manager, - model_name="content_safety", - context={"user_message": "Hello world"} -) -``` - -### Logging - -The implementation includes debug logging: - -- Cache creation: `"Created cache for model '{model_name}' with capacity {capacity}"` -- Cache hits: `"Content safety cache hit for model '{model_name}'"` -- Cache stores: `"Content safety result cached for model '{model_name}'"` - -Enable debug logging to monitor cache behavior: - -```python -import logging -logging.getLogger("nemoguardrails.library.content_safety.actions").setLevel(logging.DEBUG) -``` From 893d9a168b2b3f7d485e99f832509308618f24ee Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Thu, 16 Oct 2025 21:12:41 +0200 Subject: [PATCH 14/19] revert content_safety config --- examples/configs/content_safety/README.md | 151 +------------------- examples/configs/content_safety/config.yml | 10 -- examples/configs/content_safety/prompts.yml | 1 + 3 files changed, 5 insertions(+), 157 deletions(-) diff --git a/examples/configs/content_safety/README.md b/examples/configs/content_safety/README.md index 2194fa571..35a2d2a45 100644 --- a/examples/configs/content_safety/README.md +++ b/examples/configs/content_safety/README.md @@ -1,153 +1,10 @@ -# Content Safety Configuration +# NemoGuard ContentSafety Usage Example -This example demonstrates how to configure content safety rails with NeMo Guardrails, from basic setup to advanced per-model configurations. - -## Features - -- **Input Safety Checks**: Validates user inputs before processing -- **Output Safety Checks**: Ensures bot responses are appropriate -- **Thread Safety**: Fully thread-safe for use in multi-threaded web servers -- **Per-Model Caching**: Optional caching with configurable settings per model -- **Multiple Models**: Support for different content safety models with different configurations - -## Folder Structure +This example showcases the use of NVIDIA's [NemoGuard ContentSafety model](./../../../docs/user-guides/advanced/nemoguard-contentsafety-deployment.md) for topical and dialogue moderation. The structure of the config folder is the following: -- `config.yml` - The main configuration file with model definitions, rails configuration, and cache settings -- `prompts.yml` - Contains the content safety prompt templates used by the safety models to evaluate content - -## Configuration Overview - -### Basic Configuration - -The simplest configuration uses a single content safety model: - -```yaml -models: - - type: main - engine: nim - model: meta/llama-3.3-70b-instruct - - - type: content_safety - engine: nim - model: nvidia/llama-3.1-nemoguard-8b-content-safety - -rails: - input: - flows: - - content safety check input $model=content_safety - output: - flows: - - content safety check output $model=content_safety -``` - -### Advanced Configuration with Per-Model Caching - -For production environments, you can configure caching per model: - -```yaml -models: - - type: content_safety - engine: nim - model: nvidia/llama-3.1-nemoguard-8b-content-safety - cache: - enabled: true - capacity_per_model: 50000 # Larger cache for primary model - stats: - enabled: true - log_interval: 60.0 # Log stats every 60 seconds - - - type: llama_guard - engine: vllm_openai - model: meta-llama/Llama-Guard-7b - cache: - enabled: true - capacity_per_model: 25000 # Medium cache - stats: - enabled: false # No stats for this model -``` - -## How It Works - -1. **User Input**: When a user sends a message, it's checked by the content safety model(s) -2. **Safety Evaluation**: The content safety model evaluates the input -3. **Caching** (optional): Results are cached to avoid duplicate API calls -4. **Response Generation**: If safe, the main model generates a response -5. **Output Check**: The response is also checked for safety before returning to the user - -## Cache Configuration Options - -### Default Behavior (No Caching) - -By default, caching is **disabled**. Models without cache configuration will have no caching: - -```yaml -models: - - type: content_safety - engine: nim - model: nvidia/llama-3.1-nemoguard-8b-content-safety - # No cache config = no caching (default) -``` - -### Enabling Cache - -Add cache configuration to any model definition: - -```yaml -cache: - enabled: true # Enable caching - capacity_per_model: 10000 # Cache capacity (number of entries) - store: "memory" # Cache storage type (currently only memory) - stats: - enabled: true # Enable statistics tracking - log_interval: 300.0 # Log stats every 5 minutes (optional) -``` - -## Architecture - -Each content safety model gets its own dedicated cache instance, providing: - -- **Isolated cache management** per model -- **Different cache capacities** for different models -- **Model-specific performance tuning** -- **Thread-safe concurrent access** - -## Thread Safety - -The content safety implementation is fully thread-safe: - -- **Concurrent Requests**: Safely handles multiple simultaneous safety checks -- **Efficient Locking**: Uses RLock for minimal performance impact -- **Atomic Operations**: Prevents duplicate LLM calls for the same content - -This makes it suitable for: - -- Multi-threaded web servers (FastAPI, Flask, Django) -- Concurrent request processing -- High-traffic applications - -## Running the Example - -```bash -# From the NeMo-Guardrails root directory -nemoguardrails server --config examples/configs/content_safety/ -``` +- `config.yml` - The config file holding all the configuration options for the model. +- `prompts.yml` - The config file holding the topical rules used for topical and dialogue moderation by the current guardrail configuration. Please see the docs for more details about the [recommended ContentSafety deployment](./../../../docs/user-guides/advanced/nemoguard-contentsafety-deployment.md) methods, either using locally downloaded NIMs or NVIDIA AI Enterprise (NVAIE). - -## Benefits - -1. **Performance**: Avoid redundant content safety API calls -2. **Cost Savings**: Reduce API usage for repeated content -3. **Flexibility**: Enable caching only for models that benefit from it -4. **Clean Architecture**: Each model has its own dedicated cache -5. **Scalability**: Easy to add new models with different caching strategies - -## Tips - -- Start with no caching to establish baseline performance -- Enable caching for frequently-used models first -- Use stats logging to monitor cache effectiveness -- Adjust cache capacity based on your usage patterns -- Consider different cache sizes for different models based on their usage diff --git a/examples/configs/content_safety/config.yml b/examples/configs/content_safety/config.yml index 8e6e4a59c..f6808bf14 100644 --- a/examples/configs/content_safety/config.yml +++ b/examples/configs/content_safety/config.yml @@ -6,21 +6,11 @@ models: - type: content_safety engine: nim model: nvidia/llama-3.1-nemoguard-8b-content-safety - # Model-specific cache configuration (optional) - cache: - enabled: true - maxsize: 50000 # Larger cache for primary model - stats: - enabled: true - log_interval: 60.0 # Log stats every minute rails: input: flows: - # You can use multiple content safety models - content safety check input $model=content_safety output: flows: - content safety check output $model=content_safety -tracing: - enabled: True diff --git a/examples/configs/content_safety/prompts.yml b/examples/configs/content_safety/prompts.yml index dfd8b45a8..1321a6461 100644 --- a/examples/configs/content_safety/prompts.yml +++ b/examples/configs/content_safety/prompts.yml @@ -1,5 +1,6 @@ # These are the default prompts released by Meta, except for policy O7, which was added to address direct insults. prompts: + - task: content_safety_check_input $model=content_safety content: | Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below. From af5aee988fd2707fbf52756c7fd653d82a4fbb68 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 17 Oct 2025 09:19:49 +0200 Subject: [PATCH 15/19] remove main from lfu module --- nemoguardrails/llm/cache/lfu.py | 51 --------------------------------- 1 file changed, 51 deletions(-) diff --git a/nemoguardrails/llm/cache/lfu.py b/nemoguardrails/llm/cache/lfu.py index f1f61e667..553ecdaee 100644 --- a/nemoguardrails/llm/cache/lfu.py +++ b/nemoguardrails/llm/cache/lfu.py @@ -468,54 +468,3 @@ def contains(self, key: Any) -> bool: def capacity(self) -> int: """Get the maximum capacity of the cache.""" return self._capacity - - -# Example usage and testing -if __name__ == "__main__": - print("=== Basic LFU Cache Example ===") - # Create a basic LFU cache - cache = LFUCache(3) - - cache.put("a", 1) - cache.put("b", 2) - cache.put("c", 3) - - print(f"Get 'a': {cache.get('a')}") # Returns 1, frequency of 'a' becomes 2 - print(f"Get 'b': {cache.get('b')}") # Returns 2, frequency of 'b' becomes 2 - - cache.put("d", 4) # Evicts 'c' (least frequently used) - - print(f"Get 'c': {cache.get('c', 'Not found')}") # Returns 'Not found' - print(f"Get 'd': {cache.get('d')}") # Returns 4 - print(f"Cache size: {cache.size()}") # Returns 3 - - print("\n=== Cache with Statistics Tracking ===") - - # Create cache with statistics tracking - stats_cache = LFUCache(capacity=5, track_stats=True) - - # Add some items - for i in range(6): - stats_cache.put(f"key{i}", f"value{i}") - - # Access some items to change frequencies - for _ in range(3): - stats_cache.get("key4") # Increase frequency - stats_cache.get("key5") # Increase frequency - - # Some cache misses - stats_cache.get("nonexistent1") - stats_cache.get("nonexistent2") - - # Check statistics - print(f"\nCache statistics: {stats_cache.get_stats()}") - - # Update existing key - stats_cache.put("key4", "updated_value4") - - # Check updated statistics - print(f"\nUpdated statistics: {stats_cache.get_stats()}") - - # Reset statistics - stats_cache.reset_stats() - print(f"\nAfter reset: {stats_cache.get_stats()}") From a0d5c43d2adf524d15bc11e7c0443b0bc57894dd Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 17 Oct 2025 09:28:36 +0200 Subject: [PATCH 16/19] improve coverage for missing utils functionality --- nemoguardrails/llm/cache/utils.py | 2 +- tests/test_cache_utils.py | 61 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/nemoguardrails/llm/cache/utils.py b/nemoguardrails/llm/cache/utils.py index 0f6818b53..1cf02fe4d 100644 --- a/nemoguardrails/llm/cache/utils.py +++ b/nemoguardrails/llm/cache/utils.py @@ -146,7 +146,7 @@ def get_from_cache_and_restore_stats( restore_llm_stats_from_cache(cached_stats, cache_read_duration) processing_log = processing_log_var.get() - if processing_log: + if processing_log is not None: llm_call_info = llm_call_info_var.get() if llm_call_info: processing_log.append( diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 5e824074d..b7b1180e5 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -26,6 +26,7 @@ restore_llm_stats_from_cache, ) from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.logging.processing_log import processing_log_var from nemoguardrails.logging.stats import LLMStats @@ -313,3 +314,63 @@ def test_get_from_cache_and_restore_stats_without_llm_stats(self): llm_call_info_var.set(None) llm_stats_var.set(None) + + def test_get_from_cache_and_restore_stats_with_processing_log(self): + cache = LFUCache(capacity=10) + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": { + "total_tokens": 80, + "prompt_tokens": 60, + "completion_tokens": 20, + }, + } + cache.put("test_key", cache_entry) + + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info_var.set(llm_call_info) + llm_stats_var.set(None) + + processing_log = [] + processing_log_var.set(processing_log) + + result = get_from_cache_and_restore_stats(cache, "test_key") + + assert result is not None + assert result == {"allowed": True, "policy_violations": []} + + retrieved_log = processing_log_var.get() + assert len(retrieved_log) == 1 + assert retrieved_log[0]["type"] == "llm_call_info" + assert "timestamp" in retrieved_log[0] + assert "data" in retrieved_log[0] + assert retrieved_log[0]["data"] == llm_call_info + + llm_call_info_var.set(None) + llm_stats_var.set(None) + processing_log_var.set(None) + + def test_get_from_cache_and_restore_stats_without_processing_log(self): + cache = LFUCache(capacity=10) + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": { + "total_tokens": 50, + "prompt_tokens": 30, + "completion_tokens": 20, + }, + } + cache.put("test_key", cache_entry) + + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info_var.set(llm_call_info) + llm_stats_var.set(None) + processing_log_var.set(None) + + result = get_from_cache_and_restore_stats(cache, "test_key") + + assert result is not None + assert result == {"allowed": True, "policy_violations": []} + + llm_call_info_var.set(None) + llm_stats_var.set(None) From f7feb5cc03c9cfe50096d836a903c17d209fdf2c Mon Sep 17 00:00:00 2001 From: Pouyan <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 17 Oct 2025 09:31:38 +0200 Subject: [PATCH 17/19] Apply suggestion from @Pouyanpi Signed-off-by: Pouyan <13303554+Pouyanpi@users.noreply.github.com> --- nemoguardrails/llm/cache/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemoguardrails/llm/cache/__init__.py b/nemoguardrails/llm/cache/__init__.py index 9530009da..de2a31c6b 100644 --- a/nemoguardrails/llm/cache/__init__.py +++ b/nemoguardrails/llm/cache/__init__.py @@ -17,6 +17,5 @@ from nemoguardrails.llm.cache.interface import CacheInterface from nemoguardrails.llm.cache.lfu import LFUCache -from nemoguardrails.llm.cache.utils import create_normalized_cache_key __all__ = ["CacheInterface", "LFUCache"] From 98865d5b390a684bddb8f20a1d3e9f0b9d12288d Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 17 Oct 2025 09:36:26 +0200 Subject: [PATCH 18/19] rename capacity to maxsize --- nemoguardrails/llm/cache/interface.py | 8 ++--- nemoguardrails/llm/cache/lfu.py | 28 ++++++++--------- nemoguardrails/rails/llm/llmrails.py | 6 ++-- tests/test_cache_lfu.py | 44 +++++++++++++-------------- tests/test_cache_utils.py | 10 +++--- tests/test_content_safety_cache.py | 8 ++--- tests/test_integration_cache.py | 4 +-- tests/test_llmrails.py | 10 +++--- 8 files changed, 59 insertions(+), 59 deletions(-) diff --git a/nemoguardrails/llm/cache/interface.py b/nemoguardrails/llm/cache/interface.py index f07040098..74457913b 100644 --- a/nemoguardrails/llm/cache/interface.py +++ b/nemoguardrails/llm/cache/interface.py @@ -51,7 +51,7 @@ def put(self, key: Any, value: Any) -> None: """ Store an item in the cache. - If the cache is at capacity, this method should evict an item + If the cache is at maxsize, this method should evict an item according to the cache's eviction policy (e.g., LFU, LRU, etc.). Args: @@ -108,9 +108,9 @@ def contains(self, key: Any) -> bool: @property @abstractmethod - def capacity(self) -> int: + def maxsize(self) -> int: """ - Get the maximum capacity of the cache. + Get the maximum size of the cache. Returns: The maximum number of items the cache can hold. @@ -129,7 +129,7 @@ def get_stats(self) -> dict: - evictions: Number of items evicted - hit_rate: Percentage of requests that were hits - current_size: Current number of items in cache - - capacity: Maximum capacity of the cache + - maxsize: Maximum size of the cache The default implementation returns a message indicating that statistics tracking is not supported. diff --git a/nemoguardrails/llm/cache/lfu.py b/nemoguardrails/llm/cache/lfu.py index 553ecdaee..973224757 100644 --- a/nemoguardrails/llm/cache/lfu.py +++ b/nemoguardrails/llm/cache/lfu.py @@ -78,13 +78,13 @@ class LFUCache(CacheInterface): """ Least Frequently Used (LFU) Cache implementation. - When the cache reaches capacity, it evicts the least frequently used item. + When the cache reaches maxsize, it evicts the least frequently used item. If there are ties in frequency, it evicts the least recently used among them. """ def __init__( self, - capacity: int, + maxsize: int, track_stats: bool = False, stats_logging_interval: Optional[float] = None, ) -> None: @@ -92,14 +92,14 @@ def __init__( Initialize the LFU cache. Args: - capacity: Maximum number of items the cache can hold + maxsize: Maximum number of items the cache can hold track_stats: Enable tracking of cache statistics stats_logging_interval: Seconds between periodic stats logging (None disables logging) """ - if capacity < 0: + if maxsize < 0: raise ValueError("Capacity must be non-negative") - self._capacity = capacity + self._maxsize = maxsize self.track_stats = track_stats self._lock = threading.RLock() # Thread-safe access self._computing: dict[Any, asyncio.Future] = {} # Track keys being computed @@ -187,7 +187,7 @@ def put(self, key: Any, value: Any) -> None: # Check if we should log stats self._check_and_log_stats() - if self._capacity == 0: + if self._maxsize == 0: return if key in self.key_map: @@ -200,7 +200,7 @@ def put(self, key: Any, value: Any) -> None: self.stats["updates"] += 1 else: # Add new key - if len(self.key_map) >= self._capacity: + if len(self.key_map) >= self._maxsize: # Need to evict least frequently used item self._evict_lfu() @@ -268,7 +268,7 @@ def get_stats(self) -> dict: stats = self.stats.copy() stats["current_size"] = len(self.key_map) # Direct access within lock - stats["capacity"] = self._capacity + stats["maxsize"] = self._maxsize # Calculate hit rate total_requests = stats["hits"] + stats["misses"] @@ -313,7 +313,7 @@ def _log_stats(self) -> None: # Format the log message log_msg = ( f"LFU Cache Statistics - " - f"Size: {stats['current_size']}/{stats['capacity']} | " + f"Size: {stats['current_size']}/{stats['maxsize']} | " f"Hits: {stats['hits']} | " f"Misses: {stats['misses']} | " f"Hit Rate: {stats['hit_rate']:.2%} | " @@ -416,12 +416,12 @@ async def get_or_compute( return node.value # Now add to cache using internal logic - if self._capacity == 0: + if self._maxsize == 0: future.set_result(computed_value) return computed_value # Add new key - if len(self.key_map) >= self._capacity: + if len(self.key_map) >= self._maxsize: self._evict_lfu() # Create new node and add to cache @@ -465,6 +465,6 @@ def contains(self, key: Any) -> bool: return key in self.key_map @property - def capacity(self) -> int: - """Get the maximum capacity of the cache.""" - return self._capacity + def maxsize(self) -> int: + """Get the maximum size of the cache.""" + return self._maxsize diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 35e81bec1..d811464eb 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -542,7 +542,7 @@ def _create_model_cache(self, model) -> LFUCache: if model.cache.maxsize <= 0: raise ValueError( - f"Invalid cache capacity for model '{model.type}': {model.cache.maxsize}. " + f"Invalid cache maxsize for model '{model.type}': {model.cache.maxsize}. " "Capacity must be greater than 0. Skipping cache creation." ) @@ -551,13 +551,13 @@ def _create_model_cache(self, model) -> LFUCache: stats_logging_interval = model.cache.stats.log_interval cache = LFUCache( - capacity=model.cache.maxsize, + maxsize=model.cache.maxsize, track_stats=model.cache.stats.enabled, stats_logging_interval=stats_logging_interval, ) log.info( - f"Created cache for model '{model.type}' with capacity {model.cache.maxsize}" + f"Created cache for model '{model.type}' with maxsize {model.cache.maxsize}" ) return cache diff --git a/tests/test_cache_lfu.py b/tests/test_cache_lfu.py index e137b7f1a..b5bc5acda 100644 --- a/tests/test_cache_lfu.py +++ b/tests/test_cache_lfu.py @@ -17,7 +17,7 @@ Comprehensive test suite for LFU Cache implementation. Tests all functionality including basic operations, eviction policies, -capacity management, and edge cases. +maxsize management, and edge cases. """ import asyncio @@ -41,16 +41,16 @@ def setUp(self): def test_initialization(self): """Test cache initialization with various capacities.""" - # Normal capacity + # Normal maxsize cache = LFUCache(5) self.assertEqual(cache.size(), 0) self.assertTrue(cache.is_empty()) - # Zero capacity + # Zero maxsize cache_zero = LFUCache(0) self.assertEqual(cache_zero.size(), 0) - # Negative capacity should raise error + # Negative maxsize should raise error with self.assertRaises(ValueError): LFUCache(-1) @@ -165,7 +165,7 @@ def test_complex_eviction_scenario(self): # - 'f' (freq 1) - just added # - 'e' (freq 1) was evicted as it was least recently used among freq 1 items - # Check that we're at capacity + # Check that we're at maxsize self.assertEqual(cache.size(), 4) # 'a' should definitely still be there (highest frequency) @@ -181,8 +181,8 @@ def test_complex_eviction_scenario(self): # 'e' should have been evicted (freq 1, LRU among freq 1 items) self.assertIsNone(cache.get("e")) - def test_zero_capacity_cache(self): - """Test cache with zero capacity.""" + def test_zero_maxsize_cache(self): + """Test cache with zero maxsize.""" cache = LFUCache(0) # Put should not store anything @@ -258,17 +258,17 @@ def test_none_values(self): # Verify key exists self.assertEqual(self.cache.size(), 1) - def test_size_and_capacity(self): - """Test size tracking and capacity limits.""" + def test_size_and_maxsize(self): + """Test size tracking and maxsize limits.""" # Start empty self.assertEqual(self.cache.size(), 0) - # Add items up to capacity + # Add items up to maxsize for i in range(3): self.cache.put(f"key{i}", f"value{i}") self.assertEqual(self.cache.size(), i + 1) - # Add more items - size should stay at capacity + # Add more items - size should stay at maxsize for i in range(3, 10): self.cache.put(f"key{i}", f"value{i}") self.assertEqual(self.cache.size(), 3) @@ -360,7 +360,7 @@ def test_interface_methods_exist(self): self.assertTrue(callable(getattr(cache, "clear", None))) # Check property - self.assertEqual(cache.capacity, 5) + self.assertEqual(cache.maxsize, 5) class TestLFUCacheStatsLogging(unittest.TestCase): @@ -475,7 +475,7 @@ def test_stats_logging_with_empty_cache(self): self.assertIn("Hit Rate: 0.00%", log_message) def test_stats_logging_with_full_cache(self): - """Test stats logging when cache is at capacity.""" + """Test stats logging when cache is at maxsize.""" import logging from unittest.mock import patch @@ -642,7 +642,7 @@ def test_cache_config_with_stats_disabled(self): ) cache = LFUCache( - capacity=cache_config.maxsize, + maxsize=cache_config.maxsize, track_stats=cache_config.stats.enabled, stats_logging_interval=None, ) @@ -662,7 +662,7 @@ def test_cache_config_with_stats_tracking_only(self): ) cache = LFUCache( - capacity=cache_config.maxsize, + maxsize=cache_config.maxsize, track_stats=cache_config.stats.enabled, stats_logging_interval=cache_config.stats.log_interval, ) @@ -683,7 +683,7 @@ def test_cache_config_with_stats_logging(self): ) cache = LFUCache( - capacity=cache_config.maxsize, + maxsize=cache_config.maxsize, track_stats=cache_config.stats.enabled, stats_logging_interval=cache_config.stats.log_interval, ) @@ -700,7 +700,7 @@ def test_cache_config_default_stats(self): cache_config = ModelCacheConfig(enabled=True) cache = LFUCache( - capacity=cache_config.maxsize, + maxsize=cache_config.maxsize, track_stats=cache_config.stats.enabled, stats_logging_interval=None, ) @@ -826,7 +826,7 @@ def worker(thread_id): for future in futures: future.result() - # Cache should still be at capacity + # Cache should still be at maxsize self.assertEqual(small_cache.size(), 10) def test_concurrent_clear_operations(self): @@ -1005,7 +1005,7 @@ def worker(thread_id): size = self.cache.size() is_empty = self.cache.is_empty() - # Size should never be negative or exceed capacity + # Size should never be negative or exceed maxsize if size < 0 or size > 100: results.append(f"Invalid size: {size}") @@ -1025,7 +1025,7 @@ def worker(thread_id): def test_concurrent_contains_operations(self): """Test thread safety of contains method.""" # Use a larger cache to avoid evictions during the test - # Need capacity for: 50 existing + (5 threads × 100 new keys) = 550+ + # Need maxsize for: 50 existing + (5 threads × 100 new keys) = 550+ large_cache = LFUCache(1000, track_stats=True) # Pre-populate cache @@ -1149,7 +1149,7 @@ async def run_test(): asyncio.run(run_test()) def test_concurrent_operations_with_evictions(self): - """Test thread safety when cache is at capacity and evictions occur.""" + """Test thread safety when cache is at maxsize and evictions occur.""" # Small cache to force evictions small_cache = LFUCache(50, track_stats=True) data_integrity_errors = [] @@ -1192,7 +1192,7 @@ def worker(thread_id): f"Data integrity errors: {data_integrity_errors}", ) - # Cache should be at capacity + # Cache should be at maxsize self.assertEqual(small_cache.size(), 50) # Stats should show many evictions diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index b7b1180e5..7201e8f0f 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -253,7 +253,7 @@ def test_restore_llm_stats_from_cache_updates_llm_call_info(self): llm_stats_var.set(None) def test_get_from_cache_and_restore_stats_cache_miss(self): - cache = LFUCache(capacity=10) + cache = LFUCache(maxsize=10) llm_call_info_var.set(None) llm_stats_var.set(None) @@ -265,7 +265,7 @@ def test_get_from_cache_and_restore_stats_cache_miss(self): llm_stats_var.set(None) def test_get_from_cache_and_restore_stats_cache_hit(self): - cache = LFUCache(capacity=10) + cache = LFUCache(maxsize=10) cache_entry = { "result": {"allowed": True, "policy_violations": []}, "llm_stats": { @@ -297,7 +297,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self): llm_stats_var.set(None) def test_get_from_cache_and_restore_stats_without_llm_stats(self): - cache = LFUCache(capacity=10) + cache = LFUCache(maxsize=10) cache_entry = { "result": {"allowed": False, "policy_violations": ["policy1"]}, "llm_stats": None, @@ -316,7 +316,7 @@ def test_get_from_cache_and_restore_stats_without_llm_stats(self): llm_stats_var.set(None) def test_get_from_cache_and_restore_stats_with_processing_log(self): - cache = LFUCache(capacity=10) + cache = LFUCache(maxsize=10) cache_entry = { "result": {"allowed": True, "policy_violations": []}, "llm_stats": { @@ -351,7 +351,7 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self): processing_log_var.set(None) def test_get_from_cache_and_restore_stats_without_processing_log(self): - cache = LFUCache(capacity=10) + cache = LFUCache(maxsize=10) cache_entry = { "result": {"allowed": True, "policy_violations": []}, "llm_stats": { diff --git a/tests/test_content_safety_cache.py b/tests/test_content_safety_cache.py index 5261f0251..d503f95eb 100644 --- a/tests/test_content_safety_cache.py +++ b/tests/test_content_safety_cache.py @@ -46,7 +46,7 @@ def fake_llm_with_stats(): async def test_content_safety_cache_stores_result_and_stats( fake_llm_with_stats, mock_task_manager ): - cache = LFUCache(capacity=10) + cache = LFUCache(maxsize=10) llm_stats = LLMStats() llm_stats_var.set(llm_stats) @@ -86,7 +86,7 @@ async def test_content_safety_cache_stores_result_and_stats( async def test_content_safety_cache_retrieves_result_and_restores_stats( fake_llm_with_stats, mock_task_manager ): - cache = LFUCache(capacity=10) + cache = LFUCache(maxsize=10) cache_entry = { "result": {"allowed": True, "policy_violations": ["policy1"]}, @@ -129,7 +129,7 @@ async def test_content_safety_cache_retrieves_result_and_restores_stats( async def test_content_safety_cache_duration_reflects_cache_read_time( fake_llm_with_stats, mock_task_manager ): - cache = LFUCache(capacity=10) + cache = LFUCache(maxsize=10) cache_entry = { "result": {"allowed": True, "policy_violations": []}, @@ -186,7 +186,7 @@ async def test_content_safety_without_cache_does_not_store( async def test_content_safety_cache_handles_missing_stats_gracefully( fake_llm_with_stats, mock_task_manager ): - cache = LFUCache(capacity=10) + cache = LFUCache(maxsize=10) cache_entry = { "result": {"allowed": True, "policy_violations": []}, diff --git a/tests/test_integration_cache.py b/tests/test_integration_cache.py index 3c022031e..81719cc7f 100644 --- a/tests/test_integration_cache.py +++ b/tests/test_integration_cache.py @@ -109,8 +109,8 @@ async def test_cache_isolation_between_models(mock_init_llm_model): jailbreak_cache = model_caches["jailbreak_detection"] assert content_safety_cache is not jailbreak_cache - assert content_safety_cache.capacity == 50 - assert jailbreak_cache.capacity == 100 + assert content_safety_cache.maxsize == 50 + assert jailbreak_cache.maxsize == 100 content_safety_cache.put("key1", "value1") assert content_safety_cache.get("key1") == "value1" diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py index 0e87853d6..a5f28a986 100644 --- a/tests/test_llmrails.py +++ b/tests/test_llmrails.py @@ -1247,7 +1247,7 @@ def test_cache_initialization_with_enabled_cache(mock_init_llm_model): assert "content_safety" in model_caches assert model_caches["content_safety"] is not None - assert model_caches["content_safety"].capacity == 1000 + assert model_caches["content_safety"].maxsize == 1000 @patch("nemoguardrails.rails.llm.llmrails.init_llm_model") @@ -1282,7 +1282,7 @@ def test_cache_not_created_for_main_and_embeddings_models(mock_init_llm_model): @patch("nemoguardrails.rails.llm.llmrails.init_llm_model") -def test_cache_initialization_with_zero_capacity_raises_error(mock_init_llm_model): +def test_cache_initialization_with_zero_maxsize_raises_error(mock_init_llm_model): from nemoguardrails.rails.llm.config import ModelCacheConfig mock_llm = FakeLLM(responses=["response"]) @@ -1299,7 +1299,7 @@ def test_cache_initialization_with_zero_capacity_raises_error(mock_init_llm_mode ] ) - with pytest.raises(ValueError, match="Invalid cache capacity"): + with pytest.raises(ValueError, match="Invalid cache maxsize"): LLMRails(config=config, verbose=False) @@ -1370,5 +1370,5 @@ def test_cache_initialization_with_multiple_models(mock_init_llm_model): assert "main" not in model_caches assert "content_safety" in model_caches assert "jailbreak_detection" in model_caches - assert model_caches["content_safety"].capacity == 1000 - assert model_caches["jailbreak_detection"].capacity == 2000 + assert model_caches["content_safety"].maxsize == 1000 + assert model_caches["jailbreak_detection"].maxsize == 2000 From c8fa1a8a08ec0bbe6a7a854e25573ad5b7a4675d Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Fri, 17 Oct 2025 10:03:35 +0200 Subject: [PATCH 19/19] add test for cache interface --- tests/test_cache_interface.py | 104 ++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 tests/test_cache_interface.py diff --git a/tests/test_cache_interface.py b/tests/test_cache_interface.py new file mode 100644 index 000000000..190ba49bf --- /dev/null +++ b/tests/test_cache_interface.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Any + +import pytest + +from nemoguardrails.llm.cache.interface import CacheInterface + + +class MinimalCache(CacheInterface): + def __init__(self, maxsize: int = 10): + self._cache = {} + self._maxsize = maxsize + + def get(self, key: Any, default: Any = None) -> Any: + return self._cache.get(key, default) + + def put(self, key: Any, value: Any) -> None: + self._cache[key] = value + + def size(self) -> int: + return len(self._cache) + + def is_empty(self) -> bool: + return len(self._cache) == 0 + + def clear(self) -> None: + self._cache.clear() + + @property + def maxsize(self) -> int: + return self._maxsize + + +@pytest.fixture +def cache(): + return MinimalCache() + + +def test_contains(cache): + cache.put("key1", "value1") + assert cache.contains("key1") + + +def test_get_stats(cache): + stats = cache.get_stats() + assert isinstance(stats, dict) + assert "message" in stats + + +def test_reset_stats(cache): + cache.reset_stats() + + +def test_log_stats_now(cache): + cache.log_stats_now() + + +def test_supports_stats_logging(cache): + assert cache.supports_stats_logging() is False + + +@pytest.mark.asyncio +async def test_get_or_compute_cache_hit(cache): + cache.put("key1", "cached_value") + + async def compute_fn(): + return "computed_value" + + result = await cache.get_or_compute("key1", compute_fn) + assert result == "cached_value" + + +@pytest.mark.asyncio +async def test_get_or_compute_cache_miss(cache): + async def compute_fn(): + return "computed_value" + + result = await cache.get_or_compute("key1", compute_fn) + assert result == "computed_value" + assert cache.get("key1") == "computed_value" + + +@pytest.mark.asyncio +async def test_get_or_compute_exception(cache): + async def failing_compute(): + raise ValueError("Computation failed") + + result = await cache.get_or_compute("key1", failing_compute, default="fallback") + assert result == "fallback"