diff --git a/nemoguardrails/benchmark/__init__.py b/nemoguardrails/benchmark/__init__.py new file mode 100644 index 000000000..9ba9d4310 --- /dev/null +++ b/nemoguardrails/benchmark/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/benchmark/mock_llm_server/__init__.py b/nemoguardrails/benchmark/mock_llm_server/__init__.py new file mode 100644 index 000000000..9ba9d4310 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py new file mode 100644 index 000000000..5ed724ebf --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
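+
+"""FastAPI application implementing the OpenAI-compatible mock LLM server.
+
+Exposes the /, /health, /v1/models, /v1/chat/completions and /v1/completions
+endpoints, returning canned safe/unsafe responses with configurable latency.
+"""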
+
+
+import asyncio
+import logging
+import time
+from typing import Annotated, Optional, Union
+
+from fastapi import Depends, FastAPI, HTTPException, Request, Response
+
+from nemoguardrails.benchmark.mock_llm_server.config import (
+    ModelSettings,
+    get_settings,
+)
+from nemoguardrails.benchmark.mock_llm_server.models import (
+    ChatCompletionChoice,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionChoice,
+    CompletionRequest,
+    CompletionResponse,
+    Message,
+    Model,
+    ModelsResponse,
+    Usage,
+)
+from nemoguardrails.benchmark.mock_llm_server.response_data import (
+    calculate_tokens,
+    generate_id,
+    get_latency_seconds,
+    get_response,
+)
+
+# Create a module-level logger
+log = logging.getLogger(__name__)
+log.setLevel(logging.INFO)  # TODO Control this from the CLI args
+
+# Create a formatter to define the log message format
+formatter = logging.Formatter(
+    "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+)
+
+# Create a console handler to print logs to the console
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)  # INFO and higher will go to the console
+console_handler.setFormatter(formatter)
+
+# Attach the console handler to the logger
+log.addHandler(console_handler)
+
+
+ModelSettingsDep = Annotated[ModelSettings, Depends(get_settings)]
+
+
+def _validate_request_model(
+    config: ModelSettingsDep,
+    request: Union[CompletionRequest, ChatCompletionRequest],
+) -> None:
+    """Check that the completion or chat-completion `model` field matches the model served by this server."""
+    if request.model != config.model:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Model '{request.model}' not found. Available models: {config.model}",
+        )
+
+
+app = FastAPI(
+    title="Mock LLM Server",
+    description="OpenAI-compatible mock LLM server for testing and benchmarking",
+    version="0.0.1",
+)
+
+
+@app.middleware("http")
+async def log_http_duration(request: Request, call_next):
+    """
+    Middleware to log incoming requests and their responses.
+ """ + request_time = time.time() + response = await call_next(request) + response_time = time.time() + + duration_seconds = response_time - request_time + log.info( + "Request finished: %s, took %.3f seconds", + response.status_code, + duration_seconds, + ) + return response + + +@app.get("/") +async def root(config: ModelSettingsDep): + """Root endpoint with basic server information.""" + return { + "message": "Mock LLM Server", + "version": "0.0.1", + "description": f"OpenAI-compatible mock LLM server for model: {config.model}", + "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], + "model_configuration": config, + } + + +@app.get("/v1/models", response_model=ModelsResponse) +async def list_models(config: ModelSettingsDep): + """List available models.""" + log.debug("/v1/models request") + + model = Model( + id=config.model, object="model", created=int(time.time()), owned_by="system" + ) + response = ModelsResponse(object="list", data=[model]) + log.debug("/v1/models response: %s", response) + return response + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def chat_completions( + request: ChatCompletionRequest, config: ModelSettingsDep +) -> ChatCompletionResponse: + """Create a chat completion.""" + + log.debug("/v1/chat/completions request: %s", request) + + # Validate model exists + _validate_request_model(config, request) + + # Generate dummy response + response_content = get_response(config) + response_latency_seconds = get_latency_seconds(config, seed=12345) + + # Calculate token usage + prompt_text = " ".join([msg.content for msg in request.messages]) + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_content) + + # Create response + completion_id = generate_id("chatcmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = ChatCompletionChoice( + index=i, + message=Message(role="assistant", content=response_content), + finish_reason="stop", + ) + choices.append(choice) + + response = ChatCompletionResponse( + id=completion_id, + object="chat.completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + await asyncio.sleep(response_latency_seconds) + log.debug("/v1/chat/completions response: %s", response) + return response + + +@app.post("/v1/completions", response_model=CompletionResponse) +async def completions( + request: CompletionRequest, config: ModelSettingsDep +) -> CompletionResponse: + """Create a text completion.""" + + log.debug("/v1/completions request: %s", request) + + # Validate model exists + _validate_request_model(config, request) + + # Handle prompt (can be string or list) + if isinstance(request.prompt, list): + prompt_text = " ".join(request.prompt) + else: + prompt_text = request.prompt + + # Generate dummy response + response_text = get_response(config) + response_latency_seconds = get_latency_seconds(config, seed=12345) + + # Calculate token usage + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_text) + + # Create response + completion_id = generate_id("cmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = CompletionChoice( + text=response_text, index=i, logprobs=None, finish_reason="stop" + ) + choices.append(choice) + + response = 
CompletionResponse( + id=completion_id, + object="text_completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + await asyncio.sleep(response_latency_seconds) + log.debug("/v1/completions response: %s", response) + return response + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + log.debug("/health request") + response = {"status": "healthy", "timestamp": int(time.time())} + log.debug("/health response: %s", response) + return response diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py new file mode 100644 index 000000000..c2b9b0d6e --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import lru_cache +from pathlib import Path +from typing import Any, Optional, Union + +import yaml +from pydantic import BaseModel, Field +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, +) + +CONFIG_FILE_ENV_VAR = "MOCK_LLM_CONFIG_FILE" +config_file_path = os.getenv(CONFIG_FILE_ENV_VAR, "model_settings.yml") +CONFIG_FILE = Path(config_file_path) + + +class ModelSettings(BaseSettings): + """Pydantic model to configure the Mock LLM Server.""" + + # Mandatory fields + model: str = Field(..., description="Model name served by mock server") + unsafe_probability: float = Field( + default=0.1, description="Probability of unsafe response (between 0 and 1)" + ) + unsafe_text: str = Field(..., description="Refusal response to unsafe prompt") + safe_text: str = Field(..., description="Safe response") + + # Config with default values + # Latency sampled from a truncated-normal distribution. 
+ # Plain Normal distributions have infinite support, and can be negative + latency_min_seconds: float = Field( + default=0.1, description="Minimum latency in seconds" + ) + latency_max_seconds: float = Field( + default=5, description="Maximum latency in seconds" + ) + latency_mean_seconds: float = Field( + default=0.5, description="The average response time in seconds" + ) + latency_std_seconds: float = Field( + default=0.1, description="Standard deviation of response time" + ) + + model_config = SettingsConfigDict(env_file=CONFIG_FILE) + + +def get_settings() -> ModelSettings: + """Singleton-pattern to get settings once via lru_cache""" + settings = ModelSettings() # type: ignore (These are filled in by loading from CONFIG_FILE) + print("Returning ModelSettings: %s", settings) + return settings diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/config.yml b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/config.yml new file mode 100644 index 000000000..67df489d7 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/config.yml @@ -0,0 +1,21 @@ +models: + - type: main + engine: nim + model: meta/llama-3.3-70b-instruct + parameters: + base_url: http://localhost:8000 + + - type: content_safety + engine: nim + model: nvidia/llama-3.1-nemoguard-8b-content-safety + parameters: + base_url: http://localhost:8001 + + +rails: + input: + flows: + - content safety check input $model=content_safety + output: + flows: + - content safety check output $model=content_safety diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/prompts.yml b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/prompts.yml new file mode 100644 index 000000000..1321a6461 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/guardrail_configs/content_safety_colang1/prompts.yml @@ -0,0 +1,257 @@ +# These are the default prompts released by Meta, except for policy O7, which was added to address direct insults. +prompts: + + - task: content_safety_check_input $model=content_safety + content: | + Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below. + + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. + S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). + S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Other. + S15: Manipulation. + S16: Fraud/Deception. + S17: Malware. + S18: High Risk Gov Decision Making. + S19: Political/Misinformation/Conspiracy. + S20: Copyright/Trademark/Plagiarism. + S21: Unauthorized Advice. + S22: Illegal Activity. + S23: Immoral/Unethical. + + + + + user: {{ user_input }} + + + + Provide your safety assessment for the above conversation in the below output JSON format: + {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. 
"Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.} + + Do not include anything other than the output JSON in your response. + Output JSON: + output_parser: nemoguard_parse_prompt_safety + max_tokens: 50 + + - task: content_safety_check_output $model=content_safety + content: | + Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below. + + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. + S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). + S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Other. + S15: Manipulation. + S16: Fraud/Deception. + S17: Malware. + S18: High Risk Gov Decision Making. + S19: Political/Misinformation/Conspiracy. + S20: Copyright/Trademark/Plagiarism. + S21: Unauthorized Advice. + S22: Illegal Activity. + S23: Immoral/Unethical. + + + + + user: {{ user_input }} + + response: agent: {{ bot_response }} + + + + Provide your safety assessment for the above conversation in the below output JSON format: + {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.} + + Do not include anything other than the output JSON in your response. + Output JSON: + output_parser: nemoguard_parse_response_safety + max_tokens: 50 + + - task: content_safety_check_input $model=llama_guard + content: | + [INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories. + + O1: Violence and Hate. + Should not + - Help people plan or engage in violence. + - Encourage people to engage in violence. + - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability. + - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. + Can + - Provide information on violence and discrimination. + - Discuss the topics of hate and violence, and can discuss historical events involving violence. + O2: Sexual Content. + Should not + - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities. + Can + - Discuss topics related to sexuality and sexual education. + - Provide general information regarding sexual activities. + - Provide guidance on topics related to sexual health. + O3: Criminal Planning. + Should not + - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them. + - Encourage people to engage in criminal activities or condone criminal behavior. 
+ Can + - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes. + - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems. + O4: Guns and Illegal Weapons. + Should not + - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them. + - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior. + Can + - Discuss firearms and the arguments for and against firearm ownership. + O5: Regulated or Controlled Substances. + Should not + - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis. + - Assist or encourage people to create such substances when it is illegal to do so. + Can + - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are). + - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation. + O6: Self-Harm. + Should not + - Encourage people to harm themselves. + - Romanticize or condone self-harm or suicide. + - Provide information regarding the methods of suicide or self-harm. + - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves. + Can + - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions). + Should + - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves. + O7: Offensive Language and Insults. + Should not + - Address direct insults + - Use or encourage the use of swear words, bad language, or vulgar expressions. + - Insult, offend, or demean individuals or groups based on any personal characteristics or choices. + - Encourage or support the use of offensive or derogatory language. + Can + - Discuss the impact and societal perceptions of offensive language and insults. + - Provide information on communication strategies that avoid the use of offensive language and promote respectful dialogue. + - Offer guidance on how to address or respond to offensive language or insults in a constructive manner. + + + + User: {{ user_input }} + + + Provide your safety assessment for User in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] + output_parser: is_content_safe + + - task: content_safety_check_output $model=llama_guard_2 + content: | + <|begin_of_text|>[INST] Task: Check if there is unsafe content in Agent messages in conversations according our safety policy with the below categories. + + + S1: Violent Crimes. + S2: Non-Violent Crimes. + S3: Sex Crimes. + S4: Child Exploitation. + S5: Specialized Advice. + S6: Privacy. + S7: Intellectual Property. + S8: Indiscriminate Weapons. + S9: Hate. + S10: Self-Harm. + S11: Sexual Content. 
+ + + + + User: {{ user_input }} + + Agent: {{ bot_response }} + + + + Provide your safety assessment for ONLY THE LAST Agent in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] + output_parser: is_content_safe + max_tokens: 3 + + - task: content_safety_check_input $model=shieldgemma + content: | + You are a policy expert trying to help determine whether a user prompt is in + violation of the defined safety policies. + + + Human Question: {{ user_input }} + + + Our safety principle is defined in the below: + + * "No Hate Speech": The prompt shall not contain or seek generation of content + that expresses, incites, or promotes hate based on race, gender, ethnicity, + religion, nationality, sexual orientation, disability status, or caste. + * "No Dangerous Content": The prompt shall not contain or seek generation of content that harming oneself + and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide). + * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content + that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, + content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted. + * "No Harassment": The prompt shall not contain or seek generation of content that is malicious, + intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, + denial of tragic events, disparaging victims of violence). + === + + Does the human prompt/question violate the above principle? Your answer must start with + 'Yes' or 'No'. And then walk through step by step to be sure we answer + correctly. + output_parser: is_content_safe + max_tokens: 3 + + - task: content_safety_check_output $model=shieldgemma + content: | + You are a policy expert trying to help determine whether a chatbot response to the user is in + violation of the defined safety policies: + + + Human Question: {{ user_input }} + Chatbot Response: {{ bot_response }} + + + Our safety principle is defined in the below: + + * "No Dangerous Content": The chatbot shall not generate content that harming oneself and/or others + (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide). + * "No Harassment": The chatbot shall not generate content that is malicious, intimidating, bullying, or + abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence). + * "No Hate Speech": The chatbot shall not generate content that targeting identity and/or protected + attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying on the + bases of race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status. + We do allow content that includes a reference to a protected individual or group in a positive manner. + * "No Sexually Explicit Information": The chatbot shall not generate content that contains references to sexual acts or + other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). + Medical or scientific terms related to human anatomy or sex education are permitted. 
+ + === + + Does the chatbot response violate the above principle? Your answer must start with + 'Yes' or 'No'. And then walk through step by step to be sure we answer + correctly. + output_parser: is_content_safe + max_tokens: 3 diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.env b/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.env new file mode 100644 index 000000000..208387602 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/meta-llama-3.3-70b-instruct.env @@ -0,0 +1,8 @@ +MODEL="meta/llama-3.3-70b-instruct" +UNSAFE_PROBABILITY=0.0 +UNSAFE_TEXT="I can't help with that. Is there anything else I can assist you with?" +SAFE_TEXT="I can provide information and help with a wide range of topics, from science and history to entertainment and culture. I can also help with language-related tasks, such as translation and text summarization. However, I can't assist with requests that involve harm or illegal activities." +LATENCY_MIN_SECONDS=4. +LATENCY_MAX_SECONDS=4. +LATENCY_MEAN_SECONDS=4. +LATENCY_STD_SECONDS=0. diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.env b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.env new file mode 100644 index 000000000..786d0685f --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.env @@ -0,0 +1,8 @@ +MODEL="nvidia/llama-3.1-nemoguard-8b-content-safety" +UNSAFE_PROBABILITY=0.03 +UNSAFE_TEXT="{\"User Safety\": \"unsafe\", \"Response Safety\": \"unsafe\", \"Safety Categories\": \"Violence, Criminal Planning/Confessions\"} " +SAFE_TEXT="{\"User Safety\": \"safe\", \"Response Safety\": \"safe\"}" +LATENCY_MIN_SECONDS=0.5 +LATENCY_MAX_SECONDS=0.5 +LATENCY_MEAN_SECONDS=0.5 +LATENCY_STD_SECONDS=0.0 diff --git a/nemoguardrails/benchmark/mock_llm_server/models.py b/nemoguardrails/benchmark/mock_llm_server/models.py new file mode 100644 index 000000000..aac72d6bb --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/models.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
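+
+"""Pydantic models mirroring the OpenAI completions and chat-completions request/response schemas."""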
+ +from typing import Any, Optional, Union + +from pydantic import BaseModel, Field + + +class Message(BaseModel): + """Chat message model.""" + + role: str = Field(..., description="The role of the message author") + content: str = Field(..., description="The content of the message") + + +class ChatCompletionRequest(BaseModel): + """Chat completion request model.""" + + model: str = Field(..., description="ID of the model to use") + messages: list[Message] = Field( + ..., description="List of messages comprising the conversation" + ) + max_tokens: Optional[int] = Field( + None, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + stop: Optional[Union[str, list[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + logit_bias: Optional[dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class CompletionRequest(BaseModel): + """Text completion request model.""" + + model: str = Field(..., description="ID of the model to use") + prompt: Union[str, list[str]] = Field( + ..., description="The prompt(s) to generate completions for" + ) + max_tokens: Optional[int] = Field( + 16, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + logprobs: Optional[int] = Field( + None, description="Include log probabilities", ge=0, le=5 + ) + echo: Optional[bool] = Field( + False, description="Echo back the prompt in addition to completion" + ) + stop: Optional[Union[str, list[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + best_of: Optional[int] = Field( + 1, description="Number of completions to generate server-side", ge=1 + ) + logit_bias: Optional[dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class Usage(BaseModel): + """Token usage information.""" + + prompt_tokens: int = Field(..., description="Number of tokens in the prompt") + completion_tokens: int = Field( + ..., description="Number of tokens in the completion" + ) + total_tokens: int = Field(..., description="Total number of tokens used") + + +class 
ChatCompletionChoice(BaseModel): + """Chat completion choice.""" + + index: int = Field(..., description="The index of this choice") + message: Message = Field(..., description="The generated message") + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class CompletionChoice(BaseModel): + """Text completion choice.""" + + text: str = Field(..., description="The generated text") + index: int = Field(..., description="The index of this choice") + logprobs: Optional[dict[str, Any]] = Field( + None, description="Log probability information" + ) + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class ChatCompletionResponse(BaseModel): + """Chat completion response - https://platform.openai.com/docs/api-reference/chat/object""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: list[ChatCompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class CompletionResponse(BaseModel): + """Text completion response. https://platform.openai.com/docs/api-reference/completions/object""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("text_completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: list[CompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class Model(BaseModel): + """Model information.""" + + id: str = Field(..., description="Model identifier") + object: str = Field("model", description="Object type") + created: int = Field(..., description="Unix timestamp when the model was created") + owned_by: str = Field(..., description="Organization that owns the model") + + +class ModelsResponse(BaseModel): + """Models list response.""" + + object: str = Field("list", description="Object type") + data: list[Model] = Field(..., description="List of available models") diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py new file mode 100644 index 000000000..c5914abcf --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
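+
+"""Helpers for building mock responses: IDs, rough token counts, latency sampling, and safe/unsafe selection."""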
+ + +import random +import uuid +from typing import Optional + +import numpy as np + +from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings + + +def generate_id(prefix: str = "chatcmpl") -> str: + """Generate a unique ID for completions.""" + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def calculate_tokens(text: str) -> int: + """Rough token calculation (approximately 4 characters per token).""" + return max(1, len(text) // 4) + + +def get_response(config: ModelSettings, seed: Optional[int] = None) -> str: + """Get a dummy /completion or /chat/completion response.""" + + if is_unsafe(config, seed): + return config.unsafe_text + return config.safe_text + + +def get_latency_seconds(config: ModelSettings, seed: Optional[int] = None) -> float: + """Sample latency for this request using the model's config + Very inefficient to generate each sample singly rather than in batch + """ + if seed: + np.random.seed(seed) + + # Sample from the normal distribution using model config + latency_seconds = np.random.normal( + loc=config.latency_mean_seconds, scale=config.latency_std_seconds, size=1 + ) + + # Truncate distribution's support using min and max config values + latency_seconds = np.clip( + latency_seconds, + a_min=config.latency_min_seconds, + a_max=config.latency_max_seconds, + ) + return float(latency_seconds) + + +def is_unsafe(config: ModelSettings, seed: Optional[int] = None) -> bool: + """Check if the model should return a refusal + Very inefficient to generate each sample singly rather than in batch + """ + if seed: + np.random.seed(seed) + + refusal = np.random.binomial(n=1, p=config.unsafe_probability, size=1) + return bool(refusal[0]) diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py new file mode 100644 index 000000000..6485d8159 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Startup script for the Mock LLM Server. + +This script starts the FastAPI server with configurable host and port settings. +""" + +import argparse +import logging +import os +import sys + +import uvicorn +import yaml +from uvicorn.logging import AccessFormatter + +from nemoguardrails.benchmark.mock_llm_server.config import ( + CONFIG_FILE_ENV_VAR, + ModelSettings, + get_settings, +) + +# 1. 
Get a logger instance +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) # Set the lowest level to capture all messages + +# Set up formatter and direct it to the console +formatter = logging.Formatter( + "%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S" +) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) # DEBUG and higher will go to the console +console_handler.setFormatter(formatter) + +# Add the console handler for logging +log.addHandler(console_handler) + + +def main(): + parser = argparse.ArgumentParser(description="Run the Mock LLM Server") + parser.add_argument( + "--host", + default="0.0.0.0", + help="Host to bind the server to (default: 0.0.0.0)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to bind the server to (default: 8000)", + ) + parser.add_argument( + "--reload", action="store_true", help="Enable auto-reload for development" + ) + parser.add_argument( + "--log-level", + default="info", + choices=["critical", "error", "warning", "info", "debug", "trace"], + help="Log level (default: info)", + ) + + parser.add_argument( + "--config-file", help="YAML file to configure model", required=True + ) + + args = parser.parse_args() + + config_file = os.environ.get("CONFIG_FILE", args.config_file) + if not config_file: + raise RuntimeError( + "No CONFIG_FILE environment variable set, or --config-file CLI argument" + ) + + if not (os.path.exists(config_file) and os.path.isfile(config_file)): + raise RuntimeError(f"Can't open {config_file}") + + log.info("Using config file: %s", config_file) + os.environ[CONFIG_FILE_ENV_VAR] = config_file + + log.info(f"Starting Mock LLM Server on {args.host}:{args.port}") + log.info(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") + log.info(f"Health check at: http://{args.host}:{args.port}/health") + log.info(f"Serving model with config {config_file}") + log.info("Press Ctrl+C to stop the server") + + try: + uvicorn.run( + "api:app", + host=args.host, + port=args.port, + reload=args.reload, + log_level=args.log_level, + env_file=config_file, + ) + except KeyboardInterrupt: + log.info("\nServer stopped by user") + except Exception as e: # pylint: disable=broad-except + log.error(f"Error starting server: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/benchmark/mock_model_config.yaml b/tests/benchmark/mock_model_config.yaml new file mode 100644 index 000000000..9d3aa1bb0 --- /dev/null +++ b/tests/benchmark/mock_model_config.yaml @@ -0,0 +1,8 @@ +model: "nvidia/mock-content-safety" +unsafe_probability: 0.12 +unsafe_text: "{\"is_safe\": False}" +safe_text: "{\"is_safe\": True}" +latency_min_seconds: 0.01 +latency_max_seconds: 0.01 +latency_mean_seconds: 0.01 +latency_std_seconds: 0.0 diff --git a/tests/benchmark/test_mock_api.py b/tests/benchmark/test_mock_api.py new file mode 100644 index 000000000..5a4f75335 --- /dev/null +++ b/tests/benchmark/test_mock_api.py @@ -0,0 +1,415 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import time + +import pytest +import yaml +from fastapi.testclient import TestClient + +from nemoguardrails.benchmark.mock_llm_server.api import app + +# from nemoguardrails.benchmark.mock_llm_server.api import app +from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings, get_settings + + +def get_test_settings(): + return ModelSettings( + model="gpt-3.5-turbo", + unsafe_probability=0.1, + unsafe_text="I cannot help with that request", + safe_text="This is a safe response", + latency_min_seconds=0, + latency_max_seconds=0, + latency_mean_seconds=0, + latency_std_seconds=0, + ) + + +@pytest.fixture +def client(): + """Create a test client.""" + app.dependency_overrides[get_settings] = get_test_settings + return TestClient(app) + + +def test_get_root_endpoint_server_data(client): + """Test GET / endpoint returns correct server details (not including model info)""" + + model_name = get_test_settings().model + + response = client.get("/") + assert response.status_code == 200 + data = response.json() + assert data["message"] == "Mock LLM Server" + assert data["version"] == "0.0.1" + assert ( + data["description"] + == f"OpenAI-compatible mock LLM server for model: {model_name}" + ) + assert data["endpoints"] == [ + "/v1/models", + "/v1/chat/completions", + "/v1/completions", + ] + + +def test_get_root_endpoint_model_data(client): + """Test GET / endpoint returns correct model details""" + + response = client.get("/") + data = response.json() + model_data = data["model_configuration"] + + expected_model_data = get_test_settings().model_dump() + assert model_data == expected_model_data + + +def test_get_health_endpoint(client): + """Test GET /health endpoint.""" + pre_request_time = int(time.time()) + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "timestamp" in data + assert isinstance(data["timestamp"], int) + assert data["timestamp"] >= pre_request_time + + +def test_get_models_endpoint(client): + """Test GET /v1/models endpoint.""" + pre_request_time = int(time.time()) + response = client.get("/v1/models") + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + + expected_model = get_test_settings().model_dump() + model = data["data"][0] + assert model["id"] == expected_model["model"] + assert model["object"] == "model" + assert isinstance(model["created"], int) + assert model["created"] >= pre_request_time + assert model["owned_by"] == "system" + + +class TestChatCompletionsEndpoint: + """Test the /v1/chat/completions endpoint.""" + + def test_chat_completions_basic(self, client): + """Test basic chat completion request.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + data = response.json() + assert data["object"] == "chat.completion" + assert data["model"] == "gpt-3.5-turbo" + assert "id" in data + assert 
data["id"].startswith("chatcmpl-") + + def test_chat_completions_response_structure(self, client): + """Test the structure of chat completion response.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test message"}], + } + response = client.post("/v1/chat/completions", json=payload) + data = response.json() + + # Check response structure + assert "choices" in data + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert "message" in choice + assert choice["message"]["role"] == "assistant" + assert "content" in choice["message"] + assert choice["finish_reason"] == "stop" + + def test_chat_completions_usage(self, client): + """Test that usage information is included.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test"}], + } + response = client.post("/v1/chat/completions", json=payload) + data = response.json() + + assert "usage" in data + usage = data["usage"] + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + assert ( + usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + ) + + def test_chat_completions_multiple_choices(self, client): + """Test chat completion with n > 1.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + "n": 3, + } + response = client.post("/v1/chat/completions", json=payload) + data = response.json() + + assert len(data["choices"]) == 3 + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i + + def test_chat_completions_multiple_messages(self, client): + """Test chat completion with multiple messages.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + + def test_chat_completions_invalid_model(self, client): + """Test chat completion with invalid model name.""" + payload = { + "model": "invalid-model", + "messages": [{"role": "user", "content": "Hello"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 400 + assert "not found" in response.json()["detail"].lower() + + def test_chat_completions_missing_messages(self, client): + """Test chat completion without messages field.""" + payload = { + "model": "gpt-3.5-turbo", + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 422 # Validation error + + def test_chat_completions_empty_messages(self, client): + """Test chat completion with empty messages list.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [], + } + response = client.post("/v1/chat/completions", json=payload) + # Should either be 422 or 200 depending on validation + # Let's check it doesn't crash + assert response.status_code in [200, 422] + + def test_chat_completions_latency(self, client): + """Test that chat completions have some latency.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + } + start = time.time() + response = client.post("/v1/chat/completions", json=payload) + duration = time.time() - start + + assert response.status_code == 200 + # Should have some 
latency (at least minimal) + assert duration >= 0.0 + + +class TestCompletionsEndpoint: + """Test the /v1/completions endpoint.""" + + def test_completions_basic(self, client): + """Test basic completion request.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Once upon a time", + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 200 + data = response.json() + assert data["object"] == "text_completion" + assert data["model"] == "gpt-3.5-turbo" + assert data["id"].startswith("cmpl-") + + def test_completions_response_structure(self, client): + """Test the structure of completion response.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Test prompt", + } + response = client.post("/v1/completions", json=payload) + data = response.json() + + assert "choices" in data + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert "text" in choice + assert isinstance(choice["text"], str) + assert choice["finish_reason"] == "stop" + assert choice["logprobs"] is None + + def test_completions_string_prompt(self, client): + """Test completion with string prompt.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Single string prompt", + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 200 + + def test_completions_list_prompt(self, client): + """Test completion with list of prompts.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": ["Prompt 1", "Prompt 2", "Prompt 3"], + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 200 + data = response.json() + # Should still return a response (joined prompts) + assert "choices" in data + + def test_completions_multiple_choices(self, client): + """Test completion with n > 1.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Test", + "n": 5, + } + response = client.post("/v1/completions", json=payload) + data = response.json() + + assert len(data["choices"]) == 5 + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i + + def test_completions_usage(self, client): + """Test that usage information is included.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Test prompt", + } + response = client.post("/v1/completions", json=payload) + data = response.json() + + assert "usage" in data + usage = data["usage"] + assert usage["prompt_tokens"] > 0 + assert usage["completion_tokens"] > 0 + assert ( + usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + ) + + def test_completions_invalid_model(self, client): + """Test completion with invalid model name.""" + payload = { + "model": "wrong-model", + "prompt": "Test", + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 400 + + def test_completions_missing_prompt(self, client): + """Test completion without prompt field.""" + payload = { + "model": "gpt-3.5-turbo", + } + response = client.post("/v1/completions", json=payload) + assert response.status_code == 422 # Validation error + + +class TestMiddleware: + """Test the HTTP logging middleware.""" + + def test_middleware_logs_request(self, client): + """Test that middleware processes requests.""" + # The middleware should not affect response + response = client.get("/health") + assert response.status_code == 200 + + def test_middleware_with_post(self, client): + """Test middleware with POST request.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": 
"user", "content": "Test"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + + +class TestValidateRequestModel: + """Test the _validate_request_model function.""" + + def test_validate_request_model_valid(self, client): + """Test validation with correct model.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + + def test_validate_request_model_invalid(self, client): + """Test validation with incorrect model.""" + payload = { + "model": "nonexistent-model", + "messages": [{"role": "user", "content": "Test"}], + } + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 400 + assert "not found" in response.json()["detail"].lower() + assert "gpt-3.5-turbo" in response.json()["detail"] + + +class TestResponseContent: + """Test that responses contain expected content.""" + + def test_chat_response_content_type(self, client): + """Test that response contains either safe or unsafe text.""" + payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test"}], + } + response = client.post("/v1/chat/completions", json=payload) + data = response.json() + + content = data["choices"][0]["message"]["content"] + # Should be one of the configured responses + assert content in ["This is a safe response", "I cannot help with that request"] + + def test_completion_response_content_type(self, client): + """Test that completion response contains expected text.""" + payload = { + "model": "gpt-3.5-turbo", + "prompt": "Test", + } + response = client.post("/v1/completions", json=payload) + data = response.json() + + text = data["choices"][0]["text"] + # Should be one of the configured responses + assert text in ["This is a safe response", "I cannot help with that request"] diff --git a/tests/benchmark/test_mock_config.py b/tests/benchmark/test_mock_config.py new file mode 100644 index 000000000..d97d7df6d --- /dev/null +++ b/tests/benchmark/test_mock_config.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import tempfile
+
+import pytest
+import yaml
+
+from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings
+
+
+class TestModelSettings:
+    """Test the ModelSettings Pydantic model."""
+
+    def test_model_settings_with_defaults(self):
+        """Test creating ModelSettings with default values."""
+        config = ModelSettings(
+            model="test-model",
+            unsafe_text="Unsafe",
+            safe_text="Safe",
+        )
+        # Check defaults
+        assert config.unsafe_probability == 0.1
+        assert config.latency_min_seconds == 0.1
+        assert config.latency_max_seconds == 5
+        assert config.latency_mean_seconds == 0.5
+        assert config.latency_std_seconds == 0.1
+
+    def test_model_settings_missing_required_field(self):
+        """Test that missing required fields raise a validation error."""
+        with pytest.raises(Exception):  # Pydantic ValidationError
+            ModelSettings(  # type: ignore (Test is meant to check missing mandatory field)
+                model="test-model",
+                unsafe_text="Unsafe",
+                # Missing safe_text
+            )
+
+    def test_model_settings_serialization(self):
+        """Test that ModelSettings can be serialized to a dict."""
+        config = ModelSettings(
+            model="test-model",
+            unsafe_text="Unsafe",
+            safe_text="Safe",
+        )
+        config_dict = config.model_dump()
+        assert isinstance(config_dict, dict)
+        assert config_dict["model"] == "test-model"
+        assert config_dict["safe_text"] == "Safe"
diff --git a/tests/benchmark/test_mock_models.py b/tests/benchmark/test_mock_models.py
new file mode 100644
index 000000000..fd6d4e979
--- /dev/null
+++ b/tests/benchmark/test_mock_models.py
@@ -0,0 +1,340 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import pytest +from pydantic import ValidationError + +from nemoguardrails.benchmark.mock_llm_server.models import ( + ChatCompletionChoice, + ChatCompletionRequest, + ChatCompletionResponse, + CompletionChoice, + CompletionRequest, + CompletionResponse, + Message, + Model, + ModelsResponse, + Usage, +) + + +class TestMessage: + """Test the Message model.""" + + def test_message_creation(self): + """Test creating a Message.""" + msg = Message(role="user", content="Hello") + assert msg.role == "user" + assert msg.content == "Hello" + + def test_message_missing_fields(self): + """Test that missing required fields raise validation error.""" + with pytest.raises(ValidationError): + Message(role="user") # Missing content + + with pytest.raises(ValidationError): + Message(content="Hello") # Missing role + + +class TestChatCompletionRequest: + """Test the ChatCompletionRequest model.""" + + def test_chat_completion_request_minimal(self): + """Test creating ChatCompletionRequest with minimal fields.""" + req = ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hello")], + ) + assert req.model == "gpt-3.5-turbo" + assert len(req.messages) == 1 + assert req.temperature == 1.0 # Default + assert req.n == 1 # Default + assert req.stream is False # Default + + def test_chat_completion_request_validation(self): + """Test validation of ChatCompletionRequest fields.""" + # Test temperature bounds + with pytest.raises(ValidationError): + ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hi")], + temperature=3.0, # > 2.0 + ) + + # Test n bounds + with pytest.raises(ValidationError): + ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hi")], + n=200, # > 128 + ) + + def test_chat_completion_request_stop_variants(self): + """Test stop parameter can be string or list.""" + req1 = ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hi")], + stop="END", + ) + assert req1.stop == "END" + + req2 = ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[Message(role="user", content="Hi")], + stop=["END", "STOP"], + ) + assert req2.stop == ["END", "STOP"] + + +class TestCompletionRequest: + """Test the CompletionRequest model.""" + + def test_completion_request_minimal(self): + """Test creating CompletionRequest with minimal fields.""" + req = CompletionRequest( + model="text-davinci-003", + prompt="Hello", + ) + assert req.model == "text-davinci-003" + assert req.prompt == "Hello" + assert req.max_tokens == 16 # Default + assert req.temperature == 1.0 # Default + + def test_completion_request_prompt_string(self): + """Test CompletionRequest with string prompt.""" + req = CompletionRequest(model="test-model", prompt="Test prompt") + assert req.prompt == "Test prompt" + assert isinstance(req.prompt, str) + + def test_completion_request_prompt_list(self): + """Test CompletionRequest with list of prompts.""" + req = CompletionRequest(model="test-model", prompt=["Prompt 1", "Prompt 2"]) + assert req.prompt == ["Prompt 1", "Prompt 2"] + assert isinstance(req.prompt, list) + + def test_completion_request_all_fields(self): + """Test creating CompletionRequest with all fields.""" + req = CompletionRequest( + model="text-davinci-003", + prompt=["Prompt 1", "Prompt 2"], + max_tokens=50, + temperature=0.8, + top_p=0.95, + n=3, + stream=True, + logprobs=5, + echo=True, + stop=["STOP"], + presence_penalty=0.6, + frequency_penalty=0.4, + best_of=2, + 
logit_bias={"token1": 1.0}, + user="user456", + ) + assert req.model == "text-davinci-003" + assert req.prompt == ["Prompt 1", "Prompt 2"] + assert req.max_tokens == 50 + assert req.logprobs == 5 + assert req.echo is True + assert req.best_of == 2 + + def test_completion_request_validation(self): + """Test validation of CompletionRequest fields.""" + # Test logprobs bounds + with pytest.raises(ValidationError): + CompletionRequest( + model="test-model", + prompt="Hi", + logprobs=10, # > 5 + ) + + +class TestUsage: + """Test the Usage model.""" + + def test_usage_creation(self): + """Test creating a Usage model.""" + usage = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30) + assert usage.prompt_tokens == 10 + assert usage.completion_tokens == 20 + assert usage.total_tokens == 30 + + def test_usage_missing_fields(self): + """Test that missing fields raise validation error.""" + with pytest.raises(ValidationError): + Usage(prompt_tokens=10, completion_tokens=20) # Missing total_tokens + + +class TestChatCompletionChoice: + """Test the ChatCompletionChoice model.""" + + def test_chat_completion_choice_creation(self): + """Test creating a ChatCompletionChoice.""" + choice = ChatCompletionChoice( + index=0, + message=Message(role="assistant", content="Response"), + finish_reason="stop", + ) + assert choice.index == 0 + assert choice.message.role == "assistant" + assert choice.message.content == "Response" + assert choice.finish_reason == "stop" + + +class TestCompletionChoice: + """Test the CompletionChoice model.""" + + def test_completion_choice_creation(self): + """Test creating a CompletionChoice.""" + choice = CompletionChoice( + text="Generated text", index=0, logprobs=None, finish_reason="length" + ) + assert choice.text == "Generated text" + assert choice.index == 0 + assert choice.logprobs is None + assert choice.finish_reason == "length" + + def test_completion_choice_with_logprobs(self): + """Test CompletionChoice with logprobs.""" + choice = CompletionChoice( + text="Text", + index=0, + logprobs={"tokens": ["test"], "token_logprobs": [-0.5]}, + finish_reason="stop", + ) + assert choice.logprobs is not None + assert "tokens" in choice.logprobs + + +class TestChatCompletionResponse: + """Test the ChatCompletionResponse model.""" + + def test_chat_completion_response_creation(self): + """Test creating a ChatCompletionResponse.""" + response = ChatCompletionResponse( + id="chatcmpl-123", + object="chat.completion", + created=1234567890, + model="gpt-3.5-turbo", + choices=[ + ChatCompletionChoice( + index=0, + message=Message(role="assistant", content="Hello!"), + finish_reason="stop", + ) + ], + usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + assert response.id == "chatcmpl-123" + assert response.object == "chat.completion" + assert response.created == 1234567890 + assert response.model == "gpt-3.5-turbo" + assert len(response.choices) == 1 + assert response.usage.total_tokens == 15 + + def test_chat_completion_response_multiple_choices(self): + """Test ChatCompletionResponse with multiple choices.""" + response = ChatCompletionResponse( + id="chatcmpl-456", + object="chat.completion", + created=1234567890, + model="gpt-4", + choices=[ + ChatCompletionChoice( + index=0, + message=Message(role="assistant", content="Response 1"), + finish_reason="stop", + ), + ChatCompletionChoice( + index=1, + message=Message(role="assistant", content="Response 2"), + finish_reason="stop", + ), + ], + usage=Usage(prompt_tokens=10, completion_tokens=10, 
total_tokens=20), + ) + assert len(response.choices) == 2 + assert response.choices[0].message.content == "Response 1" + assert response.choices[1].message.content == "Response 2" + + +class TestCompletionResponse: + """Test the CompletionResponse model.""" + + def test_completion_response_creation(self): + """Test creating a CompletionResponse.""" + response = CompletionResponse( + id="cmpl-789", + object="text_completion", + created=1234567890, + model="text-davinci-003", + choices=[ + CompletionChoice( + text="Completed text", index=0, logprobs=None, finish_reason="stop" + ) + ], + usage=Usage(prompt_tokens=15, completion_tokens=10, total_tokens=25), + ) + assert response.id == "cmpl-789" + assert response.object == "text_completion" + assert response.created == 1234567890 + assert response.model == "text-davinci-003" + assert len(response.choices) == 1 + assert response.usage.total_tokens == 25 + + +class TestModel: + """Test the Model model.""" + + def test_model_creation(self): + """Test creating a Model.""" + model = Model( + id="gpt-3.5-turbo", object="model", created=1677610602, owned_by="openai" + ) + assert model.id == "gpt-3.5-turbo" + assert model.object == "model" + assert model.created == 1677610602 + assert model.owned_by == "openai" + + +class TestModelsResponse: + """Test the ModelsResponse model.""" + + def test_models_response_creation(self): + """Test creating a ModelsResponse.""" + response = ModelsResponse( + object="list", + data=[ + Model( + id="gpt-3.5-turbo", + object="model", + created=1677610602, + owned_by="openai", + ), + Model( + id="gpt-4", object="model", created=1687882410, owned_by="openai" + ), + ], + ) + assert response.object == "list" + assert len(response.data) == 2 + assert response.data[0].id == "gpt-3.5-turbo" + assert response.data[1].id == "gpt-4" + + def test_models_response_empty(self): + """Test ModelsResponse with no models.""" + response = ModelsResponse(object="list", data=[]) + assert response.object == "list" + assert len(response.data) == 0 diff --git a/tests/benchmark/test_mock_response_data.py b/tests/benchmark/test_mock_response_data.py new file mode 100644 index 000000000..0207d3b11 --- /dev/null +++ b/tests/benchmark/test_mock_response_data.py @@ -0,0 +1,498 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import re
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from nemoguardrails.benchmark.mock_llm_server.config import ModelSettings
+from nemoguardrails.benchmark.mock_llm_server.response_data import (
+    calculate_tokens,
+    generate_id,
+    get_latency_seconds,
+    get_response,
+    is_unsafe,
+)
+
+
+class TestGenerateId:
+    """Test the generate_id function."""
+
+    def test_generate_id_default_prefix(self):
+        """Test generating ID with default prefix."""
+        id1 = generate_id()
+        assert id1.startswith("chatcmpl-")
+        # ID should be in format: prefix-{8 hex chars}
+        assert len(id1) == len("chatcmpl-") + 8
+
+    def test_generate_id_custom_prefix(self):
+        """Test generating ID with custom prefix."""
+        id1 = generate_id("cmpl")
+        assert id1.startswith("cmpl-")
+        assert len(id1) == len("cmpl-") + 8
+
+    def test_generate_id_format(self):
+        """Test that generated IDs have correct format."""
+        id1 = generate_id("test")
+        # Should match pattern: prefix-{8 hex chars}
+        pattern = r"test-[0-9a-f]{8}"
+        assert re.match(pattern, id1)
+
+
+class TestCalculateTokens:
+    """Test the calculate_tokens function."""
+
+    def test_calculate_tokens_empty_string(self):
+        """Test calculating tokens for empty string."""
+        tokens = calculate_tokens("")
+        assert tokens == 1  # Returns at least 1
+
+    def test_calculate_tokens_short_text(self):
+        """Test calculating tokens for short text."""
+        tokens = calculate_tokens("Hi")
+        # 2 chars / 4 = 0, but max(1, 0) = 1
+        assert tokens == 1
+
+    def test_calculate_tokens_exact_division(self):
+        """Test calculating tokens for text divisible by 4."""
+        text = "a" * 20  # 20 chars / 4 = 5 tokens
+        tokens = calculate_tokens(text)
+        assert tokens == 5
+
+    def test_calculate_tokens_with_remainder(self):
+        """Test calculating tokens for text with remainder."""
+        text = "a" * 19  # 19 chars / 4 = 4 (integer division)
+        tokens = calculate_tokens(text)
+        assert tokens == 4
+
+    def test_calculate_tokens_long_text(self):
+        """Test calculating tokens for long text."""
+        text = "This is a longer text that should have multiple tokens." * 10
+        tokens = calculate_tokens(text)
+        expected = max(1, len(text) // 4)
+        assert tokens == expected
+
+    def test_calculate_tokens_unicode(self):
+        """Test calculating tokens with unicode characters."""
+        text = "Hello 世界 🌍"
+        tokens = calculate_tokens(text)
+        assert tokens >= 1
+        assert tokens == max(1, len(text) // 4)
+
+
+@pytest.fixture
+def model_settings() -> ModelSettings:
+    """Generate config data for use in response generation"""
+    settings = ModelSettings(
+        model="gpt-4o",
+        unsafe_probability=0.5,
+        unsafe_text="Sorry Dave, I'm afraid I can't do that.",
+        safe_text="I'm an AI assistant and am happy to help",
+        latency_min_seconds=0.2,
+        latency_max_seconds=1.0,
+        latency_mean_seconds=0.5,
+        latency_std_seconds=0.1,
+    )
+    return settings
+
+
+@pytest.fixture
+def random_seed() -> int:
+    """Return a fixed seed number for all tests"""
+    return 12345
+
+
+@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed")
+@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.binomial")
+def test_is_unsafe_mocks_no_seed(
+    mock_binomial: MagicMock, mock_seed: MagicMock, model_settings: ModelSettings
+):
+    """Check `is_unsafe()` calls the correct numpy functions when no seed is given"""
+    mock_binomial.return_value = [True]
+
+    response = is_unsafe(model_settings)
+
+    assert response == True
+    assert mock_seed.call_count == 0
+    assert mock_binomial.call_count == 1
+    mock_binomial.assert_called_once_with(
+        n=1, p=model_settings.unsafe_probability, size=1
+    )
+
+
+@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed")
+@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.binomial")
+def test_is_unsafe_mocks_with_seed(
+    mock_binomial: MagicMock,
+    mock_seed: MagicMock,
+    model_settings: ModelSettings,
+    random_seed: int,
+):
+    """Check `is_unsafe()` calls the correct numpy functions when a seed is given"""
+    mock_binomial.return_value = [False]
+
+    response = is_unsafe(model_settings, random_seed)
+
+    assert response == False
+    assert mock_seed.call_count == 1
+    assert mock_binomial.call_count == 1
+    mock_binomial.assert_called_once_with(
+        n=1, p=model_settings.unsafe_probability, size=1
+    )
+
+
+def test_is_unsafe_prob_one(model_settings: ModelSettings):
+    """Check `is_unsafe()` with probability of 1 returns True"""
+
+    model_settings.unsafe_probability = 1.0
+    response = is_unsafe(model_settings)
+    assert response == True
+
+
+def test_is_unsafe_prob_zero(model_settings: ModelSettings):
+    """Check `is_unsafe()` with probability of 0 returns False"""
+
+    model_settings.unsafe_probability = 0.0
+    response = is_unsafe(model_settings)
+    assert response == False
+
+
+def test_get_response_safe(model_settings: ModelSettings):
+    """Check we get the safe response when is_unsafe returns False"""
+    with patch(
+        "nemoguardrails.benchmark.mock_llm_server.response_data.is_unsafe"
+    ) as mock_is_unsafe:
+        mock_is_unsafe.return_value = False
+        response = get_response(model_settings)
+        assert response == model_settings.safe_text
+
+
+def test_get_response_unsafe(model_settings: ModelSettings):
+    """Check we get the unsafe response when is_unsafe returns True"""
+    with patch(
+        "nemoguardrails.benchmark.mock_llm_server.response_data.is_unsafe"
+    ) as mock_is_unsafe:
+        mock_is_unsafe.return_value = True
+        response = get_response(model_settings)
+        assert response == model_settings.unsafe_text
+
+
+@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed")
+@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.normal")
+@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.clip") +def test_get_latency_seconds_mocks_no_seed( + mock_clip, mock_normal, mock_seed, model_settings: ModelSettings +): + """Check we call the correct numpy functions (not including seed)""" + + mock_normal.return_value = model_settings.latency_mean_seconds + mock_clip.return_value = model_settings.latency_max_seconds + + result = get_latency_seconds(model_settings) + + assert result == mock_clip.return_value + assert mock_seed.call_count == 0 + mock_normal.assert_called_once_with( + loc=model_settings.latency_mean_seconds, + scale=model_settings.latency_std_seconds, + size=1, + ) + mock_clip.assert_called_once_with( + mock_normal.return_value, + a_min=model_settings.latency_min_seconds, + a_max=model_settings.latency_max_seconds, + ) + + +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed") +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.random.normal") +@patch("nemoguardrails.benchmark.mock_llm_server.response_data.np.clip") +def test_get_latency_seconds_mocks_with_seed( + mock_clip, mock_normal, mock_seed, model_settings: ModelSettings, random_seed: int +): + """Check we call the correct numpy functions (not including seed)""" + + mock_normal.return_value = model_settings.latency_mean_seconds + mock_clip.return_value = model_settings.latency_max_seconds + + result = get_latency_seconds(model_settings, seed=random_seed) + + assert result == mock_clip.return_value + mock_seed.assert_called_once_with(random_seed) + mock_normal.assert_called_once_with( + loc=model_settings.latency_mean_seconds, + scale=model_settings.latency_std_seconds, + size=1, + ) + mock_clip.assert_called_once_with( + mock_normal.return_value, + a_min=model_settings.latency_min_seconds, + a_max=model_settings.latency_max_seconds, + ) + + +# +# class TestGetResponse: +# """Test the get_response function.""" +# +# def test_get_response_safe(self, model_settings): +# """Test getting safe response when not unsafe.""" +# +# # P(Unsafe) = 0, so all responses will be safe +# model_settings.unsafe_probability = 0.0 +# +# with patch( +# "nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed" +# ) as mock_seed: +# +# response = get_response(model_settings) +# assert response == model_settings.safe_text +# assert mock_seed.call_count == 0 +# +# def test_get_response_unsafe(self, model_settings): +# """Test getting safe response when not unsafe.""" +# +# # P(Unsafe) = 1, so all responses will be unsafe +# model_settings.unsafe_probability = 1.0 +# +# with patch( +# "nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed" +# ) as mock_seed: +# +# response = get_response(model_settings) +# assert response == model_settings.unsafe_text +# assert mock_seed.call_count == 0 +# +# def test_get_response_with_seed(self, model_settings, random_seed): +# """Test that a seed is passed onto np.random.seed""" +# +# with patch( +# "nemoguardrails.benchmark.mock_llm_server.response_data.np.random.seed" +# ) as mock_seed: +# response = get_response(model_settings, seed=random_seed) +# +# assert mock_seed.call_count == 1 +# assert mock_seed.called_once_with(random_seed) +# +# +# class TestGetLatencySeconds: +# """Test the get_latency_seconds function.""" +# +# def setup_method(self): +# """Set up test configuration before each test.""" +# config_data = { +# "model": "test-model", +# "unsafe_text": "Unsafe", +# "safe_text": "Safe", +# "latency_min_seconds": 0.1, +# "latency_max_seconds": 2.0, +# 
"latency_mean_seconds": 0.5, +# "latency_std_seconds": 0.2, +# } +# with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: +# yaml.dump(config_data, f) +# self.temp_file = f.name +# load_config(self.temp_file) +# +# def teardown_method(self): +# """Clean up after each test.""" +# import os +# +# os.unlink(self.temp_file) +# +# def test_get_latency_seconds_in_bounds(self): +# """Test that latency is within configured bounds.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# latency = get_latency_seconds(config, seed=42) +# assert config.latency_min_seconds <= latency <= config.latency_max_seconds +# assert isinstance(latency, float) +# +# def test_get_latency_seconds_with_seed_deterministic(self): +# """Test that same seed produces same latency.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# latency1 = get_latency_seconds(config, seed=12345) +# latency2 = get_latency_seconds(config, seed=12345) +# assert latency1 == latency2 +# +# def test_get_latency_seconds_without_seed_random(self): +# """Test that without seed, latencies vary.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# latencies = [get_latency_seconds(config) for _ in range(20)] +# # Should have some variation (not all the same) +# assert len(set(latencies)) > 1 +# +# def test_get_latency_seconds_clipping_min(self): +# """Test that latency is clipped to minimum.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# # Try many samples to potentially get one that would be below min +# latencies = [get_latency_seconds(config, seed=i) for i in range(100)] +# assert all(lat >= config.latency_min_seconds for lat in latencies) +# +# def test_get_latency_seconds_clipping_max(self): +# """Test that latency is clipped to maximum.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# # Try many samples to potentially get one that would be above max +# latencies = [get_latency_seconds(config, seed=i) for i in range(100)] +# assert all(lat <= config.latency_max_seconds for lat in latencies) +# +# def test_get_latency_seconds_distribution_mean(self): +# """Test that latency follows expected distribution.""" +# from nemoguardrails.benchmark.mock_llm_server.config import get_config +# +# config = get_config() +# +# # Generate many samples and check mean is approximately correct +# np.random.seed(42) +# latencies = [get_latency_seconds(config) for _ in range(1000)] +# mean_latency = np.mean(latencies) +# +# # Mean should be close to configured mean (allowing for clipping) +# # With clipping, mean will be between min and max +# assert config.latency_min_seconds <= mean_latency <= config.latency_max_seconds +# +# +# class TestIsUnsafe: +# """Test the is_unsafe function.""" +# +# def setup_method(self): +# """Set up test configuration before each test.""" +# config_data = { +# "model": "test-model", +# "unsafe_probability": 0.3, +# "unsafe_text": "Unsafe", +# "safe_text": "Safe", +# } +# with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: +# yaml.dump(config_data, f) +# self.temp_file = f.name +# load_config(self.temp_file) +# +# def teardown_method(self): +# """Clean up after each test.""" +# import os +# +# os.unlink(self.temp_file) +# +# def test_is_unsafe_returns_bool(self): +# """Test that is_unsafe 