diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index 2407210fa..c1ac780b3 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -21,6 +21,13 @@ from nemoguardrails.actions.actions import action from nemoguardrails.actions.llm.utils import llm_call from nemoguardrails.context import llm_call_info_var +from nemoguardrails.llm.cache import CacheInterface +from nemoguardrails.llm.cache.utils import ( + CacheEntry, + create_normalized_cache_key, + extract_llm_stats_for_cache, + get_from_cache_and_restore_stats, +) from nemoguardrails.llm.taskmanager import LLMTaskManager from nemoguardrails.logging.explain import LLMCallInfo @@ -33,6 +40,7 @@ async def content_safety_check_input( llm_task_manager: LLMTaskManager, model_name: Optional[str] = None, context: Optional[dict] = None, + model_caches: Optional[Dict[str, CacheInterface]] = None, **kwargs, ) -> dict: _MAX_TOKENS = 3 @@ -75,6 +83,15 @@ async def content_safety_check_input( max_tokens = max_tokens or _MAX_TOKENS + cache = model_caches.get(model_name) if model_caches else None + + if cache: + cache_key = create_normalized_cache_key(check_input_prompt) + cached_result = get_from_cache_and_restore_stats(cache, cache_key) + if cached_result is not None: + log.debug(f"Content safety cache hit for model '{model_name}'") + return cached_result + result = await llm_call( llm, check_input_prompt, @@ -86,7 +103,18 @@ async def content_safety_check_input( is_safe, *violated_policies = result - return {"allowed": is_safe, "policy_violations": violated_policies} + final_result = {"allowed": is_safe, "policy_violations": violated_policies} + + if cache: + cache_key = create_normalized_cache_key(check_input_prompt) + cache_entry: CacheEntry = { + "result": final_result, + "llm_stats": extract_llm_stats_for_cache(), + } + cache.put(cache_key, cache_entry) + log.debug(f"Content safety result cached for model '{model_name}'") + + return final_result def content_safety_check_output_mapping(result: dict) -> bool: diff --git a/nemoguardrails/llm/cache/__init__.py b/nemoguardrails/llm/cache/__init__.py new file mode 100644 index 000000000..de2a31c6b --- /dev/null +++ b/nemoguardrails/llm/cache/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General-purpose caching utilities for NeMo Guardrails.""" + +from nemoguardrails.llm.cache.interface import CacheInterface +from nemoguardrails.llm.cache.lfu import LFUCache + +__all__ = ["CacheInterface", "LFUCache"] diff --git a/nemoguardrails/llm/cache/interface.py b/nemoguardrails/llm/cache/interface.py new file mode 100644 index 000000000..74457913b --- /dev/null +++ b/nemoguardrails/llm/cache/interface.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cache interface for NeMo Guardrails caching system. + +This module defines the abstract base class for cache implementations +that can be used interchangeably throughout the guardrails system. +""" + +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional + + +class CacheInterface(ABC): + """ + Abstract base class defining the interface for cache implementations. + + All cache implementations must inherit from this class and implement + the required methods to ensure compatibility with the caching system. + """ + + @abstractmethod + def get(self, key: Any, default: Any = None) -> Any: + """ + Retrieve an item from the cache. + + Args: + key: The key to look up in the cache. + default: Value to return if key is not found (default: None). + + Returns: + The value associated with the key, or default if not found. + """ + pass + + @abstractmethod + def put(self, key: Any, value: Any) -> None: + """ + Store an item in the cache. + + If the cache is at maxsize, this method should evict an item + according to the cache's eviction policy (e.g., LFU, LRU, etc.). + + Args: + key: The key to store. + value: The value to associate with the key. + """ + pass + + @abstractmethod + def size(self) -> int: + """ + Get the current number of items in the cache. + + Returns: + The number of items currently stored in the cache. + """ + pass + + @abstractmethod + def is_empty(self) -> bool: + """ + Check if the cache is empty. + + Returns: + True if the cache contains no items, False otherwise. + """ + pass + + @abstractmethod + def clear(self) -> None: + """ + Remove all items from the cache. + + After calling this method, the cache should be empty. + """ + pass + + def contains(self, key: Any) -> bool: + """ + Check if a key exists in the cache. + + This is an optional method that can be overridden for efficiency. + The default implementation uses get() to check existence. + + Args: + key: The key to check. + + Returns: + True if the key exists in the cache, False otherwise. + """ + # Default implementation - can be overridden for efficiency + sentinel = object() + return self.get(key, sentinel) is not sentinel + + @property + @abstractmethod + def maxsize(self) -> int: + """ + Get the maximum size of the cache. + + Returns: + The maximum number of items the cache can hold. + """ + pass + + def get_stats(self) -> dict: + """ + Get cache statistics. + + Returns: + Dictionary with cache statistics. The format and contents + may vary by implementation. Common fields include: + - hits: Number of cache hits + - misses: Number of cache misses + - evictions: Number of items evicted + - hit_rate: Percentage of requests that were hits + - current_size: Current number of items in cache + - maxsize: Maximum size of the cache + + The default implementation returns a message indicating that + statistics tracking is not supported. 
+ """ + return { + "message": "Statistics tracking is not supported by this cache implementation" + } + + def reset_stats(self) -> None: + """ + Reset cache statistics. + + This is an optional method that cache implementations can override + if they support statistics tracking. The default implementation does nothing. + """ + # Default no-op implementation + pass + + def log_stats_now(self) -> None: + """ + Force immediate logging of cache statistics. + + This is an optional method that cache implementations can override + if they support statistics logging. The default implementation does nothing. + + Implementations that support statistics logging should output the + current cache statistics to their configured logging backend. + """ + # Default no-op implementation + pass + + def supports_stats_logging(self) -> bool: + """ + Check if this cache implementation supports statistics logging. + + Returns: + True if the cache supports statistics logging, False otherwise. + + The default implementation returns False. Cache implementations + that support statistics logging should override this to return True + when logging is enabled. + """ + return False + + async def get_or_compute( + self, key: Any, compute_fn: Callable[[], Any], default: Any = None + ) -> Any: + """ + Atomically get a value from the cache or compute it if not present. + + This method ensures that the compute function is called at most once + even in the presence of concurrent requests for the same key. + + Args: + key: The key to look up + compute_fn: Async function to compute the value if key is not found + default: Value to return if compute_fn raises an exception + + Returns: + The cached value or the computed value + + This is an optional method with a default implementation. Cache + implementations should override this for better thread-safety guarantees. + """ + # Default implementation - not thread-safe for computation + value = self.get(key) + if value is not None: + return value + + try: + computed_value = await compute_fn() + self.put(key, computed_value) + return computed_value + except Exception: + return default diff --git a/nemoguardrails/llm/cache/lfu.py b/nemoguardrails/llm/cache/lfu.py new file mode 100644 index 000000000..973224757 --- /dev/null +++ b/nemoguardrails/llm/cache/lfu.py @@ -0,0 +1,470 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Least Frequently Used (LFU) cache implementation.""" + +import asyncio +import logging +import threading +import time +from typing import Any, Callable, Optional + +from nemoguardrails.llm.cache.interface import CacheInterface + +log = logging.getLogger(__name__) + + +class LFUNode: + """Node for the LFU cache doubly linked list.""" + + def __init__(self, key: Any, value: Any) -> None: + self.key = key + self.value = value + self.freq = 1 + self.prev: Optional["LFUNode"] = None + self.next: Optional["LFUNode"] = None + self.created_at = time.time() + self.accessed_at = self.created_at + + +class DoublyLinkedList: + """Doubly linked list to maintain nodes with the same frequency.""" + + def __init__(self) -> None: + # Create dummy head and tail nodes + self.head = LFUNode(None, None) + self.tail = LFUNode(None, None) + self.head.next = self.tail + self.tail.prev = self.head + self.size = 0 + + def append(self, node: LFUNode) -> None: + """Add node to the end of the list (before tail).""" + node.prev = self.tail.prev + node.next = self.tail + self.tail.prev.next = node + self.tail.prev = node + self.size += 1 + + def pop(self, node: Optional[LFUNode] = None) -> Optional[LFUNode]: + """Remove and return a node. If no node specified, removes the first node.""" + if self.size == 0: + return None + + if node is None: + node = self.head.next + + # Remove node from the list + node.prev.next = node.next + node.next.prev = node.prev + self.size -= 1 + + return node + + +class LFUCache(CacheInterface): + """ + Least Frequently Used (LFU) Cache implementation. + + When the cache reaches maxsize, it evicts the least frequently used item. + If there are ties in frequency, it evicts the least recently used among them. + """ + + def __init__( + self, + maxsize: int, + track_stats: bool = False, + stats_logging_interval: Optional[float] = None, + ) -> None: + """ + Initialize the LFU cache. 
+ + Args: + maxsize: Maximum number of items the cache can hold + track_stats: Enable tracking of cache statistics + stats_logging_interval: Seconds between periodic stats logging (None disables logging) + """ + if maxsize < 0: + raise ValueError("Capacity must be non-negative") + + self._maxsize = maxsize + self.track_stats = track_stats + self._lock = threading.RLock() # Thread-safe access + self._computing: dict[Any, asyncio.Future] = {} # Track keys being computed + + self.key_map: dict[Any, LFUNode] = {} # key -> node mapping + self.freq_map: dict[int, DoublyLinkedList] = {} # frequency -> list of nodes + self.min_freq = 0 # Track minimum frequency for eviction + + # Stats logging configuration + self.stats_logging_interval = stats_logging_interval + # Initialize to None to ensure first check doesn't trigger immediately + self.last_stats_log_time = None + + # Statistics tracking + if self.track_stats: + self.stats = { + "hits": 0, + "misses": 0, + "evictions": 0, + "puts": 0, + "updates": 0, + } + + def _update_node_freq(self, node: LFUNode) -> None: + """Update the frequency of a node and move it to the appropriate frequency list.""" + old_freq = node.freq + old_list = self.freq_map[old_freq] + + # Remove node from current frequency list + old_list.pop(node) + + # Update min_freq if necessary + if self.min_freq == old_freq and old_list.size == 0: + self.min_freq += 1 + # Clean up empty frequency lists + del self.freq_map[old_freq] + + # Increment frequency and add to new list + node.freq += 1 + new_freq = node.freq + node.accessed_at = time.time() # Update access time + + if new_freq not in self.freq_map: + self.freq_map[new_freq] = DoublyLinkedList() + + self.freq_map[new_freq].append(node) + + def get(self, key: Any, default: Any = None) -> Any: + """ + Get an item from the cache. + + Args: + key: The key to look up + default: Value to return if key is not found + + Returns: + The value associated with the key, or default if not found + """ + with self._lock: + # Check if we should log stats + self._check_and_log_stats() + + if key not in self.key_map: + if self.track_stats: + self.stats["misses"] += 1 + return default + + node = self.key_map[key] + + if self.track_stats: + self.stats["hits"] += 1 + + self._update_node_freq(node) + return node.value + + def put(self, key: Any, value: Any) -> None: + """ + Put an item into the cache. 
+ + Args: + key: The key to store + value: The value to associate with the key + """ + with self._lock: + # Check if we should log stats + self._check_and_log_stats() + + if self._maxsize == 0: + return + + if key in self.key_map: + # Update existing key + node = self.key_map[key] + node.value = value + node.created_at = time.time() # Reset creation time on update + self._update_node_freq(node) + if self.track_stats: + self.stats["updates"] += 1 + else: + # Add new key + if len(self.key_map) >= self._maxsize: + # Need to evict least frequently used item + self._evict_lfu() + + # Create new node and add to cache + new_node = LFUNode(key, value) + self.key_map[key] = new_node + + # Add to frequency 1 list + if 1 not in self.freq_map: + self.freq_map[1] = DoublyLinkedList() + + self.freq_map[1].append(new_node) + self.min_freq = 1 + + if self.track_stats: + self.stats["puts"] += 1 + + def _evict_lfu(self) -> None: + """Evict the least frequently used item from the cache.""" + if self.min_freq in self.freq_map: + lfu_list = self.freq_map[self.min_freq] + node_to_evict = lfu_list.pop() # Remove least recently used among LFU + + if node_to_evict: + del self.key_map[node_to_evict.key] + + if self.track_stats: + self.stats["evictions"] += 1 + + # Clean up empty frequency list + if lfu_list.size == 0: + del self.freq_map[self.min_freq] + + def size(self) -> int: + """Return the current size of the cache.""" + with self._lock: + return len(self.key_map) + + def is_empty(self) -> bool: + """Check if the cache is empty.""" + with self._lock: + return len(self.key_map) == 0 + + def clear(self) -> None: + """Clear all items from the cache.""" + with self._lock: + if self.track_stats: + # Track number of items evicted + self.stats["evictions"] += len(self.key_map) + + self.key_map.clear() + self.freq_map.clear() + self.min_freq = 0 + + def get_stats(self) -> dict: + """ + Get cache statistics. 
+ + Returns: + Dictionary with cache statistics (if tracking is enabled) + """ + with self._lock: + if not self.track_stats: + return {"message": "Statistics tracking is disabled"} + + stats = self.stats.copy() + stats["current_size"] = len(self.key_map) # Direct access within lock + stats["maxsize"] = self._maxsize + + # Calculate hit rate + total_requests = stats["hits"] + stats["misses"] + stats["hit_rate"] = ( + stats["hits"] / total_requests if total_requests > 0 else 0.0 + ) + + return stats + + def reset_stats(self) -> None: + """Reset cache statistics.""" + with self._lock: + if self.track_stats: + self.stats = { + "hits": 0, + "misses": 0, + "evictions": 0, + "puts": 0, + "updates": 0, + } + + def _check_and_log_stats(self) -> None: + """Check if enough time has passed and log stats if needed.""" + if not self.track_stats or self.stats_logging_interval is None: + return + + current_time = time.time() + + # Initialize timestamp on first check + if self.last_stats_log_time is None: + self.last_stats_log_time = current_time + return + + if current_time - self.last_stats_log_time >= self.stats_logging_interval: + self._log_stats() + self.last_stats_log_time = current_time + + def _log_stats(self) -> None: + """Log current cache statistics.""" + stats = self.get_stats() + + # Format the log message + log_msg = ( + f"LFU Cache Statistics - " + f"Size: {stats['current_size']}/{stats['maxsize']} | " + f"Hits: {stats['hits']} | " + f"Misses: {stats['misses']} | " + f"Hit Rate: {stats['hit_rate']:.2%} | " + f"Evictions: {stats['evictions']} | " + f"Puts: {stats['puts']} | " + f"Updates: {stats['updates']}" + ) + + log.info(log_msg) + + def log_stats_now(self) -> None: + """Force immediate logging of cache statistics.""" + if self.track_stats: + self._log_stats() + self.last_stats_log_time = time.time() + + def supports_stats_logging(self) -> bool: + """Check if this cache instance supports stats logging.""" + return self.track_stats and self.stats_logging_interval is not None + + async def get_or_compute( + self, key: Any, compute_fn: Callable[[], Any], default: Any = None + ) -> Any: + """ + Atomically get a value from the cache or compute it if not present. + + This method ensures that the compute function is called at most once + even in the presence of concurrent requests for the same key. 
+
+        Args:
+            key: The key to look up
+            compute_fn: Async function to compute the value if key is not found
+            default: Value to return if compute_fn raises an exception
+
+        Returns:
+            The cached value or the computed value
+        """
+        # First check if the value is already in cache
+        future = None
+        with self._lock:
+            if key in self.key_map:
+                node = self.key_map[key]
+                if self.track_stats:
+                    self.stats["hits"] += 1
+                self._update_node_freq(node)
+                return node.value
+
+            # Check if this key is already being computed
+            if key in self._computing:
+                future = self._computing[key]
+
+        # If the key is being computed, wait for it outside the lock
+        if future is not None:
+            try:
+                return await future
+            except Exception:
+                return default
+
+        # Create a future for this computation
+        future = asyncio.Future()
+        is_owner = False
+        with self._lock:
+            # Double-check the cache and computing dict
+            if key in self.key_map:
+                node = self.key_map[key]
+                if self.track_stats:
+                    self.stats["hits"] += 1
+                self._update_node_freq(node)
+                return node.value
+
+            if key in self._computing:
+                # Another thread started computing while we were waiting
+                future = self._computing[key]
+            else:
+                # We'll be the ones computing
+                self._computing[key] = future
+                is_owner = True
+
+        # If another thread owns the computation, wait for its future
+        if not is_owner:
+            try:
+                return await future
+            except Exception:
+                return default
+
+        # We're responsible for computing the value
+        try:
+            computed_value = await compute_fn()
+
+            # Store the computed value in cache
+            with self._lock:
+                # Remove from computing dict
+                self._computing.pop(key, None)
+
+                # Check one more time if someone else added it
+                if key in self.key_map:
+                    node = self.key_map[key]
+                    if self.track_stats:
+                        self.stats["hits"] += 1
+                    self._update_node_freq(node)
+                    future.set_result(node.value)
+                    return node.value
+
+                # Now add to cache using internal logic
+                if self._maxsize == 0:
+                    future.set_result(computed_value)
+                    return computed_value
+
+                # Add new key
+                if len(self.key_map) >= self._maxsize:
+                    self._evict_lfu()
+
+                # Create new node and add to cache
+                new_node = LFUNode(key, computed_value)
+                self.key_map[key] = new_node
+
+                # Add to frequency 1 list
+                if 1 not in self.freq_map:
+                    self.freq_map[1] = DoublyLinkedList()
+
+                self.freq_map[1].append(new_node)
+                self.min_freq = 1
+
+                if self.track_stats:
+                    self.stats["puts"] += 1
+
+                # Set the result in the future
+                future.set_result(computed_value)
+                return computed_value
+
+        except Exception as e:
+            with self._lock:
+                self._computing.pop(key, None)
+                future.set_exception(e)
+            return default
+
+    def contains(self, key: Any) -> bool:
+        """
+        Check if a key exists in the cache without updating its frequency.
+
+        This is more efficient than the default implementation which uses get()
+        and has the side effect of updating frequency counts.
+
+        Args:
+            key: The key to check
+
+        Returns:
+            True if the key exists in the cache, False otherwise
+        """
+        with self._lock:
+            return key in self.key_map
+
+    @property
+    def maxsize(self) -> int:
+        """Get the maximum size of the cache."""
+        return self._maxsize
diff --git a/nemoguardrails/llm/cache/utils.py b/nemoguardrails/llm/cache/utils.py
new file mode 100644
index 000000000..1cf02fe4d
--- /dev/null
+++ b/nemoguardrails/llm/cache/utils.py
@@ -0,0 +1,160 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+import json
+import re
+from time import time
+from typing import TYPE_CHECKING, List, Optional, TypedDict, Union
+
+from nemoguardrails.context import llm_call_info_var, llm_stats_var
+from nemoguardrails.logging.processing_log import processing_log_var
+from nemoguardrails.logging.stats import LLMStats
+
+if TYPE_CHECKING:
+    from nemoguardrails.llm.cache.interface import CacheInterface
+
+PROMPT_PATTERN_WHITESPACES = re.compile(r"\s+")
+
+
+class LLMStatsDict(TypedDict):
+    total_tokens: int
+    prompt_tokens: int
+    completion_tokens: int
+
+
+class CacheEntry(TypedDict):
+    result: dict
+    llm_stats: Optional[LLMStatsDict]
+
+
+def create_normalized_cache_key(
+    prompt: Union[str, List[dict]], normalize_whitespace: bool = True
+) -> str:
+    """
+    Create a normalized, hashed cache key from a prompt.
+
+    This function generates a deterministic cache key by normalizing the prompt
+    and applying SHA-256 hashing. The normalization ensures that semantically
+    equivalent prompts produce the same cache key.
+
+    Args:
+        prompt: The prompt to be cached. Can be:
+            - str: A single prompt string (for completion models)
+            - List[dict]: A list of message dictionaries for chat models
+              (e.g., [{"type": "user", "content": "Hello"}])
+            Note: render_task_prompt() returns Union[str, List[dict]]
+        normalize_whitespace: Whether to normalize whitespace characters.
+            When True, collapses all whitespace sequences to single spaces and
+            strips leading/trailing whitespace. Default: True
+
+    Returns:
+        A SHA-256 hex digest string (64 characters) suitable for use as a cache key
+
+    Raises:
+        TypeError: If prompt is not a str or List[dict]
+
+    Examples:
+        >>> create_normalized_cache_key("Hello world")
+        '64ec88ca00b268e5ba1a35678a1b5316d212f4f366b2477232534a8aeca37f3c'
+
+        >>> len(create_normalized_cache_key([{"type": "user", "content": "Hello"}]))
+        64
+    """
+    if isinstance(prompt, str):
+        prompt_str = prompt
+    elif isinstance(prompt, list):
+        if not all(isinstance(p, dict) for p in prompt):
+            raise TypeError(
+                f"All elements in prompt list must be dictionaries (messages). "
+                f"Got types: {[type(p).__name__ for p in prompt]}"
+            )
+        prompt_str = json.dumps(prompt, sort_keys=True)
+    else:
+        raise TypeError(
+            f"Invalid type for prompt: {type(prompt).__name__}. "
+            f"Expected str or List[dict]."
+        )
+
+    if normalize_whitespace:
+        prompt_str = PROMPT_PATTERN_WHITESPACES.sub(" ", prompt_str).strip()
+
+    return hashlib.sha256(prompt_str.encode("utf-8")).hexdigest()
+
+
+def restore_llm_stats_from_cache(
+    cached_stats: LLMStatsDict, cache_read_duration: float
+) -> None:
+    llm_stats = llm_stats_var.get()
+    if llm_stats is None:
+        llm_stats = LLMStats()
+        llm_stats_var.set(llm_stats)
+
+    llm_stats.inc("total_calls")
+    llm_stats.inc("total_time", cache_read_duration)
+    llm_stats.inc("total_tokens", cached_stats.get("total_tokens", 0))
+    llm_stats.inc("total_prompt_tokens", cached_stats.get("prompt_tokens", 0))
+    llm_stats.inc("total_completion_tokens", cached_stats.get("completion_tokens", 0))
+
+    llm_call_info = llm_call_info_var.get()
+    if llm_call_info:
+        llm_call_info.duration = cache_read_duration
+        llm_call_info.total_tokens = cached_stats.get("total_tokens", 0)
+        llm_call_info.prompt_tokens = cached_stats.get("prompt_tokens", 0)
+        llm_call_info.completion_tokens = cached_stats.get("completion_tokens", 0)
+        llm_call_info.from_cache = True
+        llm_call_info.started_at = time() - cache_read_duration
+        llm_call_info.finished_at = time()
+
+
+def extract_llm_stats_for_cache() -> Optional[LLMStatsDict]:
+    llm_call_info = llm_call_info_var.get()
+    if llm_call_info:
+        return {
+            "total_tokens": llm_call_info.total_tokens or 0,
+            "prompt_tokens": llm_call_info.prompt_tokens or 0,
+            "completion_tokens": llm_call_info.completion_tokens or 0,
+        }
+    return None
+
+
+def get_from_cache_and_restore_stats(
+    cache: "CacheInterface", cache_key: str
+) -> Optional[dict]:
+    cache_read_start = time()
+    cached_entry = cache.get(cache_key)
+    if cached_entry is None:
+        return None
+
+    final_result = cached_entry["result"]
+    cached_stats = cached_entry.get("llm_stats")
+    cache_read_duration = time() - cache_read_start
+
+    if cached_stats:
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration)
+
+    processing_log = processing_log_var.get()
+    if processing_log is not None:
+        llm_call_info = llm_call_info_var.get()
+        if llm_call_info:
+            processing_log.append(
+                {
+                    "type": "llm_call_info",
+                    "timestamp": time(),
+                    "data": llm_call_info,
+                }
+            )
+
+    return final_result
diff --git a/nemoguardrails/logging/explain.py b/nemoguardrails/logging/explain.py
index edf7825c2..ad75e5c45 100644
--- a/nemoguardrails/logging/explain.py
+++ b/nemoguardrails/logging/explain.py
@@ -63,6 +63,10 @@ class LLMCallInfo(LLMCallSummary):
         default="unknown",
         description="The provider of the model used for the LLM call, e.g. 'openai', 'nvidia'.",
     )
+    from_cache: bool = Field(
+        default=False,
+        description="Whether this response was retrieved from cache.",
+    )
 
 
 class ExplainInfo(BaseModel):
diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py
index 6b6a9b64a..0c037c092 100644
--- a/nemoguardrails/rails/llm/config.py
+++ b/nemoguardrails/rails/llm/config.py
@@ -69,6 +69,35 @@ colang_path_dirs.append(guardrails_stdlib_path)
+class CacheStatsConfig(BaseModel):
+    """Configuration for cache statistics tracking and logging."""
+
+    enabled: bool = Field(
+        default=False,
+        description="Whether cache statistics tracking is enabled",
+    )
+    log_interval: Optional[float] = Field(
+        default=None,
+        description="Seconds between periodic cache stats logging (None disables logging)",
+    )
+
+
+class ModelCacheConfig(BaseModel):
+    """Configuration for model caching."""
+
+    enabled: bool = Field(
+        default=False,
+        description="Whether caching is enabled (default: False - no caching)",
+    )
+    maxsize: int = Field(
+        default=50000, description="Maximum number of entries in the cache per model"
+    )
+    stats: CacheStatsConfig = Field(
+        default_factory=CacheStatsConfig,
+        description="Configuration for cache statistics tracking and logging",
+    )
+
+
 class Model(BaseModel):
     """Configuration of a model used by the rails engine.
@@ -97,6 +126,12 @@ class Model(BaseModel):
         description="Whether the mode is 'text' completion or 'chat' completion. Allowed values are 'chat' or 'text'.",
     )
+    # Cache configuration specific to this model (for content safety models)
+    cache: Optional["ModelCacheConfig"] = Field(
+        default=None,
+        description="Cache configuration for this specific model (primarily used for content safety models)",
+    )
+
     @model_validator(mode="before")
     @classmethod
     def set_and_validate_model(cls, data: Any) -> Any:
diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index e736a32df..d811464eb 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -70,6 +70,7 @@
 from nemoguardrails.embeddings.providers import register_embedding_provider
 from nemoguardrails.embeddings.providers.base import EmbeddingModel
 from nemoguardrails.kb.kb import KnowledgeBase
+from nemoguardrails.llm.cache import CacheInterface, LFUCache
 from nemoguardrails.llm.models.initializer import (
     ModelInitializationError,
     init_llm_model,
@@ -500,6 +501,7 @@ def _init_llms(self):
                 kwargs=kwargs,
             )
+            # Configure the model based on its type
             if llm_config.type == "main":
                 # If a main LLM was already injected, skip creating another
                 # one. Otherwise, create and register it.
@@ -525,6 +527,60 @@ def _init_llms(self):
         self.runtime.register_action_param("llms", llms)
+        self._initialize_model_caches()
+
+    def _create_model_cache(self, model) -> LFUCache:
+        """
+        Create a cache instance for a model based on its configuration.
+
+        Args:
+            model: The model configuration object
+
+        Returns:
+            LFUCache: The cache instance
+        """
+
+        if model.cache.maxsize <= 0:
+            raise ValueError(
+                f"Invalid cache maxsize for model '{model.type}': {model.cache.maxsize}. "
+                "Capacity must be greater than 0."
+ ) + + stats_logging_interval = None + if model.cache.stats.enabled and model.cache.stats.log_interval is not None: + stats_logging_interval = model.cache.stats.log_interval + + cache = LFUCache( + maxsize=model.cache.maxsize, + track_stats=model.cache.stats.enabled, + stats_logging_interval=stats_logging_interval, + ) + + log.info( + f"Created cache for model '{model.type}' with maxsize {model.cache.maxsize}" + ) + + return cache + + def _initialize_model_caches(self) -> None: + """Initialize caches for configured models.""" + model_caches: Optional[Dict[str, CacheInterface]] = dict() + for model in self.config.models: + if model.type in ["main", "embeddings"]: + continue + + if model.cache and model.cache.enabled: + cache = self._create_model_cache(model) + model_caches[model.type] = cache + + log.info( + f"Initialized model '{model.type}' with cache %s", + "enabled" if cache else "disabled", + ) + + if model_caches: + self.runtime.register_action_param("model_caches", model_caches) + def _get_embeddings_search_provider_instance( self, esp_config: Optional[EmbeddingSearchProvider] = None ) -> EmbeddingsIndex: diff --git a/nemoguardrails/tracing/constants.py b/nemoguardrails/tracing/constants.py index 3e0bf3179..2cf411413 100644 --- a/nemoguardrails/tracing/constants.py +++ b/nemoguardrails/tracing/constants.py @@ -119,7 +119,10 @@ class GuardrailsAttributes: ACTION_NAME = "action.name" ACTION_HAS_LLM_CALLS = "action.has_llm_calls" ACTION_LLM_CALLS_COUNT = "action.llm_calls_count" - ACTION_PARAM_PREFIX = "action.param." # For dynamic action parameters + ACTION_PARAM_PREFIX = "action.param." + + # llm attributes (application-level, not provider-level) + LLM_CACHE_HIT = "llm.cache.hit" class SpanNames: diff --git a/nemoguardrails/tracing/span_extractors.py b/nemoguardrails/tracing/span_extractors.py index cca40024f..e40318d01 100644 --- a/nemoguardrails/tracing/span_extractors.py +++ b/nemoguardrails/tracing/span_extractors.py @@ -257,6 +257,8 @@ def extract_spans( max_tokens = llm_call.raw_response.get("max_tokens") top_p = llm_call.raw_response.get("top_p") + cache_hit = hasattr(llm_call, "from_cache") and llm_call.from_cache + llm_span = LLMSpan( span_id=new_uuid(), name=span_name, @@ -276,6 +278,7 @@ def extract_spans( top_p=top_p, response_id=response_id, response_finish_reasons=finish_reasons, + cache_hit=cache_hit, # TODO: add error to LLMCallInfo for future release # error=( # True diff --git a/nemoguardrails/tracing/spans.py b/nemoguardrails/tracing/spans.py index fb89fb394..e8e7ec565 100644 --- a/nemoguardrails/tracing/spans.py +++ b/nemoguardrails/tracing/spans.py @@ -270,6 +270,11 @@ class LLMSpan(BaseSpan): default=None, description="Finish reasons for each choice" ) + cache_hit: bool = Field( + default=False, + description="Whether this LLM response was served from application cache", + ) + def to_otel_attributes(self) -> Dict[str, Any]: """Convert to OTel attributes.""" attributes = self._base_attributes() @@ -320,6 +325,8 @@ def to_otel_attributes(self) -> Dict[str, Any]: GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS ] = self.response_finish_reasons + attributes[GuardrailsAttributes.LLM_CACHE_HIT] = self.cache_hit + return attributes diff --git a/tests/test_cache_interface.py b/tests/test_cache_interface.py new file mode 100644 index 000000000..190ba49bf --- /dev/null +++ b/tests/test_cache_interface.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Any + +import pytest + +from nemoguardrails.llm.cache.interface import CacheInterface + + +class MinimalCache(CacheInterface): + def __init__(self, maxsize: int = 10): + self._cache = {} + self._maxsize = maxsize + + def get(self, key: Any, default: Any = None) -> Any: + return self._cache.get(key, default) + + def put(self, key: Any, value: Any) -> None: + self._cache[key] = value + + def size(self) -> int: + return len(self._cache) + + def is_empty(self) -> bool: + return len(self._cache) == 0 + + def clear(self) -> None: + self._cache.clear() + + @property + def maxsize(self) -> int: + return self._maxsize + + +@pytest.fixture +def cache(): + return MinimalCache() + + +def test_contains(cache): + cache.put("key1", "value1") + assert cache.contains("key1") + + +def test_get_stats(cache): + stats = cache.get_stats() + assert isinstance(stats, dict) + assert "message" in stats + + +def test_reset_stats(cache): + cache.reset_stats() + + +def test_log_stats_now(cache): + cache.log_stats_now() + + +def test_supports_stats_logging(cache): + assert cache.supports_stats_logging() is False + + +@pytest.mark.asyncio +async def test_get_or_compute_cache_hit(cache): + cache.put("key1", "cached_value") + + async def compute_fn(): + return "computed_value" + + result = await cache.get_or_compute("key1", compute_fn) + assert result == "cached_value" + + +@pytest.mark.asyncio +async def test_get_or_compute_cache_miss(cache): + async def compute_fn(): + return "computed_value" + + result = await cache.get_or_compute("key1", compute_fn) + assert result == "computed_value" + assert cache.get("key1") == "computed_value" + + +@pytest.mark.asyncio +async def test_get_or_compute_exception(cache): + async def failing_compute(): + raise ValueError("Computation failed") + + result = await cache.get_or_compute("key1", failing_compute, default="fallback") + assert result == "fallback" diff --git a/tests/test_cache_lfu.py b/tests/test_cache_lfu.py new file mode 100644 index 000000000..b5bc5acda --- /dev/null +++ b/tests/test_cache_lfu.py @@ -0,0 +1,1205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive test suite for LFU Cache implementation. 
+ +Tests all functionality including basic operations, eviction policies, +maxsize management, and edge cases. +""" + +import asyncio +import os +import threading +import time +import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Any +from unittest.mock import MagicMock, patch + +from nemoguardrails.llm.cache.lfu import LFUCache + + +class TestLFUCache(unittest.TestCase): + """Test cases for LFU Cache implementation.""" + + def setUp(self): + """Set up test fixtures.""" + self.cache = LFUCache(3) + + def test_initialization(self): + """Test cache initialization with various capacities.""" + # Normal maxsize + cache = LFUCache(5) + self.assertEqual(cache.size(), 0) + self.assertTrue(cache.is_empty()) + + # Zero maxsize + cache_zero = LFUCache(0) + self.assertEqual(cache_zero.size(), 0) + + # Negative maxsize should raise error + with self.assertRaises(ValueError): + LFUCache(-1) + + def test_basic_put_get(self): + """Test basic put and get operations.""" + # Put and get single item + self.cache.put("key1", "value1") + self.assertEqual(self.cache.get("key1"), "value1") + self.assertEqual(self.cache.size(), 1) + + # Put and get multiple items + self.cache.put("key2", "value2") + self.cache.put("key3", "value3") + + self.assertEqual(self.cache.get("key1"), "value1") + self.assertEqual(self.cache.get("key2"), "value2") + self.assertEqual(self.cache.get("key3"), "value3") + self.assertEqual(self.cache.size(), 3) + + def test_get_nonexistent_key(self): + """Test getting non-existent keys.""" + # Default behavior (returns None) + self.assertIsNone(self.cache.get("nonexistent")) + + # With custom default + self.assertEqual(self.cache.get("nonexistent", "default"), "default") + + # After adding some items + self.cache.put("key1", "value1") + self.assertIsNone(self.cache.get("key2")) + self.assertEqual(self.cache.get("key2", 42), 42) + + def test_update_existing_key(self): + """Test updating values for existing keys.""" + self.cache.put("key1", "value1") + self.cache.put("key2", "value2") + + # Update existing key + self.cache.put("key1", "new_value1") + self.assertEqual(self.cache.get("key1"), "new_value1") + + # Size should not change + self.assertEqual(self.cache.size(), 2) + + def test_lfu_eviction_basic(self): + """Test basic LFU eviction when cache is full.""" + # Fill cache + self.cache.put("a", 1) + self.cache.put("b", 2) + self.cache.put("c", 3) + + # Access 'a' and 'b' to increase their frequency + self.cache.get("a") # freq: 2 + self.cache.get("b") # freq: 2 + # 'c' remains at freq: 1 + + # Add new item - should evict 'c' (lowest frequency) + self.cache.put("d", 4) + + self.assertEqual(self.cache.get("a"), 1) + self.assertEqual(self.cache.get("b"), 2) + self.assertEqual(self.cache.get("d"), 4) + self.assertIsNone(self.cache.get("c")) # Should be evicted + + def test_lfu_with_lru_tiebreaker(self): + """Test LRU eviction among items with same frequency.""" + # Fill cache - all items have frequency 1 + self.cache.put("a", 1) + self.cache.put("b", 2) + self.cache.put("c", 3) + + # Add new item - should evict 'a' (least recently used among freq 1) + self.cache.put("d", 4) + + self.assertIsNone(self.cache.get("a")) # Should be evicted + self.assertEqual(self.cache.get("b"), 2) + self.assertEqual(self.cache.get("c"), 3) + self.assertEqual(self.cache.get("d"), 4) + + def test_complex_eviction_scenario(self): + """Test complex eviction scenario with multiple frequency levels.""" + # Create a new cache for this test + cache = LFUCache(4) + + # Add items and create 
different frequency levels + cache.put("a", 1) + cache.put("b", 2) + cache.put("c", 3) + cache.put("d", 4) + + # Create frequency pattern: + # a: freq 3 (accessed 2 more times) + # b: freq 2 (accessed 1 more time) + # c: freq 2 (accessed 1 more time) + # d: freq 1 (not accessed) + + cache.get("a") + cache.get("a") + cache.get("b") + cache.get("c") + + # Add new item - should evict 'd' (freq 1) + cache.put("e", 5) + self.assertIsNone(cache.get("d")) + + # Add another item - should evict one of the least frequently used + cache.put("f", 6) + + # After eviction, we should have: + # - 'a' (freq 3) - definitely kept + # - 'b' (freq 2) and 'c' (freq 2) - higher frequency, both kept + # - 'f' (freq 1) - just added + # - 'e' (freq 1) was evicted as it was least recently used among freq 1 items + + # Check that we're at maxsize + self.assertEqual(cache.size(), 4) + + # 'a' should definitely still be there (highest frequency) + self.assertEqual(cache.get("a"), 1) + + # 'b' and 'c' should both be there (freq 2) + self.assertEqual(cache.get("b"), 2) + self.assertEqual(cache.get("c"), 3) + + # 'f' should be there (just added) + self.assertEqual(cache.get("f"), 6) + + # 'e' should have been evicted (freq 1, LRU among freq 1 items) + self.assertIsNone(cache.get("e")) + + def test_zero_maxsize_cache(self): + """Test cache with zero maxsize.""" + cache = LFUCache(0) + + # Put should not store anything + cache.put("key", "value") + self.assertEqual(cache.size(), 0) + self.assertIsNone(cache.get("key")) + + # Multiple puts + for i in range(10): + cache.put(f"key{i}", f"value{i}") + + self.assertEqual(cache.size(), 0) + self.assertTrue(cache.is_empty()) + + def test_clear_method(self): + """Test clearing the cache.""" + # Add items + self.cache.put("a", 1) + self.cache.put("b", 2) + self.cache.put("c", 3) + + # Verify items exist + self.assertEqual(self.cache.size(), 3) + self.assertFalse(self.cache.is_empty()) + + # Clear cache + self.cache.clear() + + # Verify cache is empty + self.assertEqual(self.cache.size(), 0) + self.assertTrue(self.cache.is_empty()) + + # Verify items are gone + self.assertIsNone(self.cache.get("a")) + self.assertIsNone(self.cache.get("b")) + self.assertIsNone(self.cache.get("c")) + + # Can still use cache after clear + self.cache.put("new_key", "new_value") + self.assertEqual(self.cache.get("new_key"), "new_value") + + def test_various_data_types(self): + """Test cache with various data types as keys and values.""" + # Integer keys + self.cache.put(1, "one") + self.cache.put(2, "two") + self.assertEqual(self.cache.get(1), "one") + self.assertEqual(self.cache.get(2), "two") + + # Tuple keys + self.cache.put((1, 2), "tuple_value") + self.assertEqual(self.cache.get((1, 2)), "tuple_value") + + # Clear for more tests + self.cache.clear() + + # Complex values + self.cache.put("list", [1, 2, 3]) + self.cache.put("dict", {"a": 1, "b": 2}) + self.cache.put("set", {1, 2, 3}) + + self.assertEqual(self.cache.get("list"), [1, 2, 3]) + self.assertEqual(self.cache.get("dict"), {"a": 1, "b": 2}) + self.assertEqual(self.cache.get("set"), {1, 2, 3}) + + def test_none_values(self): + """Test storing None as a value.""" + self.cache.put("key", None) + # get should return None for the value, not the default + self.assertIsNone(self.cache.get("key")) + self.assertEqual(self.cache.get("key", "default"), None) + + # Verify key exists + self.assertEqual(self.cache.size(), 1) + + def test_size_and_maxsize(self): + """Test size tracking and maxsize limits.""" + # Start empty + 
self.assertEqual(self.cache.size(), 0) + + # Add items up to maxsize + for i in range(3): + self.cache.put(f"key{i}", f"value{i}") + self.assertEqual(self.cache.size(), i + 1) + + # Add more items - size should stay at maxsize + for i in range(3, 10): + self.cache.put(f"key{i}", f"value{i}") + self.assertEqual(self.cache.size(), 3) + + def test_is_empty(self): + """Test is_empty method in various states.""" + # Initially empty + self.assertTrue(self.cache.is_empty()) + + # After adding item + self.cache.put("key", "value") + self.assertFalse(self.cache.is_empty()) + + # After clearing + self.cache.clear() + self.assertTrue(self.cache.is_empty()) + + def test_repeated_puts_same_key(self): + """Test repeated puts with the same key maintain size=1 and update frequency.""" + self.cache.put("key", "value1") + self.assertEqual(self.cache.size(), 1) + + # Track initial state + initial_stats = self.cache.get_stats() if self.cache.track_stats else None + + # Update same key multiple times + for i in range(10): + self.cache.put("key", f"value{i}") + self.assertEqual(self.cache.size(), 1) + + # Final value should be the last one + self.assertEqual(self.cache.get("key"), "value9") + + # Verify stats if tracking enabled + if self.cache.track_stats: + final_stats = self.cache.get_stats() + # Should have 10 updates (after initial put) + self.assertEqual(final_stats["updates"], 10) + + def test_access_pattern_preserves_frequently_used(self): + """Test that frequently accessed items are preserved during evictions.""" + # Create specific access pattern + cache = LFUCache(3) + + # Add three items + cache.put("rarely_used", 1) + cache.put("sometimes_used", 2) + cache.put("frequently_used", 3) + + # Create access pattern + # frequently_used: access 10 times + for _ in range(10): + cache.get("frequently_used") + + # sometimes_used: access 3 times + for _ in range(3): + cache.get("sometimes_used") + + # rarely_used: no additional access (freq = 1) + + # Add new items to trigger evictions + cache.put("new1", 4) # Should evict rarely_used + cache.put("new2", 5) # Should evict new1 (freq = 1) + + # frequently_used and sometimes_used should still be there + self.assertEqual(cache.get("frequently_used"), 3) + self.assertEqual(cache.get("sometimes_used"), 2) + + # rarely_used and new1 should be evicted + self.assertIsNone(cache.get("rarely_used")) + self.assertIsNone(cache.get("new1")) + + # new2 should be there + self.assertEqual(cache.get("new2"), 5) + + +class TestLFUCacheInterface(unittest.TestCase): + """Test that LFUCache properly implements CacheInterface.""" + + def test_interface_methods_exist(self): + """Verify all interface methods are implemented.""" + cache = LFUCache(5) + + # Check all required methods exist and are callable + self.assertTrue(callable(getattr(cache, "get", None))) + self.assertTrue(callable(getattr(cache, "put", None))) + self.assertTrue(callable(getattr(cache, "size", None))) + self.assertTrue(callable(getattr(cache, "is_empty", None))) + self.assertTrue(callable(getattr(cache, "clear", None))) + + # Check property + self.assertEqual(cache.maxsize, 5) + + +class TestLFUCacheStatsLogging(unittest.TestCase): + """Test cases for LFU Cache statistics logging functionality.""" + + def test_stats_logging_disabled_by_default(self): + """Test that stats logging is disabled when not configured.""" + cache = LFUCache(5, track_stats=True) + self.assertFalse(cache.supports_stats_logging()) + + def test_stats_logging_requires_tracking(self): + """Test that stats logging requires stats tracking to be 
enabled.""" + # Logging without tracking + cache = LFUCache(5, track_stats=False, stats_logging_interval=1.0) + self.assertFalse(cache.supports_stats_logging()) + + # Both enabled + cache = LFUCache(5, track_stats=True, stats_logging_interval=1.0) + self.assertTrue(cache.supports_stats_logging()) + + def test_log_stats_now(self): + """Test immediate stats logging.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=60.0) + + # Add some data + cache.put("key1", "value1") + cache.put("key2", "value2") + cache.get("key1") + cache.get("nonexistent") + + with patch.object( + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + # Verify log was called + self.assertEqual(mock_log.call_count, 1) + log_message = mock_log.call_args[0][0] + + # Check log format + self.assertIn("LFU Cache Statistics", log_message) + self.assertIn("Size: 2/5", log_message) + self.assertIn("Hits: 1", log_message) + self.assertIn("Misses: 1", log_message) + self.assertIn("Hit Rate: 50.00%", log_message) + self.assertIn("Evictions: 0", log_message) + self.assertIn("Puts: 2", log_message) + self.assertIn("Updates: 0", log_message) + + def test_periodic_stats_logging(self): + """Test automatic periodic stats logging.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.5) + + # Add some data + cache.put("key1", "value1") + cache.put("key2", "value2") + + with patch.object( + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" + ) as mock_log: + # Initial operations shouldn't trigger logging + cache.get("key1") + self.assertEqual(mock_log.call_count, 0) + + # Wait for interval to pass + time.sleep(0.6) + + # Next operation should trigger logging + cache.get("key1") + self.assertEqual(mock_log.call_count, 1) + + # Another operation without waiting shouldn't trigger + cache.get("key2") + self.assertEqual(mock_log.call_count, 1) + + # Wait again + time.sleep(0.6) + cache.put("key3", "value3") + self.assertEqual(mock_log.call_count, 2) + + def test_stats_logging_with_empty_cache(self): + """Test stats logging with empty cache.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + # Generate a miss first + cache.get("nonexistent") + + # Wait for interval to pass + time.sleep(0.2) + + with patch.object( + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" + ) as mock_log: + # This will trigger stats logging with the previous miss already counted + cache.get("another_nonexistent") # Trigger check + + self.assertEqual(mock_log.call_count, 1) + log_message = mock_log.call_args[0][0] + + self.assertIn("Size: 0/5", log_message) + self.assertIn("Hits: 0", log_message) + self.assertIn("Misses: 1", log_message) # The first miss is logged + self.assertIn("Hit Rate: 0.00%", log_message) + + def test_stats_logging_with_full_cache(self): + """Test stats logging when cache is at maxsize.""" + import logging + from unittest.mock import patch + + cache = LFUCache(3, track_stats=True, stats_logging_interval=0.1) + + # Fill cache + cache.put("key1", "value1") + cache.put("key2", "value2") + cache.put("key3", "value3") + + # Cause eviction + cache.put("key4", "value4") + + with patch.object( + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" + ) as mock_log: + time.sleep(0.2) + cache.get("key4") # Trigger check + + log_message = mock_log.call_args[0][0] + 
self.assertIn("Size: 3/3", log_message) + self.assertIn("Evictions: 1", log_message) + self.assertIn("Puts: 4", log_message) + + def test_stats_logging_high_hit_rate(self): + """Test stats logging with high hit rate.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + cache.put("key1", "value1") + + # Many hits + for _ in range(99): + cache.get("key1") + + # One miss + cache.get("nonexistent") + + with patch.object( + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + log_message = mock_log.call_args[0][0] + self.assertIn("Hit Rate: 99.00%", log_message) + self.assertIn("Hits: 99", log_message) + self.assertIn("Misses: 1", log_message) + + def test_stats_logging_without_tracking(self): + """Test that log_stats_now does nothing when tracking is disabled.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=False) + + cache.put("key1", "value1") + cache.get("key1") + + with patch.object( + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + # Should not log anything + self.assertEqual(mock_log.call_count, 0) + + def test_stats_logging_interval_timing(self): + """Test that stats logging respects the interval timing.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=1.0) + + with patch.object( + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" + ) as mock_log: + # Multiple operations within interval + for i in range(10): + cache.put(f"key{i}", f"value{i}") + cache.get(f"key{i}") + time.sleep(0.05) # Total time < 1.0 + + # Should not have logged yet + self.assertEqual(mock_log.call_count, 0) + + # Wait for interval to pass + time.sleep(0.6) + cache.get("key1") # Trigger check + + # Now should have logged once + self.assertEqual(mock_log.call_count, 1) + + def test_stats_logging_with_updates(self): + """Test stats logging includes update counts.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + cache.put("key1", "value1") + cache.put("key1", "updated_value1") # Update + cache.put("key1", "updated_again") # Another update + + with patch.object( + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + log_message = mock_log.call_args[0][0] + self.assertIn("Updates: 2", log_message) + self.assertIn("Puts: 1", log_message) + + def test_stats_log_format_percentages(self): + """Test that percentages in stats log are formatted correctly.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + # Test various hit rates + test_cases = [ + (0, 0, "0.00%"), # No requests + (1, 0, "100.00%"), # All hits + (0, 1, "0.00%"), # All misses + (1, 1, "50.00%"), # 50/50 + (2, 1, "66.67%"), # 2/3 + (99, 1, "99.00%"), # High hit rate + ] + + for hits, misses, expected_rate in test_cases: + cache.reset_stats() + + # Generate hits + if hits > 0: + cache.put("hit_key", "value") + for _ in range(hits): + cache.get("hit_key") + + # Generate misses + for i in range(misses): + cache.get(f"miss_key_{i}") + + with patch.object( + logging.getLogger("nemoguardrails.llm.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + if hits > 0 or misses > 0: + log_message = mock_log.call_args[0][0] + self.assertIn(f"Hit Rate: {expected_rate}", 
log_message) + + +class TestContentSafetyCacheStatsConfig(unittest.TestCase): + """Test cache stats configuration in content safety context.""" + + def test_cache_config_with_stats_disabled(self): + """Test cache configuration with stats disabled.""" + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig + + cache_config = ModelCacheConfig( + enabled=True, maxsize=1000, stats=CacheStatsConfig(enabled=False) + ) + + cache = LFUCache( + maxsize=cache_config.maxsize, + track_stats=cache_config.stats.enabled, + stats_logging_interval=None, + ) + + self.assertIsNotNone(cache) + self.assertFalse(cache.track_stats) + self.assertFalse(cache.supports_stats_logging()) + + def test_cache_config_with_stats_tracking_only(self): + """Test cache configuration with stats tracking but no logging.""" + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig + + cache_config = ModelCacheConfig( + enabled=True, + maxsize=1000, + stats=CacheStatsConfig(enabled=True, log_interval=None), + ) + + cache = LFUCache( + maxsize=cache_config.maxsize, + track_stats=cache_config.stats.enabled, + stats_logging_interval=cache_config.stats.log_interval, + ) + + self.assertIsNotNone(cache) + self.assertTrue(cache.track_stats) + self.assertFalse(cache.supports_stats_logging()) + self.assertIsNone(cache.stats_logging_interval) + + def test_cache_config_with_stats_logging(self): + """Test cache configuration with stats tracking and logging.""" + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig + + cache_config = ModelCacheConfig( + enabled=True, + maxsize=1000, + stats=CacheStatsConfig(enabled=True, log_interval=60.0), + ) + + cache = LFUCache( + maxsize=cache_config.maxsize, + track_stats=cache_config.stats.enabled, + stats_logging_interval=cache_config.stats.log_interval, + ) + + self.assertIsNotNone(cache) + self.assertTrue(cache.track_stats) + self.assertTrue(cache.supports_stats_logging()) + self.assertEqual(cache.stats_logging_interval, 60.0) + + def test_cache_config_default_stats(self): + """Test cache configuration with default stats settings.""" + from nemoguardrails.rails.llm.config import ModelCacheConfig + + cache_config = ModelCacheConfig(enabled=True) + + cache = LFUCache( + maxsize=cache_config.maxsize, + track_stats=cache_config.stats.enabled, + stats_logging_interval=None, + ) + + self.assertIsNotNone(cache) + self.assertFalse(cache.track_stats) # Default is disabled + self.assertFalse(cache.supports_stats_logging()) + + def test_cache_config_from_dict(self): + """Test cache configuration creation from dictionary.""" + from nemoguardrails.rails.llm.config import ModelCacheConfig + + config_dict = { + "enabled": True, + "maxsize": 5000, + "stats": {"enabled": True, "log_interval": 120.0}, + } + + cache_config = ModelCacheConfig(**config_dict) + self.assertTrue(cache_config.enabled) + self.assertEqual(cache_config.maxsize, 5000) + self.assertTrue(cache_config.stats.enabled) + self.assertEqual(cache_config.stats.log_interval, 120.0) + + def test_cache_config_stats_validation(self): + """Test cache configuration validation for stats settings.""" + from nemoguardrails.rails.llm.config import CacheStatsConfig + + # Valid configurations + stats1 = CacheStatsConfig(enabled=True, log_interval=60.0) + self.assertTrue(stats1.enabled) + self.assertEqual(stats1.log_interval, 60.0) + + stats2 = CacheStatsConfig(enabled=True, log_interval=None) + self.assertTrue(stats2.enabled) + self.assertIsNone(stats2.log_interval) + + stats3 = 
CacheStatsConfig(enabled=False, log_interval=60.0) + self.assertFalse(stats3.enabled) + self.assertEqual(stats3.log_interval, 60.0) + + +class TestLFUCacheThreadSafety(unittest.TestCase): + """Test thread safety of LFU Cache implementation.""" + + def setUp(self): + """Set up test fixtures.""" + self.cache = LFUCache(100, track_stats=True) + + def test_concurrent_reads_writes(self): + """Test that concurrent reads and writes don't corrupt the cache.""" + num_threads = 10 + operations_per_thread = 100 + # Use a larger cache to avoid evictions during the test + large_cache = LFUCache(2000, track_stats=True) + errors = [] + + def worker(thread_id): + """Worker function that performs cache operations.""" + for i in range(operations_per_thread): + key = f"thread_{thread_id}_key_{i}" + value = f"thread_{thread_id}_value_{i}" + + # Put operation + large_cache.put(key, value) + + # Get operation - should always succeed with large cache + retrieved = large_cache.get(key) + + # Verify data integrity + if retrieved != value: + errors.append( + f"Data corruption for {key}: expected {value}, got {retrieved}" + ) + + # Access some shared keys + shared_key = f"shared_key_{i % 10}" + large_cache.put(shared_key, f"shared_value_{thread_id}_{i}") + large_cache.get(shared_key) + + # Run threads + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + for future in futures: + future.result() # Wait for completion and raise any exceptions + + # Check for any errors + self.assertEqual(len(errors), 0, f"Errors occurred: {errors[:5]}...") + + # Verify cache is still functional + test_key = "test_after_concurrent" + test_value = "test_value" + large_cache.put(test_key, test_value) + self.assertEqual(large_cache.get(test_key), test_value) + + # Check statistics are reasonable + stats = large_cache.get_stats() + self.assertGreater(stats["hits"], 0) + self.assertGreater(stats["puts"], 0) + + def test_concurrent_evictions(self): + """Test that concurrent operations during evictions don't corrupt the cache.""" + # Use a small cache to trigger frequent evictions + small_cache = LFUCache(10) + num_threads = 5 + operations_per_thread = 50 + + def worker(thread_id): + """Worker that adds many items to trigger evictions.""" + for i in range(operations_per_thread): + key = f"t{thread_id}_k{i}" + value = f"t{thread_id}_v{i}" + small_cache.put(key, value) + + # Try to get recently added items + if i > 0: + prev_key = f"t{thread_id}_k{i-1}" + small_cache.get(prev_key) # May or may not exist + + # Run threads + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + for future in futures: + future.result() + + # Cache should still be at maxsize + self.assertEqual(small_cache.size(), 10) + + def test_concurrent_clear_operations(self): + """Test concurrent clear operations with other operations.""" + + def writer(): + """Continuously write to cache.""" + for i in range(100): + self.cache.put(f"key_{i}", f"value_{i}") + time.sleep(0.001) # Small delay + + def clearer(): + """Periodically clear the cache.""" + for _ in range(5): + time.sleep(0.01) + self.cache.clear() + + def reader(): + """Continuously read from cache.""" + for i in range(100): + self.cache.get(f"key_{i}") + time.sleep(0.001) + + # Run operations concurrently + threads = [ + threading.Thread(target=writer), + threading.Thread(target=clearer), + threading.Thread(target=reader), + ] + + for thread in threads: + 
thread.start() + + for thread in threads: + thread.join() + + # Cache should still be functional + self.cache.put("final_key", "final_value") + self.assertEqual(self.cache.get("final_key"), "final_value") + + def test_concurrent_stats_operations(self): + """Test that concurrent operations don't corrupt statistics.""" + + def worker(thread_id): + """Worker that performs operations and checks stats.""" + for i in range(50): + key = f"stats_key_{thread_id}_{i}" + self.cache.put(key, i) + self.cache.get(key) # Hit + self.cache.get(f"nonexistent_{thread_id}_{i}") # Miss + + # Periodically check stats + if i % 10 == 0: + stats = self.cache.get_stats() + # Just verify we can get stats without error + self.assertIsInstance(stats, dict) + self.assertIn("hits", stats) + self.assertIn("misses", stats) + + # Run threads + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Final stats check + final_stats = self.cache.get_stats() + self.assertGreater(final_stats["hits"], 0) + self.assertGreater(final_stats["misses"], 0) + self.assertGreater(final_stats["puts"], 0) + + def test_get_or_compute_thread_safety(self): + """Test thread safety of get_or_compute method.""" + compute_count = threading.local() + compute_count.value = 0 + total_computes = [] + lock = threading.Lock() + + async def expensive_compute(): + """Simulate expensive computation that should only run once.""" + # Track how many times this is called + if not hasattr(compute_count, "value"): + compute_count.value = 0 + compute_count.value += 1 + + with lock: + total_computes.append(1) + + # Simulate expensive operation + await asyncio.sleep(0.1) + return f"computed_value_{len(total_computes)}" + + async def worker(thread_id): + """Worker that tries to get or compute the same key.""" + result = await self.cache.get_or_compute( + "shared_compute_key", expensive_compute, default="default" + ) + return result + + async def run_test(): + """Run the async test.""" + # Run multiple workers concurrently + tasks = [worker(i) for i in range(10)] + results = await asyncio.gather(*tasks) + + # All should get the same value + self.assertTrue( + all(r == results[0] for r in results), + f"All threads should get same value, got: {results}", + ) + + # Compute should have been called only once + self.assertEqual( + len(total_computes), + 1, + f"Compute should be called once, called {len(total_computes)} times", + ) + + return results[0] + + # Run the async test + result = asyncio.run(run_test()) + self.assertEqual(result, "computed_value_1") + + def test_get_or_compute_exception_handling(self): + """Test get_or_compute handles exceptions properly. + + NOTE: This test will produce "ValueError: Computation failed" messages in the test output. + These are EXPECTED and NORMAL - the test intentionally triggers failures to verify + that the cache handles exceptions correctly. Each of the 5 workers will generate one + error message, but all workers should receive the fallback value successfully. 
+ """ + # Optional: Uncomment to see a message before the expected errors + # print("\n[test_get_or_compute_exception_handling] Note: The following 5 'ValueError: Computation failed' messages are expected...") + + call_count = [0] + + async def failing_compute(): + """Compute function that fails.""" + call_count[0] += 1 + raise ValueError("Computation failed") + + async def worker(): + """Worker that tries to compute.""" + result = await self.cache.get_or_compute( + "failing_key", failing_compute, default="fallback" + ) + return result + + async def run_test(): + """Run the async test.""" + # Multiple workers should all get the default value + tasks = [worker() for _ in range(5)] + results = await asyncio.gather(*tasks) + + # All should get the default value + self.assertTrue(all(r == "fallback" for r in results)) + + # The compute function might be called multiple times + # since failed computations aren't cached + self.assertGreaterEqual(call_count[0], 1) + + asyncio.run(run_test()) + + def test_thread_safe_size_operations(self): + """Test that size-related operations are thread-safe.""" + results = [] + + def worker(thread_id): + """Worker that checks size consistency.""" + for i in range(100): + # Add item + self.cache.put(f"size_key_{thread_id}_{i}", i) + + # Check size + size = self.cache.size() + is_empty = self.cache.is_empty() + + # Size should never be negative or exceed maxsize + if size < 0 or size > 100: + results.append(f"Invalid size: {size}") + + # is_empty should match size + if (size == 0) != is_empty: + results.append(f"Size {size} but is_empty={is_empty}") + + # Run workers + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, i) for i in range(10)] + for future in futures: + future.result() + + # Check for any inconsistencies + self.assertEqual(len(results), 0, f"Inconsistencies found: {results}") + + def test_concurrent_contains_operations(self): + """Test thread safety of contains method.""" + # Use a larger cache to avoid evictions during the test + # Need maxsize for: 50 existing + (5 threads × 100 new keys) = 550+ + large_cache = LFUCache(1000, track_stats=True) + + # Pre-populate cache + for i in range(50): + large_cache.put(f"existing_key_{i}", f"value_{i}") + + results = [] + eviction_warnings = [] + + def worker(thread_id): + """Worker that checks contains and manipulates cache.""" + for i in range(100): + # Check existing keys + key = f"existing_key_{i % 50}" + if not large_cache.contains(key): + results.append(f"Thread {thread_id}: Missing key {key}") + + # Add new keys + new_key = f"new_key_{thread_id}_{i}" + large_cache.put(new_key, f"value_{thread_id}_{i}") + + # Check new key immediately + if not large_cache.contains(new_key): + # This could happen if cache is full and eviction occurred + # Track it separately as it's not a thread safety issue + eviction_warnings.append( + f"Thread {thread_id}: Key {new_key} possibly evicted" + ) + + # Check non-existent keys + if large_cache.contains(f"non_existent_{thread_id}_{i}"): + results.append(f"Thread {thread_id}: Found non-existent key") + + # Run workers + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Check for any errors (not counting eviction warnings) + self.assertEqual(len(results), 0, f"Errors found: {results}") + + # Eviction warnings should be minimal with large cache + if eviction_warnings: + print(f"Note: {len(eviction_warnings)} keys were evicted 
during test") + + def test_concurrent_reset_stats(self): + """Test thread safety of reset_stats operations.""" + errors = [] + + def worker(thread_id): + """Worker that performs operations and resets stats.""" + for i in range(50): + # Perform operations + self.cache.put(f"key_{thread_id}_{i}", i) + self.cache.get(f"key_{thread_id}_{i}") + self.cache.get("non_existent") + + # Periodically reset stats + if i % 10 == 0: + self.cache.reset_stats() + + # Check stats integrity + stats = self.cache.get_stats() + if any(v < 0 for v in stats.values() if isinstance(v, (int, float))): + errors.append(f"Thread {thread_id}: Negative stat value: {stats}") + + # Run workers + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Verify no errors + self.assertEqual(len(errors), 0, f"Stats errors: {errors[:5]}") + + def test_get_or_compute_concurrent_different_keys(self): + """Test get_or_compute with different keys being computed concurrently.""" + compute_counts = {} + lock = threading.Lock() + + async def compute_for_key(key): + """Compute function that tracks calls per key.""" + with lock: + compute_counts[key] = compute_counts.get(key, 0) + 1 + await asyncio.sleep(0.05) # Simulate work + return f"value_for_{key}" + + async def worker(thread_id, key_id): + """Worker that computes values for specific keys.""" + key = f"key_{key_id}" + result = await self.cache.get_or_compute( + key, lambda: compute_for_key(key), default="error" + ) + return key, result + + async def run_test(): + """Run concurrent computations for different keys.""" + # Create tasks for multiple keys, with some overlap + tasks = [] + for key_id in range(5): + for thread_id in range(3): # 3 threads per key + tasks.append(worker(thread_id, key_id)) + + results = await asyncio.gather(*tasks) + + # Verify each key was computed exactly once + for key_id in range(5): + key = f"key_{key_id}" + self.assertEqual( + compute_counts.get(key, 0), + 1, + f"{key} should be computed exactly once", + ) + + # Verify all threads got correct values + for key, value in results: + expected = f"value_for_{key}" + self.assertEqual(value, expected) + + asyncio.run(run_test()) + + def test_concurrent_operations_with_evictions(self): + """Test thread safety when cache is at maxsize and evictions occur.""" + # Small cache to force evictions + small_cache = LFUCache(50, track_stats=True) + data_integrity_errors = [] + + def worker(thread_id): + """Worker that handles potential evictions gracefully.""" + for i in range(100): + key = f"t{thread_id}_k{i}" + value = f"t{thread_id}_v{i}" + + # Put value + small_cache.put(key, value) + + # Immediately access to increase frequency + retrieved = small_cache.get(key) + + # Value might be None if evicted immediately (unlikely but possible) + if retrieved is not None and retrieved != value: + # This would indicate actual data corruption + data_integrity_errors.append( + f"Wrong value for {key}: expected {value}, got {retrieved}" + ) + + # Also work with some high-frequency keys (access multiple times) + high_freq_key = f"high_freq_{thread_id % 5}" + for _ in range(3): # Access 3 times to increase frequency + small_cache.put(high_freq_key, f"high_freq_value_{thread_id}") + small_cache.get(high_freq_key) + + # Run workers + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, i) for i in range(10)] + for future in futures: + future.result() + + # Should have no data integrity errors 
(wrong values) + self.assertEqual( + len(data_integrity_errors), + 0, + f"Data integrity errors: {data_integrity_errors}", + ) + + # Cache should be at maxsize + self.assertEqual(small_cache.size(), 50) + + # Stats should show many evictions + stats = small_cache.get_stats() + self.assertGreater(stats["evictions"], 0) + self.assertGreater(stats["puts"], 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py new file mode 100644 index 000000000..7201e8f0f --- /dev/null +++ b/tests/test_cache_utils.py @@ -0,0 +1,376 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +import pytest + +from nemoguardrails.context import llm_call_info_var, llm_stats_var +from nemoguardrails.llm.cache.lfu import LFUCache +from nemoguardrails.llm.cache.utils import ( + create_normalized_cache_key, + extract_llm_stats_for_cache, + get_from_cache_and_restore_stats, + restore_llm_stats_from_cache, +) +from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.logging.processing_log import processing_log_var +from nemoguardrails.logging.stats import LLMStats + + +class TestCacheUtils: + def test_create_normalized_cache_key_returns_sha256_hash(self): + key = create_normalized_cache_key("Hello world") + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + @pytest.mark.parametrize( + "prompt", + [ + "Hello world", + "", + " Hello world ", + "Hello world test", + "Hello\t\n\r world", + "Hello \n\t world", + ], + ) + def test_create_normalized_cache_key_with_whitespace_normalization(self, prompt): + key = create_normalized_cache_key(prompt, normalize_whitespace=True) + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + @pytest.mark.parametrize( + "prompt", + [ + "Hello world", + "Hello \n\t world", + " spaces ", + ], + ) + def test_create_normalized_cache_key_without_whitespace_normalization(self, prompt): + key = create_normalized_cache_key(prompt, normalize_whitespace=False) + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + @pytest.mark.parametrize( + "prompt1,prompt2", + [ + ("Hello \n world", "Hello world"), + ("test\t\nstring", "test string"), + (" leading", "leading"), + ], + ) + def test_create_normalized_cache_key_consistent_for_same_input( + self, prompt1, prompt2 + ): + key1 = create_normalized_cache_key(prompt1, normalize_whitespace=True) + key2 = create_normalized_cache_key(prompt2, normalize_whitespace=True) + assert key1 == key2 + + @pytest.mark.parametrize( + "prompt1,prompt2", + [ + ("Hello world", "Hello world!"), + ("test", "testing"), + ("case", "Case"), + ], + ) + def test_create_normalized_cache_key_different_for_different_input( + self, prompt1, prompt2 + ): + key1 = create_normalized_cache_key(prompt1) + key2 = create_normalized_cache_key(prompt2) + assert key1 != key2 
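For readers following these tests, here is a minimal sketch of what `create_normalized_cache_key` is assumed to do, inferred from the assertions in this file (whitespace collapsing, order-independent hashing of message dicts, 64-character SHA-256 hex digests). The function name `sketch_cache_key` and the exact normalization and serialization steps are assumptions for illustration, not the actual implementation in `nemoguardrails.llm.cache.utils`:

```python
import hashlib
import json
from typing import List, Union


def sketch_cache_key(
    prompt: Union[str, List[dict]], normalize_whitespace: bool = True
) -> str:
    """Hypothetical sketch of the cache-key helper exercised by these tests."""
    if isinstance(prompt, str):
        # Collapse runs of whitespace so "Hello \n world" and "Hello world"
        # produce the same key when normalization is enabled.
        text = " ".join(prompt.split()) if normalize_whitespace else prompt
    elif isinstance(prompt, list):
        if not all(isinstance(m, dict) for m in prompt):
            raise TypeError("All elements in prompt list must be dictionaries")
        # Sorting keys makes the digest independent of dict key order.
        text = json.dumps(prompt, sort_keys=True)
    else:
        raise TypeError(f"Invalid type for prompt: {type(prompt).__name__}")
    # 64-character lowercase hex digest, as the tests assert.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
```

The order-independent serialization is what would make two message lists that differ only in dict key order hash to the same key, as the surrounding tests require.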
+ + def test_create_normalized_cache_key_invalid_type_raises_error(self): + with pytest.raises(TypeError, match="Invalid type for prompt: int"): + create_normalized_cache_key(123) + + with pytest.raises(TypeError, match="Invalid type for prompt: dict"): + create_normalized_cache_key({"key": "value"}) + + def test_create_normalized_cache_key_list_of_dicts(self): + messages = [ + {"type": "user", "content": "Hello"}, + {"type": "assistant", "content": "Hi there!"}, + ] + key = create_normalized_cache_key(messages) + assert len(key) == 64 + assert all(c in "0123456789abcdef" for c in key) + + def test_create_normalized_cache_key_list_of_dicts_order_independent(self): + messages1 = [ + {"content": "Hello", "role": "user"}, + {"content": "Hi there!", "role": "assistant"}, + ] + messages2 = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + key1 = create_normalized_cache_key(messages1) + key2 = create_normalized_cache_key(messages2) + assert key1 == key2 + + def test_create_normalized_cache_key_invalid_list_raises_error(self): + with pytest.raises( + TypeError, + match="All elements in prompt list must be dictionaries", + ): + create_normalized_cache_key(["hello", "world"]) + + with pytest.raises( + TypeError, + match="All elements in prompt list must be dictionaries", + ): + create_normalized_cache_key([{"key": "value"}, "test"]) + + with pytest.raises( + TypeError, + match="All elements in prompt list must be dictionaries", + ): + create_normalized_cache_key([123, 456]) + + def test_extract_llm_stats_for_cache_with_llm_call_info(self): + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info.total_tokens = 100 + llm_call_info.prompt_tokens = 50 + llm_call_info.completion_tokens = 50 + llm_call_info_var.set(llm_call_info) + + stats = extract_llm_stats_for_cache() + + assert stats is not None + assert stats["total_tokens"] == 100 + assert stats["prompt_tokens"] == 50 + assert stats["completion_tokens"] == 50 + + llm_call_info_var.set(None) + + def test_extract_llm_stats_for_cache_without_llm_call_info(self): + llm_call_info_var.set(None) + + stats = extract_llm_stats_for_cache() + + assert stats is None + + def test_extract_llm_stats_for_cache_with_none_values(self): + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info.total_tokens = None + llm_call_info.prompt_tokens = None + llm_call_info.completion_tokens = None + llm_call_info_var.set(llm_call_info) + + stats = extract_llm_stats_for_cache() + + assert stats is not None + assert stats["total_tokens"] == 0 + assert stats["prompt_tokens"] == 0 + assert stats["completion_tokens"] == 0 + + llm_call_info_var.set(None) + + def test_restore_llm_stats_from_cache_creates_new_llm_stats(self): + llm_stats_var.set(None) + llm_call_info_var.set(None) + + cached_stats = { + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + } + + restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.01) + + llm_stats = llm_stats_var.get() + assert llm_stats is not None + assert llm_stats.get_stat("total_calls") == 1 + assert llm_stats.get_stat("total_time") == 0.01 + assert llm_stats.get_stat("total_tokens") == 100 + assert llm_stats.get_stat("total_prompt_tokens") == 50 + assert llm_stats.get_stat("total_completion_tokens") == 50 + + llm_stats_var.set(None) + + def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self): + llm_stats = LLMStats() + llm_stats.inc("total_calls", 5) + llm_stats.inc("total_time", 1.0) + llm_stats.inc("total_tokens", 200) + 
llm_stats_var.set(llm_stats) + + cached_stats = { + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + } + + restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.5) + + llm_stats = llm_stats_var.get() + assert llm_stats.get_stat("total_calls") == 6 + assert llm_stats.get_stat("total_time") == 1.5 + assert llm_stats.get_stat("total_tokens") == 300 + + llm_stats_var.set(None) + + def test_restore_llm_stats_from_cache_updates_llm_call_info(self): + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info_var.set(llm_call_info) + llm_stats_var.set(None) + + cached_stats = { + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + } + + restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.02) + + updated_info = llm_call_info_var.get() + assert updated_info is not None + assert updated_info.duration == 0.02 + assert updated_info.total_tokens == 100 + assert updated_info.prompt_tokens == 50 + assert updated_info.completion_tokens == 50 + assert updated_info.from_cache is True + assert updated_info.started_at is not None + assert updated_info.finished_at is not None + + llm_call_info_var.set(None) + llm_stats_var.set(None) + + def test_get_from_cache_and_restore_stats_cache_miss(self): + cache = LFUCache(maxsize=10) + llm_call_info_var.set(None) + llm_stats_var.set(None) + + result = get_from_cache_and_restore_stats(cache, "nonexistent_key") + + assert result is None + + llm_call_info_var.set(None) + llm_stats_var.set(None) + + def test_get_from_cache_and_restore_stats_cache_hit(self): + cache = LFUCache(maxsize=10) + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": { + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + }, + } + cache.put("test_key", cache_entry) + + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info_var.set(llm_call_info) + llm_stats_var.set(None) + + result = get_from_cache_and_restore_stats(cache, "test_key") + + assert result is not None + assert result == {"allowed": True, "policy_violations": []} + + llm_stats = llm_stats_var.get() + assert llm_stats is not None + assert llm_stats.get_stat("total_calls") == 1 + assert llm_stats.get_stat("total_tokens") == 100 + + updated_info = llm_call_info_var.get() + assert updated_info.from_cache is True + + llm_call_info_var.set(None) + llm_stats_var.set(None) + + def test_get_from_cache_and_restore_stats_without_llm_stats(self): + cache = LFUCache(maxsize=10) + cache_entry = { + "result": {"allowed": False, "policy_violations": ["policy1"]}, + "llm_stats": None, + } + cache.put("test_key", cache_entry) + + llm_call_info_var.set(None) + llm_stats_var.set(None) + + result = get_from_cache_and_restore_stats(cache, "test_key") + + assert result is not None + assert result == {"allowed": False, "policy_violations": ["policy1"]} + + llm_call_info_var.set(None) + llm_stats_var.set(None) + + def test_get_from_cache_and_restore_stats_with_processing_log(self): + cache = LFUCache(maxsize=10) + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": { + "total_tokens": 80, + "prompt_tokens": 60, + "completion_tokens": 20, + }, + } + cache.put("test_key", cache_entry) + + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info_var.set(llm_call_info) + llm_stats_var.set(None) + + processing_log = [] + processing_log_var.set(processing_log) + + result = get_from_cache_and_restore_stats(cache, "test_key") + + assert result is not None + assert result == {"allowed": True, 
"policy_violations": []} + + retrieved_log = processing_log_var.get() + assert len(retrieved_log) == 1 + assert retrieved_log[0]["type"] == "llm_call_info" + assert "timestamp" in retrieved_log[0] + assert "data" in retrieved_log[0] + assert retrieved_log[0]["data"] == llm_call_info + + llm_call_info_var.set(None) + llm_stats_var.set(None) + processing_log_var.set(None) + + def test_get_from_cache_and_restore_stats_without_processing_log(self): + cache = LFUCache(maxsize=10) + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": { + "total_tokens": 50, + "prompt_tokens": 30, + "completion_tokens": 20, + }, + } + cache.put("test_key", cache_entry) + + llm_call_info = LLMCallInfo(task="test_task") + llm_call_info_var.set(llm_call_info) + llm_stats_var.set(None) + processing_log_var.set(None) + + result = get_from_cache_and_restore_stats(cache, "test_key") + + assert result is not None + assert result == {"allowed": True, "policy_violations": []} + + llm_call_info_var.set(None) + llm_stats_var.set(None) diff --git a/tests/test_content_safety_cache.py b/tests/test_content_safety_cache.py new file mode 100644 index 000000000..d503f95eb --- /dev/null +++ b/tests/test_content_safety_cache.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import MagicMock + +import pytest + +from nemoguardrails.context import llm_call_info_var, llm_stats_var +from nemoguardrails.library.content_safety.actions import content_safety_check_input +from nemoguardrails.llm.cache.lfu import LFUCache +from nemoguardrails.llm.cache.utils import create_normalized_cache_key +from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.logging.stats import LLMStats +from tests.utils import FakeLLM + + +@pytest.fixture +def mock_task_manager(): + tm = MagicMock() + tm.render_task_prompt.return_value = "test prompt" + tm.get_stop_tokens.return_value = [] + tm.get_max_tokens.return_value = 3 + tm.parse_task_output.return_value = [True, "policy1"] + return tm + + +@pytest.fixture +def fake_llm_with_stats(): + llm = FakeLLM(responses=["safe"]) + return {"test_model": llm} + + +@pytest.mark.asyncio +async def test_content_safety_cache_stores_result_and_stats( + fake_llm_with_stats, mock_task_manager +): + cache = LFUCache(maxsize=10) + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + result = await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + model_caches={"test_model": cache}, + ) + + assert result["allowed"] is True + assert result["policy_violations"] == ["policy1"] + assert cache.size() == 1 + + llm_call_info = llm_call_info_var.get() + + cache_key = create_normalized_cache_key("test prompt") + cached_entry = cache.get(cache_key) + assert cached_entry is not None + assert "result" in cached_entry + assert "llm_stats" in cached_entry + + if llm_call_info and ( + llm_call_info.total_tokens + or llm_call_info.prompt_tokens + or llm_call_info.completion_tokens + ): + assert cached_entry["llm_stats"] is not None + else: + assert cached_entry["llm_stats"] is None or all( + v == 0 for v in cached_entry["llm_stats"].values() + ) + + +@pytest.mark.asyncio +async def test_content_safety_cache_retrieves_result_and_restores_stats( + fake_llm_with_stats, mock_task_manager +): + cache = LFUCache(maxsize=10) + + cache_entry = { + "result": {"allowed": True, "policy_violations": ["policy1"]}, + "llm_stats": { + "total_tokens": 100, + "prompt_tokens": 80, + "completion_tokens": 20, + }, + } + cache_key = create_normalized_cache_key("test prompt") + cache.put(cache_key, cache_entry) + + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + result = await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + model_caches={"test_model": cache}, + ) + + llm_call_info = llm_call_info_var.get() + + assert result == cache_entry["result"] + assert llm_stats.get_stat("total_calls") == 1 + assert llm_stats.get_stat("total_tokens") == 100 + assert llm_stats.get_stat("total_prompt_tokens") == 80 + assert llm_stats.get_stat("total_completion_tokens") == 20 + + assert llm_call_info.from_cache is True + assert llm_call_info.total_tokens == 100 + assert llm_call_info.prompt_tokens == 80 + assert llm_call_info.completion_tokens == 20 + assert llm_call_info.duration is not None + + +@pytest.mark.asyncio +async def test_content_safety_cache_duration_reflects_cache_read_time( + fake_llm_with_stats, mock_task_manager +): + cache = LFUCache(maxsize=10) + + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": { + "total_tokens": 50, + "prompt_tokens": 40, + "completion_tokens": 10, + 
}, + } + cache_key = create_normalized_cache_key("test prompt") + cache.put(cache_key, cache_entry) + + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + model_caches={"test_model": cache}, + ) + + llm_call_info = llm_call_info_var.get() + cache_duration = llm_call_info.duration + + assert cache_duration is not None + assert cache_duration < 0.1 + assert llm_call_info.from_cache is True + + +@pytest.mark.asyncio +async def test_content_safety_without_cache_does_not_store( + fake_llm_with_stats, mock_task_manager +): + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + llm_call_info = LLMCallInfo(task="content_safety_check_input $model=test_model") + llm_call_info_var.set(llm_call_info) + + result = await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + ) + + assert result["allowed"] is True + assert llm_call_info.from_cache is False + + +@pytest.mark.asyncio +async def test_content_safety_cache_handles_missing_stats_gracefully( + fake_llm_with_stats, mock_task_manager +): + cache = LFUCache(maxsize=10) + + cache_entry = { + "result": {"allowed": True, "policy_violations": []}, + "llm_stats": None, + } + cache_key = create_normalized_cache_key("test_key") + cache.put(cache_key, cache_entry) + + mock_task_manager.render_task_prompt.return_value = "test_key" + + llm_stats = LLMStats() + llm_stats_var.set(llm_stats) + + llm_call_info = LLMCallInfo(task="content_safety_check_input $model=test_model") + llm_call_info_var.set(llm_call_info) + + result = await content_safety_check_input( + llms=fake_llm_with_stats, + llm_task_manager=mock_task_manager, + model_name="test_model", + context={"user_message": "test input"}, + model_caches={"test_model": cache}, + ) + + assert result["allowed"] is True + assert llm_stats.get_stat("total_calls") == 0 diff --git a/tests/test_integration_cache.py b/tests/test_integration_cache.py new file mode 100644 index 000000000..81719cc7f --- /dev/null +++ b/tests/test_integration_cache.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
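Before the integration tests that follow, a minimal sketch of the per-model cache configuration they exercise; it mirrors the `RailsConfig`/`ModelCacheConfig` usage in the tests below (the `fake` engine only works together with the patched fake LLM those tests install), so treat it as an illustration rather than a canonical configuration:

```python
from nemoguardrails import RailsConfig
from nemoguardrails.rails.llm.config import CacheStatsConfig, Model, ModelCacheConfig

config = RailsConfig(
    models=[
        Model(type="main", engine="fake", model="fake"),
        Model(
            type="content_safety",
            engine="fake",
            model="fake-content-safety",
            # Cache settings are honored for guardrail models only; the tests
            # below assert that "main" and "embeddings" models never get one.
            cache=ModelCacheConfig(
                enabled=True,
                maxsize=100,
                stats=CacheStatsConfig(enabled=True, log_interval=60.0),
            ),
        ),
    ]
)
```

As the tests below assert, the resulting cache instances are exposed per model type via `runtime.registered_action_params["model_caches"]`, keyed by names such as `content_safety`.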
+ +from unittest.mock import patch + +import pytest + +from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails.rails.llm.config import CacheStatsConfig, Model, ModelCacheConfig +from tests.utils import FakeLLM + + +@pytest.mark.asyncio +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +async def test_end_to_end_cache_integration_with_content_safety(mock_init_llm_model): + mock_llm = FakeLLM(responses=["express greeting"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake-content-safety", + cache=ModelCacheConfig( + enabled=True, + maxsize=100, + stats=CacheStatsConfig(enabled=True), + ), + ), + ] + ) + + llm = FakeLLM(responses=["express greeting"]) + + rails = LLMRails(config=config, llm=llm, verbose=False) + + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + assert "content_safety" in model_caches + cache = model_caches["content_safety"] + assert cache is not None + + assert cache.size() == 0 + + messages = [{"role": "user", "content": "Hello!"}] + await rails.generate_async(messages=messages) + + if cache.size() > 0: + initial_stats = cache.get_stats() + await rails.generate_async(messages=messages) + second_stats = cache.get_stats() + assert second_stats["hits"] >= initial_stats["hits"] + + +@pytest.mark.asyncio +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +async def test_cache_isolation_between_models(mock_init_llm_model): + mock_llm = FakeLLM(responses=["safe"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake-content-safety", + cache=ModelCacheConfig(enabled=True, maxsize=50), + ), + Model( + type="jailbreak_detection", + engine="fake", + model="fake-jailbreak", + cache=ModelCacheConfig(enabled=True, maxsize=100), + ), + ] + ) + + llm = FakeLLM(responses=["safe"]) + + rails = LLMRails(config=config, llm=llm, verbose=False) + + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + assert "content_safety" in model_caches + assert "jailbreak_detection" in model_caches + + content_safety_cache = model_caches["content_safety"] + jailbreak_cache = model_caches["jailbreak_detection"] + + assert content_safety_cache is not jailbreak_cache + assert content_safety_cache.maxsize == 50 + assert jailbreak_cache.maxsize == 100 + + content_safety_cache.put("key1", "value1") + assert content_safety_cache.get("key1") == "value1" + assert jailbreak_cache.get("key1") is None + + +@pytest.mark.asyncio +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +async def test_cache_disabled_for_main_model_in_integration(mock_init_llm_model): + mock_llm = FakeLLM(responses=["safe"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=100), + ), + Model( + type="content_safety", + engine="fake", + model="fake-content-safety", + cache=ModelCacheConfig(enabled=True, maxsize=100), + ), + ] + ) + + llm = FakeLLM(responses=["safe"]) + + rails = LLMRails(config=config, llm=llm, verbose=False) + + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + assert "main" not in model_caches or model_caches["main"] is None + assert "content_safety" in 
model_caches + assert model_caches["content_safety"] is not None diff --git a/tests/test_llmrails.py b/tests/test_llmrails.py index 9b8a2b300..a5f28a986 100644 --- a/tests/test_llmrails.py +++ b/tests/test_llmrails.py @@ -1187,3 +1187,188 @@ def test_explain_calls_ensure_explain_info(): info = rails.explain() assert info == ExplainInfo() assert rails.explain_info == ExplainInfo() + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_disabled_by_default(mock_init_llm_model): + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake", + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches") + + assert model_caches is None or len(model_caches) == 0 + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_with_enabled_cache(mock_init_llm_model): + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake", + cache=ModelCacheConfig( + enabled=True, + maxsize=1000, + stats=CacheStatsConfig(enabled=False), + ), + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + + assert "content_safety" in model_caches + assert model_caches["content_safety"] is not None + assert model_caches["content_safety"].maxsize == 1000 + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_not_created_for_main_and_embeddings_models(mock_init_llm_model): + from nemoguardrails.rails.llm.config import ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=1000), + ), + Model( + type="embeddings", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=1000), + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + + assert "main" not in model_caches + assert "embeddings" not in model_caches + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_with_zero_maxsize_raises_error(mock_init_llm_model): + from nemoguardrails.rails.llm.config import ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="content_safety", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=0), + ), + ] + ) + + with pytest.raises(ValueError, match="Invalid cache maxsize"): + LLMRails(config=config, verbose=False) + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_with_stats_enabled(mock_init_llm_model): + from nemoguardrails.rails.llm.config import CacheStatsConfig, ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = 
RailsConfig( + models=[ + Model( + type="content_safety", + engine="fake", + model="fake", + cache=ModelCacheConfig( + enabled=True, + maxsize=5000, + stats=CacheStatsConfig(enabled=True, log_interval=60.0), + ), + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + + cache = model_caches["content_safety"] + assert cache is not None + assert cache.track_stats is True + assert cache.stats_logging_interval == 60.0 + assert cache.supports_stats_logging() is True + + +@patch("nemoguardrails.rails.llm.llmrails.init_llm_model") +def test_cache_initialization_with_multiple_models(mock_init_llm_model): + from nemoguardrails.rails.llm.config import ModelCacheConfig + + mock_llm = FakeLLM(responses=["response"]) + mock_init_llm_model.return_value = mock_llm + + config = RailsConfig( + models=[ + Model( + type="main", + engine="fake", + model="fake", + ), + Model( + type="content_safety", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=1000), + ), + Model( + type="jailbreak_detection", + engine="fake", + model="fake", + cache=ModelCacheConfig(enabled=True, maxsize=2000), + ), + ] + ) + + rails = LLMRails(config=config, verbose=False) + model_caches = rails.runtime.registered_action_params.get("model_caches", {}) + + assert "main" not in model_caches + assert "content_safety" in model_caches + assert "jailbreak_detection" in model_caches + assert model_caches["content_safety"].maxsize == 1000 + assert model_caches["jailbreak_detection"].maxsize == 2000 diff --git a/tests/tracing/spans/test_span_extractors.py b/tests/tracing/spans/test_span_extractors.py index 9c9c85c05..9b2c24db1 100644 --- a/tests/tracing/spans/test_span_extractors.py +++ b/tests/tracing/spans/test_span_extractors.py @@ -25,6 +25,7 @@ SpanLegacy, create_span_extractor, ) +from nemoguardrails.tracing.constants import GuardrailsAttributes from nemoguardrails.tracing.spans import LLMSpan, is_opentelemetry_span @@ -181,6 +182,78 @@ def test_span_extractor_conversation_events(self, test_data): # Content not included by default (privacy) assert "final_transcript" not in user_event.body + def test_span_extractor_cache_hit_attribute(self): + """Test that cached LLM calls are marked with cache_hit typed field.""" + llm_call_cached = LLMCallInfo( + task="generate_user_intent", + prompt="What is the weather?", + completion="I cannot provide weather information.", + llm_model_name="gpt-4", + llm_provider_name="openai", + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + started_at=time.time(), + finished_at=time.time() + 0.001, + duration=0.001, + from_cache=True, + ) + + llm_call_not_cached = LLMCallInfo( + task="generate_bot_message", + prompt="Generate a response", + completion="Here is a response", + llm_model_name="gpt-3.5-turbo", + llm_provider_name="openai", + prompt_tokens=5, + completion_tokens=15, + total_tokens=20, + started_at=time.time(), + finished_at=time.time() + 1.0, + duration=1.0, + from_cache=False, + ) + + action = ExecutedAction( + action_name="test_action", + action_params={}, + llm_calls=[llm_call_cached, llm_call_not_cached], + started_at=time.time(), + finished_at=time.time() + 1.5, + duration=1.5, + ) + + rail = ActivatedRail( + type="input", + name="test_rail", + decisions=["continue"], + executed_actions=[action], + stop=False, + started_at=time.time(), + finished_at=time.time() + 2.0, + duration=2.0, + ) + + extractor = SpanExtractorV2() + spans = extractor.extract_spans([rail]) + 
+        llm_spans = [s for s in spans if isinstance(s, LLMSpan)]
+        assert len(llm_spans) == 2
+
+        cached_span = next(s for s in llm_spans if "gpt-4" in s.name)
+        assert cached_span.cache_hit is True
+
+        attributes = cached_span.to_otel_attributes()
+        assert GuardrailsAttributes.LLM_CACHE_HIT in attributes
+        assert attributes[GuardrailsAttributes.LLM_CACHE_HIT] is True
+
+        not_cached_span = next(s for s in llm_spans if "gpt-3.5-turbo" in s.name)
+        assert not_cached_span.cache_hit is False
+
+        attributes = not_cached_span.to_otel_attributes()
+        assert GuardrailsAttributes.LLM_CACHE_HIT in attributes
+        assert attributes[GuardrailsAttributes.LLM_CACHE_HIT] is False
+
 
 class TestSpanFormatConfiguration:
     """Test span format configuration and factory."""