diff --git a/modules/vllm/README.rst b/modules/vllm/README.rst new file mode 100644 index 00000000..b8139e21 --- /dev/null +++ b/modules/vllm/README.rst @@ -0,0 +1,308 @@ +.. autoclass:: testcontainers.vllm.VLLMContainer +.. title:: testcontainers.vllm.VLLMContainer + +VLLM Container +============== + +The VLLMContainer provides a high-performance LLM inference server using the VLLM framework. It supports various models, GPU acceleration, and OpenAI-compatible API endpoints. + +Features +-------- + +- **High Performance**: Optimized for fast LLM inference with VLLM +- **GPU Support**: Automatic GPU detection and configuration +- **OpenAI Compatible**: Full OpenAI API compatibility for easy integration +- **Model Flexibility**: Support for various model formats and sizes +- **Embedding Support**: Generate embeddings for text +- **Streaming**: Support for streaming text generation +- **Batch Processing**: Efficient batch processing capabilities + +Basic Usage +----------- + +.. code-block:: python + + from testcontainers.vllm import VLLMContainer + import requests + + with VLLMContainer(model_name="microsoft/DialoGPT-medium") as vllm: + # Generate text + response = requests.post( + f"{vllm.get_api_url()}/v1/chat/completions", + json={ + "model": "microsoft/DialoGPT-medium", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "temperature": 0.7 + } + ) + + result = response.json() + print(result["choices"][0]["message"]["content"]) + +Configuration Options +--------------------- + +The VLLMContainer supports various configuration options: + +.. code-block:: python + + with VLLMContainer( + model_name="gpt2", + gpu_memory_utilization=0.9, + max_model_len=2048, + tensor_parallel_size=2, + pipeline_parallel_size=1, + trust_remote_code=True + ) as vllm: + # Your code here + pass + +Parameters +---------- + +- **image**: Docker image to use (default: ``vllm/vllm-openai:latest``) +- **model_name**: Model to load (default: ``microsoft/DialoGPT-medium``) +- **host**: Server host (default: ``0.0.0.0``) +- **port**: Server port (default: ``8000``) +- **gpu_memory_utilization**: GPU memory utilization (default: ``0.9``) +- **max_model_len**: Maximum model length (default: ``4096``) +- **tensor_parallel_size**: Tensor parallel size (default: ``1``) +- **pipeline_parallel_size**: Pipeline parallel size (default: ``1``) +- **trust_remote_code**: Trust remote code (default: ``False``) +- **download_dir**: Directory to download models (default: ``None``) +- **vllm_home**: Directory to mount for model data (default: ``None``) + +API Methods +----------- + +Health and Status +~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + # Check health status + health = vllm.get_health_status() + + # Get server information + server_info = vllm.get_server_info() + + # List available models + models = vllm.list_models() + +Text Generation +~~~~~~~~~~~~~~~ + +.. code-block:: python + + # Generate text using the container method + text = vllm.generate_text( + prompt="Write a short story", + max_tokens=100, + temperature=0.7 + ) + + # Or use direct API calls + response = requests.post( + f"{vllm.get_api_url()}/v1/chat/completions", + json={ + "model": "microsoft/DialoGPT-medium", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 50 + } + ) + +Embeddings +~~~~~~~~~~ + +.. 
code-block:: python
+
+    # Generate an embedding for a single string
+    embeddings = vllm.generate_embeddings("Hello, world!")
+
+    # Batch embeddings
+    texts = ["Hello", "World", "VLLM"]
+    batch_embeddings = vllm.generate_embeddings(texts)
+
+Note that the ``/v1/embeddings`` endpoint is only available when the loaded model supports embeddings.
+
+Streaming
+~~~~~~~~~
+
+.. code-block:: python
+
+    import json
+
+    response = requests.post(
+        f"{vllm.get_api_url()}/v1/chat/completions",
+        json={
+            "model": "microsoft/DialoGPT-medium",
+            "messages": [{"role": "user", "content": "Tell me a story"}],
+            "max_tokens": 100,
+            "stream": True
+        },
+        stream=True
+    )
+
+    for line in response.iter_lines():
+        if line:
+            line = line.decode('utf-8')
+            if line.startswith('data: '):
+                data = line[6:]
+                if data == '[DONE]':
+                    break
+                try:
+                    chunk = json.loads(data)
+                    if 'choices' in chunk and len(chunk['choices']) > 0:
+                        delta = chunk['choices'][0].get('delta', {})
+                        if 'content' in delta:
+                            print(delta['content'], end='', flush=True)
+                except json.JSONDecodeError:
+                    continue
+
+Advanced Configuration
+----------------------
+
+Using VLLM Configuration Objects
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    from testcontainers.vllm import VLLMContainer, create_sampling_params
+
+    with VLLMContainer(model_name="gpt2") as vllm:
+        # Get the VLLM configuration derived from the container settings
+        config = vllm.get_vllm_config()
+
+        # Get the server configuration
+        server_config = vllm.get_server_config()
+
+        # Build sampling parameters for later requests
+        sampling_params = create_sampling_params(
+            temperature=0.8,
+            top_p=0.9,
+            max_tokens=100
+        )
+
+GPU Support
+-----------
+
+The container automatically detects the NVIDIA runtime and requests GPU access when it is available:
+
+.. code-block:: python
+
+    # GPU support is detected automatically
+    with VLLMContainer(
+        model_name="microsoft/DialoGPT-medium",
+        gpu_memory_utilization=0.9
+    ) as vllm:
+        # The container uses the GPU if one is available
+        pass
+
+Model Persistence
+-----------------
+
+To persist downloaded models between container runs:
+
+.. code-block:: python
+
+    from pathlib import Path
+
+    # Use a persistent directory for the model cache
+    model_dir = Path.home() / ".vllm_models"
+
+    with VLLMContainer(
+        model_name="microsoft/DialoGPT-medium",
+        vllm_home=model_dir
+    ) as vllm:
+        # Models are cached in model_dir and reused on the next run
+        pass
+
+Error Handling
+--------------
+
+.. code-block:: python
+
+    try:
+        with VLLMContainer(model_name="invalid-model") as vllm:
+            # Startup fails if the model cannot be downloaded or loaded
+            pass
+    except Exception as e:
+        print(f"Container failed to start: {e}")
+
+Performance Tips
+----------------
+
+1. **GPU Memory**: Adjust ``gpu_memory_utilization`` to match your GPU memory
+2. **Model Length**: Set an appropriate ``max_model_len`` for your use case
+3. **Parallel Processing**: Use ``tensor_parallel_size`` for multi-GPU setups
+4. **Model Caching**: Use ``vllm_home`` to cache models between runs
+
+Example Applications
+--------------------
+
+RAG System
+~~~~~~~~~~
+
+.. code-block:: python
+
+    from testcontainers.vllm import VLLMContainer
+
+    with VLLMContainer(model_name="microsoft/DialoGPT-medium") as vllm:
+        # Embed the documents (requires a model with embedding support)
+        documents = ["Document 1", "Document 2", "Document 3"]
+        embeddings = vllm.generate_embeddings(documents)
+
+        # Fold the documents into the prompt as context
+        context = "Based on the documents: " + " ".join(documents)
+        response = vllm.generate_text(
+            prompt=f"{context}\n\nSummarize the key points.",
+            max_tokens=100
+        )
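+
+OpenAI Python Client
+~~~~~~~~~~~~~~~~~~~~
+
+Because the server speaks the OpenAI API, the official ``openai`` client can be pointed at the container as well. The sketch below assumes the ``openai`` package (v1 or later) is installed; it is not a dependency of this module.
+
+.. code-block:: python
+
+    from openai import OpenAI
+    from testcontainers.vllm import VLLMContainer
+
+    with VLLMContainer(model_name="microsoft/DialoGPT-medium") as vllm:
+        # The server does not check the API key by default, so any value works
+        client = OpenAI(base_url=f"{vllm.get_api_url()}/v1", api_key="EMPTY")
+        completion = client.chat.completions.create(
+            model="microsoft/DialoGPT-medium",
+            messages=[{"role": "user", "content": "Hello!"}],
+            max_tokens=50,
+        )
+        print(completion.choices[0].message.content)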
+
+Batch Processing
+~~~~~~~~~~~~~~~~
+
+.. code-block:: python
+
+    with VLLMContainer(model_name="gpt2") as vllm:
+        prompts = [
+            "What is AI?",
+            "Explain machine learning",
+            "What is deep learning?"
+        ]
+
+        for prompt in prompts:
+            response = vllm.generate_text(prompt, max_tokens=50)
+            print(f"Q: {prompt}")
+            print(f"A: {response}\n")
+
+Integration with the Configuration Models
+-----------------------------------------
+
+The container exposes its settings through the bundled Pydantic configuration models, so a deployment can be described up front and compared against what the running container reports:
+
+.. code-block:: python
+
+    from testcontainers.vllm import VLLMContainer, VLLMServerConfig
+
+    # Describe the intended deployment up front
+    expected = VLLMServerConfig(
+        model_name="microsoft/DialoGPT-medium",
+        port=8000
+    )
+
+    # Compare against the running container
+    with VLLMContainer(model_name=expected.model_name) as vllm:
+        config = vllm.get_vllm_config()
+        server_config = vllm.get_server_config()
+        assert server_config.model_name == expected.model_name
diff --git a/modules/vllm/example_basic.py b/modules/vllm/example_basic.py
new file mode 100644
index 00000000..105de875
--- /dev/null
+++ b/modules/vllm/example_basic.py
@@ -0,0 +1,210 @@
+import json
+
+import requests
+
+from testcontainers.vllm import VLLMContainer
+
+
+def basic_example():
+    """Basic example of using VLLMContainer for text generation and embeddings."""
+
+    with VLLMContainer(
+        model_name="microsoft/DialoGPT-medium",
+        gpu_memory_utilization=0.8,
+        max_model_len=1024
+    ) as vllm:
+        # Get the API endpoint
+        api_url = vllm.get_api_url()
+        print(f"VLLM API URL: {api_url}")
+
+        # Check the health status
+        health = vllm.get_health_status()
+        print(f"Health status: {health}")
+
+        # List the available models
+        models = vllm.list_models()
+        print(f"Available models: {json.dumps(models, indent=2)}")
+
+        # Generate text using chat completions
+        prompt = "Write a short poem about artificial intelligence."
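+        # Note: /v1/chat/completions requires the served model to provide a chat
+        # template; plain base models may need the /v1/completions endpoint instead.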
+ print(f"\nGenerating text for prompt: {prompt}") + + response = requests.post( + f"{api_url}/v1/chat/completions", + json={ + "model": "microsoft/DialoGPT-medium", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 100, + "temperature": 0.7, + "top_p": 0.9 + } + ) + + result = response.json() + print("\nGenerated text:") + print(result["choices"][0]["message"]["content"]) + + # Generate embeddings + text_to_embed = "The quick brown fox jumps over the lazy dog" + print(f"\nGenerating embedding for: {text_to_embed}") + + try: + embedding = vllm.generate_embeddings(text_to_embed) + print(f"\nEmbedding:") + print(f"Length: {len(embedding)}") + print(f"First 5 values: {embedding[:5]}") + except Exception as e: + print(f"Embedding generation failed (model may not support embeddings): {e}") + + # Get server information + server_info = vllm.get_server_info() + print(f"\nServer info:") + print(f"Model: {server_info['model_name']}") + print(f"Health: {server_info['health']['status']}") + + +def advanced_example(): + """Advanced example with custom configuration and streaming.""" + + with VLLMContainer( + model_name="gpt2", + gpu_memory_utilization=0.9, + max_model_len=512, + tensor_parallel_size=1, + trust_remote_code=False + ) as vllm: + api_url = vllm.get_api_url() + + # Generate text with custom parameters + prompt = "Once upon a time in a land far away" + print(f"Generating text with custom parameters for: {prompt}") + + response = requests.post( + f"{api_url}/v1/chat/completions", + json={ + "model": "gpt2", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 50, + "temperature": 0.8, + "top_p": 0.95, + "frequency_penalty": 0.1, + "presence_penalty": 0.1 + } + ) + + result = response.json() + print("Generated text:") + print(result["choices"][0]["message"]["content"]) + + # Test streaming (if supported) + print("\nTesting streaming generation:") + try: + response = requests.post( + f"{api_url}/v1/chat/completions", + json={ + "model": "gpt2", + "messages": [{"role": "user", "content": "Tell me a joke"}], + "max_tokens": 30, + "temperature": 0.7, + "stream": True + }, + stream=True + ) + + print("Streaming response:") + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + if line.startswith('data: '): + data = line[6:] # Remove 'data: ' prefix + if data == '[DONE]': + break + try: + chunk = json.loads(data) + if 'choices' in chunk and len(chunk['choices']) > 0: + delta = chunk['choices'][0].get('delta', {}) + if 'content' in delta: + print(delta['content'], end='', flush=True) + except json.JSONDecodeError: + continue + print() # New line after streaming + + except Exception as e: + print(f"Streaming failed: {e}") + + +def configuration_example(): + """Example showing different configuration options.""" + + # Example with custom download directory and model home + with VLLMContainer( + model_name="microsoft/DialoGPT-small", + gpu_memory_utilization=0.7, + max_model_len=2048, + tensor_parallel_size=1, + pipeline_parallel_size=1, + trust_remote_code=True + ) as vllm: + # Get the VLLM configuration + config = vllm.get_vllm_config() + print("VLLM Configuration:") + print(f"Model: {config.model.model}") + print(f"GPU Memory Utilization: {config.cache.gpu_memory_utilization}") + print(f"Max Model Length: {config.model.max_model_len}") + print(f"Trust Remote Code: {config.model.trust_remote_code}") + + # Get server configuration + server_config = vllm.get_server_config() + print(f"\nServer Configuration:") + print(f"Host: 
{server_config.host}") + print(f"Port: {server_config.port}") + print(f"Tensor Parallel Size: {server_config.tensor_parallel_size}") + + # Test basic functionality + health = vllm.get_health_status() + print(f"\nHealth Status: {health['status']}") + + +def batch_processing_example(): + """Example of batch processing multiple requests.""" + + with VLLMContainer(model_name="gpt2") as vllm: + api_url = vllm.get_api_url() + + # Prepare multiple prompts + prompts = [ + "What is artificial intelligence?", + "Explain machine learning in simple terms.", + "What are the benefits of automation?" + ] + + print("Processing batch of prompts:") + + # Process each prompt + for i, prompt in enumerate(prompts, 1): + print(f"\n{i}. {prompt}") + + response = requests.post( + f"{api_url}/v1/chat/completions", + json={ + "model": "gpt2", + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 50, + "temperature": 0.7 + } + ) + + result = response.json() + print(f"Response: {result['choices'][0]['message']['content']}") + + +if __name__ == "__main__": + print("=== Basic VLLM Container Example ===") + basic_example() + + print("\n=== Advanced VLLM Container Example ===") + advanced_example() + + print("\n=== Configuration Example ===") + configuration_example() + + print("\n=== Batch Processing Example ===") + batch_processing_example() diff --git a/modules/vllm/testcontainers/vllm/__init__.py b/modules/vllm/testcontainers/vllm/__init__.py new file mode 100644 index 00000000..68e9c9ad --- /dev/null +++ b/modules/vllm/testcontainers/vllm/__init__.py @@ -0,0 +1,949 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import asyncio +import json +import time +from abc import ABC, abstractmethod +from datetime import datetime +from enum import Enum +from os import PathLike +from typing import Any, Dict, List, Optional, Union, AsyncGenerator, Tuple, Callable + +import requests +import torch +import numpy as np +from docker.types.containers import DeviceRequest +from pydantic import BaseModel, Field, validator, root_validator + +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs + + +# ============================================================================ +# Core Enums and Types +# ============================================================================ + +class DeviceType(str, Enum): + """Device types supported by VLLM.""" + CUDA = "cuda" + CPU = "cpu" + TPU = "tpu" + XPU = "xpu" + ROCM = "rocm" + + +class ModelType(str, Enum): + """Model types supported by VLLM.""" + DECODER_ONLY = "decoder_only" + ENCODER_DECODER = "encoder_decoder" + EMBEDDING = "embedding" + POOLING = "pooling" + + +class AttentionBackend(str, Enum): + """Attention backends supported by VLLM.""" + FLASH_ATTN = "flash_attn" + XFORMERS = "xformers" + ROCM_FLASH_ATTN = "rocm_flash_attn" + TORCH_SDPA = "torch_sdpa" + + +class SchedulerType(str, Enum): + """Scheduler types for request management.""" + FCFS = "fcfs" # First Come First Served + PRIORITY = "priority" + + +class BlockSpacePolicy(str, Enum): + """Block space policies for memory management.""" + GUARDED = "guarded" + GUARDED_MMAP = "guarded_mmap" + + +class KVSpacePolicy(str, Enum): + """KV cache space policies.""" + EAGER = "eager" + LAZY = "lazy" + + +class QuantizationMethod(str, Enum): + """Quantization methods supported by VLLM.""" + AWQ = "awq" + GPTQ = "gptq" + SQUEEZELLM = "squeezellm" + FP8 = "fp8" + MIXED = "mixed" + BITSANDBYTES = "bitsandbytes" + AUTOROUND = "autoround" + QUARK = "quark" + TORCHAO = "torchao" + + +class LoadFormat(str, Enum): + """Model loading formats.""" + AUTO = "auto" + TORCH = "torch" + SAFETENSORS = "safetensors" + NPZ = "npz" + DUMMY = "dummy" + + +class TokenizerMode(str, Enum): + """Tokenizer modes.""" + AUTO = "auto" + SLOW = "slow" + FAST = "fast" + + +class PoolingType(str, Enum): + """Pooling types for embedding models.""" + MEAN = "mean" + MAX = "max" + CLS = "cls" + LAST = "last" + + +class SpeculativeMode(str, Enum): + """Speculative decoding modes.""" + SMALL_MODEL = "small_model" + DRAFT_MODEL = "draft_model" + MEDUSA = "medusa" + + +# ============================================================================ +# Configuration Models +# ============================================================================ + +class ModelConfig(BaseModel): + """Model-specific configuration.""" + model: str = Field(..., description="Model name or path") + tokenizer: Optional[str] = Field(None, description="Tokenizer name or path") + tokenizer_mode: TokenizerMode = Field(TokenizerMode.AUTO, description="Tokenizer mode") + trust_remote_code: bool = Field(False, description="Trust remote code") + download_dir: Optional[str] = Field(None, description="Download directory") + load_format: LoadFormat = Field(LoadFormat.AUTO, description="Model loading format") + dtype: str = Field("auto", description="Data type") + seed: int = Field(0, description="Random seed") + revision: Optional[str] = Field(None, description="Model revision") + code_revision: Optional[str] = Field(None, description="Code revision") + max_model_len: Optional[int] 
= Field(None, description="Maximum model length") + quantization: Optional[QuantizationMethod] = Field(None, description="Quantization method") + enforce_eager: bool = Field(False, description="Enforce eager execution") + max_seq_len_to_capture: int = Field(8192, description="Max sequence length to capture") + disable_custom_all_reduce: bool = Field(False, description="Disable custom all-reduce") + skip_tokenizer_init: bool = Field(False, description="Skip tokenizer initialization") + + class Config: + json_schema_extra = { + "example": { + "model": "microsoft/DialoGPT-medium", + "tokenizer_mode": "auto", + "trust_remote_code": False, + "load_format": "auto", + "dtype": "auto" + } + } + + +class CacheConfig(BaseModel): + """KV cache configuration.""" + block_size: int = Field(16, description="Block size for KV cache") + gpu_memory_utilization: float = Field(0.9, description="GPU memory utilization") + swap_space: int = Field(4, description="Swap space in GB") + cache_dtype: str = Field("auto", description="Cache data type") + num_gpu_blocks_override: Optional[int] = Field(None, description="Override number of GPU blocks") + num_cpu_blocks_override: Optional[int] = Field(None, description="Override number of CPU blocks") + block_space_policy: BlockSpacePolicy = Field(BlockSpacePolicy.GUARDED, description="Block space policy") + kv_space_policy: KVSpacePolicy = Field(KVSpacePolicy.EAGER, description="KV space policy") + enable_prefix_caching: bool = Field(False, description="Enable prefix caching") + enable_chunked_prefill: bool = Field(False, description="Enable chunked prefill") + preemption_mode: str = Field("recompute", description="Preemption mode") + enable_hybrid_engine: bool = Field(False, description="Enable hybrid engine") + num_lookahead_slots: int = Field(0, description="Number of lookahead slots") + delay_factor: float = Field(0.0, description="Delay factor") + enable_sliding_window: bool = Field(False, description="Enable sliding window") + sliding_window_size: Optional[int] = Field(None, description="Sliding window size") + sliding_window_blocks: Optional[int] = Field(None, description="Sliding window blocks") + + class Config: + json_schema_extra = { + "example": { + "block_size": 16, + "gpu_memory_utilization": 0.9, + "swap_space": 4, + "cache_dtype": "auto" + } + } + + +class LoadConfig(BaseModel): + """Model loading configuration.""" + max_model_len: Optional[int] = Field(None, description="Maximum model length") + max_num_batched_tokens: Optional[int] = Field(None, description="Maximum batched tokens") + max_num_seqs: Optional[int] = Field(None, description="Maximum number of sequences") + max_paddings: Optional[int] = Field(None, description="Maximum paddings") + max_lora_rank: int = Field(16, description="Maximum LoRA rank") + max_loras: int = Field(1, description="Maximum number of LoRAs") + max_cpu_loras: int = Field(2, description="Maximum CPU LoRAs") + lora_extra_vocab_size: int = Field(256, description="LoRA extra vocabulary size") + lora_dtype: str = Field("auto", description="LoRA data type") + device_map: Optional[str] = Field(None, description="Device map") + load_in_low_bit: Optional[str] = Field(None, description="Load in low bit") + load_in_4bit: bool = Field(False, description="Load in 4-bit") + load_in_8bit: bool = Field(False, description="Load in 8-bit") + load_in_symmetric: bool = Field(True, description="Load in symmetric") + load_in_nested: bool = Field(False, description="Load in nested") + load_in_half: bool = Field(False, description="Load in half 
precision") + load_in_bfloat16: bool = Field(False, description="Load in bfloat16") + load_in_float16: bool = Field(False, description="Load in float16") + load_in_float32: bool = Field(False, description="Load in float32") + load_in_int8: bool = Field(False, description="Load in int8") + load_in_int4: bool = Field(False, description="Load in int4") + load_in_int2: bool = Field(False, description="Load in int2") + load_in_int1: bool = Field(False, description="Load in int1") + load_in_bool: bool = Field(False, description="Load in bool") + load_in_uint8: bool = Field(False, description="Load in uint8") + load_in_uint4: bool = Field(False, description="Load in uint4") + load_in_uint2: bool = Field(False, description="Load in uint2") + load_in_uint1: bool = Field(False, description="Load in uint1") + load_in_complex64: bool = Field(False, description="Load in complex64") + load_in_complex128: bool = Field(False, description="Load in complex128") + load_in_quint8: bool = Field(False, description="Load in quint8") + load_in_quint4x2: bool = Field(False, description="Load in quint4x2") + load_in_quint2x4: bool = Field(False, description="Load in quint2x4") + load_in_quint1x8: bool = Field(False, description="Load in quint1x8") + load_in_qint8: bool = Field(False, description="Load in qint8") + load_in_qint4: bool = Field(False, description="Load in qint4") + load_in_qint2: bool = Field(False, description="Load in qint2") + load_in_qint1: bool = Field(False, description="Load in qint1") + load_in_bfloat8: bool = Field(False, description="Load in bfloat8") + load_in_float8: bool = Field(False, description="Load in float8") + load_in_half_bfloat16: bool = Field(False, description="Load in half bfloat16") + load_in_half_float16: bool = Field(False, description="Load in half float16") + load_in_half_float32: bool = Field(False, description="Load in half float32") + load_in_half_int8: bool = Field(False, description="Load in half int8") + load_in_half_int4: bool = Field(False, description="Load in half int4") + load_in_half_int2: bool = Field(False, description="Load in half int2") + load_in_half_int1: bool = Field(False, description="Load in half int1") + load_in_half_bool: bool = Field(False, description="Load in half bool") + load_in_half_uint8: bool = Field(False, description="Load in half uint8") + load_in_half_uint4: bool = Field(False, description="Load in half uint4") + load_in_half_uint2: bool = Field(False, description="Load in half uint2") + load_in_half_uint1: bool = Field(False, description="Load in half uint1") + load_in_half_complex64: bool = Field(False, description="Load in half complex64") + load_in_half_complex128: bool = Field(False, description="Load in half complex128") + load_in_half_quint8: bool = Field(False, description="Load in half quint8") + load_in_half_quint4x2: bool = Field(False, description="Load in half quint4x2") + load_in_half_quint2x4: bool = Field(False, description="Load in half quint2x4") + load_in_half_quint1x8: bool = Field(False, description="Load in half quint1x8") + load_in_half_qint8: bool = Field(False, description="Load in half qint8") + load_in_half_qint4: bool = Field(False, description="Load in half qint4") + load_in_half_qint2: bool = Field(False, description="Load in half qint2") + load_in_half_qint1: bool = Field(False, description="Load in half qint1") + load_in_half_bfloat8: bool = Field(False, description="Load in half bfloat8") + load_in_half_float8: bool = Field(False, description="Load in half float8") + + class Config: + json_schema_extra 
= { + "example": { + "max_model_len": 4096, + "max_num_batched_tokens": 8192, + "max_num_seqs": 256 + } + } + + +class ParallelConfig(BaseModel): + """Parallel execution configuration.""" + pipeline_parallel_size: int = Field(1, description="Pipeline parallel size") + tensor_parallel_size: int = Field(1, description="Tensor parallel size") + worker_use_ray: bool = Field(False, description="Use Ray for workers") + engine_use_ray: bool = Field(False, description="Use Ray for engine") + disable_custom_all_reduce: bool = Field(False, description="Disable custom all-reduce") + max_parallel_loading_workers: Optional[int] = Field(None, description="Max parallel loading workers") + ray_address: Optional[str] = Field(None, description="Ray cluster address") + placement_group: Optional[Dict[str, Any]] = Field(None, description="Ray placement group") + ray_runtime_env: Optional[Dict[str, Any]] = Field(None, description="Ray runtime environment") + + class Config: + json_schema_extra = { + "example": { + "pipeline_parallel_size": 1, + "tensor_parallel_size": 1, + "worker_use_ray": False + } + } + + +class SchedulerConfig(BaseModel): + """Scheduler configuration.""" + max_num_batched_tokens: int = Field(8192, description="Maximum batched tokens") + max_num_seqs: int = Field(256, description="Maximum number of sequences") + max_paddings: int = Field(256, description="Maximum paddings") + use_v2_block_manager: bool = Field(False, description="Use v2 block manager") + enable_chunked_prefill: bool = Field(False, description="Enable chunked prefill") + preemption_mode: str = Field("recompute", description="Preemption mode") + num_lookahead_slots: int = Field(0, description="Number of lookahead slots") + delay_factor: float = Field(0.0, description="Delay factor") + enable_sliding_window: bool = Field(False, description="Enable sliding window") + sliding_window_size: Optional[int] = Field(None, description="Sliding window size") + sliding_window_blocks: Optional[int] = Field(None, description="Sliding window blocks") + + class Config: + json_schema_extra = { + "example": { + "max_num_batched_tokens": 8192, + "max_num_seqs": 256, + "max_paddings": 256 + } + } + + +class DeviceConfig(BaseModel): + """Device configuration.""" + device: DeviceType = Field(DeviceType.CUDA, description="Device type") + device_id: int = Field(0, description="Device ID") + memory_fraction: float = Field(1.0, description="Memory fraction") + + class Config: + json_schema_extra = { + "example": { + "device": "cuda", + "device_id": 0, + "memory_fraction": 1.0 + } + } + + +class ObservabilityConfig(BaseModel): + """Observability configuration.""" + disable_log_stats: bool = Field(False, description="Disable log statistics") + disable_log_requests: bool = Field(False, description="Disable log requests") + log_requests: bool = Field(False, description="Log requests") + log_stats: bool = Field(False, description="Log statistics") + log_level: str = Field("INFO", description="Log level") + log_file: Optional[str] = Field(None, description="Log file") + log_format: str = Field("%(asctime)s - %(name)s - %(levelname)s - %(message)s", description="Log format") + + class Config: + json_schema_extra = { + "example": { + "disable_log_stats": False, + "disable_log_requests": False, + "log_level": "INFO" + } + } + + +class VllmConfig(BaseModel): + """Complete VLLM configuration aggregating all components.""" + model: ModelConfig = Field(..., description="Model configuration") + cache: CacheConfig = Field(..., description="Cache configuration") + 
load: LoadConfig = Field(..., description="Load configuration") + parallel: ParallelConfig = Field(..., description="Parallel configuration") + scheduler: SchedulerConfig = Field(..., description="Scheduler configuration") + device: DeviceConfig = Field(..., description="Device configuration") + observability: ObservabilityConfig = Field(..., description="Observability configuration") + + class Config: + json_schema_extra = { + "example": { + "model": { + "model": "microsoft/DialoGPT-medium", + "tokenizer_mode": "auto" + }, + "cache": { + "block_size": 16, + "gpu_memory_utilization": 0.9 + }, + "load": { + "max_model_len": 4096 + }, + "parallel": { + "pipeline_parallel_size": 1, + "tensor_parallel_size": 1 + }, + "scheduler": { + "max_num_batched_tokens": 8192, + "max_num_seqs": 256 + }, + "device": { + "device": "cuda", + "device_id": 0 + }, + "observability": { + "disable_log_stats": False, + "log_level": "INFO" + } + } + } + + +class SamplingParams(BaseModel): + """Sampling parameters for text generation.""" + n: int = Field(1, description="Number of output sequences to generate") + best_of: Optional[int] = Field(None, description="Number of sequences to generate and return the best") + presence_penalty: float = Field(0.0, description="Presence penalty") + frequency_penalty: float = Field(0.0, description="Frequency penalty") + repetition_penalty: float = Field(1.0, description="Repetition penalty") + temperature: float = Field(1.0, description="Sampling temperature") + top_p: float = Field(1.0, description="Top-p sampling parameter") + top_k: int = Field(-1, description="Top-k sampling parameter") + min_p: float = Field(0.0, description="Minimum probability threshold") + use_beam_search: bool = Field(False, description="Use beam search") + length_penalty: float = Field(1.0, description="Length penalty for beam search") + early_stopping: Union[bool, str] = Field(False, description="Early stopping for beam search") + stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences") + stop_token_ids: Optional[List[int]] = Field(None, description="Stop token IDs") + include_stop_str_in_output: bool = Field(False, description="Include stop string in output") + ignore_eos: bool = Field(False, description="Ignore end-of-sequence token") + skip_special_tokens: bool = Field(True, description="Skip special tokens in output") + spaces_between_special_tokens: bool = Field(True, description="Add spaces between special tokens") + logits_processor: Optional[List[Callable]] = Field(None, description="Logits processors") + prompt_logprobs: Optional[int] = Field(None, description="Number of logprobs for prompt tokens") + detokenize: bool = Field(True, description="Detokenize output") + seed: Optional[int] = Field(None, description="Random seed") + logprobs: Optional[int] = Field(None, description="Number of logprobs to return") + + class Config: + json_schema_extra = { + "example": { + "temperature": 0.7, + "top_p": 0.9, + "max_tokens": 50, + "stop": ["\n", "Human:"] + } + } + + +# ============================================================================ +# VLLM Integration Models +# ============================================================================ + +class VLLMServerConfig(BaseModel): + """Configuration for VLLM server deployment.""" + model_name: str = Field(..., description="Model name or path") + host: str = Field("0.0.0.0", description="Server host") + port: int = Field(8000, description="Server port") + gpu_memory_utilization: float = Field(0.9, description="GPU 
memory utilization") + max_model_len: int = Field(4096, description="Maximum model length") + dtype: str = Field("auto", description="Data type for model") + trust_remote_code: bool = Field(False, description="Trust remote code") + download_dir: Optional[str] = Field(None, description="Download directory for models") + load_format: str = Field("auto", description="Model loading format") + tensor_parallel_size: int = Field(1, description="Tensor parallel size") + pipeline_parallel_size: int = Field(1, description="Pipeline parallel size") + max_num_seqs: int = Field(256, description="Maximum number of sequences") + max_num_batched_tokens: int = Field(8192, description="Maximum batched tokens") + max_paddings: int = Field(256, description="Maximum paddings") + disable_log_stats: bool = Field(False, description="Disable log statistics") + revision: Optional[str] = Field(None, description="Model revision") + code_revision: Optional[str] = Field(None, description="Code revision") + tokenizer: Optional[str] = Field(None, description="Tokenizer name") + tokenizer_mode: str = Field("auto", description="Tokenizer mode") + skip_tokenizer_init: bool = Field(False, description="Skip tokenizer initialization") + enforce_eager: bool = Field(False, description="Enforce eager execution") + max_seq_len_to_capture: int = Field(8192, description="Max sequence length to capture") + + class Config: + json_schema_extra = { + "example": { + "model_name": "microsoft/DialoGPT-medium", + "host": "0.0.0.0", + "port": 8000, + "gpu_memory_utilization": 0.9, + "max_model_len": 4096 + } + } + + +class VLLMEmbeddingServerConfig(BaseModel): + """Configuration for VLLM embedding server deployment.""" + model_name: str = Field(..., description="Embedding model name or path") + host: str = Field("0.0.0.0", description="Server host") + port: int = Field(8001, description="Server port") + gpu_memory_utilization: float = Field(0.9, description="GPU memory utilization") + max_model_len: int = Field(512, description="Maximum model length for embeddings") + dtype: str = Field("auto", description="Data type for model") + trust_remote_code: bool = Field(False, description="Trust remote code") + download_dir: Optional[str] = Field(None, description="Download directory for models") + load_format: str = Field("auto", description="Model loading format") + tensor_parallel_size: int = Field(1, description="Tensor parallel size") + pipeline_parallel_size: int = Field(1, description="Pipeline parallel size") + max_num_seqs: int = Field(256, description="Maximum number of sequences") + max_num_batched_tokens: int = Field(8192, description="Maximum batched tokens") + max_paddings: int = Field(256, description="Maximum paddings") + disable_log_stats: bool = Field(False, description="Disable log statistics") + + class Config: + json_schema_extra = { + "example": { + "model_name": "sentence-transformers/all-MiniLM-L6-v2", + "host": "0.0.0.0", + "port": 8001, + "gpu_memory_utilization": 0.9, + "max_model_len": 512 + } + } + + +# ============================================================================ +# Utility Functions +# ============================================================================ + +def create_vllm_config( + model: str, + gpu_memory_utilization: float = 0.9, + max_model_len: Optional[int] = None, + dtype: str = "auto", + trust_remote_code: bool = False, + **kwargs +) -> VllmConfig: + """Create a VLLM configuration with common defaults.""" + model_config = ModelConfig( + model=model, + trust_remote_code=trust_remote_code, + 
dtype=dtype, + max_model_len=max_model_len + ) + + cache_config = CacheConfig( + gpu_memory_utilization=gpu_memory_utilization + ) + + load_config = LoadConfig( + max_model_len=max_model_len + ) + + parallel_config = ParallelConfig() + scheduler_config = SchedulerConfig() + device_config = DeviceConfig() + observability_config = ObservabilityConfig() + + return VllmConfig( + model=model_config, + cache=cache_config, + load=load_config, + parallel=parallel_config, + scheduler=scheduler_config, + device=device_config, + observability=observability_config, + **kwargs + ) + + +def create_sampling_params( + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + max_tokens: int = 16, + stop: Optional[Union[str, List[str]]] = None, + **kwargs +) -> SamplingParams: + """Create sampling parameters with common defaults.""" + return SamplingParams( + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop=stop, + **kwargs + ) + + +class VLLMModelInfo: + """Information about a VLLM model.""" + + def __init__(self, name: str, size: int = 0, status: str = "unknown"): + self.name = name + self.size = size + self.status = status + + +class VLLMContainer(DockerContainer): + """ + VLLM Container for high-performance LLM inference. + + :param image: the vllm image to use (default: :code:`vllm/vllm-openai:latest`) + :param model_name: the model to load (default: :code:`microsoft/DialoGPT-medium`) + :param host: server host (default: :code:`0.0.0.0`) + :param port: server port (default: :code:`8000`) + :param gpu_memory_utilization: GPU memory utilization (default: :code:`0.9`) + :param max_model_len: maximum model length (default: :code:`4096`) + :param tensor_parallel_size: tensor parallel size (default: :code:`1`) + :param pipeline_parallel_size: pipeline parallel size (default: :code:`1`) + :param trust_remote_code: trust remote code (default: :code:`False`) + :param download_dir: directory to download models (default: :code:`None`) + :param vllm_home: the directory to mount for model data (default: :code:`None`) + + Examples: + + .. doctest:: + + >>> from testcontainers.vllm import VLLMContainer + >>> with VLLMContainer(model_name="gpt2") as vllm: + ... vllm.get_health_status() + {'status': 'healthy'} + + .. code-block:: python + + >>> import requests + >>> from testcontainers.vllm import VLLMContainer + >>> + >>> with VLLMContainer(model_name="microsoft/DialoGPT-medium") as vllm: + ... # Generate text + ... response = requests.post( + ... f"{vllm.get_api_url()}/v1/chat/completions", + ... json={ + ... "model": "microsoft/DialoGPT-medium", + ... "messages": [{"role": "user", "content": "Hello, how are you?"}], + ... "max_tokens": 50, + ... "temperature": 0.7 + ... } + ... ) + ... result = response.json() + ... 
print(result["choices"][0]["message"]["content"]) + """ + + VLLM_PORT = 8000 + + def __init__( + self, + image: str = "vllm/vllm-openai:latest", + model_name: str = "microsoft/DialoGPT-medium", + host: str = "0.0.0.0", + port: int = 8000, + gpu_memory_utilization: float = 0.9, + max_model_len: int = 4096, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + trust_remote_code: bool = False, + download_dir: Optional[Union[str, PathLike]] = None, + vllm_home: Optional[Union[str, PathLike]] = None, + **kwargs, + ): + super().__init__(image=image, **kwargs) + + # Store configuration + self.model_name = model_name + self.host = host + self.port = port + self.gpu_memory_utilization = gpu_memory_utilization + self.max_model_len = max_model_len + self.tensor_parallel_size = tensor_parallel_size + self.pipeline_parallel_size = pipeline_parallel_size + self.trust_remote_code = trust_remote_code + self.download_dir = download_dir + self.vllm_home = vllm_home + + # Expose the VLLM port + self.with_exposed_ports(VLLMContainer.VLLM_PORT) + + # Add GPU capabilities if available + self._check_and_add_gpu_capabilities() + + # Set up volume mappings + self._setup_volume_mappings() + + def _check_and_add_gpu_capabilities(self): + """Check for GPU capabilities and add them if available.""" + try: + info = self.get_docker_client().client.info() + if "nvidia" in info.get("Runtimes", {}): + self._kwargs = { + **self._kwargs, + "device_requests": [DeviceRequest(count=-1, capabilities=[["gpu"]])] + } + except Exception: + # If we can't detect GPU capabilities, continue without them + pass + + def _setup_volume_mappings(self): + """Set up volume mappings for model storage.""" + if self.vllm_home: + self.with_volume_mapping(self.vllm_home, "/root/.cache/huggingface", "rw") + + if self.download_dir: + self.with_volume_mapping(self.download_dir, "/models", "rw") + + def _build_vllm_command(self) -> List[str]: + """Build the VLLM server command.""" + cmd = [ + "python", "-m", "vllm.entrypoints.openai.api_server", + "--model", self.model_name, + "--host", self.host, + "--port", str(self.port), + "--gpu-memory-utilization", str(self.gpu_memory_utilization), + "--max-model-len", str(self.max_model_len), + "--tensor-parallel-size", str(self.tensor_parallel_size), + "--pipeline-parallel-size", str(self.pipeline_parallel_size), + ] + + if self.trust_remote_code: + cmd.append("--trust-remote-code") + + if self.download_dir: + cmd.extend(["--download-dir", "/models"]) + + return cmd + + def start(self) -> "VLLMContainer": + """ + Start the VLLM server. + """ + # Set the command to start VLLM server + cmd = self._build_vllm_command() + self.with_command(cmd) + + # Start the container + super().start() + + # Wait for the server to be ready + wait_for_logs(self, "Uvicorn running on", timeout=120) + + # Additional health check + self._wait_for_health_check() + + return self + + def _wait_for_health_check(self, timeout: int = 60): + """Wait for the VLLM server to be healthy.""" + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"{self.get_api_url()}/health", timeout=5) + if response.status_code == 200: + return + except requests.RequestException: + pass + time.sleep(2) + + raise RuntimeError(f"VLLM server did not become healthy within {timeout} seconds") + + def get_api_url(self) -> str: + """ + Return the API URL of the VLLM server. 
+ """ + host = self.get_container_host_ip() + exposed_port = self.get_exposed_port(VLLMContainer.VLLM_PORT) + return f"http://{host}:{exposed_port}" + + def get_endpoint(self) -> str: + """ + Return the endpoint of the VLLM server (alias for get_api_url). + """ + return self.get_api_url() + + @property + def id(self) -> str: + """ + Return the container ID. + """ + return self._container.id + + def get_health_status(self) -> Dict[str, Any]: + """ + Get the health status of the VLLM server. + """ + try: + response = requests.get(f"{self.get_api_url()}/health", timeout=10) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + return {"status": "unhealthy", "error": str(e)} + + def list_models(self) -> List[Dict[str, Any]]: + """ + List available models. + """ + try: + response = requests.get(f"{self.get_api_url()}/v1/models", timeout=10) + response.raise_for_status() + return response.json().get("data", []) + except requests.RequestException as e: + raise RuntimeError(f"Failed to list models: {e}") + + def generate_text( + self, + prompt: str, + max_tokens: int = 50, + temperature: float = 0.7, + top_p: float = 0.9, + **kwargs + ) -> str: + """ + Generate text using the VLLM server. + + Args: + prompt: The input prompt + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling parameter + **kwargs: Additional parameters + + Returns: + Generated text + """ + payload = { + "model": self.model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "stream": False, + **kwargs + } + + try: + response = requests.post( + f"{self.get_api_url()}/v1/chat/completions", + json=payload, + timeout=60 + ) + response.raise_for_status() + result = response.json() + return result["choices"][0]["message"]["content"] + except requests.RequestException as e: + raise RuntimeError(f"Failed to generate text: {e}") + + def generate_embeddings( + self, + texts: Union[str, List[str]], + model: Optional[str] = None + ) -> Union[List[float], List[List[float]]]: + """ + Generate embeddings using the VLLM server. + + Args: + texts: Text or list of texts to embed + model: Model to use for embeddings (defaults to self.model_name) + + Returns: + Embeddings as list of floats or list of lists of floats + """ + if model is None: + model = self.model_name + + if isinstance(texts, str): + texts = [texts] + single_text = True + else: + single_text = False + + payload = { + "model": model, + "input": texts, + "encoding_format": "float" + } + + try: + response = requests.post( + f"{self.get_api_url()}/v1/embeddings", + json=payload, + timeout=60 + ) + response.raise_for_status() + result = response.json() + embeddings = [item["embedding"] for item in result["data"]] + + if single_text: + return embeddings[0] + return embeddings + except requests.RequestException as e: + raise RuntimeError(f"Failed to generate embeddings: {e}") + + def get_vllm_config(self) -> VllmConfig: + """ + Get the VLLM configuration used by this container. + """ + return create_vllm_config( + model=self.model_name, + gpu_memory_utilization=self.gpu_memory_utilization, + max_model_len=self.max_model_len, + trust_remote_code=self.trust_remote_code + ) + + def get_server_config(self) -> VLLMServerConfig: + """ + Get the VLLM server configuration. 
+ """ + return VLLMServerConfig( + model_name=self.model_name, + host=self.host, + port=self.port, + gpu_memory_utilization=self.gpu_memory_utilization, + max_model_len=self.max_model_len, + tensor_parallel_size=self.tensor_parallel_size, + pipeline_parallel_size=self.pipeline_parallel_size, + trust_remote_code=self.trust_remote_code, + download_dir=str(self.download_dir) if self.download_dir else None + ) + + def commit_to_image(self, image_name: str) -> None: + """ + Commit the current container to a new image. + + Args: + image_name: Name of the new image + """ + docker_client = self.get_docker_client() + existing_images = docker_client.client.images.list(name=image_name) + if not existing_images and self.id: + docker_client.client.containers.get(self.id).commit( + repository=image_name, + conf={"Labels": {"org.testcontainers.session-id": ""}} + ) + + def get_metrics(self) -> Dict[str, Any]: + """ + Get server metrics. + """ + try: + response = requests.get(f"{self.get_api_url()}/metrics", timeout=10) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + raise RuntimeError(f"Failed to get metrics: {e}") + + def get_server_info(self) -> Dict[str, Any]: + """ + Get server information. + """ + try: + response = requests.get(f"{self.get_api_url()}/v1/models", timeout=10) + response.raise_for_status() + models_data = response.json() + + health = self.get_health_status() + + return { + "health": health, + "models": models_data, + "api_url": self.get_api_url(), + "model_name": self.model_name, + "config": { + "model_name": self.model_name, + "host": self.host, + "port": self.port, + "gpu_memory_utilization": self.gpu_memory_utilization, + "max_model_len": self.max_model_len, + "tensor_parallel_size": self.tensor_parallel_size, + "pipeline_parallel_size": self.pipeline_parallel_size, + "trust_remote_code": self.trust_remote_code, + "download_dir": str(self.download_dir) if self.download_dir else None + } + } + except requests.RequestException as e: + raise RuntimeError(f"Failed to get server info: {e}") \ No newline at end of file diff --git a/modules/vllm/tests/test_vllm.py b/modules/vllm/tests/test_vllm.py new file mode 100644 index 00000000..28189ed7 --- /dev/null +++ b/modules/vllm/tests/test_vllm.py @@ -0,0 +1,381 @@ +import json +import random +import string +import time +from pathlib import Path + +import pytest +import requests +from testcontainers.vllm import VLLMContainer + + +def random_string(length=6): + """Generate a random string for testing.""" + return "".join(random.choices(string.ascii_lowercase, k=length)) + + +def test_vllm_container_basic(): + """Test basic VLLM container functionality.""" + with VLLMContainer(model_name="gpt2") as vllm: + # Test API URL + api_url = vllm.get_api_url() + assert api_url.startswith("http://") + assert ":8000" in api_url + + # Test health check + health = vllm.get_health_status() + assert "status" in health + + # Test container ID + assert vllm.id is not None + + +def test_vllm_container_with_custom_config(): + """Test VLLM container with custom configuration.""" + with VLLMContainer( + model_name="gpt2", + gpu_memory_utilization=0.8, + max_model_len=1024, + tensor_parallel_size=1, + trust_remote_code=False + ) as vllm: + # Test configuration + config = vllm.get_vllm_config() + assert config.model.model == "gpt2" + assert config.cache.gpu_memory_utilization == 0.8 + assert config.model.max_model_len == 1024 + + # Test server configuration + server_config = vllm.get_server_config() + assert 
server_config.model_name == "gpt2" + assert server_config.gpu_memory_utilization == 0.8 + assert server_config.max_model_len == 1024 + + +def test_vllm_text_generation(): + """Test text generation functionality.""" + with VLLMContainer(model_name="gpt2") as vllm: + # Test direct API call + api_url = vllm.get_api_url() + response = requests.post( + f"{api_url}/v1/chat/completions", + json={ + "model": "gpt2", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + "temperature": 0.7 + }, + timeout=30 + ) + + assert response.status_code == 200 + result = response.json() + assert "choices" in result + assert len(result["choices"]) > 0 + assert "message" in result["choices"][0] + assert "content" in result["choices"][0]["message"] + + +def test_vllm_container_methods(): + """Test VLLM container helper methods.""" + with VLLMContainer(model_name="gpt2") as vllm: + # Test generate_text method + text = vllm.generate_text( + prompt="Hello", + max_tokens=5, + temperature=0.7 + ) + assert isinstance(text, str) + assert len(text) > 0 + + +def test_vllm_models_list(): + """Test listing available models.""" + with VLLMContainer(model_name="gpt2") as vllm: + models = vllm.list_models() + assert isinstance(models, list) + # Should have at least one model (the one we loaded) + assert len(models) > 0 + + +def test_vllm_server_info(): + """Test getting server information.""" + with VLLMContainer(model_name="gpt2") as vllm: + server_info = vllm.get_server_info() + assert "health" in server_info + assert "models" in server_info + assert "api_url" in server_info + assert "model_name" in server_info + assert "config" in server_info + assert server_info["model_name"] == "gpt2" + + +def test_vllm_embeddings(): + """Test embedding generation (if supported by model).""" + with VLLMContainer(model_name="gpt2") as vllm: + try: + # Test single text embedding + embedding = vllm.generate_embeddings("Hello, world!") + assert isinstance(embedding, list) + assert len(embedding) > 0 + assert all(isinstance(x, (int, float)) for x in embedding) + + # Test batch embeddings + texts = ["Hello", "World", "Test"] + batch_embeddings = vllm.generate_embeddings(texts) + assert isinstance(batch_embeddings, list) + assert len(batch_embeddings) == 3 + assert all(isinstance(emb, list) for emb in batch_embeddings) + + except Exception as e: + # Some models may not support embeddings + pytest.skip(f"Embeddings not supported: {e}") + + +def test_vllm_streaming(): + """Test streaming text generation.""" + with VLLMContainer(model_name="gpt2") as vllm: + api_url = vllm.get_api_url() + + try: + response = requests.post( + f"{api_url}/v1/chat/completions", + json={ + "model": "gpt2", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + "temperature": 0.7, + "stream": True + }, + stream=True, + timeout=30 + ) + + assert response.status_code == 200 + + # Collect streaming chunks + chunks = [] + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + if line.startswith('data: '): + data = line[6:] + if data == '[DONE]': + break + try: + chunk = json.loads(data) + if 'choices' in chunk and len(chunk['choices']) > 0: + delta = chunk['choices'][0].get('delta', {}) + if 'content' in delta: + chunks.append(delta['content']) + except json.JSONDecodeError: + continue + + # Should have received some chunks + assert len(chunks) > 0 + + except Exception as e: + pytest.skip(f"Streaming not supported: {e}") + + +def test_vllm_with_model_persistence(tmp_path: Path): + """Test VLLM container with 
model persistence.""" + with VLLMContainer( + model_name="gpt2", + vllm_home=tmp_path + ) as vllm: + # Test that container starts successfully with persistent storage + health = vllm.get_health_status() + assert "status" in health + + # Test basic functionality + models = vllm.list_models() + assert isinstance(models, list) + + +def test_vllm_commit_to_image(): + """Test committing container to image.""" + new_image_name = f"tc-vllm-test-{random_string(length=4).lower()}" + + with VLLMContainer(model_name="gpt2") as vllm: + # Test that container works + health = vllm.get_health_status() + assert "status" in health + + # Commit to new image + vllm.commit_to_image(new_image_name) + + # Verify the new image exists and works + with VLLMContainer(new_image_name) as vllm: + health = vllm.get_health_status() + assert "status" in health + + +def test_vllm_metrics(): + """Test getting server metrics.""" + with VLLMContainer(model_name="gpt2") as vllm: + try: + metrics = vllm.get_metrics() + # Metrics endpoint might return different formats + assert isinstance(metrics, (dict, str)) + except Exception as e: + # Metrics endpoint might not be available + pytest.skip(f"Metrics not available: {e}") + + +def test_vllm_error_handling(): + """Test error handling for invalid requests.""" + with VLLMContainer(model_name="gpt2") as vllm: + api_url = vllm.get_api_url() + + # Test invalid model name + response = requests.post( + f"{api_url}/v1/chat/completions", + json={ + "model": "invalid-model", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10 + }, + timeout=10 + ) + + # Should return an error + assert response.status_code != 200 + + +def test_vllm_different_models(): + """Test VLLM container with different model configurations.""" + models_to_test = [ + "gpt2", + "microsoft/DialoGPT-small" + ] + + for model_name in models_to_test: + with VLLMContainer(model_name=model_name) as vllm: + # Test basic functionality + health = vllm.get_health_status() + assert "status" in health + + # Test configuration + config = vllm.get_vllm_config() + assert config.model.model == model_name + + +def test_vllm_parallel_configuration(): + """Test VLLM container with parallel processing configuration.""" + with VLLMContainer( + model_name="gpt2", + tensor_parallel_size=1, + pipeline_parallel_size=1 + ) as vllm: + # Test configuration + server_config = vllm.get_server_config() + assert server_config.tensor_parallel_size == 1 + assert server_config.pipeline_parallel_size == 1 + + # Test basic functionality + health = vllm.get_health_status() + assert "status" in health + + +def test_vllm_custom_ports(): + """Test VLLM container with custom port configuration.""" + with VLLMContainer( + model_name="gpt2", + port=8001 + ) as vllm: + # Test that the port is correctly configured + api_url = vllm.get_api_url() + assert ":8001" in api_url + + # Test basic functionality + health = vllm.get_health_status() + assert "status" in health + + +def test_vllm_trust_remote_code(): + """Test VLLM container with trust_remote_code option.""" + with VLLMContainer( + model_name="gpt2", + trust_remote_code=True + ) as vllm: + # Test configuration + config = vllm.get_vllm_config() + assert config.model.trust_remote_code is True + + # Test basic functionality + health = vllm.get_health_status() + assert "status" in health + + +def test_vllm_memory_utilization(): + """Test VLLM container with different memory utilization settings.""" + with VLLMContainer( + model_name="gpt2", + gpu_memory_utilization=0.7 + ) as vllm: + # Test 
configuration
+        config = vllm.get_vllm_config()
+        assert config.cache.gpu_memory_utilization == 0.7
+
+        # Test basic functionality
+        health = vllm.get_health_status()
+        assert "status" in health
+
+
+def test_vllm_max_model_length():
+    """Test VLLM container with a custom maximum model length."""
+    with VLLMContainer(
+        model_name="gpt2",
+        max_model_len=2048
+    ) as vllm:
+        # Test configuration
+        config = vllm.get_vllm_config()
+        assert config.model.max_model_len == 2048
+
+        # Test basic functionality
+        health = vllm.get_health_status()
+        assert "status" in health
+
+
+def test_vllm_integration_with_config_helpers():
+    """Test integration with the bundled configuration helpers."""
+    from testcontainers.vllm import create_sampling_params
+
+    with VLLMContainer(model_name="gpt2") as vllm:
+        # Test getting the VLLM configuration
+        config = vllm.get_vllm_config()
+        assert config is not None
+        assert config.model.model == "gpt2"
+
+        # Test building sampling parameters
+        sampling_params = create_sampling_params(
+            temperature=0.8,
+            top_p=0.9,
+            max_tokens=50
+        )
+        assert sampling_params.temperature == 0.8
+        assert sampling_params.top_p == 0.9
+
+        # Test basic functionality
+        health = vllm.get_health_status()
+        assert "status" in health
+
+
+def test_vllm_server_config_model():
+    """Test the server configuration model returned by the container."""
+    from testcontainers.vllm import VLLMServerConfig
+
+    with VLLMContainer(model_name="gpt2") as vllm:
+        # Test server configuration
+        server_config = vllm.get_server_config()
+        assert isinstance(server_config, VLLMServerConfig)
+        assert server_config.model_name == "gpt2"
+        assert server_config.port == 8000
+
+        # Test basic functionality
+        health = vllm.get_health_status()
+        assert "status" in health
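+
+
+# The test below is a lightweight sketch that exercises only the bundled
+# Pydantic configuration helpers, so it runs without Docker, a GPU, or a
+# downloaded model. It assumes these helpers are exported from
+# testcontainers.vllm as defined in this module.
+def test_vllm_config_helpers_without_container():
+    """Check the defaults of the configuration helpers without starting a container."""
+    from testcontainers.vllm import VLLMServerConfig, create_sampling_params, create_vllm_config
+
+    config = create_vllm_config(model="gpt2", gpu_memory_utilization=0.8)
+    assert config.model.model == "gpt2"
+    assert config.cache.gpu_memory_utilization == 0.8
+    assert config.parallel.tensor_parallel_size == 1
+
+    sampling = create_sampling_params(temperature=0.5, top_p=0.9, stop=["\n"])
+    assert sampling.temperature == 0.5
+    assert sampling.stop == ["\n"]
+
+    server = VLLMServerConfig(model_name="gpt2")
+    assert server.port == 8000
+    assert server.max_model_len == 4096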