diff --git a/README.md b/README.md index a83030f..bcdc27a 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,159 @@ Additional functionality like cost modelling and MLFlow experiment tracking is e For more details, check out our selection of end-to-end code examples in the [examples](https://github.com/awslabs/llmeter/tree/main/examples) folder! +## πŸ–ΌοΈ Multi-Modal Payload Support + +LLMeter supports creating payloads with multi-modal content including images, videos, audio, and documents alongside text. This enables testing of modern multi-modal AI models. + +### Installation for Multi-Modal Support + +For enhanced format detection from file content (recommended), install the optional `multimodal` extra: + +```terminal +pip install 'llmeter[multimodal]' +``` + +Or with uv: + +```terminal +uv pip install 'llmeter[multimodal]' +``` + +This installs the `puremagic` library for content-based format detection using magic bytes. Without it, format detection falls back to file extensions. 
+ +### Basic Multi-Modal Usage + +```python +from llmeter.endpoints import BedrockConverse + +# Single image from file +payload = BedrockConverse.create_payload( + user_message="What is in this image?", + images=["photo.jpg"], + max_tokens=256 +) + +# Multiple images +payload = BedrockConverse.create_payload( + user_message="Compare these images:", + images=["image1.jpg", "image2.png"], + max_tokens=512 +) + +# Image from bytes (requires puremagic for format detection) +with open("photo.jpg", "rb") as f: + image_bytes = f.read() + +payload = BedrockConverse.create_payload( + user_message="What is in this image?", + images=[image_bytes], + max_tokens=256 +) + +# Mixed content types +payload = BedrockConverse.create_payload( + user_message="Analyze this presentation and supporting materials", + documents=["slides.pdf"], + images=["chart.png"], + max_tokens=1024 +) + +# Video analysis +payload = BedrockConverse.create_payload( + user_message="Describe what happens in this video", + videos=["clip.mp4"], + max_tokens=1024 +) +``` + +### Supported Content Types + +- **Images**: JPEG, PNG, GIF, WebP +- **Documents**: PDF +- **Videos**: MP4, MOV, AVI +- **Audio**: MP3, WAV, OGG + +Format support varies by model. The library detects formats automatically and lets the API endpoint validate compatibility. + +### Endpoint-Specific Format Handling + +Different endpoints expect different format strings: + +- **Bedrock**: Uses short format strings (e.g., `"jpeg"`, `"png"`, `"pdf"`) +- **OpenAI**: Uses full MIME types (e.g., `"image/jpeg"`, `"image/png"`) +- **SageMaker**: Uses Bedrock format by default (model-dependent) + +The library handles these differences automatically based on the endpoint you're using. + +### ⚠️ Security Warning: Format Detection Is NOT Input Validation + +**IMPORTANT**: The format detection in this library is for testing and development convenience ONLY. It is NOT a security mechanism and MUST NOT be used with untrusted files without proper validation. 
+ +#### What This Library Does + +- Detects likely file format from magic bytes (puremagic) or extension (mimetypes) +- Reads binary content from files +- Packages content for API endpoints +- Provides type checking (bytes vs strings) + +#### What This Library Does NOT Do + +- ❌ Validate file content safety or integrity +- ❌ Scan for malicious content or malware +- ❌ Sanitize or clean file data +- ❌ Protect against malformed or exploited files +- ❌ Guarantee format correctness beyond detection heuristics +- ❌ Validate file size or prevent memory exhaustion +- ❌ Check for embedded scripts or exploits +- ❌ Verify file authenticity or source + +#### Intended Use Cases + +This format detection is designed for: + +- **Testing and development**: Loading known-safe test files during development +- **Internal tools**: Processing files from trusted internal sources +- **Prototyping**: Quick experimentation with multi-modal models +- **Controlled environments**: Scenarios where file sources are fully trusted + +#### NOT Intended For + +This format detection should NOT be used for: + +- **Production user uploads**: Files uploaded by end users through web forms or APIs +- **External file sources**: Files from untrusted URLs, email attachments, or third-party systems +- **Security-sensitive applications**: Any application where file safety is critical +- **Public-facing services**: Services that accept files from the internet + +#### Recommended Security Practices for Untrusted Files + +When working with untrusted files (user uploads, external sources, etc.), you MUST implement proper security measures: + +1. **Validate file sources**: Only accept files from trusted, authenticated sources +2. **Scan for malware**: Use antivirus/malware scanning (e.g., ClamAV) before processing +3. **Validate file integrity**: Verify checksums, digital signatures, or other integrity mechanisms +4. 
**Sanitize content**: Use specialized libraries to validate and sanitize file content: + - Images: Re-encode with PIL/Pillow to strip metadata and validate structure + - PDFs: Use PDF sanitization libraries to remove scripts and validate structure + - Videos: Re-encode with ffmpeg to validate and sanitize +5. **Limit file sizes**: Enforce maximum file size limits before reading into memory +6. **Sandbox processing**: Process untrusted files in isolated environments (containers, VMs) +7. **Validate API responses**: Check that API endpoints successfully processed the content +8. **Implement rate limiting**: Prevent abuse through excessive file uploads +9. **Log and monitor**: Track file processing for security auditing + +### Backward Compatibility + +Text-only payloads continue to work exactly as before: + +```python +# Still works - no changes needed +payload = BedrockConverse.create_payload( + user_message="Hello, world!", + max_tokens=256 +) +``` + ## Analyze and compare results You can analyze the results of a single run or a load test by generating interactive charts. You can find examples in the [examples](examples) folder.
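As a concrete illustration of the size-limit recommendation in the security checklist earlier in this README, a minimal pre-read check (a sketch; `read_bounded` and the 10 MiB cap are hypothetical, not part of LLMeter):

```python
import os

MAX_UPLOAD_BYTES = 10 * 1024 * 1024  # assumed cap; tune per deployment

def read_bounded(path: str, limit: int = MAX_UPLOAD_BYTES) -> bytes:
    # Check the on-disk size before loading anything into memory,
    # guarding against memory exhaustion from oversized untrusted files
    size = os.path.getsize(path)
    if size > limit:
        raise ValueError(f"{path} is {size} bytes; limit is {limit}")
    with open(path, "rb") as f:
        return f.read()
```

A check like this belongs before any call that passes file content into `create_payload`, since payload construction reads entire files into memory.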
diff --git a/docs/reference/json_utils.md b/docs/reference/json_utils.md new file mode 100644 index 0000000..a30ad78 --- /dev/null +++ b/docs/reference/json_utils.md @@ -0,0 +1 @@ +::: llmeter.json_utils diff --git a/llmeter/callbacks/base.py b/llmeter/callbacks/base.py index 5c6a5d2..449f287 100644 --- a/llmeter/callbacks/base.py +++ b/llmeter/callbacks/base.py @@ -4,13 +4,21 @@ from __future__ import annotations -import os +import importlib +import json +import logging from abc import ABC -from typing import final +from typing import Any, final + +from upath.types import ReadablePathLike, WritablePathLike from ..endpoints.base import InvocationResponse +from ..json_utils import LLMeterEncoder from ..results import Result from ..runner import _RunConfig +from ..utils import ensure_path + +logger = logging.getLogger(__name__) class Callback(ABC): @@ -21,8 +29,17 @@ class Callback(ABC): associated with test runs or individual model invocations. A Callback object may implement multiple of the defined lifecycle hooks (such as - `before_invoke`, `after_run`, etc). Callbacks must support serializing their configuration to - a file (by implementing `save_to_file`), and loading back (via `load_from_file`). + `before_invoke`, `after_run`, etc). Callbacks support serializing their configuration via + ``to_dict()`` / ``from_dict()`` (and the convenience wrappers ``save_to_file()`` / + ``load_from_file()``). + + Serialization uses a ``_callback_type`` marker (``"module:ClassName"``) so that + ``Callback.from_dict()`` can dynamically import and reconstruct the correct subclass + without a hardcoded registry. This means third-party callbacks round-trip through JSON + automatically, as long as the defining module is importable at load time. + + Subclasses with complex nested state (like ``CostModel``) can override ``to_dict()`` and + ``from_dict()`` while preserving the type marker by calling ``super()``.
""" async def before_invoke(self, payload: dict) -> None: @@ -70,46 +87,149 @@ async def after_run(self, result: Result) -> None: """ pass - def save_to_file(self, path: os.PathLike | str) -> None: - """Save this Callback to file + # -- Serialization ----------------------------------------------------------------- - Individual Callbacks implement this method to save their configuration to a file that will - be re-loadable with the equivalent `_load_from_file()` method. + def to_dict(self) -> dict: + """Serialize this callback's configuration to a JSON-safe dict. - Args: - path: (Local or Cloud) path where the callback is saved + The returned dict includes a ``_callback_type`` key with the fully-qualified + class path (``"module:ClassName"``), enabling ``Callback.from_dict`` to + reconstruct the correct subclass without a hardcoded registry. + + By default, all public (non-underscore-prefixed) instance attributes are + included. Subclasses with richer state should override this method and call + ``super().to_dict()`` to preserve the type marker. + + Returns: + dict: A JSON-serializable dictionary representation of this callback. 
+ + Example:: + + >>> from llmeter.callbacks import CostModel + >>> from llmeter.callbacks.cost.dimensions import InputTokens + >>> model = CostModel(request_dims=[InputTokens(price_per_million=3.0)]) + >>> d = model.to_dict() + >>> d["_callback_type"] + 'llmeter.callbacks.cost.model:CostModel' """ - raise NotImplementedError("TODO: Callback.save_to_file is not yet implemented!") + cls = self.__class__ + data: dict[str, Any] = { + "_callback_type": f"{cls.__module__}:{cls.__qualname__}", + } + data.update({k: v for k, v in vars(self).items() if not k.startswith("_")}) + return data - @staticmethod - @final - def load_from_file(path: os.PathLike | str) -> Callback: - """Load (any type of) Callback from file + @classmethod + def from_dict(cls, raw: dict, **kwargs: Any) -> Callback: + """Reconstruct a Callback from a dict produced by ``to_dict()``. - `Callback.load_from_file()` attempts to detect the type of Callback saved in a given file, - and use the relevant implementation's `_load_from_file` method to load it. + Uses the ``_callback_type`` field to dynamically import and instantiate + the correct subclass. If called on a concrete subclass (e.g. + ``CostModel.from_dict(...)``), the ``_callback_type`` is still respected + so that the dict always controls which class is created. Args: - path: (Local or Cloud) path where the callback is saved + raw: A dictionary previously produced by ``to_dict()`` (or loaded from + JSON). Must contain a ``_callback_type`` key. + **kwargs: Extra keyword arguments forwarded to the resolved class + constructor (or its own ``from_dict`` if it overrides this method). + Returns: - callback: A loaded Callback - for example an `MlflowCallback`. + Callback: An instance of the appropriate Callback subclass. + + Raises: + ValueError: If ``_callback_type`` is missing from *raw*. + ImportError: If the module referenced by ``_callback_type`` cannot be + imported. + AttributeError: If the class name cannot be found in the referenced + module. 
+ + Example:: + + >>> from llmeter.callbacks.base import Callback + >>> d = { + ... "_callback_type": "llmeter.callbacks.mlflow:MlflowCallback", + ... "step": 1, + ... "nested": False, + ... } + >>> cb = Callback.from_dict(d) # returns an MlflowCallback instance """ - raise NotImplementedError( - "TODO: Callback.load_from_file is not yet implemented!" - ) + raw = dict(raw) # shallow copy β€” don't mutate caller's dict + callback_type = raw.pop("_callback_type", None) + if callback_type is None: + raise ValueError( + "Cannot deserialize Callback: dict is missing '_callback_type' key. " + f"Got keys: {list(raw.keys())}" + ) + + module_path, class_name = callback_type.rsplit(":", 1) + module = importlib.import_module(module_path) + callback_cls = getattr(module, class_name) + + # If the resolved class has its own from_dict (e.g. CostModel), delegate to it + # so that subclass-specific deserialization logic is honoured. + if callback_cls is not cls and "from_dict" in callback_cls.__dict__: + # Re-inject _callback_type so the subclass from_dict can pop it if needed + return callback_cls.from_dict(raw, **kwargs) + + # Remove any keys the constructor doesn't expect (e.g. _type from JSONableBase) + raw.pop("_type", None) + return callback_cls(**raw, **kwargs) + + def to_json(self, **kwargs: Any) -> str: + """Serialize this callback to a JSON string. + + Args: + **kwargs: Extra keyword arguments forwarded to ``json.dumps`` + (e.g. ``indent``). + + Returns: + str: JSON representation of this callback. + """ + kwargs.setdefault("cls", LLMeterEncoder) + return json.dumps(self.to_dict(), **kwargs) @classmethod - def _load_from_file(cls, path: os.PathLike | str) -> Callback: - """Load this Callback from file + def from_json(cls, json_string: str, **kwargs: Any) -> Callback: + """Reconstruct a Callback from a JSON string produced by ``to_json()``. + + Args: + json_string: A valid JSON string. + **kwargs: Extra keyword arguments forwarded to ``from_dict``. 
+ + Returns: + Callback: An instance of the appropriate Callback subclass. + """ + return cls.from_dict(json.loads(json_string), **kwargs) + + def save_to_file(self, path: WritablePathLike) -> None: + """Save this Callback's configuration to a JSON file. - Individual Callbacks implement this method to define how they can be loaded from files - created by the equivalent `save_to_file()` method. + The file can be loaded back with ``Callback.load_from_file(path)``. Args: - path: (Local or Cloud) path where the callback is saved + path: (Local or Cloud) path where the callback should be saved. + """ + path = ensure_path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: + f.write(self.to_json(indent=4)) + + @staticmethod + @final + def load_from_file(path: ReadablePathLike) -> Callback: + """Load (any type of) Callback from a JSON file. + + The ``_callback_type`` field inside the file determines which subclass is + instantiated, so callers don't need to know the concrete type in advance. + + Args: + path: (Local or Cloud) path to a JSON file previously created by + ``save_to_file()``. + Returns: - callback: The loaded Callback object + Callback: The deserialized callback instance. """ - raise NotImplementedError( - "TODO: Callback._load_from_file is not yet implemented!" - ) + path = ensure_path(path) + with path.open("r") as f: + return Callback.from_dict(json.load(f)) diff --git a/llmeter/callbacks/cost/model.py b/llmeter/callbacks/cost/model.py index d491aad..9f3170b 100644 --- a/llmeter/callbacks/cost/model.py +++ b/llmeter/callbacks/cost/model.py @@ -1,8 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 # Python Built-Ins: -from dataclasses import dataclass, field import importlib +from dataclasses import dataclass, field + +from llmeter.utils import ensure_path # Local Dependencies: from ...endpoints.base import InvocationResponse @@ -11,7 +13,7 @@ from ..base import Callback from .dimensions import IRequestCostDimension, IRunCostDimension from .results import CalculatedCostWithDimensions -from .serde import from_dict_with_class_map, JSONableBase +from .serde import JSONableBase, from_dict_with_class, from_dict_with_class_map @dataclass @@ -199,9 +201,23 @@ async def after_run(self, result: Result) -> None: result, recalculate_request_costs=False, save=True ) + def to_dict(self, **kwargs) -> dict: + """Serialize the cost model to a JSON-safe dict. + + Injects the ``_callback_type`` marker from ``Callback.to_dict()`` into the + dict produced by ``JSONableBase.to_dict()``, so that + ``Callback.from_dict()`` can reconstruct this ``CostModel`` dynamically. 
+ """ + data = JSONableBase.to_dict(self, **kwargs) + cls = self.__class__ + data["_callback_type"] = f"{cls.__module__}:{cls.__qualname__}" + return data + def save_to_file(self, path: str) -> None: """Save the cost model (including all dimensions) to a JSON file""" - with open(path, "w") as f: + path = ensure_path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: f.write(self.to_json()) @classmethod @@ -211,16 +227,20 @@ def from_dict(cls, raw: dict, alt_classes: dict = {}, **kwargs) -> "CostModel": **alt_classes, } raw_args = {**raw} + # Strip callback/serde type markers β€” they're not constructor args + raw_args.pop("_callback_type", None) + raw_args.pop("_type", None) for key in ("request_dims", "run_dims"): if key in raw_args: raw_args[key] = { name: from_dict_with_class_map(d, class_map=dim_classes) for name, d in raw_args[key].items() } - return super().from_dict(raw_args, alt_classes=alt_classes, **kwargs) + return from_dict_with_class(raw=raw_args, cls=cls, **kwargs) @classmethod def _load_from_file(cls, path: str): """Load the cost model (including all dimensions) from a JSON file""" - with open(path, "r") as f: + path = ensure_path(path) + with path.open("r") as f: return cls.from_json(f.read()) diff --git a/llmeter/callbacks/cost/serde.py b/llmeter/callbacks/cost/serde.py index 843fad2..f04dc39 100644 --- a/llmeter/callbacks/cost/serde.py +++ b/llmeter/callbacks/cost/serde.py @@ -3,15 +3,17 @@ """(De/re)serialization interfaces for saving Cost Model objects to file and loading them back""" # Python Built-Ins: -from dataclasses import is_dataclass -from datetime import date, datetime, time import json import logging -import os -from typing import Any, Callable, Dict, Protocol, Type, TypeVar +from dataclasses import is_dataclass +from typing import Any, Protocol, TypeVar # External Dependencies: from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike + +from ...json_utils import 
LLMeterEncoder +from ...utils import ensure_path logger = logging.getLogger(__name__) @@ -31,6 +33,7 @@ def from_dict(cls, **kwargs) -> TSerializable: def is_dataclass_instance(obj): """Check whether `obj` is an instance of any dataclass + See: https://docs.python.org/3/library/dataclasses.html#dataclasses.is_dataclass """ return is_dataclass(obj) and not isinstance(obj, type) @@ -38,17 +41,19 @@ def is_dataclass_instance(obj): def to_dict_recursive_generic(obj: object, **kwargs) -> dict: """Convert a vaguely dataclass-like object (with maybe IJSONable fields) to a JSON-ready dict + The output dict is augmented with `_type` storing the `__class__.__name__` of the provided `obj`. + Args: obj: The object to convert **kwargs: Optional extra parameters to insert in the output dictionary """ obj_classname = obj.__class__.__name__ - result = {} if obj_classname == "dict" else {"_type": obj.__class__.__name__} + result: dict = {} if obj_classname == "dict" else {"_type": obj.__class__.__name__} if hasattr(obj, "__dict__"): - # We *don't* use dataclass asdict() here because we want our custom behaviour instead of its - # recursion: + # We *don't* use dataclass asdict() here because we want our custom behaviour instead of + # its recursion: result.update(obj.__dict__) elif ( isinstance(obj, dict) @@ -66,18 +71,18 @@ def to_dict_recursive_generic(obj: object, **kwargs) -> dict: result[k] = to_dict_recursive_generic(v) elif isinstance(v, (list, tuple)): result[k] = [to_dict_recursive_generic(item) for item in v] - elif isinstance(v, (date, datetime, time)): - result[k] = v.isoformat() return result TFromDict = TypeVar("TFromDict") -def from_dict_with_class(raw: dict, cls: Type[TFromDict], **kwargs) -> TFromDict: +def from_dict_with_class(raw: dict, cls: type[TFromDict], **kwargs) -> TFromDict: """Initialize an instance of a class from a plain dict (with optional extra kwargs) + If the input dictionary contains a `_type` key, and this doesn't match the provided `cls.__name__`, 
a warning will be logged. + Args: raw: A plain Python dict, for example loaded from a JSON file cls: The class to create an instance of @@ -94,12 +99,13 @@ def from_dict_with_class(raw: dict, cls: Type[TFromDict], **kwargs) -> TFromDict def from_dict_with_class_map( - raw: dict, class_map: Dict[str, Type[TFromDict]], **kwargs + raw: dict, class_map: dict[str, type[TFromDict]], **kwargs ) -> TFromDict: """Initialize an instance of a class from a plain dict (with optional extra kwargs) + Args: raw: A plain Python dict which must contain a `_type` key - classes: A mapping from `_type` string to class to create an instance of + class_map: A mapping from `_type` string to class to create an instance of **kwargs: Optional extra keyword arguments to pass to the constructor """ if "_type" not in raw: @@ -120,9 +126,9 @@ class JSONableBase: @classmethod def from_dict( - cls: Type[TJSONable], + cls: type[TJSONable], raw: dict, - alt_classes: Dict[str, TJSONable] = {}, + alt_classes: dict[str, TJSONable] = {}, **kwargs: Any, ) -> TJSONable: """Initialize an instance of this class from a plain dict (with optional extra kwargs) @@ -144,23 +150,26 @@ def from_dict( return from_dict_with_class(raw=raw, cls=cls, **kwargs) @classmethod - def from_file(cls: Type[TJSONable], input_path: os.PathLike, **kwargs) -> TJSONable: + def from_file( + cls: type[TJSONable], input_path: ReadablePathLike, **kwargs + ) -> TJSONable: """Initialize an instance of this class from a (local or Cloud) JSON file + Args: input_path: The path to the JSON data file. 
**kwargs: Optional extra keyword arguments to pass to `from_dict()` """ - input_path = Path(input_path) + input_path = ensure_path(input_path) with input_path.open("r") as f: return cls.from_json(f.read(), **kwargs) @classmethod - def from_json(cls: Type[TJSONable], json_string: str, **kwargs: Any) -> TJSONable: + def from_json(cls: type[TJSONable], json_string: str, **kwargs: Any) -> TJSONable: """Initialize an instance of this class from a JSON string (with optional extra kwargs) Args: json_string: A string containing valid JSON data - **kwargs: Optional extra keyword arguments to pass to `from_dict()`` + **kwargs: Optional extra keyword arguments to pass to `from_dict()` """ return cls.from_dict(json.loads(json_string), **kwargs) @@ -175,29 +184,28 @@ def to_dict(self, **kwargs) -> dict: def to_file( self, - output_path: os.PathLike, + output_path: WritablePathLike, indent: int | str | None = 4, - default: Callable[[Any], Any] | None = str, **kwargs: Any, ) -> Path: """Save the state of the object to a (local or Cloud) JSON file Args: output_path: The path where the configuration file will be saved. - indent: Optional indentation passed through to `to_json()` and therefore `json.dumps()` - default: Optional function to convert non-JSON-serializable objects to strings, passed - through to `to_json()` and therefore to `json.dumps()` + indent: Optional indentation passed through to `to_json()` and therefore + `json.dumps()` **kwargs: Optional extra keyword arguments to pass to `to_json()` Returns: output_path: Universal Path representation of the target file. 
""" - output_path = Path(output_path) + output_path = ensure_path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w") as f: - f.write(self.to_json(indent=indent, default=default, **kwargs)) + f.write(self.to_json(indent=indent, **kwargs)) return output_path def to_json(self, **kwargs) -> str: """Serialize this object to JSON, with optional kwargs passed through to `json.dumps()`""" + kwargs.setdefault("cls", LLMeterEncoder) return json.dumps(self.to_dict(), **kwargs) diff --git a/llmeter/callbacks/mlflow.py b/llmeter/callbacks/mlflow.py index b83635b..bfb10de 100644 --- a/llmeter/callbacks/mlflow.py +++ b/llmeter/callbacks/mlflow.py @@ -66,16 +66,29 @@ def __init__(self, step: int | None = None, nested: bool = False) -> None: # Check MLFlow is installed by polling any attribute on the module to trigger DeferredError mlflow.__version__ + def to_dict(self) -> dict: + """Serialize this callback's configuration to a JSON-safe dict. + + Returns: + dict: Contains ``_callback_type``, ``step``, and ``nested``. + """ + data = super().to_dict() + # Exclude the mlflow version check side-effect attribute if present + return data + @classmethod - async def _load_from_file(cls, path: str): - raise NotImplementedError( - "TODO: MlflowCallback does not yet support loading from file" - ) + def from_dict(cls, raw: dict, **kwargs) -> "MlflowCallback": + """Reconstruct an MlflowCallback from a dict produced by ``to_dict()``. - def save_to_file(self) -> str | None: - raise NotImplementedError( - "TODO: MlflowCallback does not yet support saving to file" - ) + Args: + raw: Dictionary with ``step`` and ``nested`` keys. + + Returns: + MlflowCallback: The deserialized callback. + """ + raw = dict(raw) + raw.pop("_callback_type", None) + return cls(**raw, **kwargs) async def _log_llmeter_run(self, result: Result): """Log parameters and metrics from an LLMeter run to MLflow. 
diff --git a/llmeter/endpoints/base.py b/llmeter/endpoints/base.py index 9ce6903..932c83d 100644 --- a/llmeter/endpoints/base.py +++ b/llmeter/endpoints/base.py @@ -7,13 +7,16 @@ import importlib import json -import os from abc import ABC, abstractmethod from dataclasses import asdict, dataclass from typing import Any from uuid import uuid4 from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike + +from ..json_utils import LLMeterEncoder +from ..utils import ensure_path # @dataclass(slots=True) @@ -47,13 +50,109 @@ class InvocationResponse: retries: int | None = None def to_json(self, **kwargs) -> str: - def default_serializer(obj): - try: - return str(obj) - except Exception: - return None + """ + Convert InvocationResponse to JSON string with binary content support. + + This method serializes the InvocationResponse object to a JSON string, with + automatic handling of binary content (bytes objects) in the input_payload field. + Binary data is converted to base64-encoded strings wrapped in marker objects, + enabling JSON serialization while preserving the ability to restore the original + bytes during deserialization. + + Binary Content Handling: + When the input_payload contains bytes objects (e.g., images, video), they are + automatically converted to base64-encoded strings and wrapped in marker objects + with the key "__llmeter_bytes__". This approach enables JSON serialization of + multimodal payloads while maintaining round-trip integrity. + + The marker object format is: {"__llmeter_bytes__": "<base64-encoded string>"} - return json.dumps(asdict(self), default=default_serializer, **kwargs) + For non-serializable types other than bytes, the encoder falls back to str() + representation to ensure the response can always be serialized.
+ + Args: + **kwargs: Additional arguments passed to json.dumps (e.g., indent, sort_keys) + + Returns: + str: JSON representation of the response + + Examples: + Serialize a response with binary content in the payload: + + >>> # Create a response with binary image data in the payload + >>> with open("image.jpg", "rb") as f: + ... image_bytes = f.read() + >>> response = InvocationResponse( + ... response_text="The image shows a cat.", + ... input_payload={ + ... "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + ... "messages": [{ + ... "role": "user", + ... "content": [ + ... {"text": "What is in this image?"}, + ... { + ... "image": { + ... "format": "jpeg", + ... "source": {"bytes": image_bytes} + ... } + ... } + ... ] + ... }] + ... }, + ... time_to_last_token=1.23, + ... num_tokens_output=15 + ... ) + >>> json_str = response.to_json() + >>> # The JSON string contains marker objects for binary data + >>> "__llmeter_bytes__" in json_str + True + + Serialize with pretty printing: + + >>> json_str = response.to_json(indent=2) + >>> print(json_str) + { + "response_text": "The image shows a cat.", + "input_payload": { + "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "messages": [ + { + "role": "user", + "content": [ + {"text": "What is in this image?"}, + { + "image": { + "format": "jpeg", + "source": { + "bytes": {"__llmeter_bytes__": "/9j/4AAQSkZJRg..."} + } + } + } + ] + } + ] + }, + "time_to_last_token": 1.23, + "num_tokens_output": 15, + ... 
+ } + + Round-trip serialization with binary preservation: + + >>> # Serialize to JSON + >>> json_str = response.to_json() + >>> # Parse back to dict + >>> import json + >>> from llmeter.json_utils import llmeter_bytes_decoder + >>> response_dict = json.loads(json_str, object_hook=llmeter_bytes_decoder) + >>> # Binary data is preserved + >>> original_bytes = response.input_payload["messages"][0]["content"][1]["image"]["source"]["bytes"] + >>> restored_bytes = response_dict["input_payload"]["messages"][0]["content"][1]["image"]["source"]["bytes"] + >>> original_bytes == restored_bytes + True + """ + kwargs.setdefault("cls", LLMeterEncoder) + return json.dumps(asdict(self), **kwargs) @staticmethod def error_output( @@ -165,7 +264,7 @@ def __subclasshook__(cls, C: type) -> bool: return True return NotImplemented - def save(self, output_path: os.PathLike) -> os.PathLike: + def save(self, output_path: WritablePathLike) -> Path: """ Save the endpoint configuration to a JSON file. @@ -178,11 +277,10 @@ def save(self, output_path: os.PathLike) -> os.PathLike: Returns: None """ - output_path = Path(output_path) + output_path = ensure_path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w") as f: - endpoint_conf = self.to_dict() - json.dump(endpoint_conf, f, indent=4, default=str) + json.dump(self, f, indent=4, cls=LLMeterEncoder) return output_path def to_dict(self) -> dict: @@ -192,12 +290,14 @@ def to_dict(self) -> dict: Returns: Dict: A dictionary representation of the endpoint configuration. 
""" - endpoint_conf = {k: v for k, v in vars(self).items() if not k.startswith("_")} + endpoint_conf = { + k: v for k, v in vars(self).items() if not k.startswith("_") + } endpoint_conf["endpoint_type"] = self.__class__.__name__ return endpoint_conf @classmethod - def load_from_file(cls, input_path: os.PathLike) -> "Endpoint": + def load_from_file(cls, input_path: ReadablePathLike) -> "Endpoint": """ Load an endpoint configuration from a JSON file. @@ -213,7 +313,7 @@ def load_from_file(cls, input_path: os.PathLike) -> "Endpoint": with the configuration from the file. """ - input_path = Path(input_path) + input_path = ensure_path(input_path) with input_path.open("r") as f: data = json.load(f) endpoint_type = data.pop("endpoint_type") diff --git a/llmeter/endpoints/bedrock.py b/llmeter/endpoints/bedrock.py index 9edf30b..302df9f 100644 --- a/llmeter/endpoints/bedrock.py +++ b/llmeter/endpoints/bedrock.py @@ -19,11 +19,117 @@ from botocore.config import Config from botocore.exceptions import ClientError +from ..prompt_utils import ( + detect_format_from_bytes, + detect_format_from_file, + read_file, +) from .base import Endpoint, InvocationResponse logger = logging.getLogger(__name__) +def _mime_to_format(mime_type: str) -> str | None: + """Convert MIME type to Bedrock format string. + + Maps MIME types to format names used by Bedrock Converse API. 
+ + Args: + mime_type: MIME type (e.g., "image/jpeg", "application/pdf") + + Returns: + str | None: Format string (e.g., "jpeg", "png", "pdf") or None if not recognized + """ + mime_map = { + "image/jpeg": "jpeg", + "image/png": "png", + "image/gif": "gif", + "image/webp": "webp", + "application/pdf": "pdf", + "video/mp4": "mp4", + "video/quicktime": "mov", + "video/x-msvideo": "avi", + "audio/mpeg": "mp3", + "audio/wav": "wav", + "audio/x-wav": "wav", + "audio/ogg": "ogg", + } + return mime_map.get(mime_type) + + +def _build_content_blocks( + user_message: str | list[str] | None, + images: list[bytes] | list[str] | None, + documents: list[bytes] | list[str] | None, + videos: list[bytes] | list[str] | None, + audio: list[bytes] | list[str] | None, +) -> list[dict]: + """Build content blocks from parameters. + + Returns list of content block dictionaries in Bedrock Converse API format. + + Args: + user_message: Text message(s) + images: List of image bytes or file paths + documents: List of document bytes or file paths + videos: List of video bytes or file paths + audio: List of audio bytes or file paths + + Returns: + list[dict]: Content blocks + + Raises: + ValueError: If format cannot be auto-detected from bytes + """ + content = [] + + # Add text blocks first + if user_message: + messages = [user_message] if isinstance(user_message, str) else user_message + for msg in messages: + content.append({"text": msg}) + + # Add media blocks in order: images, videos, audio, documents + for media_list, media_type in [ + (images, "image"), + (videos, "video"), + (audio, "audio"), + (documents, "document"), + ]: + if media_list: + for item in media_list: + if isinstance(item, bytes): + # Bytes provided directly - detect MIME type from content + data = item + mime_type = detect_format_from_bytes(data) + if mime_type is None: + raise ValueError( + f"Cannot detect format from bytes for {media_type}. 
" + "Either install puremagic for content-based detection " + "or provide file path for extension-based detection." + ) + fmt = _mime_to_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for {media_type}" + ) + else: + # File path - read and detect MIME type from file + data = read_file(item) + mime_type = detect_format_from_file(item) + if mime_type is None: + raise ValueError(f"Cannot detect format from file: {item}") + fmt = _mime_to_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for file: {item}" + ) + + content.append({media_type: {"format": fmt, "source": {"bytes": data}}}) + + return content + + class BedrockBase(Endpoint): """Base class for interacting with Amazon Bedrock endpoints. @@ -114,43 +220,156 @@ def _parse_payload(self, payload): @staticmethod def create_payload( - user_message: str | list[str], max_tokens: int = 256, **kwargs: Any + user_message: str | list[str] | None = None, + max_tokens: int | None = None, + *, + images: list[bytes] | list[str] | None = None, + documents: list[bytes] | list[str] | None = None, + videos: list[bytes] | list[str] | None = None, + audio: list[bytes] | list[str] | None = None, + **kwargs: Any, ) -> dict: """ - Create a payload for the Bedrock Converse API request. + Create a payload for the Bedrock Converse API request with optional multi-modal content. + + ⚠️ SECURITY WARNING: Format detection is for testing/development convenience ONLY. + This method does NOT validate file safety, integrity, or protect against malicious + content. DO NOT use with untrusted files (user uploads, external sources) without + proper validation, sanitization, and security measures. Args: - user_message (str | Sequence[str]): The user's message or a sequence of messages. - max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 256. + user_message (str | list[str] | None): The user's message or a sequence of messages. 
+ max_tokens (int | None): The maximum number of tokens to generate. Defaults to 256. + images (list[bytes] | list[str] | None): List of image bytes or file paths (keyword-only). + documents (list[bytes] | list[str] | None): List of document bytes or file paths (keyword-only). + videos (list[bytes] | list[str] | None): List of video bytes or file paths (keyword-only). + audio (list[bytes] | list[str] | None): List of audio bytes or file paths (keyword-only). **kwargs: Additional keyword arguments to include in the payload. Returns: dict: The formatted payload for the Bedrock API request. Raises: - TypeError: If user_message is not a string or list of strings - ValueError: If max_tokens is not a positive integer + TypeError: If parameters have invalid types + ValueError: If parameters have invalid values + FileNotFoundError: If a file path doesn't exist + IOError: If a file cannot be read + + Security: + - Format detection (puremagic/extension) is NOT security validation + - Malicious files can have misleading extensions or forged magic bytes + - This method does NOT scan for malware or sanitize content + - Users MUST validate and sanitize untrusted files before calling this method + - Intended for testing/development with trusted files only + - NOT intended for production user uploads without proper security measures + + Examples: + # Text only (backward compatible) + create_payload(user_message="Hello") + create_payload("Hello", 256) # Positional args still work + + # Single image from file path (trusted source) + create_payload( + user_message="What's in this image?", + images=["photo.jpg"] + ) + + # Multiple images from bytes (trusted source) + create_payload( + user_message="Compare these images", + images=[image_bytes1, image_bytes2] + ) + + # Mixed content (trusted source) + create_payload( + user_message="Analyze this", + images=["chart.png"], + documents=["report.pdf"] + ) """ - if not isinstance(user_message, (str, list)): - raise TypeError("user_message 
must be a string or list of strings") + # Set default for max_tokens if not provided + if max_tokens is None: + max_tokens = 256 + + # Check if any multi-modal content is provided + has_multimodal = any([images, documents, videos, audio]) + + # Backward compatibility: if only user_message provided, use old logic + if not has_multimodal: + if user_message is None: + raise ValueError("user_message is required when no media is provided") + + if not isinstance(user_message, (str, list)): + raise TypeError("user_message must be a string or list of strings") + + if isinstance(user_message, list): + if not all(isinstance(msg, str) for msg in user_message): + raise TypeError("All messages must be strings") + if not user_message: + raise ValueError("user_message list cannot be empty") + + if not isinstance(max_tokens, int) or max_tokens <= 0: + raise ValueError("max_tokens must be a positive integer") - if isinstance(user_message, list): - if not all(isinstance(msg, str) for msg in user_message): - raise TypeError("All messages must be strings") - if not user_message: - raise ValueError("user_message list cannot be empty") + if isinstance(user_message, str): + user_message = [user_message] + + try: + payload: dict = { + "messages": [ + {"role": "user", "content": [{"text": k}]} for k in user_message + ], + } + payload.update(kwargs) + if payload.get("inferenceConfig") is None: + payload["inferenceConfig"] = {} + + payload["inferenceConfig"] = { + **payload["inferenceConfig"], + "maxTokens": max_tokens, + } + return payload + + except Exception as e: + logger.error(f"Error creating payload: {e}") + raise RuntimeError(f"Failed to create payload: {str(e)}") + + # Multi-modal path: validate types + if images is not None and not isinstance(images, list): + raise TypeError("images must be a list") + if documents is not None and not isinstance(documents, list): + raise TypeError("documents must be a list") + if videos is not None and not isinstance(videos, list): + raise 
TypeError("videos must be a list") + if audio is not None and not isinstance(audio, list): + raise TypeError("audio must be a list") + + # Validate list items are bytes or str + for media_list, media_name in [ + (images, "images"), + (documents, "documents"), + (videos, "videos"), + (audio, "audio"), + ]: + if media_list: + for item in media_list: + if not isinstance(item, (bytes, str)): + raise TypeError( + f"Items in {media_name} list must be bytes or str (file path), " + f"got {type(item).__name__}" + ) if not isinstance(max_tokens, int) or max_tokens <= 0: raise ValueError("max_tokens must be a positive integer") - if isinstance(user_message, str): - user_message = [user_message] - try: + # Build content blocks + content_blocks = _build_content_blocks( + user_message, images, documents, videos, audio + ) + payload: dict = { - "messages": [ - {"role": "user", "content": [{"text": k}]} for k in user_message - ], + "messages": [{"role": "user", "content": content_blocks}], } payload.update(kwargs) if payload.get("inferenceConfig") is None: @@ -164,7 +383,7 @@ def create_payload( except Exception as e: logger.error(f"Error creating payload: {e}") - raise RuntimeError(f"Failed to create payload: {str(e)}") + raise class BedrockConverse(BedrockBase): diff --git a/llmeter/endpoints/bedrock_invoke.py b/llmeter/endpoints/bedrock_invoke.py index 20a9403..0b0c118 100644 --- a/llmeter/endpoints/bedrock_invoke.py +++ b/llmeter/endpoints/bedrock_invoke.py @@ -8,10 +8,11 @@ from uuid import uuid4 import boto3 +import jmespath from botocore.config import Config from botocore.exceptions import ClientError -import jmespath +from ..json_utils import LLMeterEncoder from .base import Endpoint, InvocationResponse logger = logging.getLogger(__name__) @@ -261,7 +262,7 @@ def invoke(self, payload: dict) -> InvocationResponse: raise TypeError("Payload must be a dictionary") try: - req_body = json.dumps(payload).encode("utf-8") + req_body = json.dumps(payload, 
cls=LLMeterEncoder).encode("utf-8") try: start_t = time.perf_counter() client_response = self._bedrock_client.invoke_model( # type: ignore @@ -353,7 +354,7 @@ def __init__( ) def invoke(self, payload: dict) -> InvocationResponse: - req_body = json.dumps(payload).encode("utf-8") + req_body = json.dumps(payload, cls=LLMeterEncoder).encode("utf-8") try: start_t = time.perf_counter() client_response = self._bedrock_client.invoke_model_with_response_stream( # type: ignore diff --git a/llmeter/endpoints/openai.py b/llmeter/endpoints/openai.py index 88dd041..a3e8d85 100644 --- a/llmeter/endpoints/openai.py +++ b/llmeter/endpoints/openai.py @@ -12,6 +12,108 @@ from openai.types.chat import ChatCompletion from .base import Endpoint, InvocationResponse +from ..prompt_utils import read_file, detect_format_from_bytes, detect_format_from_file + +logger = logging.getLogger(__name__) + + +def _mime_to_openai_format(mime_type: str) -> str | None: + """Convert MIME type to OpenAI format string. + + OpenAI expects full MIME types for media content. + + Args: + mime_type: MIME type (e.g., "image/jpeg", "application/pdf") + + Returns: + str | None: Format string (full MIME type) or None if not recognized + """ + # OpenAI uses full MIME types + supported_mimes = { + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + "application/pdf", + "video/mp4", + "audio/mpeg", + "audio/wav", + } + return mime_type if mime_type in supported_mimes else None + + +def _build_content_blocks_openai( + user_message: str | list[str] | None, + images: list[bytes] | list[str] | None, + documents: list[bytes] | list[str] | None, + videos: list[bytes] | list[str] | None, + audio: list[bytes] | list[str] | None, +) -> list[dict]: + """Build content blocks from parameters for OpenAI API. + + Returns list of content block dictionaries in OpenAI format. 
+ + Args: + user_message: Text message(s) + images: List of image bytes or file paths + documents: List of document bytes or file paths + videos: List of video bytes or file paths + audio: List of audio bytes or file paths + + Returns: + list[dict]: Content blocks + + Raises: + ValueError: If format cannot be auto-detected from bytes + """ + content = [] + + # Add text blocks first + if user_message: + messages = [user_message] if isinstance(user_message, str) else user_message + for msg in messages: + content.append({"text": msg}) + + # Add media blocks in order: images, videos, audio, documents + for media_list, media_type in [ + (images, "image"), + (videos, "video"), + (audio, "audio"), + (documents, "document"), + ]: + if media_list: + for item in media_list: + if isinstance(item, bytes): + # Bytes provided directly - detect MIME type from content + data = item + mime_type = detect_format_from_bytes(data) + if mime_type is None: + raise ValueError( + f"Cannot detect format from bytes for {media_type}. " + "Either install puremagic for content-based detection " + "or provide file path for extension-based detection." 
+ ) + fmt = _mime_to_openai_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for {media_type}" + ) + else: + # File path - read and detect MIME type from file + data = read_file(item) + mime_type = detect_format_from_file(item) + if mime_type is None: + raise ValueError(f"Cannot detect format from file: {item}") + fmt = _mime_to_openai_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for file: {item}" + ) + + content.append({media_type: {"format": fmt, "source": {"bytes": data}}}) + + return content + logger = logging.getLogger(__name__) @@ -65,26 +167,129 @@ def _parse_payload(self, payload): @staticmethod def create_payload( - user_message: str | Sequence[str], max_tokens: int = 256, **kwargs: Any + user_message: str | Sequence[str] | None = None, + max_tokens: int = 256, + *, + images: list[bytes] | list[str] | None = None, + documents: list[bytes] | list[str] | None = None, + videos: list[bytes] | list[str] | None = None, + audio: list[bytes] | list[str] | None = None, + **kwargs: Any, ) -> dict: - """Create a payload for the OpenAI API request. + """Create a payload for the OpenAI API request with optional multi-modal content. + + ⚠️ SECURITY WARNING: Format detection is for testing/development convenience ONLY. + This method does NOT validate file safety, integrity, or protect against malicious + content. DO NOT use with untrusted files (user uploads, external sources) without + proper validation, sanitization, and security measures. Args: - user_message (str | Sequence[str]): User message(s) to send + user_message (str | Sequence[str] | None): User message(s) to send max_tokens (int, optional): Maximum tokens in response. Defaults to 256. + images (list[bytes] | list[str] | None): List of image bytes or file paths (keyword-only). + documents (list[bytes] | list[str] | None): List of document bytes or file paths (keyword-only). 
+ videos (list[bytes] | list[str] | None): List of video bytes or file paths (keyword-only). + audio (list[bytes] | list[str] | None): List of audio bytes or file paths (keyword-only). **kwargs: Additional payload parameters Returns: dict: Formatted payload for API request + + Raises: + TypeError: If parameters have invalid types + ValueError: If parameters have invalid values + FileNotFoundError: If a file path doesn't exist + IOError: If a file cannot be read + + Security: + - Format detection (puremagic/extension) is NOT security validation + - Malicious files can have misleading extensions or forged magic bytes + - This method does NOT scan for malware or sanitize content + - Users MUST validate and sanitize untrusted files before calling this method + - Intended for testing/development with trusted files only + - NOT intended for production user uploads without proper security measures + + Examples: + # Text only (backward compatible) + create_payload(user_message="Hello") + + # Single image from file path (trusted source) + create_payload( + user_message="What's in this image?", + images=["photo.jpg"] + ) + + # Multiple images from bytes (trusted source) + create_payload( + user_message="Compare these images", + images=[image_bytes1, image_bytes2] + ) + + # Mixed content (trusted source) + create_payload( + user_message="Analyze this", + images=["chart.png"], + documents=["report.pdf"] + ) """ - if isinstance(user_message, str): - user_message = [user_message] - payload = { - "messages": [{"role": "user", "content": k} for k in user_message], - "max_tokens": max_tokens, - } - payload.update(kwargs) - return payload + # Check if any multi-modal content is provided + has_multimodal = any([images, documents, videos, audio]) + + # Backward compatibility: if only user_message provided, use old logic + if not has_multimodal: + if isinstance(user_message, str): + user_message = [user_message] + payload = { + "messages": [{"role": "user", "content": k} for k in 
user_message], + "max_tokens": max_tokens, + } + payload.update(kwargs) + return payload + + # Multi-modal path: validate types + if images is not None and not isinstance(images, list): + raise TypeError("images must be a list") + if documents is not None and not isinstance(documents, list): + raise TypeError("documents must be a list") + if videos is not None and not isinstance(videos, list): + raise TypeError("videos must be a list") + if audio is not None and not isinstance(audio, list): + raise TypeError("audio must be a list") + + # Validate list items are bytes or str + for media_list, media_name in [ + (images, "images"), + (documents, "documents"), + (videos, "videos"), + (audio, "audio"), + ]: + if media_list: + for item in media_list: + if not isinstance(item, (bytes, str)): + raise TypeError( + f"Items in {media_name} list must be bytes or str (file path), " + f"got {type(item).__name__}" + ) + + if not isinstance(max_tokens, int) or max_tokens <= 0: + raise ValueError("max_tokens must be a positive integer") + + try: + # Build content blocks + content_blocks = _build_content_blocks_openai( + user_message, images, documents, videos, audio + ) + + payload = { + "messages": [{"role": "user", "content": content_blocks}], + "max_tokens": max_tokens, + } + payload.update(kwargs) + return payload + + except Exception as e: + logger.error(f"Error creating payload: {e}") + raise class OpenAICompletionEndpoint(OpenAIEndpoint): diff --git a/llmeter/endpoints/sagemaker.py b/llmeter/endpoints/sagemaker.py index 05c67af..ec81708 100644 --- a/llmeter/endpoints/sagemaker.py +++ b/llmeter/endpoints/sagemaker.py @@ -13,6 +13,112 @@ from botocore.exceptions import ClientError from .base import Endpoint, InvocationResponse +from ..prompt_utils import read_file, detect_format_from_bytes, detect_format_from_file + +logger = logging.getLogger(__name__) + + +def _mime_to_sagemaker_format(mime_type: str) -> str | None: + """Convert MIME type to SageMaker format string. 
+ + SageMaker format depends on the deployed model, using Bedrock format as default. + + Args: + mime_type: MIME type (e.g., "image/jpeg", "application/pdf") + + Returns: + str | None: Format string (e.g., "jpeg", "png", "pdf") or None if not recognized + """ + # SageMaker uses Bedrock-style short format strings as default + mime_map = { + "image/jpeg": "jpeg", + "image/png": "png", + "image/gif": "gif", + "image/webp": "webp", + "application/pdf": "pdf", + "video/mp4": "mp4", + "video/quicktime": "mov", + "video/x-msvideo": "avi", + "audio/mpeg": "mp3", + "audio/wav": "wav", + "audio/x-wav": "wav", + "audio/ogg": "ogg", + } + return mime_map.get(mime_type) + + +def _build_content_blocks_sagemaker( + user_message: str | list[str] | None, + images: list[bytes] | list[str] | None, + documents: list[bytes] | list[str] | None, + videos: list[bytes] | list[str] | None, + audio: list[bytes] | list[str] | None, +) -> list[dict]: + """Build content blocks from parameters for SageMaker API. + + Returns list of content block dictionaries in SageMaker format. 
+ + Args: + user_message: Text message(s) + images: List of image bytes or file paths + documents: List of document bytes or file paths + videos: List of video bytes or file paths + audio: List of audio bytes or file paths + + Returns: + list[dict]: Content blocks + + Raises: + ValueError: If format cannot be auto-detected from bytes + """ + content = [] + + # Add text blocks first + if user_message: + messages = [user_message] if isinstance(user_message, str) else user_message + for msg in messages: + content.append({"text": msg}) + + # Add media blocks in order: images, videos, audio, documents + for media_list, media_type in [ + (images, "image"), + (videos, "video"), + (audio, "audio"), + (documents, "document"), + ]: + if media_list: + for item in media_list: + if isinstance(item, bytes): + # Bytes provided directly - detect MIME type from content + data = item + mime_type = detect_format_from_bytes(data) + if mime_type is None: + raise ValueError( + f"Cannot detect format from bytes for {media_type}. " + "Either install puremagic for content-based detection " + "or provide file path for extension-based detection." 
+ ) + fmt = _mime_to_sagemaker_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for {media_type}" + ) + else: + # File path - read and detect MIME type from file + data = read_file(item) + mime_type = detect_format_from_file(item) + if mime_type is None: + raise ValueError(f"Cannot detect format from file: {item}") + fmt = _mime_to_sagemaker_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for file: {item}" + ) + + content.append({media_type: {"format": fmt, "source": {"bytes": data}}}) + + return content + logger = logging.getLogger(__name__) @@ -54,22 +160,139 @@ def _parse_input(self, payload: dict) -> str | None: @staticmethod def create_payload( - input_text: str | list[str], + input_text: str | list[str] | None = None, max_tokens: int = 256, inference_parameters: dict = {}, + *, + images: list[bytes] | list[str] | None = None, + documents: list[bytes] | list[str] | None = None, + videos: list[bytes] | list[str] | None = None, + audio: list[bytes] | list[str] | None = None, **kwargs, ): - payload = { - "inputs": input_text, - "parameters": { - "max_new_tokens": max_tokens, - "details": True, - }, - } - if inference_parameters: - payload["parameters"].update(inference_parameters) - payload.update(kwargs) - return payload + """Create a payload for the SageMaker API request with optional multi-modal content. + + ⚠️ SECURITY WARNING: Format detection is for testing/development convenience ONLY. + This method does NOT validate file safety, integrity, or protect against malicious + content. DO NOT use with untrusted files (user uploads, external sources) without + proper validation, sanitization, and security measures. + + Args: + input_text (str | list[str] | None): The input text or a sequence of texts. + max_tokens (int): Maximum tokens to generate. Defaults to 256. + inference_parameters (dict): Additional inference parameters. Defaults to {}. 
+ images (list[bytes] | list[str] | None): List of image bytes or file paths (keyword-only). + documents (list[bytes] | list[str] | None): List of document bytes or file paths (keyword-only). + videos (list[bytes] | list[str] | None): List of video bytes or file paths (keyword-only). + audio (list[bytes] | list[str] | None): List of audio bytes or file paths (keyword-only). + **kwargs: Additional keyword arguments to include in the payload. + + Returns: + dict: The formatted payload for the SageMaker API request. + + Raises: + TypeError: If parameters have invalid types + ValueError: If parameters have invalid values + FileNotFoundError: If a file path doesn't exist + IOError: If a file cannot be read + + Security: + - Format detection (puremagic/extension) is NOT security validation + - Malicious files can have misleading extensions or forged magic bytes + - This method does NOT scan for malware or sanitize content + - Users MUST validate and sanitize untrusted files before calling this method + - Intended for testing/development with trusted files only + - NOT intended for production user uploads without proper security measures + + Examples: + # Text only (backward compatible) + create_payload(input_text="Hello") + + # Single image from file path (trusted source) + create_payload( + input_text="What's in this image?", + images=["photo.jpg"] + ) + + # Multiple images from bytes (trusted source) + create_payload( + input_text="Compare these images", + images=[image_bytes1, image_bytes2] + ) + + # Mixed content (trusted source) + create_payload( + input_text="Analyze this", + images=["chart.png"], + documents=["report.pdf"] + ) + """ + # Check if any multi-modal content is provided + has_multimodal = any([images, documents, videos, audio]) + + # Backward compatibility: if only input_text provided, use old logic + if not has_multimodal: + payload = { + "inputs": input_text, + "parameters": { + "max_new_tokens": max_tokens, + "details": True, + }, + } + if 
inference_parameters: + payload["parameters"].update(inference_parameters) + payload.update(kwargs) + return payload + + # Multi-modal path: validate types + if images is not None and not isinstance(images, list): + raise TypeError("images must be a list") + if documents is not None and not isinstance(documents, list): + raise TypeError("documents must be a list") + if videos is not None and not isinstance(videos, list): + raise TypeError("videos must be a list") + if audio is not None and not isinstance(audio, list): + raise TypeError("audio must be a list") + + # Validate list items are bytes or str + for media_list, media_name in [ + (images, "images"), + (documents, "documents"), + (videos, "videos"), + (audio, "audio"), + ]: + if media_list: + for item in media_list: + if not isinstance(item, (bytes, str)): + raise TypeError( + f"Items in {media_name} list must be bytes or str (file path), " + f"got {type(item).__name__}" + ) + + if not isinstance(max_tokens, int) or max_tokens <= 0: + raise ValueError("max_tokens must be a positive integer") + + try: + # Build content blocks + content_blocks = _build_content_blocks_sagemaker( + input_text, images, documents, videos, audio + ) + + payload = { + "inputs": content_blocks, + "parameters": { + "max_new_tokens": max_tokens, + "details": True, + }, + } + if inference_parameters: + payload["parameters"].update(inference_parameters) + payload.update(kwargs) + return payload + + except Exception as e: + logger.error(f"Error creating payload: {e}") + raise class SageMakerEndpoint(SageMakerBase): @@ -251,23 +474,141 @@ def invoke(self, payload: dict) -> InvocationResponse: @staticmethod def create_payload( - input_text: str | list[str], + input_text: str | list[str] | None = None, max_tokens: int = 256, inference_parameters: dict = {}, + *, + images: list[bytes] | list[str] | None = None, + documents: list[bytes] | list[str] | None = None, + videos: list[bytes] | list[str] | None = None, + audio: list[bytes] | list[str] 
| None = None, **kwargs, ): - payload = { - "inputs": input_text, - "parameters": { - "max_new_tokens": max_tokens, - "details": True, - }, - "stream": True, - } - if inference_parameters: - payload["parameters"].update(inference_parameters) - payload.update(kwargs) - return payload + """Create a payload for the SageMaker streaming API request with optional multi-modal content. + + ⚠️ SECURITY WARNING: Format detection is for testing/development convenience ONLY. + This method does NOT validate file safety, integrity, or protect against malicious + content. DO NOT use with untrusted files (user uploads, external sources) without + proper validation, sanitization, and security measures. + + Args: + input_text (str | list[str] | None): The input text or a sequence of texts. + max_tokens (int): Maximum tokens to generate. Defaults to 256. + inference_parameters (dict): Additional inference parameters. Defaults to {}. + images (list[bytes] | list[str] | None): List of image bytes or file paths (keyword-only). + documents (list[bytes] | list[str] | None): List of document bytes or file paths (keyword-only). + videos (list[bytes] | list[str] | None): List of video bytes or file paths (keyword-only). + audio (list[bytes] | list[str] | None): List of audio bytes or file paths (keyword-only). + **kwargs: Additional keyword arguments to include in the payload. + + Returns: + dict: The formatted payload for the SageMaker streaming API request. 
+ + Raises: + TypeError: If parameters have invalid types + ValueError: If parameters have invalid values + FileNotFoundError: If a file path doesn't exist + IOError: If a file cannot be read + + Security: + - Format detection (puremagic/extension) is NOT security validation + - Malicious files can have misleading extensions or forged magic bytes + - This method does NOT scan for malware or sanitize content + - Users MUST validate and sanitize untrusted files before calling this method + - Intended for testing/development with trusted files only + - NOT intended for production user uploads without proper security measures + + Examples: + # Text only (backward compatible) + create_payload(input_text="Hello") + + # Single image from file path (trusted source) + create_payload( + input_text="What's in this image?", + images=["photo.jpg"] + ) + + # Multiple images from bytes (trusted source) + create_payload( + input_text="Compare these images", + images=[image_bytes1, image_bytes2] + ) + + # Mixed content (trusted source) + create_payload( + input_text="Analyze this", + images=["chart.png"], + documents=["report.pdf"] + ) + """ + # Check if any multi-modal content is provided + has_multimodal = any([images, documents, videos, audio]) + + # Backward compatibility: if only input_text provided, use old logic + if not has_multimodal: + payload = { + "inputs": input_text, + "parameters": { + "max_new_tokens": max_tokens, + "details": True, + }, + "stream": True, + } + if inference_parameters: + payload["parameters"].update(inference_parameters) + payload.update(kwargs) + return payload + + # Multi-modal path: validate types + if images is not None and not isinstance(images, list): + raise TypeError("images must be a list") + if documents is not None and not isinstance(documents, list): + raise TypeError("documents must be a list") + if videos is not None and not isinstance(videos, list): + raise TypeError("videos must be a list") + if audio is not None and not 
isinstance(audio, list): + raise TypeError("audio must be a list") + + # Validate list items are bytes or str + for media_list, media_name in [ + (images, "images"), + (documents, "documents"), + (videos, "videos"), + (audio, "audio"), + ]: + if media_list: + for item in media_list: + if not isinstance(item, (bytes, str)): + raise TypeError( + f"Items in {media_name} list must be bytes or str (file path), " + f"got {type(item).__name__}" + ) + + if not isinstance(max_tokens, int) or max_tokens <= 0: + raise ValueError("max_tokens must be a positive integer") + + try: + # Build content blocks + content_blocks = _build_content_blocks_sagemaker( + input_text, images, documents, videos, audio + ) + + payload = { + "inputs": content_blocks, + "parameters": { + "max_new_tokens": max_tokens, + "details": True, + }, + "stream": True, + } + if inference_parameters: + payload["parameters"].update(inference_parameters) + payload.update(kwargs) + return payload + + except Exception as e: + logger.error(f"Error creating payload: {e}") + raise class TokenIterator: diff --git a/llmeter/experiments.py b/llmeter/experiments.py index 0a70a17..dcf6a0e 100644 --- a/llmeter/experiments.py +++ b/llmeter/experiments.py @@ -15,12 +15,14 @@ from tqdm.auto import tqdm from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike from llmeter.callbacks.base import Callback from llmeter.results import Result +from llmeter.utils import ensure_path from .endpoints.base import Endpoint -from .plotting import plot_heatmap, plot_load_test_results, color_sequences +from .plotting import color_sequences, plot_heatmap, plot_load_test_results from .prompt_utils import CreatePromptCollection from .runner import Runner from .tokenizers import Tokenizer @@ -38,7 +40,7 @@ class LoadTestResult: results: dict[int, Result] test_name: str - output_path: os.PathLike | str | None = None + output_path: WritablePathLike | None = None def plot_results(self, show: bool = True, format: 
Literal["html", "png"] = "html"): figs = plot_load_test_results(self) @@ -57,7 +59,7 @@ def plot_results(self, show: bool = True, format: Literal["html", "png"] = "html f.update_layout(colorway=c_seqs[i % len(c_seqs)]) if self.output_path is not None: - output_path = Path(self.output_path) + output_path = ensure_path(self.output_path) # save figure to the output path output_path.parent.mkdir(parents=True, exist_ok=True) for k, f in figs.items(): @@ -96,7 +98,7 @@ def load( raise FileNotFoundError("Load path cannot be None or empty") if isinstance(load_path, str): - load_path = Path(load_path) + load_path = ensure_path(load_path) if not load_path.exists(): raise FileNotFoundError(f"Load path {load_path} does not exist") @@ -130,7 +132,7 @@ class LoadTest: sequence_of_clients: list[int] min_requests_per_client: int = 1 min_requests_per_run: int = 10 - output_path: os.PathLike | str | None = None + output_path: WritablePathLike | None = None tokenizer: Tokenizer | None = None test_name: str | None = None callbacks: list[Callback] | None = None @@ -143,9 +145,9 @@ def _get_n_requests(self, clients): return int(ceil(self.min_requests_per_run / clients)) return int(self.min_requests_per_client) - async def run(self, output_path: os.PathLike | None = None): + async def run(self, output_path: WritablePathLike | None = None): try: - output_path = Path(output_path or self.output_path) / self._test_name + output_path = ensure_path(output_path or self.output_path) / self._test_name except Exception: output_path = None _runner = Runner( @@ -209,9 +211,9 @@ class LatencyHeatmap: """ endpoint: Endpoint - source_file: os.PathLike | str + source_file: ReadablePathLike clients: int = 4 - output_path: os.PathLike | str | None = None + output_path: WritablePathLike | None = None input_lengths: list[int] = field(default_factory=lambda: [10, 50, 200, 500]) output_lengths: list[int] = field(default_factory=lambda: [128, 256, 512, 1024]) requests_per_combination: int = 1 @@ -224,7 +226,7 
@@ def __post_init__(self) -> None: requests_per_combination=self.requests_per_combination, input_lengths=self.input_lengths, output_lengths=self.output_lengths, - source_file=Path(self.source_file), + source_file=ensure_path(self.source_file), tokenizer=self.tokenizer, # type: ignore ) @@ -239,7 +241,7 @@ def __post_init__(self) -> None: self._runner = Runner( endpoint=self.endpoint, - output_path=Path(self.output_path) + output_path=ensure_path(self.output_path) if self.output_path is not None else None, tokenizer=self.tokenizer, @@ -249,7 +251,7 @@ async def run(self, output_path=None): # Handle None output_path properly final_output_path = output_path or self.output_path if final_output_path is not None: - final_output_path = Path(final_output_path) + final_output_path = ensure_path(final_output_path) heatmap_results = await self._runner.run( payload=self.payload, diff --git a/llmeter/json_utils.py b/llmeter/json_utils.py new file mode 100644 index 0000000..167e755 --- /dev/null +++ b/llmeter/json_utils.py @@ -0,0 +1,133 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""JSON encoding and decoding helpers used across LLMeter + +Provides a unified :class:`json.JSONEncoder` subclass and a matching decoder hook for +round-tripping binary content (``bytes``) through JSON via base64 marker objects, while also +handling ``datetime``, ``os.PathLike``, and objects that implement ``to_dict()``. 
+
+Example::
+
+    import json
+    from llmeter.json_utils import LLMeterEncoder, llmeter_bytes_decoder
+
+    payload = {"image": {"bytes": b"\\xff\\xd8\\xff\\xe0"}}
+
+    # Serialize
+    encoded = json.dumps(payload, cls=LLMeterEncoder)
+
+    # Deserialize (bytes are restored automatically)
+    decoded = json.loads(encoded, object_hook=llmeter_bytes_decoder)
+    assert decoded == payload
+"""
+
+import base64
+import json
+import os
+from datetime import date, datetime, time, timezone
+from typing import Any
+
+from upath import UPath as Path
+
+
+class LLMeterEncoder(json.JSONEncoder):
+    """JSON encoder that handles common non-serializable types found in LLMeter.
+
+    Type handling (checked in order):
+
+    * Objects with a ``to_dict()`` method: delegates to that method.
+    * ``bytes``: wrapped in a ``{"__llmeter_bytes__": "<base64>"}`` marker object
+      so that :func:`llmeter_bytes_decoder` can restore them on the way back.
+    * ``datetime``: converted to a UTC ISO-8601 string with a ``Z`` suffix.
+    * ``date`` / ``time``: converted via ``.isoformat()``.
+    * ``os.PathLike``: converted to a POSIX path string.
+    * Anything else: ``str()`` fallback (returns ``None`` if that also fails).
+
+    Customization:
+        To handle additional types, subclass ``LLMeterEncoder`` and override
+        :meth:`default`. Call ``super().default(obj)`` as a fallback so that the
+        built-in type handling is preserved. Because the encoder is used
+        consistently across LLMeter (payloads, results, run configs), any type
+        that implements a ``to_dict()`` method will be serialized automatically
+        without needing a custom encoder.
+
+    Examples:
+        Subclassing to handle a custom type:
+
+        >>> import json
+        >>> import numpy as np
+        >>> from llmeter.json_utils import LLMeterEncoder
+        >>>
+        >>> class MyEncoder(LLMeterEncoder):
+        ...     def default(self, obj):
+        ...         if isinstance(obj, np.ndarray):
+        ...             return obj.tolist()
+        ...         return super().default(obj)
+        ...
+        >>> json.dumps({"data": np.array([1, 2, 3])}, cls=MyEncoder)
+        '{"data": [1, 2, 3]}'
+
+        Using ``to_dict()`` (no subclassing needed):
+
+        >>> class MyPayload:
+        ...     def __init__(self, model_id, temperature):
+        ...         self.model_id = model_id
+        ...         self.temperature = temperature
+        ...     def to_dict(self):
+        ...         return {"model_id": self.model_id, "temperature": self.temperature}
+        ...
+        >>> json.dumps({"payload": MyPayload("gpt-4", 0.7)}, cls=LLMeterEncoder)
+        '{"payload": {"model_id": "gpt-4", "temperature": 0.7}}'
+    """
+
+    def default(self, obj: Any) -> Any:
+        """Encode a single non-serializable object.
+
+        Args:
+            obj: The object that the default encoder could not handle.
+
+        Returns:
+            A JSON-serializable representation of *obj*.
+        """
+        if hasattr(obj, "to_dict") and callable(obj.to_dict):
+            return obj.to_dict()
+        if isinstance(obj, bytes):
+            return {"__llmeter_bytes__": base64.b64encode(obj).decode("utf-8")}
+        if isinstance(obj, datetime):
+            if obj.tzinfo is not None:
+                obj = obj.astimezone(timezone.utc)
+            return obj.isoformat(timespec="seconds").replace("+00:00", "Z")
+        if isinstance(obj, (date, time)):
+            return obj.isoformat()
+        if isinstance(obj, (os.PathLike, Path)):
+            return Path(obj).as_posix()
+        try:
+            return str(obj)
+        except Exception:
+            return None
+
+
+def llmeter_bytes_decoder(dct: dict) -> dict | bytes:
+    """Decode ``__llmeter_bytes__`` marker objects back to ``bytes``.
+
+    Intended for use as the ``object_hook`` argument to :func:`json.loads` or
+    :func:`json.load`. Marker objects produced by :class:`LLMeterEncoder` are
+    detected and converted back to ``bytes``; all other dicts pass through unchanged.
+
+    Args:
+        dct: A dictionary produced by the JSON parser.
+
+    Returns:
+        The original ``bytes`` if *dct* is a marker object, otherwise *dct* unchanged.
+
+    Example::
+
+        >>> import json
+        >>> from llmeter.json_utils import llmeter_bytes_decoder
+        >>> json.loads('{"__llmeter_bytes__": "/9j/4A=="}', object_hook=llmeter_bytes_decoder)
+        b'\\xff\\xd8\\xff\\xe0'
+    """
+    if "__llmeter_bytes__" in dct and len(dct) == 1:
+        return base64.b64decode(dct["__llmeter_bytes__"])
+    return dct
diff --git a/llmeter/plotting/plotting.py b/llmeter/plotting/plotting.py
index 8852550..23561b5 100644
--- a/llmeter/plotting/plotting.py
+++ b/llmeter/plotting/plotting.py
@@ -22,7 +22,7 @@
 try:
     import kaleido
-except ModuleNotFoundError as e:
+except ImportError as e:
     kaleido = DeferredError(e)
 
 from typing import TYPE_CHECKING
diff --git a/llmeter/prompt_utils.py b/llmeter/prompt_utils.py
index 0cf5126..b7b190c 100644
--- a/llmeter/prompt_utils.py
+++ b/llmeter/prompt_utils.py
@@ -3,24 +3,157 @@
 import json
 import logging
+import random
 from dataclasses import dataclass
 from itertools import product
-import os
-import random
 from typing import Any, Callable, Iterator
 
 from upath import UPath as Path
+from upath.types import ReadablePathLike, WritablePathLike
 
+from .json_utils import LLMeterEncoder, llmeter_bytes_decoder
 from .tokenizers import DummyTokenizer, Tokenizer
+from .utils import DeferredError, ensure_path
 
 logger = logging.getLogger(__name__)
 
+# Optional dependency: puremagic for content-based format detection
+try:
+    import puremagic
+except ImportError as e:
+    logger.debug(
+        "puremagic not available. Format detection will fall back to file extensions. "
+        "Install with: pip install 'llmeter[multimodal]'"
+    )
+    puremagic = DeferredError(e)
+
+
+# Multi-modal content utilities
+
+
+def read_file(file_path: str) -> bytes:
+    """Read binary content from a file.
+
+    Args:
+        file_path: Path to the file
+
+    Returns:
+        bytes: File content
+
+    Raises:
+        FileNotFoundError: If file doesn't exist
+        IOError: If file cannot be read
+    """
+    try:
+        _path = ensure_path(file_path)
+        with _path.open("rb") as f:
+            return f.read()
+    except FileNotFoundError:
+        raise FileNotFoundError(f"File not found: {file_path}")
+    except Exception as e:
+        raise IOError(f"Failed to read file {file_path}: {e}")
+
+
+def detect_format_from_extension(file_path: str) -> str | None:
+    """Detect MIME type from file extension.
+
+    Args:
+        file_path: Path to the file
+
+    Returns:
+        str | None: MIME type or None if extension not recognized
+
+    Examples:
+        >>> detect_format_from_extension("image.jpg")
+        'image/jpeg'
+        >>> detect_format_from_extension("document.pdf")
+        'application/pdf'
+    """
+    extension = ensure_path(file_path).suffix.lower()
+
+    # Map common extensions to MIME types
+    extension_to_mime = {
+        ".jpg": "image/jpeg",
+        ".jpeg": "image/jpeg",
+        ".png": "image/png",
+        ".gif": "image/gif",
+        ".webp": "image/webp",
+        ".pdf": "application/pdf",
+        ".mp4": "video/mp4",
+        ".mov": "video/quicktime",
+        ".avi": "video/x-msvideo",
+        ".mp3": "audio/mpeg",
+        ".wav": "audio/wav",
+        ".ogg": "audio/ogg",
+    }
+
+    return extension_to_mime.get(extension)
+
+
+def detect_format_from_bytes(content: bytes) -> str | None:
+    """Detect MIME type from bytes content using puremagic.
+
+    Args:
+        content: Binary content
+
+    Returns:
+        str | None: MIME type or None if detection fails or puremagic not available
+
+    Examples:
+        >>> detect_format_from_bytes(b"\\xff\\xd8\\xff\\xe0")  # JPEG magic bytes
+        'image/jpeg'
+    """
+    try:
+        # Get MIME type from content using puremagic (v2.0+ API)
+        mime_type = puremagic.from_string(content, mime=True)
+        return mime_type if mime_type else None
+    except (ImportError, AttributeError):
+        # puremagic not available or DeferredError raised
+        return None
+    except Exception:
+        pass
+
+    return None
+
+
+def detect_format_from_file(file_path: str) -> str | None:
+    """Detect MIME type from file using puremagic or extension fallback.
+
+    Args:
+        file_path: Path to the file
+
+    Returns:
+        str | None: MIME type or None if format cannot be detected
+
+    Examples:
+        >>> detect_format_from_file("photo.jpg")
+        'image/jpeg'
+    """
+    # Try puremagic first if available
+    try:
+        matches = puremagic.magic_file(file_path)
+        if matches:
+            # Extract MIME type from first match
+            mime_type = (
+                matches[0].mime_type if hasattr(matches[0], "mime_type") else None
+            )
+            if mime_type:
+                return mime_type
+    except (ImportError, AttributeError):
+        # puremagic not available or DeferredError raised
+        pass
+    except Exception:
+        pass
+
+    # Fallback to extension-based detection
+    return detect_format_from_extension(file_path)
+
+
 @dataclass
 class CreatePromptCollection:
     input_lengths: list[int]
     output_lengths: list[int]
-    source_file: os.PathLike
+    source_file: ReadablePathLike
     requests_per_combination: int = 1
     tokenizer: Tokenizer | None = None
     source_file_encoding: str = "utf-8-sig"
@@ -36,8 +169,8 @@ def create_collection(self) -> list[Any]:
         )
         return random.sample(collection, k=len(collection))
 
-    def _generate_sample(self, source_file: os.PathLike, sample_size: int) -> str:
-        source_file = Path(source_file)
+    def _generate_sample(self, source_file: ReadablePathLike, sample_size: int) -> str:
+        source_file = ensure_path(source_file)
         sample = []
         with source_file.open(encoding=self.source_file_encoding, mode="r") as f:
             for line in f:
@@ -54,7 +187,7 @@ def _generate_samples(self) -> None:
 
 def load_prompts(
-    file_path: os.PathLike,
+    file_path: ReadablePathLike,
     create_payload_fn: Callable,
     create_payload_kwargs: dict = {},
     file_pattern: str | None = None,
@@ -83,7 +216,7 @@ def load_prompts(
     """
 
-    file_path = Path(file_path)
+    file_path = ensure_path(file_path)
     if file_path.is_file():
         with file_path.open(mode="r") as f:
             for line in f:
@@ -107,12 +240,24 @@ def load_prompts(
             continue
 
 
-def load_payloads(file_path: os.PathLike | str) -> Iterator[dict]:
+def load_payloads(
+    file_path: ReadablePathLike,
+) -> Iterator[dict]:
     """
-    Load JSON payload(s) from a file or directory.
+    Load JSON payload(s) from a file or directory with binary content support.
 
     This function reads JSON data from either a single file or multiple files
-    in a directory. It supports both .json and .jsonl file formats.
+    in a directory. It supports both .json and .jsonl file formats. Binary content
+    (bytes objects) that were serialized using LLMeterEncoder are automatically
+    restored during deserialization.
+
+    Binary Content Handling:
+        When loading payloads saved with save_payloads(), marker objects with the key
+        "__llmeter_bytes__" are automatically detected and converted back to bytes objects.
+        The base64-encoded strings are decoded to restore the original binary data,
+        enabling round-trip preservation of multimodal content like images and video.
+
+        The marker object format is: {"__llmeter_bytes__": "<base64-encoded string>"}
 
     Args:
         file_path (Union[Path, str]): Path to a JSON file or a directory
@@ -127,8 +272,57 @@ def load_payloads(file_path: os.PathLike | str) -> Iterator[dict]:
         ValidationError: If the JSON data does not conform to the expected schema.
         IOError: If there's an error reading the file.
+
+    Examples:
+        Load a Bedrock Converse API payload with image content:
+
+        >>> # Assuming a file was saved with save_payloads() containing binary data
+        >>> payloads = list(load_payloads("/tmp/output/payload.jsonl"))
+        >>> payload = payloads[0]
+        >>> # Binary content is automatically restored as bytes
+        >>> image_bytes = payload["messages"][0]["content"][1]["image"]["source"]["bytes"]
+        >>> isinstance(image_bytes, bytes)
+        True
+        >>> # The bytes can be used directly with the API
+        >>> print(f"Image size: {len(image_bytes)} bytes")
+        Image size: 52341 bytes
+
+        Load multiple payloads with video content:
+
+        >>> for payload in load_payloads("/tmp/output/multimodal.jsonl"):
+        ...     video_content = payload["messages"][0]["content"][1]
+        ...     if "video" in video_content:
+        ...         video_bytes = video_content["video"]["source"]["bytes"]
+        ...         print(f"Loaded video: {len(video_bytes)} bytes")
+        Loaded video: 1048576 bytes
+
+        Load all payloads from a directory:
+
+        >>> # Load all .json and .jsonl files in a directory
+        >>> all_payloads = list(load_payloads("/tmp/output/"))
+        >>> print(f"Loaded {len(all_payloads)} payloads")
+        Loaded 5 payloads
+
+        Round-trip example showing binary preservation:
+
+        >>> # Original payload with binary data
+        >>> original = {
+        ...     "modelId": "test-model",
+        ...     "messages": [{
+        ...         "role": "user",
+        ...         "content": [
+        ...             {"image": {"source": {"bytes": b"\\xff\\xd8\\xff\\xe0"}}}
+        ...         ]
+        ...     }]
+        ... }
+        >>> # Save and load
+        >>> save_payloads(original, "/tmp/test")
+        PosixPath('/tmp/test/payload.jsonl')
+        >>> loaded = list(load_payloads("/tmp/test/payload.jsonl"))[0]
+        >>> # Binary data is preserved byte-for-byte
+        >>> original == loaded
+        True
     """
 
-    file_path = Path(file_path)
+    file_path = ensure_path(file_path)
     if not file_path.exists():
         raise FileNotFoundError(f"The specified path does not exist: {file_path}")
@@ -148,11 +342,13 @@ def _load_data_file(file: Path) -> Iterator[dict]:
                 try:
                     if not line.strip():
                         continue
-                    yield json.loads(line.strip())
+                    yield json.loads(
+                        line.strip(), object_hook=llmeter_bytes_decoder
+                    )
                 except json.JSONDecodeError as e:
                     print(f"Error decoding JSON in {file}: {e}")
         else:
             # Assume it's a regular JSON file
-            yield json.load(f)
+            yield json.load(f, object_hook=llmeter_bytes_decoder)
     except IOError as e:
         print(f"Error reading file {file}: {e}")
     except json.JSONDecodeError as e:
@@ -161,24 +357,109 @@ def _load_data_file(file: Path) -> Iterator[dict]:
 
 def save_payloads(
     payloads: list[dict] | dict,
-    output_path: os.PathLike | str,
+    output_path: WritablePathLike,
     output_file: str = "payload.jsonl",
 ) -> Path:
     """
-    Save payloads to a file.
+    Save payloads to a file with support for binary content.
+
+    This function saves payloads to a JSONL file, with automatic handling of binary
+    content (bytes objects) through base64 encoding. Binary data is wrapped in marker
+    objects during serialization to enable round-trip preservation.
+
+    Binary Content Handling:
+        When a payload contains bytes objects (e.g., images, video), they are automatically
+        converted to base64-encoded strings and wrapped in a marker object with the key
+        "__llmeter_bytes__". This approach enables JSON serialization while preserving
+        the ability to restore the original bytes during deserialization with load_payloads().
+
+        The marker object format is: {"__llmeter_bytes__": "<base64-encoded string>"}
 
     Args:
-        payloads (Iterator[Dict]): An iterator of payloads (dicts).
+ payloads (Union[list[dict], dict]): Payload(s) to save. May contain bytes objects + at any nesting level. output_path (Union[Path, str]): The directory path where the output file should be saved. - output_file (str, optional): The name of the output file. Defaults to "payloads.jsonl". + output_file (str, optional): The name of the output file. Defaults to "payload.jsonl". Returns: - output_file_path (UPath): The path to the output file. + Path: The path to the output file. Raises: IOError: If there's an error writing to the file. + TypeError: If payload contains unserializable types. + + Examples: + Save a Bedrock Converse API payload with image content: + + >>> import base64 + >>> # Create a payload with binary image data + >>> with open("image.jpg", "rb") as f: + ... image_bytes = f.read() + >>> payload = { + ... "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + ... "messages": [{ + ... "role": "user", + ... "content": [ + ... {"text": "What is in this image?"}, + ... { + ... "image": { + ... "format": "jpeg", + ... "source": {"bytes": image_bytes} + ... } + ... } + ... ] + ... }] + ... } + >>> output_path = save_payloads(payload, "/tmp/output") + >>> print(output_path) + /tmp/output/payload.jsonl + + Save multiple payloads with video content: + + >>> with open("video.mp4", "rb") as f: + ... video_bytes = f.read() + >>> payloads = [ + ... { + ... "modelId": "anthropic.claude-3-sonnet-20240229-v1:0", + ... "messages": [{ + ... "role": "user", + ... "content": [ + ... {"text": "Describe this video"}, + ... { + ... "video": { + ... "format": "mp4", + ... "source": {"bytes": video_bytes} + ... } + ... } + ... ] + ... }] + ... } + ... 
] + >>> save_payloads(payloads, "/tmp/output", "multimodal.jsonl") + PosixPath('/tmp/output/multimodal.jsonl') + + The saved JSON file will contain marker objects for binary data: + + >>> # Example of what gets written to the file: + >>> # { + >>> # "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + >>> # "messages": [{ + >>> # "role": "user", + >>> # "content": [ + >>> # {"text": "What is in this image?"}, + >>> # { + >>> # "image": { + >>> # "format": "jpeg", + >>> # "source": { + >>> # "bytes": {"__llmeter_bytes__": "/9j/4AAQSkZJRg..."} + >>> # } + >>> # } + >>> # } + >>> # ] + >>> # }] + >>> # } """ - output_path = Path(output_path) + output_path = ensure_path(output_path) output_path.mkdir(parents=True, exist_ok=True) output_file_path = output_path / output_file @@ -186,5 +467,5 @@ def save_payloads( payloads = [payloads] with output_file_path.open(mode="w") as f: for payload in payloads: - f.write(json.dumps(payload) + "\n") + f.write(json.dumps(payload, cls=LLMeterEncoder) + "\n") return output_file_path diff --git a/llmeter/results.py b/llmeter/results.py index 45a1c00..f5481a3 100644 --- a/llmeter/results.py +++ b/llmeter/results.py @@ -3,42 +3,22 @@ import json import logging -import os from dataclasses import asdict, dataclass -from datetime import datetime, timezone +from datetime import datetime from functools import cached_property from numbers import Number from typing import Any, Sequence import jmespath -from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike from .endpoints import InvocationResponse -from .utils import summary_stats_from_list +from .json_utils import LLMeterEncoder +from .utils import ensure_path, summary_stats_from_list logger = logging.getLogger(__name__) -def utc_datetime_serializer(obj: Any) -> str: - """ - Serialize datetime objects to UTC ISO format strings. - - Args: - obj: Object to serialize. If datetime, converts to ISO format string with 'Z' timezone. 
- Otherwise returns string representation. - - Returns: - str: ISO format string with 'Z' timezone for datetime objects, or string representation - for other objects. - """ - if isinstance(obj, datetime): - # Convert to UTC if timezone is set - if obj.tzinfo is not None: - obj = obj.astimezone(timezone.utc) - return obj.isoformat(timespec="seconds").replace("+00:00", "Z") - return str(obj) - - @dataclass class Result: """Results of a test run.""" @@ -49,7 +29,7 @@ class Result: n_requests: int total_test_time: float | None = None model_id: str | None = None - output_path: os.PathLike | None = None + output_path: WritablePathLike | None = None endpoint_name: str | None = None provider: str | None = None run_name: str | None = None @@ -58,7 +38,7 @@ class Result: end_time: datetime | None = None def __str__(self): - return json.dumps(self.stats, indent=4, default=utc_datetime_serializer) + return json.dumps(self.stats, indent=4, cls=LLMeterEncoder) def __post_init__(self): """Initialize the Result instance.""" @@ -81,7 +61,7 @@ def _update_contributed_stats(self, stats: dict[str, Number]): ) self._contributed_stats.update(stats) - def save(self, output_path: os.PathLike | str | None = None): + def save(self, output_path: WritablePathLike | None = None): """ Save the results to disk or cloud storage. @@ -109,9 +89,8 @@ def save(self, output_path: os.PathLike | str | None = None): which provides a unified interface for working with different file systems. 
""" - try: - output_path = Path(self.output_path or output_path) - except TypeError: + output_path = ensure_path(self.output_path or output_path) + if output_path is None: raise ValueError("No output path provided") output_path.mkdir(parents=True, exist_ok=True) @@ -120,28 +99,32 @@ def save(self, output_path: os.PathLike | str | None = None): stats_path = output_path / "stats.json" with summary_path.open("w") as f, stats_path.open("w") as s: f.write(self.to_json(indent=4)) - s.write(json.dumps(self.stats, indent=4, default=utc_datetime_serializer)) + s.write(json.dumps(self.stats, indent=4, cls=LLMeterEncoder)) responses_path = output_path / "responses.jsonl" if not responses_path.exists(): with responses_path.open("w") as f: for response in self.responses: - f.write(json.dumps(asdict(response)) + "\n") + f.write(json.dumps(asdict(response), cls=LLMeterEncoder) + "\n") def to_json(self, **kwargs): """Return the results as a JSON string.""" + kwargs.setdefault("cls", LLMeterEncoder) summary = { k: o for k, o in asdict(self).items() if k not in ["responses", "stats"] } - return json.dumps(summary, default=utc_datetime_serializer, **kwargs) + return json.dumps(summary, **kwargs) def to_dict(self, include_responses: bool = False): - """Return the results as a dictionary.""" + """Return the results as a dictionary with JSON-serializable values.""" + data = asdict(self) + # Serialize datetime objects so stats dict is always JSON-safe + for key in ("start_time", "end_time"): + if key in data and isinstance(data[key], datetime): + data[key] = LLMeterEncoder().default(data[key]) if include_responses: - return asdict(self) - return { - k: o for k, o in asdict(self).items() if k not in ["responses", "stats"] - } + return data + return {k: v for k, v in data.items() if k not in ["responses", "stats"]} def load_responses(self) -> list[InvocationResponse]: """ @@ -163,7 +146,7 @@ def load_responses(self) -> list[InvocationResponse]: raise ValueError( "No output_path set on this 
Result. Cannot locate responses file." ) - responses_path = Path(self.output_path) / "responses.jsonl" + responses_path = ensure_path(self.output_path) / "responses.jsonl" with responses_path.open("r") as f: self.responses = [ InvocationResponse(**json.loads(line)) for line in f if line @@ -175,7 +158,7 @@ def load_responses(self) -> list[InvocationResponse]: @classmethod def load( - cls, result_path: os.PathLike | str, load_responses: bool = True + cls, result_path: ReadablePathLike, load_responses: bool = True ) -> "Result": """ Load run results from disk or cloud storage. @@ -205,7 +188,7 @@ def load( either file. """ - result_path = Path(result_path) + result_path = ensure_path(result_path) summary_path = result_path / "summary.json" with summary_path.open("r") as g: diff --git a/llmeter/runner.py b/llmeter/runner.py index 47626e3..6f56d1b 100644 --- a/llmeter/runner.py +++ b/llmeter/runner.py @@ -19,14 +19,16 @@ from tqdm.auto import tqdm, trange from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike -from llmeter.utils import now_utc +from .utils import ensure_path, now_utc if TYPE_CHECKING: # Avoid circular import: We only need typing for Callback from .callbacks.base import Callback from .endpoints.base import Endpoint, InvocationResponse +from .json_utils import LLMeterEncoder from .prompt_utils import load_payloads, save_payloads from .results import Result from .tokenizers import DummyTokenizer, Tokenizer @@ -52,11 +54,11 @@ class _RunConfig: """ endpoint: Endpoint | dict | None = None - output_path: str | Path | None = None + output_path: WritablePathLike | None = None tokenizer: Tokenizer | Any | None = None clients: int = 1 n_requests: int | None = None - payload: dict | list[dict] | os.PathLike | str | None = None + payload: dict | list[dict] | ReadablePathLike | None = None run_name: str | None = None run_description: str | None = None timeout: int | float = 60 @@ -84,7 +86,7 @@ def __post_init__(self, 
disable_client_progress_bar, disable_clients_progress_ba self._endpoint = self.endpoint if self.output_path is not None: - self.output_path = Path(self.output_path) + self.output_path = ensure_path(self.output_path) if self.tokenizer is None: self.tokenizer = DummyTokenizer() @@ -95,7 +97,7 @@ def __post_init__(self, disable_client_progress_bar, disable_clients_progress_ba def save( self, - output_path: os.PathLike | str | None = None, + output_path: WritablePathLike | None = None, file_name: str = "run_config.json", ): """Save the configuration to a disk or cloud storage. @@ -104,13 +106,13 @@ def save( output_path: Optional override for output folder. By default, self.output_path is used. file_name: File name to create under `output_path`. """ - output_path = Path(output_path or self.output_path) + output_path = ensure_path(output_path or self.output_path) output_path.mkdir(parents=True, exist_ok=True) run_config_path = output_path / file_name config_copy = replace(self) - if self.payload and (not isinstance(self.payload, (os.PathLike, str))): + if self.payload and (not isinstance(self.payload, (Path, str))): payload_path = save_payloads(self.payload, output_path) config_copy.payload = payload_path @@ -121,8 +123,11 @@ def save( if not isinstance(self.tokenizer, dict): config_copy.tokenizer = Tokenizer.to_dict(self.tokenizer) + if self.callbacks: + config_copy.callbacks = [cb.to_dict() for cb in self.callbacks] + with run_config_path.open("w") as f: - f.write(json.dumps(asdict(config_copy), default=str, indent=4)) + f.write(json.dumps(asdict(config_copy), cls=LLMeterEncoder, indent=4)) @classmethod def load(cls, load_path: Path | str, file_name: str = "run_config.json"): @@ -132,11 +137,15 @@ def load(cls, load_path: Path | str, file_name: str = "run_config.json"): output_path: Folder under which the configuration is stored file_name: File name within `output_path` for the run configuration JSON. 
""" - load_path = Path(load_path) - with open(load_path / file_name) as f: + load_path = ensure_path(load_path) + with (load_path / file_name).open() as f: config = json.load(f) config["endpoint"] = Endpoint.load(config["endpoint"]) config["tokenizer"] = Tokenizer.load(config["tokenizer"]) + if config.get("callbacks"): + from .callbacks.base import Callback # deferred: callbacks.base imports _RunConfig + + config["callbacks"] = [Callback.from_dict(cb) for cb in config["callbacks"]] return cls(**config) @@ -149,15 +158,15 @@ class _Run(_RunConfig): """ def __post_init__(self, disable_client_progress_bar, disable_clients_progress_bar): - assert ( - self.run_name is not None - ), "Test Run must be created with an explicit run_name" + assert self.run_name is not None, ( + "Test Run must be created with an explicit run_name" + ) super().__post_init__(disable_client_progress_bar, disable_clients_progress_bar) - assert ( - self.endpoint is not None - ), "Test Run must be created with an explicit Endpoint" + assert self.endpoint is not None, ( + "Test Run must be created with an explicit Endpoint" + ) self._validate_and_prepare_payload() self._responses = [] @@ -168,7 +177,7 @@ def _validate_and_prepare_payload(self): This method ensures that the payload is valid and prepared for the test run. 
""" assert self.payload, "No payload provided" - if isinstance(self.payload, (os.PathLike, str)): + if isinstance(self.payload, (Path, str)): self.payload = list(load_payloads(self.payload)) if isinstance(self.payload, dict): self.payload = [self.payload] @@ -403,7 +412,7 @@ async def _invoke_n_c( end_t = time.perf_counter() total_test_time = end_t - start_t logger.info( - f"Generated {clients} connections with {n_requests} invocations each in {total_test_time*1000:.2f} seconds" + f"Generated {clients} connections with {n_requests} invocations each in {total_test_time * 1000:.2f} seconds" ) # Signal the token counting task to exit @@ -455,7 +464,7 @@ async def _run(self): run_start_time = now_utc() _, (total_test_time, start_time, end_time) = await asyncio.gather( self._process_results_from_q( - output_path=Path(self.output_path) / "responses.jsonl" + output_path=ensure_path(self.output_path) / "responses.jsonl" if self.output_path else None, ), @@ -474,7 +483,7 @@ async def _run(self): return result self._progress_bar.close() - logger.info(f"Test completed in {total_test_time*1000:.2f} seconds.") + logger.info(f"Test completed in {total_test_time * 1000:.2f} seconds.") result = replace( result, @@ -575,7 +584,9 @@ def _prepare_run(self, **kwargs) -> _Run: run_params["run_name"] = f"{datetime.now():%Y%m%d-%H%M}" if self.output_path and not kwargs.get("output_path"): # Run output path is nested under run name subfolder unless explicitly set: - run_params["output_path"] = Path(self.output_path) / run_params["run_name"] + run_params["output_path"] = ( + ensure_path(self.output_path) / run_params["run_name"] + ) # Validate that clients parameter is set and is a positive integer clients = run_params.get("clients") if clients is None: @@ -595,7 +606,7 @@ async def run( tokenizer: Tokenizer | Any | None = None, clients: int | None = None, n_requests: int | None = None, - payload: dict | list[dict] | os.PathLike | str | None = None, + payload: dict | list[dict] | 
ReadablePathLike | None = None, run_name: str | None = None, run_description: str | None = None, timeout: int | float | None = None, diff --git a/llmeter/tokenizers.py b/llmeter/tokenizers.py index 6f4e893..91a606b 100644 --- a/llmeter/tokenizers.py +++ b/llmeter/tokenizers.py @@ -54,7 +54,8 @@ def load_from_file(tokenizer_path: UPath | None) -> Tokenizer: """ if tokenizer_path is None: return DummyTokenizer() - with open(tokenizer_path, "r") as f: + tokenizer_path = UPath(tokenizer_path) + with tokenizer_path.open("r") as f: tokenizer_info = json.load(f) return _load_tokenizer_from_info(tokenizer_info) @@ -123,7 +124,7 @@ def save_tokenizer(tokenizer: Any, output_path: UPath | str) -> UPath: output_path = UPath(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "w") as f: + with output_path.open("w") as f: json.dump(tokenizer_info, f) return output_path diff --git a/llmeter/utils.py b/llmeter/utils.py index d072e58..c9a0170 100644 --- a/llmeter/utils.py +++ b/llmeter/utils.py @@ -4,7 +4,10 @@ from itertools import filterfalse from math import isnan from statistics import StatisticsError, mean, median, quantiles -from typing import Any, Sequence +from typing import Any, Sequence, overload + +from upath import UPath +from upath.types import ReadablePathLike, WritablePathLike class DeferredError: @@ -90,3 +93,30 @@ def now_utc() -> datetime: datetime: Current UTC datetime object """ return datetime.now(timezone.utc) + + +@overload +def ensure_path(path: ReadablePathLike | WritablePathLike) -> UPath: ... + + +@overload +def ensure_path(path: ReadablePathLike | WritablePathLike | None) -> UPath | None: ... + + +def ensure_path( + path: ReadablePathLike | WritablePathLike | None = None, +) -> UPath | None: + """Normalize a path-like argument to a UPath instance. + + Converts strings, os.PathLike objects, and UPath instances into a + consistent UPath representation. Passes through None unchanged. 
+ + Args: + path: A string, path-like object, or None. + + Returns: + A UPath instance, or None if the input was None. + """ + if path is None: + return None + return UPath(path) diff --git a/mkdocs.yml b/mkdocs.yml index 99d5198..1c7e808 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -61,6 +61,7 @@ nav: - openai: reference/endpoints/openai.md - sagemaker: reference/endpoints/sagemaker.md - experiments: reference/experiments.md + - json_utils: reference/json_utils.md - plotting: - reference/plotting/index.md - results: reference/results.md diff --git a/pyproject.toml b/pyproject.toml index 6a5062d..ba00368 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,8 @@ openai = ["openai>=1.35.1"] litellm = ["litellm>=1.47.1"] plotting = ["plotly>=5.24.1", "kaleido<=0.2.1", "pandas>=2.2.0"] mlflow = ["mlflow-skinny>=3.10.0"] -all = ["openai>=1.35.1", "litellm>=1.47.1", "plotly>=5.24.1", "kaleido<=0.2.1", "pandas>=2.2.0", "mlflow-skinny>=3.10.0"] +multimodal = ["puremagic>=1.28"] +all = ["openai>=1.35.1", "litellm>=1.47.1", "plotly>=5.24.1", "kaleido<=0.2.1", "pandas>=2.2.0", "mlflow-skinny>=3.10.0", "puremagic>=1.28"] [project.urls] Repository = "https://github.com/awslabs/llmeter" @@ -57,12 +58,16 @@ test = [ "pytest-cov>=5.0.0", "pillow>=12.1.1", "aws-bedrock-token-generator>=1.1.0", + # Include all optional dependencies for comprehensive testing "openai>=1.35.1", "litellm>=1.47.1", "plotly>=5.24.1", "kaleido<=0.2.1", "pandas>=2.2.0", "mlflow-skinny>=3.10.0", + "puremagic>=1.28", + "transformers>=4.40.2", + "tiktoken>=0.7.0", ] [tool.pytest.ini_options] diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 8ec054a..2459eb4 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -116,10 +116,59 @@ def bedrock_openai_test_model(): ) +@pytest.fixture(scope="session") +def bedrock_openai_multimodal_test_model(): + """ + Get test model ID for OpenAI SDK multimodal tests. 
+
+    The model ID can be overridden via the BEDROCK_OPENAI_MULTIMODAL_TEST_MODEL environment variable.
+    Defaults to qwen.qwen3-vl-235b-a22b-instruct, which supports images and is available in Bedrock.
+
+    Note: This model is specifically for testing multimodal content (images, video, etc.)
+    via Bedrock's OpenAI-compatible endpoint.
+
+    Returns:
+        str: OpenAI multimodal model ID for Bedrock OpenAI SDK testing.
+    """
+    return os.environ.get(
+        "BEDROCK_OPENAI_MULTIMODAL_TEST_MODEL", "qwen.qwen3-vl-235b-a22b-instruct"
+    )
+
+
+@pytest.fixture
+def test_image_bytes():
+    """
+    Create test images as binary JPEG data for multimodal testing.
+
+    Returns two small JPEG images (32x32 pixels) as bytes objects.
+    These are used to test binary content serialization and API calls.
+    JPEG format is used for broad model compatibility.
+
+    Returns:
+        tuple: (image1_bytes, image2_bytes) - Two JPEG images as bytes
+    """
+    import io
+
+    from PIL import Image
+
+    # 32x32 red square JPEG - binary format
+    img1 = Image.new("RGB", (32, 32), color=(255, 0, 0))
+    buf1 = io.BytesIO()
+    img1.save(buf1, format="JPEG")
+    image1_bytes = buf1.getvalue()
+
+    # 32x32 blue square JPEG - binary format
+    img2 = Image.new("RGB", (32, 32), color=(0, 0, 255))
+    buf2 = io.BytesIO()
+    img2.save(buf2, format="JPEG")
+    image2_bytes = buf2.getvalue()
+
+    return image1_bytes, image2_bytes
+
+
 @pytest.fixture(scope="session")
 def bedrock_openai_endpoint_url(aws_region):
     """
-    Construct Bedrock OpenAI-compatible endpoint URL.
+    Construct Bedrock OpenAI-compatible endpoint URL for standard models.
 
     Args:
         aws_region: AWS region from the aws_region fixture.
@@ -130,6 +179,23 @@
     return f"https://bedrock-runtime.{aws_region}.amazonaws.com/openai/v1"
 
 
+@pytest.fixture(scope="session")
+def bedrock_openai_multimodal_endpoint_url(aws_region):
+    """
+    Construct Bedrock OpenAI-compatible endpoint URL for multimodal models (Bedrock Mantle).
+
+    This endpoint supports non-OpenAI models (e.g., Qwen, Kimi) via the OpenAI-compatible
+    interface and is required for multimodal content (images, video).
+
+    Args:
+        aws_region: AWS region from the aws_region fixture.
+
+    Returns:
+        str: Bedrock Mantle endpoint URL for multimodal testing.
+    """
+    return f"https://bedrock-mantle.{aws_region}.api.aws/v1"
+
+
 @pytest.fixture
 def test_payload():
     """
diff --git a/tests/integ/test_bedrock_converse.py b/tests/integ/test_bedrock_converse.py
index 76e4652..359541c 100644
--- a/tests/integ/test_bedrock_converse.py
+++ b/tests/integ/test_bedrock_converse.py
@@ -335,3 +335,254 @@
     # Verify response has an ID
     assert response.id is not None, "Response should have an ID"
+
+
+def test_save_load_payload_with_image(test_payload_with_image, tmp_path):
+    """
+    Test saving and loading payload with image content using serialization.
+
+    This test validates that:
+    - Payloads with image bytes can be saved to disk
+    - Saved payloads contain valid JSON with marker objects
+    - Loaded payloads restore bytes objects correctly
+    - Round-trip preserves byte-for-byte equality
+
+    **Validates: Requirements 1.6, 9.3**
+
+    Args:
+        test_payload_with_image: Test payload with image content (from fixture).
+        tmp_path: Temporary directory for test files (from pytest).
+    """
+    from llmeter.prompt_utils import save_payloads, load_payloads
+
+    # Save payload with image
+    output_file = save_payloads(test_payload_with_image, tmp_path, "test_image.jsonl")
+    assert output_file.exists(), "Output file should be created"
+
+    # Verify file contains valid JSON with marker objects
+    with output_file.open("r") as f:
+        content = f.read()
+        assert "__llmeter_bytes__" in content, "File should contain marker objects"
+
+    # Load payload and verify bytes are restored
+    loaded_payloads = list(load_payloads(output_file))
+    assert len(loaded_payloads) == 1, "Should load exactly one payload"
+
+    loaded = loaded_payloads[0]
+    original_bytes = test_payload_with_image["messages"][0]["content"][0]["image"]["source"]["bytes"]
+    loaded_bytes = loaded["messages"][0]["content"][0]["image"]["source"]["bytes"]
+
+    assert isinstance(loaded_bytes, bytes), "Loaded bytes should be bytes type"
+    assert loaded_bytes == original_bytes, "Bytes should match after round-trip"
+    assert loaded == test_payload_with_image, "Full payload should match after round-trip"
+
+
+def test_save_load_payload_with_video(tmp_path):
+    """
+    Test saving and loading payload with video content using serialization.
+
+    This test validates that:
+    - Payloads with video bytes can be saved to disk
+    - Saved payloads contain valid JSON with marker objects
+    - Loaded payloads restore bytes objects correctly
+    - Video content path (messages[].content[].video.source.bytes) is handled
+
+    **Validates: Requirements 1.6, 9.4**
+
+    Args:
+        tmp_path: Temporary directory for test files (from pytest).
+    """
+    from llmeter.prompt_utils import save_payloads, load_payloads
+
+    # Create a test payload with video content (simulated video bytes)
+    video_payload = {
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "video": {
+                            "format": "mp4",
+                            "source": {"bytes": b"\x00\x00\x00\x18ftypmp42"},  # MP4 header
+                        }
+                    },
+                    {"text": "What is happening in this video?"},
+                ],
+            }
+        ],
+        "inferenceConfig": {"maxTokens": 150},
+    }
+
+    # Save payload with video
+    output_file = save_payloads(video_payload, tmp_path, "test_video.jsonl")
+    assert output_file.exists(), "Output file should be created"
+
+    # Verify file contains valid JSON with marker objects
+    with output_file.open("r") as f:
+        content = f.read()
+        assert "__llmeter_bytes__" in content, "File should contain marker objects"
+
+    # Load payload and verify bytes are restored
+    loaded_payloads = list(load_payloads(output_file))
+    assert len(loaded_payloads) == 1, "Should load exactly one payload"
+
+    loaded = loaded_payloads[0]
+    original_bytes = video_payload["messages"][0]["content"][0]["video"]["source"]["bytes"]
+    loaded_bytes = loaded["messages"][0]["content"][0]["video"]["source"]["bytes"]
+
+    assert isinstance(loaded_bytes, bytes), "Loaded bytes should be bytes type"
+    assert loaded_bytes == original_bytes, "Bytes should match after round-trip"
+    assert loaded == video_payload, "Full payload should match after round-trip"
+
+
+def test_save_load_multiple_images(tmp_path):
+    """
+    Test saving and loading payload with multiple images in single payload.
+
+    This test validates that:
+    - Payloads with multiple image bytes can be saved to disk
+    - All images are correctly serialized with marker objects
+    - All images are correctly restored after loading
+    - Multiple bytes objects in same payload are handled independently
+
+    **Validates: Requirements 9.8**
+
+    Args:
+        tmp_path: Temporary directory for test files (from pytest).
+    """
+    from llmeter.prompt_utils import save_payloads, load_payloads
+
+    # Create a test payload with multiple images
+    multi_image_payload = {
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "image": {
+                            "format": "png",
+                            "source": {"bytes": b"\x89PNG\r\n\x1a\n"},  # PNG header
+                        }
+                    },
+                    {"text": "Compare these images:"},
+                    {
+                        "image": {
+                            "format": "jpeg",
+                            "source": {"bytes": b"\xff\xd8\xff\xe0"},  # JPEG header
+                        }
+                    },
+                    {
+                        "image": {
+                            "format": "png",
+                            "source": {"bytes": b"\x89PNG\r\n\x1a\n\x00\x00"},  # Different PNG
+                        }
+                    },
+                ],
+            }
+        ],
+        "inferenceConfig": {"maxTokens": 200},
+    }
+
+    # Save payload with multiple images
+    output_file = save_payloads(multi_image_payload, tmp_path, "test_multi_image.jsonl")
+    assert output_file.exists(), "Output file should be created"
+
+    # Verify file contains multiple marker objects
+    with output_file.open("r") as f:
+        content = f.read()
+        marker_count = content.count("__llmeter_bytes__")
+        assert marker_count == 3, f"File should contain 3 marker objects, found {marker_count}"
+
+    # Load payload and verify all bytes are restored
+    loaded_payloads = list(load_payloads(output_file))
+    assert len(loaded_payloads) == 1, "Should load exactly one payload"
+
+    loaded = loaded_payloads[0]
+    content = loaded["messages"][0]["content"]
+
+    # Verify first image
+    assert isinstance(content[0]["image"]["source"]["bytes"], bytes)
+    assert content[0]["image"]["source"]["bytes"] == b"\x89PNG\r\n\x1a\n"
+
+    # Verify second image
+    assert isinstance(content[2]["image"]["source"]["bytes"], bytes)
+    assert content[2]["image"]["source"]["bytes"] == b"\xff\xd8\xff\xe0"
+
+    # Verify third image
+    assert isinstance(content[3]["image"]["source"]["bytes"], bytes)
+    assert content[3]["image"]["source"]["bytes"] == b"\x89PNG\r\n\x1a\n\x00\x00"
+
+    # Verify full payload matches
+    assert loaded == multi_image_payload, "Full payload should match after round-trip"
+
+
+@pytest.mark.integ
+def test_round_trip_bedrock_converse_structure(
+    test_payload_with_image, tmp_path, aws_credentials, aws_region
+):
+    """
+    Test round-trip serialization with actual Bedrock Converse API structure.
+
+    This test validates that:
+    - Complete Bedrock Converse API payload structure is preserved
+    - All nested fields (modelId, messages, inferenceConfig) are maintained
+    - Image bytes in messages[].content[].image.source.bytes path are handled
+    - Round-trip produces identical payload structure
+    - Loaded payload can be used with the endpoint's invoke method
+
+    **Validates: Requirements 1.6, 9.3**
+
+    Args:
+        test_payload_with_image: Test payload with image content (from fixture).
+        tmp_path: Temporary directory for test files (from pytest).
+        aws_credentials: Boto3 session with valid AWS credentials (from fixture).
+        aws_region: AWS region for testing (from fixture).
+    """
+    from llmeter.prompt_utils import save_payloads, load_payloads
+    from llmeter.endpoints.bedrock import BedrockConverse
+
+    # Create a complete Bedrock Converse payload with all typical fields
+    complete_payload = {
+        "modelId": "anthropic.claude-3-haiku-20240307-v1:0",
+        "messages": test_payload_with_image["messages"],
+        "inferenceConfig": {
+            "maxTokens": 150,
+            "temperature": 0.7,
+            "topP": 0.9,
+        },
+        "system": [{"text": "You are a helpful assistant that describes images."}],
+    }
+
+    # Save and load the complete payload
+    output_file = save_payloads(complete_payload, tmp_path, "test_complete.jsonl")
+    loaded_payloads = list(load_payloads(output_file))
+
+    assert len(loaded_payloads) == 1, "Should load exactly one payload"
+    loaded = loaded_payloads[0]
+
+    # Verify all top-level fields are preserved
+    assert loaded["modelId"] == complete_payload["modelId"]
+    assert loaded["inferenceConfig"] == complete_payload["inferenceConfig"]
+    assert loaded["system"] == complete_payload["system"]
+
+    # Verify messages structure is preserved
+    assert len(loaded["messages"]) == len(complete_payload["messages"])
+    assert loaded["messages"][0]["role"] == complete_payload["messages"][0]["role"]
+
+    # Verify image bytes are correctly restored
+    original_bytes = complete_payload["messages"][0]["content"][0]["image"]["source"]["bytes"]
+    loaded_bytes = loaded["messages"][0]["content"][0]["image"]["source"]["bytes"]
+    assert isinstance(loaded_bytes, bytes)
+    assert loaded_bytes == original_bytes
+
+    # Verify complete equality
+    assert loaded == complete_payload, "Complete payload should match after round-trip"
+
+    # Verify the loaded payload can be used with the endpoint's invoke method
+    # Extract model_id from the loaded payload
+    model_id = loaded.pop("modelId")
+    endpoint = BedrockConverse(model_id=model_id, region=aws_region)
+    response = endpoint.invoke(loaded)
+
+    # Verify the endpoint successfully processed the loaded payload
+    assert response.response_text is not None, "Response should contain text"
+    assert len(response.response_text) > 0, "Response text should not be empty"
+    assert response.error is None, f"Response should not contain errors: {response.error}"
diff --git a/tests/integ/test_bedrock_invoke.py b/tests/integ/test_bedrock_invoke.py
index a816f3e..00750e1 100644
--- a/tests/integ/test_bedrock_invoke.py
+++ b/tests/integ/test_bedrock_invoke.py
@@ -95,9 +95,9 @@
     # Verify token counts are present and positive
     assert response.num_tokens_input is not None, "Input token count should not be None"
     assert response.num_tokens_input > 0, "Input token count should be positive"
-    assert (
-        response.num_tokens_output is not None
-    ), "Output token count should not be None"
+    assert response.num_tokens_output is not None, (
+        "Output token count should not be None"
+    )
     assert response.num_tokens_output > 0, "Output token count should be positive"
 
     # Verify response time is measured and positive
@@ -105,9 +105,9 @@ def test_bedrock_invoke_non_streaming(aws_credentials, aws_region, bedrock_test_
     assert response.time_to_last_token > 0, "Response time should be positive"
 
     # Verify no errors in response
-    assert (
-        response.error is None
-    ), f"Response should not contain errors: {response.error}"
+    assert response.error is None, (
+        f"Response should not contain errors: {response.error}"
+    )
 
     # Verify response has an ID
     assert response.id is not None, "Response should have an ID"
@@ -187,32 +187,196 @@ def test_bedrock_invoke_streaming(aws_credentials, aws_region, bedrock_test_mode
     # Verify token counts are present and positive
     assert response.num_tokens_input is not None, "Input token count should not be None"
     assert response.num_tokens_input > 0, "Input token count should be positive"
-    assert (
-        response.num_tokens_output is not None
-    ), "Output token count should not be None"
+    assert response.num_tokens_output is not None, (
+        "Output token count should not be None"
+    )
     assert response.num_tokens_output > 0, "Output token count should be positive"
 
     # Verify time to first token is measured and positive
-    assert (
-        response.time_to_first_token is not None
-    ), "Time to first token should not be None"
+    assert response.time_to_first_token is not None, (
+        "Time to first token should not be None"
+    )
     assert response.time_to_first_token > 0, "Time to first token should be positive"
 
     # Verify time to last token is measured and positive
-    assert (
-        response.time_to_last_token is not None
-    ), "Time to last token should not be None"
+    assert response.time_to_last_token is not None, (
+        "Time to last token should not be None"
+    )
     assert response.time_to_last_token > 0, "Time to last token should be positive"
 
     # Verify TTLT > TTFT (streaming should take time to complete)
-    assert (
-        response.time_to_last_token > response.time_to_first_token
-    ), "Time to last token should be greater than time to first token"
+    assert response.time_to_last_token > response.time_to_first_token, (
+        "Time to last token should be greater than time to first token"
+    )
 
     # Verify no errors in response
-    assert (
-        response.error is None
-    ), f"Response should not contain errors: {response.error}"
+    assert response.error is None, (
+        f"Response should not contain errors: {response.error}"
+    )
 
     # Verify response has an ID
     assert response.id is not None, "Response should have an ID"
+
+
+def test_save_load_invoke_payload_with_image(tmp_path):
+    """
+    Test saving and loading Bedrock Invoke API payload with image content.
+
+    This test validates that:
+    - Bedrock Invoke API payloads with image bytes can be saved to disk
+    - Provider-specific payload structure (Anthropic Claude format) is preserved
+    - Saved payloads contain valid JSON with marker objects
+    - Loaded payloads restore bytes objects correctly
+    - Round-trip preserves byte-for-byte equality
+
+    **Validates: Requirements 9.5**
+
+    Args:
+        tmp_path: Temporary directory for test files (from pytest).
+    """
+    from llmeter.prompt_utils import save_payloads, load_payloads
+    import io
+    from PIL import Image
+
+    # Create a simple test image
+    img = Image.new("RGB", (50, 50), color="blue")
+    img_buffer = io.BytesIO()
+    img.save(img_buffer, format="PNG")
+    img_bytes = img_buffer.getvalue()
+
+    # Create Bedrock Invoke API payload with Anthropic Claude format
+    # This uses the native Messages API format, not Converse API
+    invoke_payload = {
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": 150,
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/png",
+                            "data": img_bytes,  # Binary data in provider-specific format
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": "What is in this image?",
+                    },
+                ],
+            }
+        ],
+    }
+
+    # Save payload with image
+    output_file = save_payloads(invoke_payload, tmp_path, "test_invoke_image.jsonl")
+    assert output_file.exists(), "Output file should be created"
+
+    # Verify file contains valid JSON with marker objects
+    with output_file.open("r") as f:
+        content = f.read()
+        assert "__llmeter_bytes__" in content, "File should contain marker objects"
+
+    # Load payload and verify bytes are restored
+    loaded_payloads = list(load_payloads(output_file))
+    assert len(loaded_payloads) == 1, "Should load exactly one payload"
+
+    loaded = loaded_payloads[0]
+    original_bytes = invoke_payload["messages"][0]["content"][0]["source"]["data"]
+    loaded_bytes = loaded["messages"][0]["content"][0]["source"]["data"]
+
+    assert isinstance(loaded_bytes, bytes), "Loaded bytes should be bytes type"
+    assert loaded_bytes == original_bytes, "Bytes should match after round-trip"
+    assert loaded == invoke_payload, "Full payload should match after round-trip"
+
+
+@pytest.mark.integ
+def test_round_trip_invoke_structure(
+    tmp_path, aws_credentials, aws_region, bedrock_test_model
+):
+    """
+    Test round-trip serialization with actual Bedrock Invoke API structure.
+
+    This test validates that:
+    - Complete Bedrock Invoke API payload structure is preserved
+    - Provider-specific fields (anthropic_version, max_tokens) are maintained
+    - Text content in messages is preserved correctly
+    - Round-trip produces identical payload structure
+    - Loaded payload can be used with the endpoint's invoke method
+
+    **Validates: Requirements 9.5**
+
+    Args:
+        tmp_path: Temporary directory for test files (from pytest).
+        aws_credentials: Boto3 session with valid AWS credentials (from fixture).
+        aws_region: AWS region for testing (from fixture).
+        bedrock_test_model: Model ID for testing (from fixture).
+    """
+    from llmeter.prompt_utils import save_payloads, load_payloads
+
+    # Create a complete Bedrock Invoke payload with provider-specific structure
+    # This mimics the actual Anthropic Claude Messages API format used by InvokeModel
+    complete_invoke_payload = {
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": 200,
+        "temperature": 0.7,
+        "top_p": 0.9,
+        "messages": [
+            {
+                "role": "user",
+                "content": "Please provide a brief response to this test message.",
+            }
+        ],
+        "system": "You are a helpful assistant.",
+    }
+
+    # Save and load the complete payload
+    output_file = save_payloads(
+        complete_invoke_payload, tmp_path, "test_invoke_complete.jsonl"
+    )
+    loaded_payloads = list(load_payloads(output_file))
+
+    assert len(loaded_payloads) == 1, "Should load exactly one payload"
+    loaded = loaded_payloads[0]
+
+    # Verify all top-level fields are preserved
+    assert loaded["anthropic_version"] == complete_invoke_payload["anthropic_version"]
+    assert loaded["max_tokens"] == complete_invoke_payload["max_tokens"]
+    assert loaded["temperature"] == complete_invoke_payload["temperature"]
+    assert loaded["top_p"] == complete_invoke_payload["top_p"]
+    assert loaded["system"] == complete_invoke_payload["system"]
+
+    # Verify messages structure is preserved
+    assert len(loaded["messages"]) == len(complete_invoke_payload["messages"])
+    assert (
+        loaded["messages"][0]["role"] == complete_invoke_payload["messages"][0]["role"]
+    )
+    assert (
+        loaded["messages"][0]["content"]
+        == complete_invoke_payload["messages"][0]["content"]
+    )
+
+    # Verify complete equality
+    assert loaded == complete_invoke_payload, (
+        "Complete payload should match after round-trip"
+    )
+
+    # Verify the loaded payload can be used with the endpoint's invoke method
+    endpoint = BedrockInvoke(
+        model_id=bedrock_test_model,
+        region=aws_region,
+        generated_text_jmespath="content[0].text",
+        generated_token_count_jmespath="usage.output_tokens",
+        input_token_count_jmespath="usage.input_tokens",
+        input_text_jmespath="messages[0].content",
+    )
+    response = endpoint.invoke(loaded)
+
+    # Verify the endpoint successfully processed the loaded payload
+    assert response.response_text is not None, "Response should contain text"
+    assert len(response.response_text) > 0, "Response text should not be empty"
+    assert response.error is None, (
+        f"Response should not contain errors: {response.error}"
+    )
diff --git a/tests/integ/test_openai_bedrock.py b/tests/integ/test_openai_bedrock.py
index 63c8902..974f317 100644
--- a/tests/integ/test_openai_bedrock.py
+++ b/tests/integ/test_openai_bedrock.py
@@ -92,25 +92,25 @@
     # Verify response contains non-empty text
     assert response.choices is not None, "Response choices should not be None"
     assert len(response.choices) > 0, "Response should have at least one choice"
-    assert (
-        response.choices[0].message.content is not None
-    ), "Response content should not be None"
-    assert (
-        len(response.choices[0].message.content) > 0
-    ), "Response content should not be empty"
-    assert isinstance(
-        response.choices[0].message.content, str
-    ), "Response content should be a string"
+    assert response.choices[0].message.content is not None, (
+        "Response content should not be None"
+    )
+    assert len(response.choices[0].message.content) > 0, (
+        "Response content should not be empty"
+    )
+    assert isinstance(response.choices[0].message.content, str), (
+        "Response content should be a string"
+    )
 
     # Verify token counts are present and positive
     assert response.usage is not None, "Response usage should not be None"
-    assert (
-        response.usage.prompt_tokens is not None
-    ), "Input token count should not be None"
+    assert response.usage.prompt_tokens is not None, (
+        "Input token count should not be None"
+    )
     assert response.usage.prompt_tokens > 0, "Input token count should be positive"
-    assert (
-        response.usage.completion_tokens is not None
-    ), "Output token count should not be None"
+    assert response.usage.completion_tokens is not None, (
+        "Output token count should not be None"
+    )
     assert response.usage.completion_tokens > 0, "Output token count should be positive"
 
     # Verify response time is measured and positive
@@ -225,9 +225,213 @@
     assert time_to_last_token > 0, "Time to last token should be positive"
 
     # Verify TTLT > TTFT (streaming should take time to complete)
-    assert (
-        time_to_last_token > time_to_first_token
-    ), "Time to last token should be greater than time to first token"
+    assert time_to_last_token > time_to_first_token, (
+        "Time to last token should be greater than time to first token"
+    )
 
     # Verify response has an ID
     assert response_id is not None, "Response should have an ID"
+
+
+def test_save_load_openai_payload_with_image_url(tmp_path):
+    """
+    Test saving and loading OpenAI payload with image_url data URI.
+
+    This test validates that:
+    - OpenAI payloads with image_url data URIs can be saved to disk
+    - Saved payloads contain valid JSON with marker objects
+    - Loaded payloads restore bytes objects correctly in image_url.url field
+    - Round-trip preserves byte-for-byte equality
+
+    **Validates: Requirements 9.6, 9.7**
+
+    Args:
+        tmp_path: Temporary directory for test files (from pytest).
+    """
+    from llmeter.prompt_utils import save_payloads, load_payloads
+    import base64
+
+    # Create a small test image (1x1 red pixel PNG)
+    test_image_bytes = base64.b64decode(
+        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
+    )
+
+    # Create OpenAI chat.completions payload with image_url data URI
+    openai_payload = {
+        "model": "gpt-4o",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What is in this image?"},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": test_image_bytes  # Store as bytes for serialization
+                        },
+                    },
+                ],
+            }
+        ],
+        "max_tokens": 100,
+    }
+
+    # Save payload with image_url
+    output_file = save_payloads(openai_payload, tmp_path, "test_openai_image.jsonl")
+    assert output_file.exists(), "Output file should be created"
+
+    # Verify file contains valid JSON with marker objects
+    with output_file.open("r") as f:
+        content = f.read()
+        assert "__llmeter_bytes__" in content, "File should contain marker objects"
+
+    # Load payload and verify bytes are restored
+    loaded_payloads = list(load_payloads(output_file))
+    assert len(loaded_payloads) == 1, "Should load exactly one payload"
+
+    loaded = loaded_payloads[0]
+    original_bytes = openai_payload["messages"][0]["content"][1]["image_url"]["url"]
+    loaded_bytes = loaded["messages"][0]["content"][1]["image_url"]["url"]
+
+    assert isinstance(loaded_bytes, bytes), "Loaded bytes should be bytes type"
+    assert loaded_bytes == original_bytes, "Bytes should match after round-trip"
+    assert loaded == openai_payload, "Full payload should match after round-trip"
+
+
+@pytest.mark.integ
+@pytest.mark.skipif(not OPENAI_AVAILABLE, reason="OpenAI SDK not installed")
+def test_save_load_openai_complete_structure(
+    tmp_path, aws_credentials, aws_region, bedrock_openai_multimodal_test_model,
+    bedrock_openai_multimodal_endpoint_url, test_image_bytes
+):
+    """
+    Test round-trip with actual OpenAI chat.completions structure and API call.
+
+    This test validates that:
+    - Complete OpenAI chat.completions payloads serialize correctly
+    - All typical OpenAI fields are preserved
+    - Multiple content items with mixed text and images work correctly
+    - The messages[].content[].image_url.url path is handled correctly
+    - Binary image data is preserved through save/load cycle
+    - Loaded payload can be used with the OpenAI client for multimodal requests
+
+    **Validates: Requirements 9.6, 9.7**
+
+    Args:
+        tmp_path: Temporary directory for test files (from pytest).
+        aws_credentials: Boto3 session with valid AWS credentials (from fixture).
+        aws_region: AWS region for testing (from fixture).
+        bedrock_openai_multimodal_test_model: Model ID for OpenAI SDK multimodal testing (from fixture).
+        bedrock_openai_multimodal_endpoint_url: Bedrock Mantle endpoint URL for multimodal (from fixture).
+        test_image_bytes: Tuple of test images as binary data (from fixture).
+
+    AWS Permissions Required:
+        - bedrock:InvokeModel
+
+    Estimated Cost:
+        ~$0.0002 per test run (using Qwen3-VL with minimal tokens)
+    """
+    from llmeter.prompt_utils import save_payloads, load_payloads
+    import base64
+
+    # Get test images as binary data
+    image1_binary, image2_binary = test_image_bytes
+
+    # Create complete OpenAI payload with images stored as binary bytes
+    # Note: We store raw bytes in the payload, which will be serialized to base64 for storage
+    complete_payload = {
+        "model": bedrock_openai_multimodal_test_model,
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe these images briefly:"},
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": image1_binary},  # Binary bytes
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": image2_binary},  # Binary bytes
+                    },
+                ],
+            },
+        ],
+        "max_tokens": 100,
+    }
+
+    # Save and load the complete payload
+    output_file = save_payloads(
+        complete_payload, tmp_path, "test_openai_complete.jsonl"
+    )
+    loaded_payloads = list(load_payloads(output_file))
+
+    assert len(loaded_payloads) == 1, "Should load exactly one payload"
+    loaded = loaded_payloads[0]
+
+    # Verify structure is preserved
+    assert loaded["model"] == complete_payload["model"]
+    assert len(loaded["messages"]) == 1
+    assert loaded["messages"][0]["role"] == "user"
+    assert len(loaded["messages"][0]["content"]) == 3
+
+    # Verify image bytes are restored correctly from serialization
+    loaded_image1_binary = loaded["messages"][0]["content"][1]["image_url"]["url"]
+    loaded_image2_binary = loaded["messages"][0]["content"][2]["image_url"]["url"]
+
+    assert isinstance(loaded_image1_binary, bytes), "First image should be bytes"
+    assert isinstance(loaded_image2_binary, bytes), "Second image should be bytes"
+    assert loaded_image1_binary == image1_binary, "First image bytes should match"
+    assert loaded_image2_binary == image2_binary, "Second image bytes should match"
+
+    # Verify full round-trip equality
+    assert loaded == complete_payload, "Full payload should match after round-trip"
+
+    # Verify the loaded payload can be used with the OpenAI client
+    # Note: OpenAI API expects base64-encoded data URIs (ASCII strings), not raw bytes
+    # Multimodal models require the Bedrock Mantle endpoint, not bedrock-runtime
+    token = provide_token(region=aws_region)
+    client = OpenAI(api_key=token, base_url=bedrock_openai_multimodal_endpoint_url)
+
+    # Convert binary bytes to base64-encoded ASCII strings for API call
+    image1_base64_ascii = base64.b64encode(loaded_image1_binary).decode("utf-8")
+    image2_base64_ascii = base64.b64encode(loaded_image2_binary).decode("utf-8")
+
+    # Build API payload - data URI format with JPEG MIME type
+    api_payload = {
+        "model": bedrock_openai_multimodal_test_model,
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe these images briefly:"},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{image1_base64_ascii}",
+                        },
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{image2_base64_ascii}",
+                        },
+                    },
+                ],
+            },
+        ],
+        "max_tokens": 100,
+    }
+
+    # Invoke the API with the loaded multimodal payload
+    response = client.chat.completions.create(**api_payload)
+
+    # Verify the client successfully processed the loaded payload
+    assert response.choices is not None, "Response should contain choices"
+    assert len(response.choices) > 0, "Response should have at least one choice"
+    assert response.choices[0].message.content is not None, (
+        "Response should contain text"
+    )
+    assert len(response.choices[0].message.content) > 0, (
+        "Response text should not be empty"
+    )
diff --git a/tests/unit/callbacks/cost/test_model.py b/tests/unit/callbacks/cost/test_model.py
index b1ea1a8..6eed5bd 100644
--- a/tests/unit/callbacks/cost/test_model.py
+++ b/tests/unit/callbacks/cost/test_model.py
@@ -25,6 +25,7 @@ def test_cost_model_serialization():
     assert model.run_dims["ComputeSeconds"].price_per_hour == 50
     assert model.to_dict() == {
         "_type": "CostModel",
+        "_callback_type": "llmeter.callbacks.cost.model:CostModel",
         "request_dims": {
             "TokensIn": {
                 "_type": "InputTokens",
diff --git a/tests/unit/callbacks/test_base.py b/tests/unit/callbacks/test_base.py
index b3aa824..c236350 100644
--- a/tests/unit/callbacks/test_base.py
+++ b/tests/unit/callbacks/test_base.py
@@ -1,51 +1,67 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
+import json
+
 import pytest
 
 from llmeter.callbacks.base import Callback
 
 
-class TestBase:
-    def test__load_from_file_not_implemented(self):
-        """
-        Test that _load_from_file raises NotImplementedError.
-        """
-        with pytest.raises(NotImplementedError):
-            Callback._load_from_file("valid_path.json")
-
-    def test_load_from_file_not_implemented(self):
-        """
-        Test that load_from_file raises a NotImplementedError as it's not yet implemented.
-        """
-        with pytest.raises(NotImplementedError):
-            Callback.load_from_file("valid_path.txt")
-
-    def test_load_from_file_raises_not_implemented_error(self):
-        """
-        Test that Callback.load_from_file raises a NotImplementedError.
-
-        This test verifies that calling the static method load_from_file
-        on the Callback class raises a NotImplementedError, as the method
-        is not yet implemented.
-        """
-        with pytest.raises(NotImplementedError) as excinfo:
-            Callback.load_from_file("dummy_path")
+class _DummyCallback(Callback):
+    """A minimal concrete callback for testing serialization."""
 
-        assert (
-            str(excinfo.value)
-            == "TODO: Callback.load_from_file is not yet implemented!"
-        )
+    def __init__(self, alpha: int = 1, beta: str = "hello"):
+        self.alpha = alpha
+        self.beta = beta
 
-    def test_load_from_file_raises_not_implemented_error_2(self):
-        """
-        Test that _load_from_file raises NotImplementedError when called.
-        This method is not yet implemented in the base Callback class.
-        """
-        with pytest.raises(NotImplementedError) as excinfo:
-            Callback._load_from_file("dummy_path")
+class TestCallbackSerialization:
+    def test_to_dict_includes_callback_type(self):
+        cb = _DummyCallback(alpha=42, beta="world")
+        d = cb.to_dict()
 
         assert (
-            str(excinfo.value)
-            == "TODO: Callback._load_from_file is not yet implemented!"
+ d["_callback_type"] + == f"{_DummyCallback.__module__}:{_DummyCallback.__qualname__}" ) + assert d["alpha"] == 42 + assert d["beta"] == "world" + + def test_to_dict_excludes_private_attrs(self): + cb = _DummyCallback() + cb._internal = "secret" + d = cb.to_dict() + assert "_internal" not in d + + def test_from_dict_round_trip(self): + cb = _DummyCallback(alpha=7, beta="test") + d = cb.to_dict() + restored = Callback.from_dict(d) + assert isinstance(restored, _DummyCallback) + assert restored.alpha == 7 + assert restored.beta == "test" + + def test_from_dict_missing_callback_type_raises(self): + with pytest.raises(ValueError, match="_callback_type"): + Callback.from_dict({"alpha": 1}) + + def test_to_json_round_trip(self): + cb = _DummyCallback(alpha=99) + json_str = cb.to_json() + restored = Callback.from_json(json_str) + assert isinstance(restored, _DummyCallback) + assert restored.alpha == 99 + + def test_save_and_load_from_file(self, tmp_path): + cb = _DummyCallback(alpha=5, beta="file_test") + file_path = tmp_path / "callback.json" + cb.save_to_file(file_path) + + loaded = Callback.load_from_file(file_path) + assert isinstance(loaded, _DummyCallback) + assert loaded.alpha == 5 + assert loaded.beta == "file_test" + + # Verify the file is valid JSON with the type marker + with open(file_path) as f: + data = json.load(f) + assert "_callback_type" in data diff --git a/tests/unit/endpoints/test_bedrock_multimodal.py b/tests/unit/endpoints/test_bedrock_multimodal.py new file mode 100644 index 0000000..c3a9a06 --- /dev/null +++ b/tests/unit/endpoints/test_bedrock_multimodal.py @@ -0,0 +1,247 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from llmeter.endpoints.bedrock import BedrockBase + + +class TestBedrockMultiModal: + """Test multi-modal functionality for Bedrock endpoints.""" + + def test_create_payload_single_image_from_file(self): + """Test creating payload with single image from file path.""" + # Create a temporary image file + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") # JPEG magic bytes + temp_path = f.name + + try: + payload = BedrockBase.create_payload( + user_message="What's in this image?", images=[temp_path], max_tokens=256 + ) + + assert "messages" in payload + assert len(payload["messages"]) == 1 + assert payload["messages"][0]["role"] == "user" + + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert content[0]["text"] == "What's in this image?" + assert "image" in content[1] + assert content[1]["image"]["format"] == "jpeg" + assert "bytes" in content[1]["image"]["source"] + + finally: + Path(temp_path).unlink() + + def test_create_payload_single_image_from_bytes(self): + """Test creating payload with single image from bytes.""" + # Create a minimal valid JPEG file + # JPEG structure: SOI (FFD8) + APP0 marker + minimal data + EOI (FFD9) + jpeg_bytes = ( + b"\xff\xd8" # SOI (Start of Image) + b"\xff\xe0" # APP0 marker + b"\x00\x10" # APP0 length (16 bytes) + b"JFIF\x00" # JFIF identifier + b"\x01\x01" # JFIF version 1.1 + b"\x00" # density units (0 = no units) + b"\x00\x01" # X density = 1 + b"\x00\x01" # Y density = 1 + b"\x00\x00" # thumbnail width and height = 0 + b"\xff\xd9" # EOI (End of Image) + ) + + try: + payload = BedrockBase.create_payload( + user_message="What's in this image?", + images=[jpeg_bytes], + max_tokens=256, + ) + + assert "messages" in payload + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert "image" in content[1] 
+ assert content[1]["image"]["format"] == "jpeg" + assert content[1]["image"]["source"]["bytes"] == jpeg_bytes + except ValueError as e: + # If puremagic can't detect the format, skip this test + if "Cannot detect format from bytes" in str(e): + pytest.skip("puremagic cannot detect format from minimal JPEG bytes") + raise + + def test_create_payload_multiple_images(self): + """Test creating payload with multiple images.""" + # Create temporary image files + temp_files = [] + for i in range(2): + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n\x1a\n") # PNG magic bytes + temp_files.append(f.name) + + try: + payload = BedrockBase.create_payload( + user_message="Compare these images", images=temp_files, max_tokens=512 + ) + + content = payload["messages"][0]["content"] + assert len(content) == 3 # text + 2 images + assert content[0]["text"] == "Compare these images" + assert "image" in content[1] + assert "image" in content[2] + + finally: + for path in temp_files: + Path(path).unlink() + + def test_create_payload_mixed_content(self): + """Test creating payload with mixed content types.""" + # Create temporary files + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + img_file.write(b"\xff\xd8\xff\xe0") + img_path = img_file.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc_file: + doc_file.write(b"%PDF-1.4") + doc_path = doc_file.name + + try: + payload = BedrockBase.create_payload( + user_message="Analyze this", + images=[img_path], + documents=[doc_path], + max_tokens=1024, + ) + + content = payload["messages"][0]["content"] + assert len(content) == 3 # text + image + document + assert content[0]["text"] == "Analyze this" + assert "image" in content[1] + assert content[1]["image"]["format"] == "jpeg" + assert "document" in content[2] + assert content[2]["document"]["format"] == "pdf" + + finally: + Path(img_path).unlink() + Path(doc_path).unlink() + + def 
test_create_payload_text_only_backward_compatible(self):
+        """Test that text-only payloads still work (backward compatibility)."""
+        payload = BedrockBase.create_payload(
+            user_message="Hello, world!", max_tokens=256
+        )
+
+        assert "messages" in payload
+        content = payload["messages"][0]["content"]
+        assert len(content) == 1
+        assert content[0]["text"] == "Hello, world!"
+
+    def test_create_payload_empty_media_lists(self):
+        """Test that empty media lists are handled correctly."""
+        payload = BedrockBase.create_payload(
+            user_message="Hello", images=[], documents=None, max_tokens=256
+        )
+
+        # Should behave like text-only
+        content = payload["messages"][0]["content"]
+        assert len(content) == 1
+        assert content[0]["text"] == "Hello"
+
+    def test_create_payload_invalid_image_type(self):
+        """Test that invalid image types raise TypeError."""
+        with pytest.raises(
+            TypeError, match="Items in images list must be bytes or str"
+        ):
+            BedrockBase.create_payload(
+                user_message="Test",
+                images=[123],  # Invalid type
+                max_tokens=256,
+            )
+
+    def test_create_payload_invalid_images_not_list(self):
+        """Test that a non-list images parameter raises TypeError."""
+        with pytest.raises(TypeError, match="images must be a list"):
+            BedrockBase.create_payload(
+                user_message="Test", images="not_a_list", max_tokens=256
+            )
+
+    def test_create_payload_missing_file(self):
+        """Test that a missing file raises FileNotFoundError."""
+        with pytest.raises(FileNotFoundError, match="File not found"):
+            BedrockBase.create_payload(
+                user_message="Test", images=["/nonexistent/file.jpg"], max_tokens=256
+            )
+
+    def test_create_payload_bytes_with_undetectable_format(self):
+        """Test that bytes with no detectable format raise ValueError."""
+        # Random bytes that don't match any known format
+        random_bytes = b"\x00\x01\x02\x03\x04\x05"
+
+        with pytest.raises(ValueError, match="Cannot detect format from bytes"):
+            BedrockBase.create_payload(
+                user_message="Test", images=[random_bytes], max_tokens=256
+            )
+ + def test_create_payload_file_without_extension(self): + """Test that file without recognized extension raises ValueError.""" + with tempfile.NamedTemporaryFile(suffix="", delete=False) as f: + f.write(b"some content") + temp_path = f.name + + try: + with pytest.raises( + ValueError, match="(Cannot detect format|Unsupported MIME type)" + ): + BedrockBase.create_payload( + user_message="Test", images=[temp_path], max_tokens=256 + ) + finally: + Path(temp_path).unlink() + + def test_create_payload_content_ordering(self): + """Test that content blocks are ordered correctly: text, images, videos, audio, documents.""" + # Create temporary files for each media type with proper extensions + # puremagic may not detect these minimal magic bytes, so we rely on extensions + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img: + img.write( + b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + ) + img_path = img.name + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as vid: + vid.write(b"\x00\x00\x00\x20ftypisom\x00\x00\x02\x00isomiso2mp41") + vid_path = vid.name + + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as aud: + aud.write(b"ID3\x03\x00\x00\x00\x00\x00\x00") + aud_path = aud.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc: + doc.write(b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n") + doc_path = doc.name + + try: + payload = BedrockBase.create_payload( + user_message="Analyze all", + images=[img_path], + videos=[vid_path], + audio=[aud_path], + documents=[doc_path], + max_tokens=1024, + ) + + content = payload["messages"][0]["content"] + assert len(content) == 5 # text + 4 media types + assert "text" in content[0] + assert "image" in content[1] + assert "video" in content[2] + assert "audio" in content[3] + assert "document" in content[4] + + finally: + for path in [img_path, vid_path, aud_path, doc_path]: + Path(path).unlink() diff --git 
a/tests/unit/endpoints/test_multimodal_properties.py b/tests/unit/endpoints/test_multimodal_properties.py new file mode 100644 index 0000000..8d8736a --- /dev/null +++ b/tests/unit/endpoints/test_multimodal_properties.py @@ -0,0 +1,682 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Property-based tests for multi-modal payload functionality. + +This module contains property-based tests using Hypothesis to validate +the correctness properties defined in the multi-modal payload design document. +Each test validates universal properties that should hold across all valid inputs. +""" + +import json +import tempfile +from pathlib import Path + +import pytest +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st + +from llmeter.endpoints.bedrock import BedrockBase +from llmeter.json_utils import LLMeterEncoder, llmeter_bytes_decoder +from llmeter.prompt_utils import ( + load_payloads, + save_payloads, +) + + +# Test file content generators +@st.composite +def image_bytes(draw): + """Generate valid JPEG image bytes with magic bytes.""" + # JPEG magic bytes: FF D8 FF E0 + jpeg_header = b"\xff\xd8\xff\xe0" + # Add some random content + content_size = draw(st.integers(min_value=10, max_value=100)) + content = draw(st.binary(min_size=content_size, max_size=content_size)) + # JPEG end marker: FF D9 + jpeg_footer = b"\xff\xd9" + return jpeg_header + content + jpeg_footer + + +@st.composite +def png_bytes(draw): + """Generate valid PNG image bytes with magic bytes.""" + # PNG magic bytes: 89 50 4E 47 0D 0A 1A 0A + png_header = b"\x89PNG\r\n\x1a\n" + content_size = draw(st.integers(min_value=10, max_value=100)) + content = draw(st.binary(min_size=content_size, max_size=content_size)) + return png_header + content + + +@st.composite +def pdf_bytes(draw): + """Generate valid PDF bytes with magic bytes.""" + # PDF magic bytes: %PDF- + pdf_header = b"%PDF-1.4\n" + 
content_size = draw(st.integers(min_value=10, max_value=100))
+    content = draw(st.binary(min_size=content_size, max_size=content_size))
+    pdf_footer = b"\n%%EOF"
+    return pdf_header + content + pdf_footer
+
+
+@st.composite
+def mp4_bytes(draw):
+    """Generate valid MP4 video bytes with magic bytes."""
+    # MP4 magic bytes typically start with ftyp
+    mp4_header = b"\x00\x00\x00\x20ftypisom"
+    content_size = draw(st.integers(min_value=10, max_value=100))
+    content = draw(st.binary(min_size=content_size, max_size=content_size))
+    return mp4_header + content
+
+
+@st.composite
+def mp3_bytes(draw):
+    """Generate valid MP3 audio bytes with magic bytes."""
+    # MP3 magic bytes: ID3 or FF FB
+    mp3_header = b"ID3\x03\x00\x00"
+    content_size = draw(st.integers(min_value=10, max_value=100))
+    content = draw(st.binary(min_size=content_size, max_size=content_size))
+    return mp3_header + content
+
+
+# Strategy for generating media bytes
+media_bytes_strategy = st.one_of(
+    image_bytes(),
+    png_bytes(),
+    pdf_bytes(),
+    mp4_bytes(),
+    mp3_bytes(),
+)
+
+
+# Strategy for generating file extensions
+image_extensions = st.sampled_from([".jpg", ".jpeg", ".png", ".gif", ".webp"])
+video_extensions = st.sampled_from([".mp4", ".mov", ".avi"])
+audio_extensions = st.sampled_from([".mp3", ".wav", ".ogg"])
+document_extensions = st.sampled_from([".pdf"])
+
+
+def create_temp_file(tmp_path: Path, content: bytes, extension: str) -> str:
+    """Create a uniquely named temporary file with the given content and extension."""
+    # Note: id(content) is NOT unique across calls that reuse the same bytes
+    # object, so it would silently return the same path repeatedly. Index by the
+    # number of files already in the directory to guarantee unique names.
+    index = len(list(tmp_path.iterdir()))
+    file_path = tmp_path / f"test_file_{index}{extension}"
+    file_path.write_bytes(content)
+    return str(file_path)
+
+
+# Property-based tests
+
+
+@given(
+    file_content=image_bytes(),
+    extension=image_extensions,
+)
+@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture])
+def test_property_1_file_path_content_inclusion(file_content, extension):
+    """Property 1: File path content inclusion.
+ + **Validates: Requirements 1.2, 2.2, 3.2, 4.2** + + For any media type (image, video, audio, document) and valid file path, + when creating a payload with that file path in the corresponding parameter list, + the resulting payload SHALL contain a content block with the file's bytes and + the format detected from the file extension. + """ + # Create temporary file and use it within the same context + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + file_path = create_temp_file(tmp_path, file_content, extension) + + # Create payload with the file + payload = BedrockBase.create_payload( + user_message="Test message", images=[file_path], max_tokens=256 + ) + + # Verify payload structure + assert "messages" in payload + assert len(payload["messages"]) == 1 + assert "content" in payload["messages"][0] + + content_blocks = payload["messages"][0]["content"] + + # Find the image content block + image_blocks = [block for block in content_blocks if "image" in block] + assert len(image_blocks) == 1 + + # Verify the image block contains the file's bytes + image_block = image_blocks[0] + assert "image" in image_block + assert "source" in image_block["image"] + assert "bytes" in image_block["image"]["source"] + assert image_block["image"]["source"]["bytes"] == file_content + + # Verify format is detected + assert "format" in image_block["image"] + assert image_block["image"]["format"] in ["jpeg", "png", "gif", "webp"] + + +@given( + num_images=st.integers(min_value=1, max_value=5), + file_content=image_bytes(), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_2_multiple_items_preservation(num_images, file_content): + """Property 2: Multiple items preservation. 
+ + **Validates: Requirements 1.3, 2.3, 3.3, 4.3** + + For any media type and list of file paths, when creating a payload, + the resulting payload SHALL contain exactly as many content blocks of that + media type as there are items in the list, preserving their count. + """ + # Create multiple temporary files + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + file_paths = [] + for i in range(num_images): + file_path = create_temp_file(tmp_path, file_content, ".jpg") + file_paths.append(file_path) + + # Create payload with multiple images + payload = BedrockBase.create_payload( + user_message="Test message", images=file_paths, max_tokens=256 + ) + + # Verify payload structure + content_blocks = payload["messages"][0]["content"] + + # Count image blocks + image_blocks = [block for block in content_blocks if "image" in block] + + # Verify count matches input + assert len(image_blocks) == num_images + + +@given( + num_texts=st.integers(min_value=1, max_value=3), + num_images=st.integers(min_value=1, max_value=3), + num_videos=st.integers(min_value=0, max_value=2), + num_audio=st.integers(min_value=0, max_value=2), + num_documents=st.integers(min_value=0, max_value=2), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_3_content_ordering_preservation( + num_texts, num_images, num_videos, num_audio, num_documents +): + """Property 3: Content ordering preservation. + + **Validates: Requirements 1.4** + + For any combination of text messages and media items, when creating a payload, + the resulting content array SHALL preserve the order: text blocks first (in order), + then media blocks (in the order: images, videos, audio, documents). 
+ """ + # Create text messages + text_messages = [f"Text {i}" for i in range(num_texts)] + + # Create media files + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + images = [ + create_temp_file(tmp_path, b"\xff\xd8\xff\xe0test\xff\xd9", ".jpg") + for _ in range(num_images) + ] + videos = [ + create_temp_file(tmp_path, b"\x00\x00\x00\x20ftypisomtest", ".mp4") + for _ in range(num_videos) + ] + audio = [ + create_temp_file(tmp_path, b"ID3\x03\x00\x00test", ".mp3") + for _ in range(num_audio) + ] + documents = [ + create_temp_file(tmp_path, b"%PDF-1.4\ntest\n%%EOF", ".pdf") + for _ in range(num_documents) + ] + + # Create payload + payload = BedrockBase.create_payload( + user_message=text_messages if len(text_messages) > 1 else text_messages[0], + images=images if images else None, + videos=videos if videos else None, + audio=audio if audio else None, + documents=documents if documents else None, + max_tokens=256, + ) + + # Verify ordering + content_blocks = payload["messages"][0]["content"] + + # Extract block types in order + block_types = [] + for block in content_blocks: + if "text" in block: + block_types.append("text") + elif "image" in block: + block_types.append("image") + elif "video" in block: + block_types.append("video") + elif "audio" in block: + block_types.append("audio") + elif "document" in block: + block_types.append("document") + + # Verify text blocks come first + text_count = block_types.count("text") + assert text_count == num_texts + assert block_types[:text_count] == ["text"] * text_count + + # Verify media blocks follow in order: images, videos, audio, documents + media_blocks = block_types[text_count:] + expected_order = ( + ["image"] * num_images + + ["video"] * num_videos + + ["audio"] * num_audio + + ["document"] * num_documents + ) + assert media_blocks == expected_order + + +def valid_file_path_string(s: str) -> bool: + """Check if a string is a valid file path (no null bytes, not too long).""" + if not s or 
len(s) > 255: + return False + if "\x00" in s: + return False + # Check if it's a valid path that doesn't exist + try: + path = Path(s) + return not path.exists() + except (ValueError, OSError): + return False + + +@given( + non_existent_path=st.text(min_size=1, max_size=50).filter(valid_file_path_string) +) +@settings(max_examples=100) +def test_property_4_missing_file_error_handling(non_existent_path): + """Property 4: Missing file error handling. + + **Validates: Requirements 5.2** + + For any non-existent file path provided in any media parameter, + attempting to create a payload SHALL raise a FileNotFoundError + with a message containing the file path. + """ + with pytest.raises(FileNotFoundError) as exc_info: + BedrockBase.create_payload( + user_message="Test message", images=[non_existent_path], max_tokens=256 + ) + + # Verify error message contains the file path + assert non_existent_path in str(exc_info.value) + + +@given( + invalid_item=st.one_of( + st.integers(), + st.floats(), + st.dictionaries(st.text(), st.text()), + st.lists(st.integers()), + ) +) +@settings(max_examples=100) +def test_property_5_invalid_type_rejection(invalid_item): + """Property 5: Invalid type rejection. + + **Validates: Requirements 5.4** + + For any media parameter list containing items that are neither bytes nor strings, + attempting to create a payload SHALL raise a TypeError with a descriptive message. 
+ """ + with pytest.raises(TypeError) as exc_info: + BedrockBase.create_payload( + user_message="Test message", images=[invalid_item], max_tokens=256 + ) + + # Verify error message is descriptive + error_msg = str(exc_info.value) + assert "images" in error_msg + assert "bytes" in error_msg or "str" in error_msg + + +@given( + user_message=st.one_of( + st.text(min_size=1, max_size=100), + st.lists(st.text(min_size=1, max_size=50), min_size=1, max_size=3), + ), + max_tokens=st.integers(min_value=1, max_value=4096), +) +@settings(max_examples=100) +def test_property_6_backward_compatibility(user_message, max_tokens): + """Property 6: Backward compatibility for text-only payloads. + + **Validates: Requirements 6.1, 6.2, 6.4** + + For any text input provided using the user_message parameter (without media parameters), + the new create_payload implementation SHALL produce output identical to the current + implementation, maintaining the same structure and field values. + """ + # Create payload with text only + payload = BedrockBase.create_payload( + user_message=user_message, max_tokens=max_tokens + ) + + # Verify expected structure + assert "messages" in payload + assert "inferenceConfig" in payload + assert payload["inferenceConfig"]["maxTokens"] == max_tokens + + # Verify messages structure + messages = payload["messages"] + if isinstance(user_message, str): + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert len(messages[0]["content"]) == 1 + assert messages[0]["content"][0]["text"] == user_message + else: + assert len(messages) == len(user_message) + for i, msg in enumerate(messages): + assert msg["role"] == "user" + assert len(msg["content"]) == 1 + assert msg["content"][0]["text"] == user_message[i] + + +@given( + max_tokens=st.integers(min_value=1, max_value=4096), + extra_kwargs=st.dictionaries( + st.text(min_size=1, max_size=20).filter( + lambda x: ( + x + not in [ + "messages", + "inferenceConfig", + "images", + "documents", + "videos", 
+ "audio", + "user_message", + "max_tokens", + ] + ) + ), + st.one_of(st.text(), st.integers(), st.booleans()), + min_size=0, + max_size=3, + ), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_7_existing_parameters_preservation(max_tokens, extra_kwargs): + """Property 7: Existing parameters preservation. + + **Validates: Requirements 6.3** + + For any values of max_tokens and additional kwargs, when creating a payload + (text-only or multi-modal), these parameters SHALL be preserved in the + resulting payload structure exactly as they are in the current implementation. + """ + # Create a test image file + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + image_path = create_temp_file(tmp_path, b"\xff\xd8\xff\xe0test\xff\xd9", ".jpg") + + # Create payload with extra kwargs + payload = BedrockBase.create_payload( + user_message="Test message", + images=[image_path], + max_tokens=max_tokens, + **extra_kwargs, + ) + + # Verify max_tokens is preserved + assert payload["inferenceConfig"]["maxTokens"] == max_tokens + + # Verify extra kwargs are preserved + for key, value in extra_kwargs.items(): + assert key in payload + assert payload[key] == value + + +@given( + image_content=image_bytes(), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_8_serialization_round_trip(image_content): + """Property 8: Multi-modal payload serialization round-trip. + + **Validates: Requirements 7.1, 7.2, 7.3, 7.4, 7.5, 9.2, 9.3** + + For any valid multi-modal payload containing binary content, the round-trip + property SHALL hold: load_payloads(save_payloads(payload)) == payload, + preserving byte-for-byte equality of all binary content and exact equality + of all other values. 
+ """ + # Create a test image file + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + image_path = create_temp_file(tmp_path, image_content, ".jpg") + + # Create payload with binary content + original_payload = BedrockBase.create_payload( + user_message="Test message", images=[image_path], max_tokens=256 + ) + + # Save and load the payload + save_path = save_payloads(original_payload, tmp_dir, "test_payload.jsonl") + loaded_payloads = list(load_payloads(save_path)) + + # Verify round-trip preservation + assert len(loaded_payloads) == 1 + loaded_payload = loaded_payloads[0] + + # Verify structure equality + assert loaded_payload == original_payload + + # Verify binary content is preserved byte-for-byte + original_bytes = original_payload["messages"][0]["content"][1]["image"][ + "source" + ]["bytes"] + loaded_bytes = loaded_payload["messages"][0]["content"][1]["image"]["source"][ + "bytes" + ] + assert original_bytes == loaded_bytes + assert isinstance(loaded_bytes, bytes) + + +@given( + extension=st.sampled_from([".jpg", ".png", ".pdf", ".mp4", ".mp3"]), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_9_format_detection(extension): + """Property 9: Format detection from file content. + + **Validates: Requirements 8.5** + + For any file with recognizable content (JPEG, PNG, PDF, MP4, etc.), + when creating a payload, the format SHALL be correctly detected from + the file's magic bytes using puremagic (or from extension as fallback) + and included in the content block. 
+ """ + # Generate appropriate content for the extension + if extension in [".jpg", ".jpeg"]: + file_content = b"\xff\xd8\xff\xe0test\xff\xd9" + media_param = "images" + media_key = "image" + elif extension == ".png": + file_content = b"\x89PNG\r\n\x1a\ntest" + media_param = "images" + media_key = "image" + elif extension == ".pdf": + file_content = b"%PDF-1.4\ntest\n%%EOF" + media_param = "documents" + media_key = "document" + elif extension == ".mp4": + file_content = b"\x00\x00\x00\x20ftypisomtest" + media_param = "videos" + media_key = "video" + elif extension == ".mp3": + file_content = b"ID3\x03\x00\x00test" + media_param = "audio" + media_key = "audio" + + # Create temporary file and use it within the same context + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + file_path = create_temp_file(tmp_path, file_content, extension) + + # Create payload + kwargs = {media_param: [file_path]} + payload = BedrockBase.create_payload( + user_message="Test message", max_tokens=256, **kwargs + ) + + # Find the media content block + content_blocks = payload["messages"][0]["content"] + media_blocks = [block for block in content_blocks if media_key in block] + assert len(media_blocks) == 1 + + # Verify format is detected and included + media_block = media_blocks[0] + assert "format" in media_block[media_key] + assert isinstance(media_block[media_key]["format"], str) + assert len(media_block[media_key]["format"]) > 0 + + +@given( + num_prompts=st.integers(min_value=1, max_value=5), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_10_load_prompts_integration(num_prompts): + """Property 10: load_prompts integration with multi-modal payloads. + + **Validates: Requirements 9.1** + + For any create_payload call that produces multi-modal payloads, + when used with load_prompts, the function SHALL yield valid multi-modal + payloads that can be serialized and used with endpoint invoke methods. 
+ """ + # Create a test image file + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + image_path = create_temp_file(tmp_path, b"\xff\xd8\xff\xe0test\xff\xd9", ".jpg") + + # Create a prompts file + prompts_file = tmp_path / "prompts.txt" + prompts = [f"Prompt {i}" for i in range(num_prompts)] + prompts_file.write_text("\n".join(prompts)) + + # Define create_payload function that produces multi-modal payloads + def create_multimodal_payload(input_text, **kwargs): + return BedrockBase.create_payload( + user_message=input_text, images=[image_path], max_tokens=256, **kwargs + ) + + # Use load_prompts with the multi-modal create_payload function + from llmeter.prompt_utils import load_prompts + + payloads = list( + load_prompts( + prompts_file, + create_payload_fn=create_multimodal_payload, + create_payload_kwargs={}, + ) + ) + + # Verify correct number of payloads + assert len(payloads) == num_prompts + + # Verify each payload is valid and contains multi-modal content + for i, payload in enumerate(payloads): + assert "messages" in payload + content_blocks = payload["messages"][0]["content"] + + # Should have text and image blocks + text_blocks = [block for block in content_blocks if "text" in block] + image_blocks = [block for block in content_blocks if "image" in block] + + assert len(text_blocks) >= 1 + assert len(image_blocks) == 1 + + # Verify the payload can be serialized + json_str = json.dumps(payload, cls=LLMeterEncoder) + assert len(json_str) > 0 + + # Verify it can be deserialized + deserialized = json.loads(json_str, object_hook=llmeter_bytes_decoder) + assert deserialized == payload + + +@given( + extra_kwargs=st.dictionaries( + st.text(min_size=1, max_size=20).filter( + lambda x: ( + x + not in [ + "messages", + "inferenceConfig", + "user_message", + "max_tokens", + "images", + "documents", + "videos", + "audio", + ] + ) + ), + st.one_of(st.text(min_size=1), st.integers(), st.booleans()), + min_size=1, + max_size=3, + ) +) 
+@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_11_create_payload_kwargs_compatibility(extra_kwargs): + """Property 11: create_payload_kwargs pattern compatibility. + + **Validates: Requirements 9.4** + + For any dictionary of additional parameters passed via create_payload_kwargs, + when used with load_prompts or directly with create_payload, these parameters + SHALL be passed through and included in the resulting payload. + """ + # Create a test image file + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + image_path = create_temp_file(tmp_path, b"\xff\xd8\xff\xe0test\xff\xd9", ".jpg") + + # Create payload with extra kwargs + payload = BedrockBase.create_payload( + user_message="Test message", + images=[image_path], + max_tokens=256, + **extra_kwargs, + ) + + # Verify all extra kwargs are present in the payload + for key, value in extra_kwargs.items(): + assert key in payload + assert payload[key] == value + + # Test with load_prompts pattern + prompts_file = tmp_path / "prompts.txt" + prompts_file.write_text("Test prompt") + + def create_multimodal_payload(input_text, **kwargs): + return BedrockBase.create_payload( + user_message=input_text, images=[image_path], max_tokens=256, **kwargs + ) + + from llmeter.prompt_utils import load_prompts + + payloads = list( + load_prompts( + prompts_file, + create_payload_fn=create_multimodal_payload, + create_payload_kwargs=extra_kwargs, + ) + ) + + # Verify kwargs are passed through + assert len(payloads) == 1 + for key, value in extra_kwargs.items(): + assert key in payloads[0] + assert payloads[0][key] == value diff --git a/tests/unit/endpoints/test_multimodal_serialization.py b/tests/unit/endpoints/test_multimodal_serialization.py new file mode 100644 index 0000000..0d353fb --- /dev/null +++ b/tests/unit/endpoints/test_multimodal_serialization.py @@ -0,0 +1,253 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + + +from llmeter.endpoints.bedrock import BedrockBase +from llmeter.endpoints.openai import OpenAIEndpoint +from llmeter.endpoints.sagemaker import SageMakerBase +from llmeter.prompt_utils import save_payloads, load_payloads, load_prompts + + +class TestMultiModalSerialization: + """Test serialization and deserialization of multi-modal payloads.""" + + def test_save_and_load_single_image_payload(self): + """Test saving and loading a payload with a single image.""" + # Create a payload with image bytes + image_bytes = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + + # Create temporary image file + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_image_path = f.name + + try: + payload = BedrockBase.create_payload( + user_message="What's in this image?", + images=[temp_image_path], + max_tokens=256, + ) + + # Save payload + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payload, temp_dir, "test_payload.jsonl") + assert output_path.exists() + + # Load payload + loaded_payloads = list(load_payloads(output_path)) + assert len(loaded_payloads) == 1 + + loaded_payload = loaded_payloads[0] + + # Verify structure + assert "messages" in loaded_payload + content = loaded_payload["messages"][0]["content"] + assert len(content) == 2 # text + image + + # Verify binary content is preserved + loaded_image_bytes = content[1]["image"]["source"]["bytes"] + assert isinstance(loaded_image_bytes, bytes) + assert loaded_image_bytes == image_bytes + + finally: + Path(temp_image_path).unlink() + + def test_save_and_load_multiple_content_types(self): + """Test saving and loading a payload with multiple content types.""" + # Create test files + image_bytes = b"\xff\xd8\xff\xe0" + pdf_bytes = b"%PDF-1.4\n" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + 
img_file.write(image_bytes) + img_path = img_file.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc_file: + doc_file.write(pdf_bytes) + doc_path = doc_file.name + + try: + payload = BedrockBase.create_payload( + user_message="Analyze this", + images=[img_path], + documents=[doc_path], + max_tokens=1024, + ) + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payload, temp_dir) + loaded_payloads = list(load_payloads(output_path)) + + assert len(loaded_payloads) == 1 + loaded_payload = loaded_payloads[0] + + # Verify all content is preserved + content = loaded_payload["messages"][0]["content"] + assert len(content) == 3 # text + image + document + + # Verify binary content + loaded_image = content[1]["image"]["source"]["bytes"] + loaded_doc = content[2]["document"]["source"]["bytes"] + + assert isinstance(loaded_image, bytes) + assert isinstance(loaded_doc, bytes) + assert loaded_image == image_bytes + assert loaded_doc == pdf_bytes + + finally: + Path(img_path).unlink() + Path(doc_path).unlink() + + def test_round_trip_preservation(self): + """Test that binary content is preserved byte-for-byte in round-trip.""" + # Create payload with bytes directly + image_bytes = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_path = f.name + + try: + original_payload = BedrockBase.create_payload( + user_message="Test", images=[temp_path], max_tokens=256 + ) + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(original_payload, temp_dir) + loaded_payload = list(load_payloads(output_path))[0] + + # Verify exact equality + assert original_payload == loaded_payload + + finally: + Path(temp_path).unlink() + + def test_save_multiple_payloads(self): + """Test saving and loading multiple payloads.""" + image_bytes = 
b"\xff\xd8\xff\xe0" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_path = f.name + + try: + payloads = [ + BedrockBase.create_payload( + user_message=f"Image {i}", images=[temp_path], max_tokens=256 + ) + for i in range(3) + ] + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payloads, temp_dir) + loaded_payloads = list(load_payloads(output_path)) + + assert len(loaded_payloads) == 3 + + # Verify each payload + for i, loaded in enumerate(loaded_payloads): + content = loaded["messages"][0]["content"] + assert content[0]["text"] == f"Image {i}" + assert isinstance(content[1]["image"]["source"]["bytes"], bytes) + + finally: + Path(temp_path).unlink() + + def test_load_prompts_with_multimodal_create_payload(self): + """Test load_prompts integration with multi-modal create_payload.""" + # Create a prompts file + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: + f.write("What is this?\n") + f.write("Describe the image\n") + prompts_path = f.name + + # Create an image file + image_bytes = b"\xff\xd8\xff\xe0" + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + img_file.write(image_bytes) + img_path = img_file.name + + try: + # Define a create_payload function that includes an image + def create_multimodal_payload(input_text, **kwargs): + return BedrockBase.create_payload( + user_message=input_text, images=[img_path], max_tokens=256, **kwargs + ) + + # Load prompts with multi-modal payload creation + payloads = list(load_prompts(prompts_path, create_multimodal_payload)) + + assert len(payloads) == 2 + + # Verify each payload has the image + for payload in payloads: + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert "image" in content[1] + assert isinstance(content[1]["image"]["source"]["bytes"], bytes) + + finally: + Path(prompts_path).unlink() + Path(img_path).unlink() 
+ + def test_openai_payload_serialization(self): + """Test serialization of OpenAI multi-modal payloads.""" + image_bytes = b"\xff\xd8\xff\xe0" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_path = f.name + + try: + payload = OpenAIEndpoint.create_payload( + user_message="Test", images=[temp_path], max_tokens=256 + ) + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payload, temp_dir) + loaded_payload = list(load_payloads(output_path))[0] + + # Verify exact equality + assert payload == loaded_payload + + # Verify OpenAI-specific format + content = loaded_payload["messages"][0]["content"] + assert content[1]["image"]["format"] == "image/jpeg" + + finally: + Path(temp_path).unlink() + + def test_sagemaker_payload_serialization(self): + """Test serialization of SageMaker multi-modal payloads.""" + image_bytes = b"\xff\xd8\xff\xe0" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_path = f.name + + try: + payload = SageMakerBase.create_payload( + input_text="Test", images=[temp_path], max_tokens=256 + ) + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payload, temp_dir) + loaded_payload = list(load_payloads(output_path))[0] + + # Verify exact equality + assert payload == loaded_payload + + # Verify SageMaker-specific format + content = loaded_payload["inputs"] + assert content[1]["image"]["format"] == "jpeg" + + finally: + Path(temp_path).unlink() diff --git a/tests/unit/endpoints/test_multimodal_utilities.py b/tests/unit/endpoints/test_multimodal_utilities.py new file mode 100644 index 0000000..1a7f82f --- /dev/null +++ b/tests/unit/endpoints/test_multimodal_utilities.py @@ -0,0 +1,212 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from llmeter.prompt_utils import ( + read_file, + detect_format_from_extension, + detect_format_from_bytes, + detect_format_from_file, +) + + +class TestReadFile: + """Test the read_file utility function.""" + + def test_read_file_valid_path(self): + """Test reading a valid file.""" + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"test content") + temp_path = f.name + + try: + content = read_file(temp_path) + assert content == b"test content" + finally: + Path(temp_path).unlink() + + def test_read_file_nonexistent(self): + """Test reading a non-existent file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File not found"): + read_file("/nonexistent/file.txt") + + def test_read_file_binary_content(self): + """Test reading binary content.""" + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") # JPEG magic bytes + temp_path = f.name + + try: + content = read_file(temp_path) + assert content == b"\xff\xd8\xff\xe0" + finally: + Path(temp_path).unlink() + + +class TestDetectFormatFromExtension: + """Test the detect_format_from_extension utility function.""" + + def test_detect_jpeg_extension(self): + """Test detecting JPEG format from .jpg extension.""" + mime_type = detect_format_from_extension("image.jpg") + assert mime_type == "image/jpeg" + + def test_detect_jpeg_extension_uppercase(self): + """Test detecting JPEG format from .JPG extension.""" + mime_type = detect_format_from_extension("image.JPG") + assert mime_type == "image/jpeg" + + def test_detect_png_extension(self): + """Test detecting PNG format from .png extension.""" + mime_type = detect_format_from_extension("image.png") + assert mime_type == "image/png" + + def test_detect_pdf_extension(self): + """Test detecting PDF format from .pdf extension.""" + mime_type = 
detect_format_from_extension("document.pdf") + assert mime_type == "application/pdf" + + def test_detect_mp4_extension(self): + """Test detecting MP4 format from .mp4 extension.""" + mime_type = detect_format_from_extension("video.mp4") + assert mime_type == "video/mp4" + + def test_detect_mp3_extension(self): + """Test detecting MP3 format from .mp3 extension.""" + mime_type = detect_format_from_extension("audio.mp3") + assert mime_type == "audio/mpeg" + + def test_detect_wav_extension(self): + """Test detecting WAV format from .wav extension.""" + mime_type = detect_format_from_extension("audio.wav") + assert mime_type == "audio/wav" + + def test_detect_unknown_extension(self): + """Test detecting unknown extension returns None.""" + mime_type = detect_format_from_extension("file.unknown") + assert mime_type is None + + def test_detect_no_extension(self): + """Test detecting file without extension returns None.""" + mime_type = detect_format_from_extension("file") + assert mime_type is None + + +class TestDetectFormatFromBytes: + """Test the detect_format_from_bytes utility function.""" + + def test_detect_jpeg_from_bytes_with_puremagic(self): + """Test detecting JPEG format from bytes with puremagic.""" + jpeg_bytes = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + try: + mime_type = detect_format_from_bytes(jpeg_bytes) + # If puremagic is installed, it should detect the format + if mime_type is not None: + assert mime_type == "image/jpeg" + else: + # If puremagic is not installed, it returns None + pytest.skip("puremagic not installed") + except ImportError: + pytest.skip("puremagic not installed") + + def test_detect_png_from_bytes_with_puremagic(self): + """Test detecting PNG format from bytes with puremagic.""" + png_bytes = b"\x89PNG\r\n\x1a\n" + try: + mime_type = detect_format_from_bytes(png_bytes) + # If puremagic is installed, it should detect the format + if mime_type is not None: + assert mime_type == "image/png" + 
else: + # If puremagic is not installed, it returns None + pytest.skip("puremagic not installed") + except ImportError: + pytest.skip("puremagic not installed") + + def test_detect_format_without_puremagic(self): + """Test that detection returns None when puremagic is not available.""" + # Mock puremagic to raise ImportError when accessed + with patch("llmeter.prompt_utils.puremagic") as mock_puremagic: + mock_puremagic.from_string.side_effect = ImportError("puremagic not available") + mime_type = detect_format_from_bytes(b"\xff\xd8\xff\xe0") + assert mime_type is None + + +class TestDetectFormatFromFile: + """Test the detect_format_from_file utility function.""" + + def test_detect_format_from_jpeg_file(self): + """Test detecting format from JPEG file.""" + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write( + b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + ) + temp_path = f.name + + try: + mime_type = detect_format_from_file(temp_path) + assert mime_type == "image/jpeg" + finally: + Path(temp_path).unlink() + + def test_detect_format_from_png_file(self): + """Test detecting format from PNG file.""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n\x1a\n") + temp_path = f.name + + try: + mime_type = detect_format_from_file(temp_path) + assert mime_type == "image/png" + finally: + Path(temp_path).unlink() + + def test_detect_format_from_pdf_file(self): + """Test detecting format from PDF file.""" + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: + f.write(b"%PDF-1.4\n") + temp_path = f.name + + try: + mime_type = detect_format_from_file(temp_path) + assert mime_type == "application/pdf" + finally: + Path(temp_path).unlink() + + def test_detect_format_fallback_to_extension(self): + """Test that detection falls back to extension when puremagic not installed.""" + # Create a file with JPEG magic bytes + with 
tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") + temp_path = f.name + + try: + # Mock puremagic to raise ImportError when accessed + with patch("llmeter.prompt_utils.puremagic") as mock_puremagic: + mock_puremagic.magic_file.side_effect = ImportError("puremagic not available") + mime_type = detect_format_from_file(temp_path) + # Should fall back to extension-based detection + assert mime_type == "image/jpeg" + finally: + Path(temp_path).unlink() + + def test_detect_format_no_extension_no_puremagic(self): + """Test that detection returns None for file without extension when puremagic unavailable.""" + with tempfile.NamedTemporaryFile(suffix="", delete=False) as f: + f.write(b"some content") + temp_path = f.name + + try: + # Mock puremagic to raise ImportError when accessed + with patch("llmeter.prompt_utils.puremagic") as mock_puremagic: + mock_puremagic.magic_file.side_effect = ImportError("puremagic not available") + mime_type = detect_format_from_file(temp_path) + # Should fall back to extension-based detection, which returns None for no extension + assert mime_type is None + finally: + Path(temp_path).unlink() diff --git a/tests/unit/endpoints/test_openai_multimodal.py b/tests/unit/endpoints/test_openai_multimodal.py new file mode 100644 index 0000000..d587178 --- /dev/null +++ b/tests/unit/endpoints/test_openai_multimodal.py @@ -0,0 +1,144 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from llmeter.endpoints.openai import OpenAIEndpoint + + +class TestOpenAIMultiModal: + """Test multi-modal functionality for OpenAI endpoints.""" + + def test_create_payload_single_image_from_file(self): + """Test creating payload with single image from file path.""" + # Create a temporary image file + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") # JPEG magic bytes + temp_path = f.name + + try: + payload = OpenAIEndpoint.create_payload( + user_message="What's in this image?", images=[temp_path], max_tokens=256 + ) + + assert "messages" in payload + assert len(payload["messages"]) == 1 + assert payload["messages"][0]["role"] == "user" + + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert content[0]["text"] == "What's in this image?" + assert "image" in content[1] + # OpenAI uses full MIME types + assert content[1]["image"]["format"] == "image/jpeg" + assert "bytes" in content[1]["image"]["source"] + + finally: + Path(temp_path).unlink() + + def test_create_payload_single_image_from_bytes(self): + """Test creating payload with single image from bytes.""" + # Create a minimal valid JPEG file + jpeg_bytes = ( + b"\xff\xd8" # SOI (Start of Image) + b"\xff\xe0" # APP0 marker + b"\x00\x10" # APP0 length (16 bytes) + b"JFIF\x00" # JFIF identifier + b"\x01\x01" # JFIF version 1.1 + b"\x00" # density units (0 = no units) + b"\x00\x01" # X density = 1 + b"\x00\x01" # Y density = 1 + b"\x00\x00" # thumbnail width and height = 0 + b"\xff\xd9" # EOI (End of Image) + ) + + try: + payload = OpenAIEndpoint.create_payload( + user_message="What's in this image?", + images=[jpeg_bytes], + max_tokens=256, + ) + + assert "messages" in payload + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert "image" in content[1] + # OpenAI uses full MIME 
types + assert content[1]["image"]["format"] == "image/jpeg" + assert content[1]["image"]["source"]["bytes"] == jpeg_bytes + except ValueError as e: + # If puremagic can't detect the format, skip this test + if "Cannot detect format from bytes" in str(e): + pytest.skip("puremagic cannot detect format from minimal JPEG bytes") + raise + + def test_create_payload_mixed_content(self): + """Test creating payload with mixed content types.""" + # Create temporary files + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + img_file.write(b"\xff\xd8\xff\xe0") + img_path = img_file.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc_file: + doc_file.write(b"%PDF-1.4") + doc_path = doc_file.name + + try: + payload = OpenAIEndpoint.create_payload( + user_message="Analyze this", + images=[img_path], + documents=[doc_path], + max_tokens=1024, + ) + + content = payload["messages"][0]["content"] + assert len(content) == 3 # text + image + document + assert content[0]["text"] == "Analyze this" + assert "image" in content[1] + # OpenAI uses full MIME types + assert content[1]["image"]["format"] == "image/jpeg" + assert "document" in content[2] + assert content[2]["document"]["format"] == "application/pdf" + + finally: + Path(img_path).unlink() + Path(doc_path).unlink() + + def test_create_payload_text_only_backward_compatible(self): + """Test that text-only payloads still work (backward compatibility).""" + payload = OpenAIEndpoint.create_payload( + user_message="Hello, world!", max_tokens=256 + ) + + assert "messages" in payload + content = payload["messages"][0]["content"] + # Text-only should be a string, not a list + assert content == "Hello, world!" 
+ + def test_create_payload_invalid_image_type(self): + """Test that invalid image types raise TypeError.""" + with pytest.raises( + TypeError, match="Items in images list must be bytes or str" + ): + OpenAIEndpoint.create_payload( + user_message="Test", + images=[123], # Invalid type + max_tokens=256, + ) + + def test_create_payload_invalid_images_not_list(self): + """Test that non-list images parameter raises TypeError.""" + with pytest.raises(TypeError, match="images must be a list"): + OpenAIEndpoint.create_payload( + user_message="Test", images="not_a_list", max_tokens=256 + ) + + def test_create_payload_missing_file(self): + """Test that missing file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File not found"): + OpenAIEndpoint.create_payload( + user_message="Test", images=["/nonexistent/file.jpg"], max_tokens=256 + ) diff --git a/tests/unit/endpoints/test_sagemaker_multimodal.py b/tests/unit/endpoints/test_sagemaker_multimodal.py new file mode 100644 index 0000000..cb82236 --- /dev/null +++ b/tests/unit/endpoints/test_sagemaker_multimodal.py @@ -0,0 +1,140 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from llmeter.endpoints.sagemaker import SageMakerBase + + +class TestSageMakerMultiModal: + """Test multi-modal functionality for SageMaker endpoints.""" + + def test_create_payload_single_image_from_file(self): + """Test creating payload with single image from file path.""" + # Create a temporary image file + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") # JPEG magic bytes + temp_path = f.name + + try: + payload = SageMakerBase.create_payload( + input_text="What's in this image?", images=[temp_path], max_tokens=256 + ) + + assert "inputs" in payload + content = payload["inputs"] + assert len(content) == 2 # text + image + assert content[0]["text"] == "What's in this image?" + assert "image" in content[1] + # SageMaker uses Bedrock-style short format strings + assert content[1]["image"]["format"] == "jpeg" + assert "bytes" in content[1]["image"]["source"] + + finally: + Path(temp_path).unlink() + + def test_create_payload_single_image_from_bytes(self): + """Test creating payload with single image from bytes.""" + # Create a minimal valid JPEG file + jpeg_bytes = ( + b"\xff\xd8" # SOI (Start of Image) + b"\xff\xe0" # APP0 marker + b"\x00\x10" # APP0 length (16 bytes) + b"JFIF\x00" # JFIF identifier + b"\x01\x01" # JFIF version 1.1 + b"\x00" # density units (0 = no units) + b"\x00\x01" # X density = 1 + b"\x00\x01" # Y density = 1 + b"\x00\x00" # thumbnail width and height = 0 + b"\xff\xd9" # EOI (End of Image) + ) + + try: + payload = SageMakerBase.create_payload( + input_text="What's in this image?", + images=[jpeg_bytes], + max_tokens=256, + ) + + assert "inputs" in payload + content = payload["inputs"] + assert len(content) == 2 # text + image + assert "image" in content[1] + # SageMaker uses Bedrock-style short format strings + assert content[1]["image"]["format"] == "jpeg" + assert 
content[1]["image"]["source"]["bytes"] == jpeg_bytes + except ValueError as e: + # If puremagic can't detect the format, skip this test + if "Cannot detect format from bytes" in str(e): + pytest.skip("puremagic cannot detect format from minimal JPEG bytes") + raise + + def test_create_payload_mixed_content(self): + """Test creating payload with mixed content types.""" + # Create temporary files + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + img_file.write(b"\xff\xd8\xff\xe0") + img_path = img_file.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc_file: + doc_file.write(b"%PDF-1.4") + doc_path = doc_file.name + + try: + payload = SageMakerBase.create_payload( + input_text="Analyze this", + images=[img_path], + documents=[doc_path], + max_tokens=1024, + ) + + content = payload["inputs"] + assert len(content) == 3 # text + image + document + assert content[0]["text"] == "Analyze this" + assert "image" in content[1] + # SageMaker uses Bedrock-style short format strings + assert content[1]["image"]["format"] == "jpeg" + assert "document" in content[2] + assert content[2]["document"]["format"] == "pdf" + + finally: + Path(img_path).unlink() + Path(doc_path).unlink() + + def test_create_payload_text_only_backward_compatible(self): + """Test that text-only payloads still work (backward compatibility).""" + payload = SageMakerBase.create_payload( + input_text="Hello, world!", max_tokens=256 + ) + + assert "inputs" in payload + # Text-only should be a string, not a list + assert payload["inputs"] == "Hello, world!" 
+ + def test_create_payload_invalid_image_type(self): + """Test that invalid image types raise TypeError.""" + with pytest.raises( + TypeError, match="Items in images list must be bytes or str" + ): + SageMakerBase.create_payload( + input_text="Test", + images=[123], # Invalid type + max_tokens=256, + ) + + def test_create_payload_invalid_images_not_list(self): + """Test that non-list images parameter raises TypeError.""" + with pytest.raises(TypeError, match="images must be a list"): + SageMakerBase.create_payload( + input_text="Test", images="not_a_list", max_tokens=256 + ) + + def test_create_payload_missing_file(self): + """Test that missing file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File not found"): + SageMakerBase.create_payload( + input_text="Test", images=["/nonexistent/file.jpg"], max_tokens=256 + ) diff --git a/tests/unit/test_prompt_utils.py b/tests/unit/test_prompt_utils.py index c8a07e9..cac7ddb 100644 --- a/tests/unit/test_prompt_utils.py +++ b/tests/unit/test_prompt_utils.py @@ -17,9 +17,430 @@ load_prompts, save_payloads, ) +from llmeter.json_utils import LLMeterEncoder, llmeter_bytes_decoder from llmeter.tokenizers import DummyTokenizer +class TestLLMeterEncoder: + """Unit tests for LLMeterEncoder class. + + These tests verify specific examples and edge cases for the LLMeterEncoder + class, complementing the property-based tests. + + Requirements: 1.1, 1.2, 1.3, 1.6 + """ + + def test_simple_bytes_object_serialization(self): + """Test serialization of a simple bytes object. 
+ + Validates: Requirements 1.1, 1.2, 1.3 + """ + payload = {"data": b"hello world"} + + # Serialize using the encoder + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Verify it's valid JSON + parsed = json.loads(serialized) + + # Verify the marker object structure + assert "__llmeter_bytes__" in parsed["data"] + assert len(parsed["data"]) == 1 + assert isinstance(parsed["data"]["__llmeter_bytes__"], str) + + # Verify the base64 encoding is correct + import base64 + + decoded = base64.b64decode(parsed["data"]["__llmeter_bytes__"]) + assert decoded == b"hello world" + + def test_nested_bytes_in_dict_structure(self): + """Test serialization of bytes nested in complex dict structure. + + Validates: Requirements 1.1, 1.3, 1.5 + """ + payload = { + "modelId": "test-model", + "messages": [ + { + "role": "user", + "content": [ + {"text": "What is in this image?"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": b"\xff\xd8\xff\xe0\x00\x10JFIF"}, + } + }, + ], + } + ], + } + + # Serialize + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Verify it's valid JSON + parsed = json.loads(serialized) + + # Verify structure is preserved + assert parsed["modelId"] == "test-model" + assert parsed["messages"][0]["role"] == "user" + assert parsed["messages"][0]["content"][0]["text"] == "What is in this image?" + + # Verify bytes are replaced with marker + bytes_marker = parsed["messages"][0]["content"][1]["image"]["source"]["bytes"] + assert "__llmeter_bytes__" in bytes_marker + assert len(bytes_marker) == 1 + + # Verify the base64 encoding + import base64 + + decoded = base64.b64decode(bytes_marker["__llmeter_bytes__"]) + assert decoded == b"\xff\xd8\xff\xe0\x00\x10JFIF" + + def test_empty_bytes_object(self): + """Test serialization of an empty bytes object. 
+ + Validates: Requirements 1.1, 1.3 + """ + payload = {"empty": b""} + + # Serialize + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Verify it's valid JSON + parsed = json.loads(serialized) + + # Verify marker object exists + assert "__llmeter_bytes__" in parsed["empty"] + + # Verify empty bytes decodes correctly + import base64 + + decoded = base64.b64decode(parsed["empty"]["__llmeter_bytes__"]) + assert decoded == b"" + assert len(decoded) == 0 + + def test_large_binary_data_1mb(self): + """Test serialization of large binary data (1MB). + + Validates: Requirements 1.6, 10.1 + """ + import os + + # Create 1MB of random binary data + large_data = os.urandom(1024 * 1024) + payload = {"large_image": large_data} + + # Serialize + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Verify it's valid JSON + parsed = json.loads(serialized) + + # Verify marker object exists + assert "__llmeter_bytes__" in parsed["large_image"] + + # Verify the data round-trips correctly + import base64 + + decoded = base64.b64decode(parsed["large_image"]["__llmeter_bytes__"]) + assert decoded == large_data + assert len(decoded) == 1024 * 1024 + + def test_multiple_bytes_objects_in_payload(self): + """Test serialization of payload with multiple bytes objects. 
+ + Validates: Requirements 1.1, 1.3, 9.8 + """ + payload = { + "image1": b"first image data", + "image2": b"second image data", + "nested": {"image3": b"third image data"}, + } + + # Serialize + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Verify it's valid JSON + parsed = json.loads(serialized) + + # Verify all bytes objects have markers + assert "__llmeter_bytes__" in parsed["image1"] + assert "__llmeter_bytes__" in parsed["image2"] + assert "__llmeter_bytes__" in parsed["nested"]["image3"] + + # Verify all decode correctly + import base64 + + assert ( + base64.b64decode(parsed["image1"]["__llmeter_bytes__"]) + == b"first image data" + ) + assert ( + base64.b64decode(parsed["image2"]["__llmeter_bytes__"]) + == b"second image data" + ) + assert ( + base64.b64decode(parsed["nested"]["image3"]["__llmeter_bytes__"]) + == b"third image data" + ) + + def test_bytes_in_list(self): + """Test serialization of bytes objects within lists. + + Validates: Requirements 1.1, 1.3, 1.5 + """ + payload = {"images": [b"image1", b"image2", b"image3"]} + + # Serialize + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Verify it's valid JSON + parsed = json.loads(serialized) + + # Verify all list items have markers + assert len(parsed["images"]) == 3 + for item in parsed["images"]: + assert "__llmeter_bytes__" in item + assert len(item) == 1 + + def test_mixed_types_with_bytes(self): + """Test serialization of payload with mixed types including bytes. 
+ + Validates: Requirements 1.1, 1.5 + """ + payload = { + "string": "text value", + "number": 42, + "float": 3.14, + "boolean": True, + "null": None, + "list": [1, 2, 3], + "bytes": b"binary data", + "nested": {"more_bytes": b"more binary"}, + } + + # Serialize + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Verify it's valid JSON + parsed = json.loads(serialized) + + # Verify non-bytes types are preserved + assert parsed["string"] == "text value" + assert parsed["number"] == 42 + assert parsed["float"] == 3.14 + assert parsed["boolean"] is True + assert parsed["null"] is None + assert parsed["list"] == [1, 2, 3] + + # Verify bytes have markers + assert "__llmeter_bytes__" in parsed["bytes"] + assert "__llmeter_bytes__" in parsed["nested"]["more_bytes"] + + +class TestLLMeterBytesDecoder: + """Unit tests for llmeter_bytes_decoder function. + + These tests verify specific examples and edge cases for the llmeter_bytes_decoder + function, complementing the property-based tests. + + Requirements: 2.1, 2.2, 2.3, 2.4, 6.2 + """ + + def test_marker_object_decoding(self): + """Test decoding of a marker object with valid base64. + + Validates: Requirements 2.1, 2.2, 2.3 + """ + + # Create a marker object with base64-encoded bytes + marker = {"__llmeter_bytes__": "aGVsbG8gd29ybGQ="} # "hello world" in base64 + + # Decode the marker + result = llmeter_bytes_decoder(marker) + + # Verify it returns bytes + assert isinstance(result, bytes) + assert result == b"hello world" + + def test_non_marker_dict_passthrough(self): + """Test that non-marker dicts are returned unchanged. 
+
+        Validates: Requirements 2.4
+        """
+
+        # Regular dict without marker key
+        regular_dict = {"key": "value", "number": 42, "nested": {"data": "test"}}
+
+        # Decode should return unchanged
+        result = llmeter_bytes_decoder(regular_dict)
+
+        # Verify it's the same dict
+        assert result == regular_dict
+        assert result is regular_dict  # Should be the exact same object
+
+    def test_invalid_base64_error_handling(self):
+        """Test that invalid base64 in marker raises appropriate error.
+
+        Validates: Requirements 6.2
+        """
+        import binascii
+
+        # Marker with invalid base64 string
+        invalid_marker = {"__llmeter_bytes__": "not-valid-base64!!!"}
+
+        # Should raise binascii.Error when trying to decode
+        with pytest.raises(binascii.Error):
+            llmeter_bytes_decoder(invalid_marker)
+
+    def test_multi_key_dict_with_marker_key_not_decoded(self):
+        """Test that multi-key dict containing marker key is not decoded.
+
+        This is a safety check to ensure we only decode single-key marker objects.
+
+        Validates: Requirements 2.4
+        """
+
+        # Dict with marker key but also other keys (should not be decoded)
+        multi_key_dict = {
+            "__llmeter_bytes__": "aGVsbG8=",
+            "other_key": "other_value",
+        }
+
+        # Should return unchanged (not decode)
+        result = llmeter_bytes_decoder(multi_key_dict)
+
+        # Verify it's returned as-is
+        assert result == multi_key_dict
+        assert isinstance(result, dict)
+        assert "__llmeter_bytes__" in result
+        assert "other_key" in result
+
+    def test_empty_bytes_decoding(self):
+        """Test decoding of marker object with empty bytes.
+
+        Validates: Requirements 2.1, 2.3
+        """
+
+        # Marker with empty base64 string (empty bytes)
+        empty_marker = {"__llmeter_bytes__": ""}
+
+        # Decode
+        result = llmeter_bytes_decoder(empty_marker)
+
+        # Verify it returns empty bytes
+        assert isinstance(result, bytes)
+        assert result == b""
+        assert len(result) == 0
+
+    def test_large_binary_data_decoding(self):
+        """Test decoding of marker object with large binary data.
+
+        Validates: Requirements 2.1, 2.3
+        """
+        import base64
+        import os
+
+        # Create 1MB of random binary data
+        large_data = os.urandom(1024 * 1024)
+        base64_encoded = base64.b64encode(large_data).decode("utf-8")
+
+        # Create marker
+        marker = {"__llmeter_bytes__": base64_encoded}
+
+        # Decode
+        result = llmeter_bytes_decoder(marker)
+
+        # Verify it matches original data
+        assert isinstance(result, bytes)
+        assert result == large_data
+        assert len(result) == 1024 * 1024
+
+    def test_nested_structure_with_marker(self):
+        """Test that decoder works correctly when used with json.loads on nested structures.
+
+        Validates: Requirements 2.1, 2.5
+        """
+
+        # JSON string with nested marker objects
+        json_str = json.dumps(
+            {
+                "modelId": "test-model",
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": [
+                            {"text": "What is this?"},
+                            {
+                                "image": {
+                                    "source": {
+                                        "bytes": {
+                                            "__llmeter_bytes__": "aGVsbG8="  # "hello"
+                                        }
+                                    }
+                                }
+                            },
+                        ],
+                    }
+                ],
+            }
+        )
+
+        # Load with decoder
+        result = json.loads(json_str, object_hook=llmeter_bytes_decoder)
+
+        # Verify structure is preserved
+        assert result["modelId"] == "test-model"
+        assert result["messages"][0]["role"] == "user"
+
+        # Verify bytes are decoded
+        bytes_value = result["messages"][0]["content"][1]["image"]["source"]["bytes"]
+        assert isinstance(bytes_value, bytes)
+        assert bytes_value == b"hello"
+
+    def test_dict_without_marker_key(self):
+        """Test that dict without marker key is returned unchanged.
+
+        Validates: Requirements 2.4
+        """
+
+        # Dict without the marker key
+        normal_dict = {"data": "value", "count": 123}
+
+        # Should return unchanged
+        result = llmeter_bytes_decoder(normal_dict)
+
+        assert result == normal_dict
+        assert isinstance(result, dict)
+
+    def test_marker_with_special_characters(self):
+        """Test decoding marker with special characters in base64.
+ + Validates: Requirements 2.1, 2.3 + """ + import base64 + + + # Binary data with special characters + special_data = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01" + base64_encoded = base64.b64encode(special_data).decode("utf-8") + + # Create marker + marker = {"__llmeter_bytes__": base64_encoded} + + # Decode + result = llmeter_bytes_decoder(marker) + + # Verify + assert isinstance(result, bytes) + assert result == special_data + + class TestCreatePromptCollection: """Tests for CreatePromptCollection class.""" @@ -178,7 +599,9 @@ def create_payload(input_text, max_tokens=100): prompts = list( load_prompts( - prompt_file, create_payload, create_payload_kwargs={"max_tokens": 50} + prompt_file, + create_payload, + create_payload_kwargs={"max_tokens": 50}, ) ) assert len(prompts) == 1 @@ -195,9 +618,7 @@ def test_load_prompts_from_directory(self): def create_payload(input_text): return {"text": input_text} - prompts = list( - load_prompts(dir_path, create_payload, file_pattern="*.txt") - ) + prompts = list(load_prompts(dir_path, create_payload, file_pattern="*.txt")) assert len(prompts) == 2 def test_load_prompts_skips_empty_lines(self): @@ -216,7 +637,9 @@ def test_load_prompts_handles_exceptions(self): """Test that exceptions in create_payload_fn are handled when loading from directory.""" with tempfile.TemporaryDirectory() as tmpdir: dir_path = Path(tmpdir) - (dir_path / "prompts.txt").write_text("Good prompt\nBad prompt\nAnother good prompt\n") + (dir_path / "prompts.txt").write_text( + "Good prompt\nBad prompt\nAnother good prompt\n" + ) def create_payload(input_text): if "Bad" in input_text: @@ -334,12 +757,305 @@ def test_load_payloads_handles_io_error(self): try: json_file.chmod(0o000) # Should handle the error gracefully - payloads = list(load_payloads(json_file)) + _ = list(load_payloads(json_file)) # May be empty due to permission error finally: # Restore permissions for cleanup json_file.chmod(0o644) + def test_load_payloads_with_marker_objects(self): + 
"""Test loading payload with marker objects. + + Validates: Requirements 2.1, 3.1, 5.2 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Create a payload with marker objects (simulating saved binary content) + payload_with_markers = { + "modelId": "test-model", + "image": { + "source": { + "bytes": { + "__llmeter_bytes__": "aGVsbG8gd29ybGQ=" # "hello world" + } + } + }, + } + + # Write to file + with jsonl_file.open("w") as f: + f.write(json.dumps(payload_with_markers) + "\n") + + # Load using load_payloads + payloads = list(load_payloads(jsonl_file)) + + # Verify payload was loaded + assert len(payloads) == 1 + + # Verify bytes were restored from marker + loaded_payload = payloads[0] + assert loaded_payload["modelId"] == "test-model" + assert isinstance( + loaded_payload["image"]["source"]["bytes"], bytes + ) + assert loaded_payload["image"]["source"]["bytes"] == b"hello world" + + def test_load_payloads_bytes_correctly_restored(self): + """Test that bytes objects are correctly restored from marker objects. 
+
+        Validates: Requirements 2.1, 2.3, 5.2
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            jsonl_file = Path(tmpdir) / "payload.jsonl"
+
+            # Create multiple payloads with different bytes content
+            payloads_with_markers = [
+                {"id": 1, "data": {"__llmeter_bytes__": "Zmlyc3Q="}},  # "first"
+                {"id": 2, "data": {"__llmeter_bytes__": "c2Vjb25k"}},  # "second"
+                {
+                    "id": 3,
+                    "nested": {
+                        "deep": {"data": {"__llmeter_bytes__": "dGhpcmQ="}}  # "third"
+                    },
+                },
+            ]
+
+            # Write to file
+            with jsonl_file.open("w") as f:
+                for payload in payloads_with_markers:
+                    f.write(json.dumps(payload) + "\n")
+
+            # Load payloads
+            loaded = list(load_payloads(jsonl_file))
+
+            # Verify all bytes were restored correctly
+            assert len(loaded) == 3
+            assert loaded[0]["data"] == b"first"
+            assert loaded[1]["data"] == b"second"
+            assert loaded[2]["nested"]["deep"]["data"] == b"third"
+
+            # Verify they are bytes objects
+            assert isinstance(loaded[0]["data"], bytes)
+            assert isinstance(loaded[1]["data"], bytes)
+            assert isinstance(loaded[2]["nested"]["deep"]["data"], bytes)
+
+    def test_load_payloads_bytes_marker_round_trip(self):
+        """Test that bytes markers are decoded automatically during load.
+
+        Validates: Requirements 5.2, 5.4
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Save a payload containing raw bytes; save_payloads encodes them
+            # as marker objects on disk
+            payload = {"image": b"\xff\xd8\xff\xe0"}
+            saved_path = save_payloads(payload, Path(tmpdir))
+
+            # Load from the path save_payloads returned and verify the bytes
+            # are restored
+            loaded = list(load_payloads(saved_path))
+
+            assert len(loaded) == 1
+            assert loaded[0]["image"] == b"\xff\xd8\xff\xe0"
+
+    def test_load_payloads_backward_compatibility_old_format(self):
+        """Test backward compatibility: old format loads successfully.
+
+        When loading payloads saved with the old format (no binary data, no markers),
+        they should load successfully without any issues.
+ + Validates: Requirements 3.1, 5.2, 5.4 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "old_format.jsonl" + + # Create payloads in old format (no marker objects) + old_format_payloads = [ + {"modelId": "test-model-1", "prompt": "Hello world", "maxTokens": 100}, + { + "modelId": "test-model-2", + "messages": [ + {"role": "user", "content": "What is the weather?"} + ], + }, + { + "modelId": "test-model-3", + "config": {"temperature": 0.7, "topP": 0.9}, + }, + ] + + # Write using standard json.dumps (old format) + with jsonl_file.open("w") as f: + for payload in old_format_payloads: + f.write(json.dumps(payload) + "\n") + + # Load using new load_payloads (with binary support) + loaded = list(load_payloads(jsonl_file)) + + # Verify all payloads loaded successfully + assert len(loaded) == 3 + + # Verify content matches exactly + assert loaded[0] == old_format_payloads[0] + assert loaded[1] == old_format_payloads[1] + assert loaded[2] == old_format_payloads[2] + + # Verify no marker objects were introduced + for payload in loaded: + assert "__llmeter_bytes__" not in json.dumps(payload) + + def test_load_payloads_round_trip_with_binary_content(self): + """Test round-trip: save with binary, load, verify bytes restored. 
+ + Validates: Requirements 2.1, 2.5, 4.1 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + + # Original payload with binary content + original_payload = { + "modelId": "test-model", + "messages": [ + { + "role": "user", + "content": [ + {"text": "Describe this image"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": b"\xff\xd8\xff\xe0\x00\x10JFIF"}, + } + }, + ], + } + ], + } + + # Save using save_payloads + saved_path = save_payloads(original_payload, output_path) + + # Load using load_payloads + loaded = list(load_payloads(saved_path)) + + # Verify round-trip integrity + assert len(loaded) == 1 + loaded_payload = loaded[0] + + # Verify structure is preserved + assert loaded_payload["modelId"] == original_payload["modelId"] + assert len(loaded_payload["messages"]) == 1 + assert loaded_payload["messages"][0]["role"] == "user" + + # Verify bytes were restored correctly + loaded_bytes = loaded_payload["messages"][0]["content"][1]["image"][ + "source" + ]["bytes"] + original_bytes = original_payload["messages"][0]["content"][1]["image"][ + "source" + ]["bytes"] + + assert isinstance(loaded_bytes, bytes) + assert loaded_bytes == original_bytes + + def test_load_payloads_multiple_marker_objects_in_single_payload(self): + """Test loading payload with multiple marker objects. 
+ + Validates: Requirements 2.1, 9.8 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Payload with multiple marker objects + payload_with_multiple_markers = { + "modelId": "test-model", + "images": [ + {"id": 1, "data": {"__llmeter_bytes__": "aW1hZ2Ux"}}, # "image1" + {"id": 2, "data": {"__llmeter_bytes__": "aW1hZ2Uy"}}, # "image2" + {"id": 3, "data": {"__llmeter_bytes__": "aW1hZ2Uz"}}, # "image3" + ], + } + + # Write to file + with jsonl_file.open("w") as f: + f.write(json.dumps(payload_with_multiple_markers) + "\n") + + # Load payload + loaded = list(load_payloads(jsonl_file)) + + # Verify all marker objects were decoded + assert len(loaded) == 1 + payload = loaded[0] + + assert len(payload["images"]) == 3 + assert payload["images"][0]["data"] == b"image1" + assert payload["images"][1]["data"] == b"image2" + assert payload["images"][2]["data"] == b"image3" + + # Verify all are bytes + for image in payload["images"]: + assert isinstance(image["data"], bytes) + + def test_load_payloads_empty_bytes_marker(self): + """Test loading marker object with empty bytes. + + Validates: Requirements 2.1, 2.3 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Payload with empty bytes marker + payload_with_empty = { + "id": 1, + "empty_data": {"__llmeter_bytes__": ""}, # Empty base64 = empty bytes + } + + # Write to file + with jsonl_file.open("w") as f: + f.write(json.dumps(payload_with_empty) + "\n") + + # Load payload + loaded = list(load_payloads(jsonl_file)) + + # Verify empty bytes were restored + assert len(loaded) == 1 + assert loaded[0]["empty_data"] == b"" + assert isinstance(loaded[0]["empty_data"], bytes) + assert len(loaded[0]["empty_data"]) == 0 + + def test_load_payloads_large_binary_data(self): + """Test loading payload with large binary data (1MB). 
+ + Validates: Requirements 2.1, 10.2 + """ + import base64 + import os + + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Create 1MB of random binary data + large_data = os.urandom(1024 * 1024) + base64_encoded = base64.b64encode(large_data).decode("utf-8") + + # Payload with large marker object + payload_with_large = { + "id": 1, + "large_image": {"__llmeter_bytes__": base64_encoded}, + } + + # Write to file + with jsonl_file.open("w") as f: + f.write(json.dumps(payload_with_large) + "\n") + + # Load payload + loaded = list(load_payloads(jsonl_file)) + + # Verify large bytes were restored correctly + assert len(loaded) == 1 + assert isinstance(loaded[0]["large_image"], bytes) + assert loaded[0]["large_image"] == large_data + assert len(loaded[0]["large_image"]) == 1024 * 1024 + class TestSavePayloads: """Tests for save_payloads function.""" @@ -402,6 +1118,159 @@ def test_save_payloads_returns_path(self): assert isinstance(result_path, Path) assert result_path.exists() + def test_save_payloads_with_bytes_objects(self): + """Test saving payload with bytes objects. 
+ + Validates: Requirements 1.1, 3.3, 5.1 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payload = { + "modelId": "test-model", + "image": {"source": {"bytes": b"\xff\xd8\xff\xe0\x00\x10JFIF"}}, + } + + result_path = save_payloads(payload, output_path) + + # Verify file was created + assert result_path.exists() + + # Read the file and verify it contains valid JSON + with result_path.open("r") as f: + line = f.readline() + parsed = json.loads(line) + + # Verify structure is preserved + assert parsed["modelId"] == "test-model" + + # Verify bytes are replaced with marker object + assert "__llmeter_bytes__" in parsed["image"]["source"]["bytes"] + assert len(parsed["image"]["source"]["bytes"]) == 1 + + def test_save_payloads_file_contains_valid_json_with_markers(self): + """Test that saved file contains valid JSON with marker objects. + + Validates: Requirements 1.1, 1.3 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payloads = [ + {"id": 1, "data": b"first bytes"}, + {"id": 2, "data": b"second bytes"}, + {"id": 3, "nested": {"data": b"nested bytes"}}, + ] + + result_path = save_payloads(payloads, output_path) + + # Read and parse each line + with result_path.open("r") as f: + lines = f.readlines() + + assert len(lines) == 3 + + # Verify each line is valid JSON + for i, line in enumerate(lines): + parsed = json.loads(line) + assert parsed["id"] == i + 1 + + # Verify marker objects exist + if i < 2: + assert "__llmeter_bytes__" in parsed["data"] + else: + assert "__llmeter_bytes__" in parsed["nested"]["data"] + + def test_save_payloads_handles_to_dict_objects(self): + """Test that payloads with to_dict() objects are serialized correctly. + + Objects implementing to_dict() are handled by LLMeterEncoder automatically. 
+ + Validates: Requirements 5.1, 5.3 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + + class CustomObj: + def to_dict(self): + return {"custom": "value"} + + payload = {"test": "data", "obj": CustomObj()} + + result_path = save_payloads(payload, output_path) + + with result_path.open("r") as f: + line = f.readline().strip() + + parsed = json.loads(line) + assert parsed["test"] == "data" + assert parsed["obj"] == {"custom": "value"} + + def test_save_payloads_backward_compatibility_no_bytes(self): + """Test backward compatibility when payload has no bytes. + + When a payload contains no bytes objects, the output should be identical + to using standard json.dumps (no marker objects introduced). + + Validates: Requirements 3.3, 5.1 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payload = { + "modelId": "test-model", + "prompt": "Hello world", + "maxTokens": 100, + "temperature": 0.7, + } + + # Save with new encoder + result_path = save_payloads(payload, output_path) + + # Read the saved content + with result_path.open("r") as f: + saved_line = f.readline().strip() + + # Compare with standard json.dumps + standard_json = json.dumps(payload) + + # They should be identical (no marker objects introduced) + assert saved_line == standard_json + + # Verify no marker keys exist + parsed = json.loads(saved_line) + assert "__llmeter_bytes__" not in json.dumps(parsed) + + def test_save_payloads_multiple_payloads_with_mixed_content(self): + """Test saving multiple payloads with mixed binary and non-binary content. 
+ + Validates: Requirements 1.1, 3.3 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payloads = [ + {"id": 1, "text": "no binary"}, + {"id": 2, "image": b"binary data"}, + {"id": 3, "text": "also no binary"}, + ] + + result_path = save_payloads(payloads, output_path) + + # Read all lines + with result_path.open("r") as f: + lines = f.readlines() + + assert len(lines) == 3 + + # First payload: no marker + parsed1 = json.loads(lines[0]) + assert "__llmeter_bytes__" not in json.dumps(parsed1) + + # Second payload: has marker + parsed2 = json.loads(lines[1]) + assert "__llmeter_bytes__" in parsed2["image"] + + # Third payload: no marker + parsed3 = json.loads(lines[2]) + assert "__llmeter_bytes__" not in json.dumps(parsed3) + # Property-based tests class TestPromptUtilsProperties: @@ -410,7 +1279,7 @@ class TestPromptUtilsProperties: @given( st.lists( st.dictionaries( - st.text(min_size=1, max_size=50), + st.text(min_size=1, max_size=50).filter(lambda k: k != "__llmeter_bytes__"), st.one_of(st.text(max_size=100), st.integers(), st.booleans()), min_size=1, max_size=10, @@ -470,3 +1339,549 @@ def test_create_collection_length_combinations(self, input_lengths, output_lengt result = collection.create_collection() expected_length = len(input_lengths) * len(output_lengths) assert len(result) == expected_length + + +class TestErrorHandling: + """Unit tests for error conditions in serialization/deserialization. + + These tests verify that errors are handled gracefully with descriptive messages. + + Requirements: 6.1, 6.2, 6.3, 6.4 + """ + + def test_unserializable_type_str_fallback(self): + """Test that unserializable types are serialized via str() fallback. 
+
+        Validates: Requirements 6.1
+        """
+
+        class CustomUnserializableObject:
+            """A custom class that cannot be serialized to JSON natively."""
+
+            def __init__(self, value):
+                self.value = value
+
+        payload = {
+            "modelId": "test-model",
+            "custom_object": CustomUnserializableObject(42),
+        }
+
+        # The unified encoder falls back to str() for unknown types
+        result = json.dumps(payload, cls=LLMeterEncoder)
+        assert "test-model" in result
+
+    def test_invalid_json_error_handling(self):
+        """Test that invalid JSON lines are skipped gracefully during load.
+
+        Validates: Requirements 6.2
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            jsonl_file = Path(tmpdir) / "invalid.jsonl"
+
+            # Write invalid JSON to file
+            with jsonl_file.open("w") as f:
+                f.write("This is not valid JSON at all\n")
+
+            # load_payloads catches the JSONDecodeError and skips the line
+            # rather than raising
+            payloads = list(load_payloads(jsonl_file))
+
+            # Invalid line should be skipped
+            assert len(payloads) == 0
+
+    def test_invalid_base64_error_handling(self):
+        """Test that invalid base64 in marker objects raises binascii.Error.
+
+        Validates: Requirements 6.2, 6.3
+        """
+        import binascii
+
+        # Create a marker object with invalid base64 (incorrect padding).
+        # Base64 strings must have a length that is a multiple of 4, so a
+        # single character triggers binascii.Error
+        invalid_marker = {"__llmeter_bytes__": "a"}
+
+        # Should raise binascii.Error when trying to decode
+        with pytest.raises(binascii.Error):
+            llmeter_bytes_decoder(invalid_marker)
+
+    def test_invalid_base64_in_load_payloads(self):
+        """Test that invalid base64 in saved payload is handled during load.
+
+        Validates: Requirements 6.2, 6.3
+        """
+        import binascii
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            jsonl_file = Path(tmpdir) / "invalid_base64.jsonl"
+
+            # Create a payload with invalid base64 in its marker object; a
+            # single character triggers binascii.Error on decode
+            invalid_payload = {
+                "modelId": "test-model",
+                "image": {"source": {"bytes": {"__llmeter_bytes__": "a"}}},
+            }
+
+            # Write to file
+            with jsonl_file.open("w") as f:
+                f.write(json.dumps(invalid_payload) + "\n")
+
+            # Should raise binascii.Error when trying to load
+            with pytest.raises(binascii.Error):
+                list(load_payloads(jsonl_file))
+
+    def test_save_payloads_str_failure_encoded_as_none(self):
+        """Test that str() failures in the encoder are not propagated as errors.
+
+        LLMeterEncoder serializes objects whose str() raises as None, which is
+        valid JSON, so save_payloads completes without raising.
+
+        Validates: Requirements 6.4
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            output_path = Path(tmpdir)
+
+            # An object whose str() raises: LLMeterEncoder returns None for these
+            class FailingStr:
+                def __str__(self):
+                    raise RuntimeError("Cannot convert")
+
+            payload = {"test": FailingStr()}
+            result_path = save_payloads(payload, output_path)
+
+            with result_path.open("r") as f:
+                parsed = json.loads(f.readline())
+            assert parsed["test"] is None
+
+    def test_invalid_json_deserialization_error(self):
+        """Test that invalid JSON lines are handled gracefully during load.
+
+        Validates: Requirements 6.4
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            jsonl_file = Path(tmpdir) / "payload.jsonl"
+
+            # Write invalid JSON
+            with jsonl_file.open("w") as f:
+                f.write("not valid json\n")
+
+            # load_payloads prints errors for invalid JSON lines but doesn't raise
+            loaded = list(load_payloads(jsonl_file))
+            assert len(loaded) == 0
+
+    def test_unserializable_nested_object_str_fallback(self):
+        """Test str() fallback for unserializable object in nested structure.
+ + Validates: Requirements 6.1, 6.3 + """ + + class NestedCustomObject: + """A custom class for testing nested error handling.""" + + def __init__(self): + self.data = "test" + + payload = { + "modelId": "test-model", + "messages": [ + { + "role": "user", + "content": [ + {"text": "Hello"}, + {"custom": NestedCustomObject()}, + ], + } + ], + } + + # The unified encoder falls back to str() for unknown types + result = json.dumps(payload, cls=LLMeterEncoder) + assert "test-model" in result + + def test_bytes_serialization_with_encoding_error(self): + """Test that bytes serialization handles all byte values correctly. + + This test verifies that bytes with any value (0-255) can be serialized + without encoding errors. + + Validates: Requirements 1.1, 6.1 + """ + # Create bytes with all possible byte values + all_bytes = bytes(range(256)) + payload = {"data": all_bytes} + + # Should serialize without errors + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Should be valid JSON + parsed = json.loads(serialized) + + # Verify marker exists + assert "__llmeter_bytes__" in parsed["data"] + + # Verify round-trip works + + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + assert deserialized["data"] == all_bytes + + def test_empty_payload_serialization(self): + """Test that empty payloads are handled correctly. + + Validates: Requirements 1.1, 6.1 + """ + # Empty dict + empty_payload = {} + + # Should serialize without errors + serialized = json.dumps(empty_payload, cls=LLMeterEncoder) + assert serialized == "{}" + + # Empty list + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + empty_list = [] + + # Should handle empty list + result_path = save_payloads(empty_list, output_path) + + # File should exist but be empty + with result_path.open("r") as f: + content = f.read() + assert content == "" + + def test_none_value_serialization(self): + """Test that None values are handled correctly. 
+ + Validates: Requirements 1.1, 1.5 + """ + payload = { + "modelId": "test-model", + "optional_field": None, + "nested": {"also_none": None}, + } + + # Should serialize without errors + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Should be valid JSON + parsed = json.loads(serialized) + + # Verify None values are preserved + assert parsed["optional_field"] is None + assert parsed["nested"]["also_none"] is None + + def test_unicode_in_payload_with_bytes(self): + """Test that Unicode strings work correctly alongside bytes. + + Validates: Requirements 1.1, 1.5 + """ + payload = { + "text": "Hello δΈ–η•Œ 🌍", + "emoji": "πŸŽ‰πŸŽŠπŸŽˆ", + "bytes": b"\xff\xfe\xfd", + } + + # Should serialize without errors + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Should be valid JSON + parsed = json.loads(serialized) + + # Verify Unicode is preserved + assert parsed["text"] == "Hello δΈ–η•Œ 🌍" + assert parsed["emoji"] == "πŸŽ‰πŸŽŠπŸŽˆ" + + # Verify bytes have marker + assert "__llmeter_bytes__" in parsed["bytes"] + + +class TestPerformance: + """Performance tests for serialization/deserialization. + + These tests verify that serialization and deserialization of large binary data + completes within acceptable time limits and doesn't create unnecessary data copies. + + Requirements: 10.1, 10.2, 10.3, 10.4 + """ + + def test_1mb_image_serialization_performance(self): + """Test that 1MB image serialization completes within 100ms. 
+ + Validates: Requirements 10.1 + """ + import os + import time + + # Create 1MB of random binary data (simulating an image) + large_image = os.urandom(1024 * 1024) + payload = { + "modelId": "test-model", + "messages": [ + { + "role": "user", + "content": [ + {"text": "Describe this image"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": large_image}, + } + }, + ], + } + ], + } + + # Measure serialization time + start_time = time.perf_counter() + serialized = json.dumps(payload, cls=LLMeterEncoder) + end_time = time.perf_counter() + + # Calculate elapsed time in milliseconds + elapsed_ms = (end_time - start_time) * 1000 + + # Verify serialization completed within 100ms + assert ( + elapsed_ms < 100 + ), f"Serialization took {elapsed_ms:.2f}ms, expected < 100ms" + + # Verify the result is valid JSON + assert isinstance(serialized, str) + parsed = json.loads(serialized) + assert "__llmeter_bytes__" in parsed["messages"][0]["content"][1]["image"][ + "source" + ]["bytes"] + + def test_1mb_image_deserialization_performance(self): + """Test that 1MB image deserialization completes within 100ms. 
+
+        Validates: Requirements 10.2
+        """
+        import os
+        import time
+
+        # Create 1MB of random binary data
+        large_image = os.urandom(1024 * 1024)
+        payload = {
+            "modelId": "test-model",
+            "image": {"source": {"bytes": large_image}},
+        }
+
+        # First serialize the payload
+        serialized = json.dumps(payload, cls=LLMeterEncoder)
+
+        # Measure deserialization time
+        start_time = time.perf_counter()
+        deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder)
+        end_time = time.perf_counter()
+
+        # Calculate elapsed time in milliseconds
+        elapsed_ms = (end_time - start_time) * 1000
+
+        # Verify deserialization completed within 100ms
+        assert (
+            elapsed_ms < 100
+        ), f"Deserialization took {elapsed_ms:.2f}ms, expected < 100ms"
+
+        # Verify the result is correct
+        assert isinstance(deserialized["image"]["source"]["bytes"], bytes)
+        assert deserialized["image"]["source"]["bytes"] == large_image
+
+    def test_serialization_no_unnecessary_copies(self):
+        """Test that serialization doesn't create unnecessary data copies.
+
+        This test verifies that the serialization process is memory-efficient
+        by checking that the base64 encoding completes without creating
+        multiple intermediate copies of the binary data.
+ + Validates: Requirements 10.3 + """ + import os + import sys + + # Create a moderately large binary payload (512KB) + binary_data = os.urandom(512 * 1024) + payload = {"data": binary_data} + + # Get initial memory usage (approximate) + initial_size = sys.getsizeof(binary_data) + + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Parse to verify structure + parsed = json.loads(serialized) + assert "__llmeter_bytes__" in parsed["data"] + + # The serialized string should be roughly 4/3 the size of the original + # (base64 encoding overhead) plus JSON structure overhead + # We verify it's not significantly larger (which would indicate copies) + serialized_size = sys.getsizeof(serialized) + base64_expected_size = (initial_size * 4 // 3) + 1000 # +1000 for JSON overhead + + # Allow 50% overhead for Python string internals and JSON structure + max_acceptable_size = base64_expected_size * 1.5 + + assert ( + serialized_size < max_acceptable_size + ), f"Serialized size {serialized_size} exceeds expected {max_acceptable_size}" + + def test_deserialization_no_unnecessary_copies(self): + """Test that deserialization doesn't create unnecessary data copies. + + This test verifies that the deserialization process is memory-efficient + by checking that the base64 decoding is done efficiently without creating + multiple intermediate copies of the binary data. 
+ + Validates: Requirements 10.4 + """ + import os + import sys + + # Create a moderately large binary payload (512KB) + binary_data = os.urandom(512 * 1024) + payload = {"data": binary_data} + + # Serialize first + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Deserialize + + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + + # Verify the deserialized bytes match original + assert deserialized["data"] == binary_data + + # The deserialized bytes should be approximately the original size + deserialized_size = sys.getsizeof(deserialized["data"]) + original_size = sys.getsizeof(binary_data) + + # Allow small overhead for Python object internals + # bytes objects have minimal overhead + max_acceptable_size = original_size * 1.1 + + assert ( + deserialized_size < max_acceptable_size + ), f"Deserialized size {deserialized_size} exceeds expected {max_acceptable_size}" + + def test_round_trip_performance_with_multiple_images(self): + """Test round-trip performance with multiple large images. + + This test verifies that serialization and deserialization remain performant + even with multiple large binary objects in a single payload. 
+ + Validates: Requirements 10.1, 10.2 + """ + import os + import time + + # Create payload with 3 images of 512KB each (total ~1.5MB) + images = [os.urandom(512 * 1024) for _ in range(3)] + payload = { + "modelId": "test-model", + "images": [ + {"id": i, "data": img} for i, img in enumerate(images) + ], + } + + # Measure serialization time + start_time = time.perf_counter() + serialized = json.dumps(payload, cls=LLMeterEncoder) + serialize_time = (time.perf_counter() - start_time) * 1000 + + # Measure deserialization time + + start_time = time.perf_counter() + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + deserialize_time = (time.perf_counter() - start_time) * 1000 + + # With 3 images, we allow proportionally more time (but still reasonable) + # Each operation should complete in under 200ms for 1.5MB total + assert ( + serialize_time < 200 + ), f"Serialization took {serialize_time:.2f}ms, expected < 200ms" + assert ( + deserialize_time < 200 + ), f"Deserialization took {deserialize_time:.2f}ms, expected < 200ms" + + # Verify correctness + assert len(deserialized["images"]) == 3 + for i, img in enumerate(images): + assert deserialized["images"][i]["data"] == img + + def test_serialization_performance_scales_linearly(self): + """Test that serialization performance scales linearly with data size. + + This test verifies that doubling the data size roughly doubles the time, + indicating no algorithmic inefficiencies. 
+ + Validates: Requirements 10.1, 10.3 + """ + import os + import time + + # Test with 256KB + small_data = os.urandom(256 * 1024) + small_payload = {"data": small_data} + + start_time = time.perf_counter() + json.dumps(small_payload, cls=LLMeterEncoder) + small_time = time.perf_counter() - start_time + + # Test with 512KB (2x size) + large_data = os.urandom(512 * 1024) + large_payload = {"data": large_data} + + start_time = time.perf_counter() + json.dumps(large_payload, cls=LLMeterEncoder) + large_time = time.perf_counter() - start_time + + # Large should take roughly 2x the time (allow 3x for variance) + # This verifies linear scaling, not quadratic or worse + assert ( + large_time < small_time * 3 + ), f"Performance doesn't scale linearly: {small_time:.4f}s vs {large_time:.4f}s" + + def test_deserialization_performance_scales_linearly(self): + """Test that deserialization performance scales linearly with data size. + + This test verifies that doubling the data size roughly doubles the time, + indicating no algorithmic inefficiencies. 
+ + Validates: Requirements 10.2, 10.4 + """ + import os + import time + + + # Test with 256KB + small_data = os.urandom(256 * 1024) + small_payload = {"data": small_data} + small_serialized = json.dumps(small_payload, cls=LLMeterEncoder) + + start_time = time.perf_counter() + json.loads(small_serialized, object_hook=llmeter_bytes_decoder) + small_time = time.perf_counter() - start_time + + # Test with 512KB (2x size) + large_data = os.urandom(512 * 1024) + large_payload = {"data": large_data} + large_serialized = json.dumps(large_payload, cls=LLMeterEncoder) + + start_time = time.perf_counter() + json.loads(large_serialized, object_hook=llmeter_bytes_decoder) + large_time = time.perf_counter() - start_time + + # Large should take roughly 2x the time (allow 3x for variance) + # This verifies linear scaling, not quadratic or worse + assert ( + large_time < small_time * 3 + ), f"Performance doesn't scale linearly: {small_time:.4f}s vs {large_time:.4f}s" diff --git a/tests/unit/test_property_save_load.py b/tests/unit/test_property_save_load.py index 5b0a7c9..7b96dc7 100644 --- a/tests/unit/test_property_save_load.py +++ b/tests/unit/test_property_save_load.py @@ -156,7 +156,7 @@ class TestPayloadSaveLoadProperties: @given( st.lists( st.dictionaries( - st.text(min_size=1, max_size=50), + st.text(min_size=1, max_size=50).filter(lambda x: x != "__llmeter_bytes__"), st.one_of( st.text(max_size=200), st.integers(), @@ -172,7 +172,11 @@ class TestPayloadSaveLoadProperties: ) @settings(deadline=None) def test_save_load_payloads_preserves_data(self, payloads): - """Save/load roundtrip should preserve all payload data.""" + """Save/load roundtrip should preserve all payload data. + + Note: Excludes the reserved key '__llmeter_bytes__' which is used internally + for binary content serialization. 
+ """ with tempfile.TemporaryDirectory() as tmpdir: output_path = Path(tmpdir) diff --git a/tests/unit/test_results.py b/tests/unit/test_results.py index 9262d16..c2da0a8 100644 --- a/tests/unit/test_results.py +++ b/tests/unit/test_results.py @@ -217,6 +217,27 @@ def test_stats_property_empty_result(): assert f"{metric}-{stat}" not in stats +def test_stats_json_serializable_with_datetimes(): + """stats dict should be directly JSON-serializable even with datetime fields.""" + from datetime import datetime, timezone + + result = Result( + responses=sample_responses_successful[:2], + total_requests=2, + clients=1, + n_requests=2, + total_test_time=1.0, + start_time=datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + end_time=datetime(2025, 1, 1, 12, 0, 1, tzinfo=timezone.utc), + ) + stats = result.stats + # Should not raise TypeError + json_str = json.dumps(stats) + parsed = json.loads(json_str) + assert parsed["start_time"] == "2025-01-01T12:00:00Z" + assert parsed["end_time"] == "2025-01-01T12:00:01Z" + + @pytest.fixture def temp_dir(tmp_path: Path): return UPath(tmp_path) @@ -306,3 +327,278 @@ def test_save_method_existing_responses(sample_result: Result, temp_dir: UPath): responses = [json.loads(line) for line in f] assert len(responses) == 6 # 5 original + 1 extra assert responses[-1]["id"] == "extra_response" + + +# Tests for LLMeterEncoder +# Validates Requirements: 7.1, 7.2 + + +def test_llmeter_encoder_handles_bytes(): + """Test that LLMeterEncoder handles bytes objects via base64 marker. 
+ + Validates: Requirements 7.1, 7.2 + """ + from llmeter.json_utils import LLMeterEncoder + + # Test bytes object encoding + payload = {"image": {"bytes": b"\xff\xd8\xff\xe0"}} + encoded = json.dumps(payload, cls=LLMeterEncoder) + + # Verify marker object is present + assert "__llmeter_bytes__" in encoded + + # Verify it can be decoded back + decoded = json.loads(encoded, object_hook=lambda d: d) + assert decoded["image"]["bytes"]["__llmeter_bytes__"] == "/9j/4A==" + + +def test_llmeter_encoder_str_fallback(): + """Test that LLMeterEncoder falls back to str() for custom objects. + + Validates: Requirements 7.1, 7.2 + """ + from llmeter.json_utils import LLMeterEncoder + + # Create a custom object with __str__ method + class CustomObject: + def __str__(self): + return "custom_string_representation" + + payload = {"custom": CustomObject()} + encoded = json.dumps(payload, cls=LLMeterEncoder) + + # Verify str() fallback was used + decoded = json.loads(encoded) + assert decoded["custom"] == "custom_string_representation" + + +def test_llmeter_encoder_none_on_str_failure(): + """Test that LLMeterEncoder returns None when str() conversion fails. + + Validates: Requirements 7.1, 7.2 + """ + from llmeter.json_utils import LLMeterEncoder + + # Create a custom object that raises exception in __str__ + class FailingObject: + def __str__(self): + raise RuntimeError("Cannot convert to string") + + payload = {"failing": FailingObject()} + encoded = json.dumps(payload, cls=LLMeterEncoder) + + # Verify None was returned + decoded = json.loads(encoded) + assert decoded["failing"] is None + + +def test_llmeter_encoder_mixed_types(): + """Test that LLMeterEncoder handles mixed types correctly. 
+ + Validates: Requirements 7.1, 7.2 + """ + from llmeter.json_utils import LLMeterEncoder + + class CustomObject: + def __str__(self): + return "custom" + + # Mix of bytes, custom objects, and standard types + payload = { + "bytes_field": b"\x00\x01\x02", + "custom_field": CustomObject(), + "string_field": "normal string", + "int_field": 42, + "nested": { + "bytes": b"\xff\xfe", + "custom": CustomObject() + } + } + + encoded = json.dumps(payload, cls=LLMeterEncoder) + decoded = json.loads(encoded) + + # Verify bytes were encoded with marker + assert "__llmeter_bytes__" in encoded + + # Verify custom objects were converted to strings + assert decoded["custom_field"] == "custom" + assert decoded["nested"]["custom"] == "custom" + + # Verify standard types remain unchanged + assert decoded["string_field"] == "normal string" + assert decoded["int_field"] == 42 + + +# Tests for InvocationResponse.to_json +# Validates Requirements: 7.1, 7.2, 7.3 + + +def test_invocation_response_to_json_with_binary_content(): + """Test InvocationResponse.to_json with binary content in input_payload. 
+ + Validates: Requirements 7.1, 7.2 + """ + # Create InvocationResponse with binary content in input_payload + response = InvocationResponse( + response_text="This is an image of a cat", + input_payload={ + "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "messages": [{ + "role": "user", + "content": [ + {"text": "What is in this image?"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": b"\xff\xd8\xff\xe0\x00\x10JFIF"} + } + } + ] + }] + }, + id="test-123", + time_to_first_token=0.5, + time_to_last_token=1.2, + num_tokens_input=15, + num_tokens_output=8 + ) + + # Serialize to JSON + json_str = response.to_json() + + # Verify it's valid JSON + parsed = json.loads(json_str) + + # Verify marker object is present for bytes + assert "__llmeter_bytes__" in json_str + + # Verify structure is preserved + assert parsed["response_text"] == "This is an image of a cat" + assert parsed["id"] == "test-123" + assert parsed["input_payload"]["modelId"] == "anthropic.claude-3-haiku-20240307-v1:0" + + # Verify bytes were encoded with marker + bytes_marker = parsed["input_payload"]["messages"][0]["content"][1]["image"]["source"]["bytes"] + assert "__llmeter_bytes__" in bytes_marker + assert bytes_marker["__llmeter_bytes__"] == "/9j/4AAQSkZJRg==" + + +def test_invocation_response_to_json_valid_json_output(): + """Test that InvocationResponse.to_json produces valid JSON. 
+ + Validates: Requirements 7.1, 7.2 + """ + response = InvocationResponse( + response_text="Test response", + input_payload={ + "image_data": b"\x89PNG\r\n\x1a\n", + "text": "Analyze this image" + }, + id="test-456" + ) + + # Serialize to JSON + json_str = response.to_json() + + # Verify it's valid JSON (should not raise exception) + parsed = json.loads(json_str) + + # Verify all fields are present + assert "response_text" in parsed + assert "input_payload" in parsed + assert "id" in parsed + + # Verify bytes were properly encoded + assert isinstance(parsed["input_payload"]["image_data"], dict) + assert "__llmeter_bytes__" in parsed["input_payload"]["image_data"] + + +def test_invocation_response_to_json_round_trip(): + """Test round-trip serialization/deserialization with InvocationResponse. + + Validates: Requirements 7.1, 7.2, 7.3 + """ + from llmeter.json_utils import llmeter_bytes_decoder + + # Create original response with binary content + original_payload = { + "video": { + "format": "mp4", + "source": {"bytes": b"\x00\x00\x00\x18ftypmp42"} + }, + "prompt": "Describe this video" + } + + response = InvocationResponse( + response_text="A video of a sunset", + input_payload=original_payload, + id="test-789", + time_to_first_token=1.0, + time_to_last_token=2.5 + ) + + # Serialize to JSON + json_str = response.to_json() + + # Deserialize back + parsed = json.loads(json_str, object_hook=llmeter_bytes_decoder) + + # Verify input_payload was restored correctly + assert parsed["input_payload"] == original_payload + assert isinstance(parsed["input_payload"]["video"]["source"]["bytes"], bytes) + assert parsed["input_payload"]["video"]["source"]["bytes"] == b"\x00\x00\x00\x18ftypmp42" + + # Verify other fields + assert parsed["response_text"] == "A video of a sunset" + assert parsed["id"] == "test-789" + assert parsed["time_to_first_token"] == 1.0 + assert parsed["time_to_last_token"] == 2.5 + + +def test_invocation_response_to_json_no_binary_content(): + """Test 
InvocationResponse.to_json with no binary content (backward compatibility). + + Validates: Requirements 7.1, 7.2 + """ + response = InvocationResponse( + response_text="Simple text response", + input_payload={ + "modelId": "test-model", + "prompt": "Hello, world!" + }, + id="test-no-binary" + ) + + # Serialize to JSON + json_str = response.to_json() + + # Verify no marker objects are present + assert "__llmeter_bytes__" not in json_str + + # Verify it's valid JSON + parsed = json.loads(json_str) + assert parsed["input_payload"]["prompt"] == "Hello, world!" + + +def test_invocation_response_to_json_with_kwargs(): + """Test that InvocationResponse.to_json passes through kwargs to json.dumps. + + Validates: Requirements 7.4 + """ + response = InvocationResponse( + response_text="Test", + input_payload={"data": b"\x01\x02\x03"}, + id="test-kwargs" + ) + + # Test with indent parameter + json_str = response.to_json(indent=2) + + # Verify indentation is present + assert "\n" in json_str + assert " " in json_str + + # Verify it's still valid JSON + parsed = json.loads(json_str) + assert parsed["id"] == "test-kwargs" diff --git a/tests/unit/test_serialization_properties.py b/tests/unit/test_serialization_properties.py new file mode 100644 index 0000000..637b082 --- /dev/null +++ b/tests/unit/test_serialization_properties.py @@ -0,0 +1,861 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Property-based tests for LLMeterEncoder and llmeter_bytes_decoder. + +This module contains property-based tests using Hypothesis to verify that +LLMeterEncoder correctly handles all supported types (bytes, datetime, date, +time, PathLike, to_dict() objects, and str() fallback) and that +llmeter_bytes_decoder restores binary content from marker objects. 
+
+Feature: json-serialization-optimization
+"""
+
+import base64
+import json
+from datetime import date, datetime, time, timezone
+from pathlib import PurePosixPath
+
+from hypothesis import given, settings
+from hypothesis import strategies as st
+from hypothesis.strategies import composite
+
+from llmeter.json_utils import LLMeterEncoder, llmeter_bytes_decoder
+
+# This module implements the following property-based tests:
+# - Property 1: Serialization produces valid JSON with marker objects
+# - Property 2: Serialization preserves non-binary structure
+# - Property 3: Deserialization restores bytes from markers
+# - Property 4: Deserialization preserves non-marker dicts
+# - Property 5: Round-trip serialization preserves data integrity
+# - Property 6: Round-trip preserves dictionary key ordering
+# - Property 7: Non-binary payloads are backward compatible
+# - Property 8: Serialization handles unknown types via str() fallback
+# - Property 9: Deserialization errors are descriptive
+# - Property 10: InvocationResponse serializes binary payloads
+
+
+# Custom strategies for generating test data
+@composite
+def json_value_strategy(draw, allow_bytes=True):
+    """Generate JSON-compatible values, optionally including bytes."""
+    base_types = st.one_of(
+        st.none(),
+        st.booleans(),
+        st.integers(),
+        st.floats(allow_nan=False, allow_infinity=False),
+        st.text(max_size=100),
+    )
+
+    if allow_bytes:
+        return draw(st.one_of(base_types, st.binary(max_size=1000)))
+    return draw(base_types)
+
+
+@composite
+def nested_dict_strategy(draw, max_depth=3, allow_bytes=True):
+    """Generate nested dictionary structures with optional bytes objects."""
+    if max_depth == 0:
+        return draw(json_value_strategy(allow_bytes=allow_bytes))
+
+    return draw(
+        st.dictionaries(
+            keys=st.text(min_size=1, max_size=20).filter(lambda k: k != "__llmeter_bytes__"),
+            values=st.one_of(
+                json_value_strategy(allow_bytes=allow_bytes),
+                nested_dict_strategy(max_depth=max_depth - 1, allow_bytes=allow_bytes),
+                st.lists(json_value_strategy(allow_bytes=allow_bytes), max_size=5),
+            ),
+            max_size=5,
+        )
+    )
+
+
+@composite
+def payload_with_bytes_strategy(draw):
+    """Generate payloads that contain at least one bytes object."""
+    # Start with a nested dict that may or may not have bytes
+    payload = draw(nested_dict_strategy(max_depth=3, allow_bytes=True))
+
+    # Ensure at least one bytes object exists by adding a guaranteed bytes field
+    payload["_test_bytes"] = draw(st.binary(min_size=1, max_size=500))
+
+    return payload
+
+
+@composite
+def payload_without_bytes_strategy(draw):
+    """Generate payloads that contain no bytes objects."""
+    return draw(nested_dict_strategy(max_depth=3, allow_bytes=False))
+
+
+class TestSerializationProperties:
+    """Property-based tests for serialization behavior."""
+
+    @given(payload_with_bytes_strategy())
+    @settings(max_examples=100)
+    def test_property_1_serialization_produces_valid_json_with_marker_objects(
+        self, payload
+    ):
+        """Property 1: Serialization produces valid JSON with marker objects.
+
+        **Validates: Requirements 1.1, 1.3**
+
+        For any payload containing bytes objects at any nesting level, serializing
+        the payload SHALL produce valid JSON where each bytes object is replaced
+        with a marker object containing the key "__llmeter_bytes__" and a
+        base64-encoded string value.
+ """ + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Verify it's valid JSON by parsing it + parsed = json.loads(serialized) + assert isinstance(parsed, dict) + + # Verify that marker objects exist in the serialized string + assert "__llmeter_bytes__" in serialized + + # Helper function to verify marker objects in nested structures + def verify_markers(obj): + """Recursively verify that bytes are replaced with marker objects.""" + if isinstance(obj, dict): + # Check if this is a marker object + if "__llmeter_bytes__" in obj: + # Verify it's a valid marker object + assert len(obj) == 1, "Marker object should have only one key" + base64_str = obj["__llmeter_bytes__"] + assert isinstance(base64_str, str), ( + "Marker value should be a string" + ) + # Verify it's valid base64 by attempting to decode + try: + base64.b64decode(base64_str) + except Exception as e: + raise AssertionError(f"Invalid base64 in marker object: {e}") + else: + # Recursively check nested structures + for value in obj.values(): + verify_markers(value) + elif isinstance(obj, list): + for item in obj: + verify_markers(item) + + # Verify all marker objects in the parsed structure + verify_markers(parsed) + + +class TestDeserializationProperties: + """Property-based tests for deserialization behavior.""" + + @given(st.binary(min_size=0, max_size=1000)) + @settings(max_examples=100) + def test_property_3_deserialization_restores_bytes_from_markers( + self, original_bytes + ): + """Property 3: Deserialization restores bytes from markers. + + **Validates: Requirements 2.1** + + For any JSON string containing marker objects with the "__llmeter_bytes__" + key, deserializing SHALL convert each marker object back to the original + bytes object by base64-decoding the string value. 
+ """ + + # Create a marker object from the original bytes + base64_str = base64.b64encode(original_bytes).decode("utf-8") + marker_object = {"__llmeter_bytes__": base64_str} + + # Decode the marker object + decoded_bytes = llmeter_bytes_decoder(marker_object) + + # Verify the decoded bytes match the original + assert isinstance(decoded_bytes, bytes), "Decoder should return bytes" + assert decoded_bytes == original_bytes, ( + "Decoded bytes should match original bytes" + ) + + @given(payload_without_bytes_strategy()) + @settings(max_examples=100) + def test_property_4_deserialization_preserves_non_marker_dicts(self, payload): + """Property 4: Deserialization preserves non-marker dicts. + + **Validates: Requirements 2.4** + + For any payload containing dictionaries without the "__llmeter_bytes__" + marker key, deserializing SHALL return those dictionaries unchanged. + """ + + # Helper function to recursively apply decoder to all dicts + def apply_decoder_recursively(obj): + """Apply decoder to all dicts in the structure.""" + if isinstance(obj, dict): + # Apply decoder to this dict + decoded = llmeter_bytes_decoder(obj) + # If it's still a dict (not converted to bytes), recurse + if isinstance(decoded, dict): + return {k: apply_decoder_recursively(v) for k, v in decoded.items()} + return decoded + elif isinstance(obj, list): + return [apply_decoder_recursively(item) for item in obj] + else: + return obj + + # Apply decoder recursively to the entire payload + decoded_payload = apply_decoder_recursively(payload) + + # Verify the payload is unchanged + assert decoded_payload == payload, ( + "Decoder should return non-marker dicts unchanged" + ) + + # Verify no bytes objects were introduced + def verify_no_bytes(obj): + """Recursively verify no bytes objects exist.""" + if isinstance(obj, bytes): + raise AssertionError("Decoder introduced bytes object unexpectedly") + elif isinstance(obj, dict): + for value in obj.values(): + verify_no_bytes(value) + elif 
isinstance(obj, list): + for item in obj: + verify_no_bytes(item) + + verify_no_bytes(decoded_payload) + + +class TestRoundTripProperties: + """Property-based tests for round-trip serialization integrity.""" + + @given(payload_with_bytes_strategy()) + @settings(max_examples=100) + def test_property_2_serialization_preserves_non_binary_structure(self, payload): + """Property 2: Serialization preserves non-binary structure. + + **Validates: Requirements 1.5** + + For any payload structure (keys, nesting levels, value types except bytes), + serializing then parsing the JSON SHALL preserve the structure identically, + with only bytes objects replaced by marker objects. + + # Feature: json-serialization-optimization, Property 2: Serialization preserves non-binary structure + """ + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Parse the JSON (without decoding markers back to bytes) + parsed = json.loads(serialized) + + # Helper function to verify structure preservation + def verify_structure(original, parsed_obj, path=""): + """Recursively verify that structure is preserved except for bytes.""" + if isinstance(original, bytes): + # Bytes should be replaced with marker object + assert isinstance(parsed_obj, dict), ( + f"At {path}: Expected marker dict for bytes, got {type(parsed_obj)}" + ) + assert "__llmeter_bytes__" in parsed_obj, ( + f"At {path}: Expected marker object for bytes" + ) + assert len(parsed_obj) == 1, ( + f"At {path}: Marker object should have only one key" + ) + # Verify the base64 string can be decoded back to original bytes + decoded = base64.b64decode(parsed_obj["__llmeter_bytes__"]) + assert decoded == original, ( + f"At {path}: Decoded bytes don't match original" + ) + elif isinstance(original, dict): + # Dict structure should be preserved + assert isinstance(parsed_obj, dict), ( + f"At {path}: Expected dict, got {type(parsed_obj)}" + ) + assert set(original.keys()) == set(parsed_obj.keys()), ( + f"At {path}: 
Dict keys differ. Original: {set(original.keys())}, " + f"Parsed: {set(parsed_obj.keys())}" + ) + # Verify nesting is preserved + for key in original.keys(): + verify_structure( + original[key], + parsed_obj[key], + f"{path}.{key}" if path else key + ) + elif isinstance(original, list): + # List structure should be preserved + assert isinstance(parsed_obj, list), ( + f"At {path}: Expected list, got {type(parsed_obj)}" + ) + assert len(original) == len(parsed_obj), ( + f"At {path}: List lengths differ. Original: {len(original)}, " + f"Parsed: {len(parsed_obj)}" + ) + # Verify each element + for i, (orig_item, parsed_item) in enumerate(zip(original, parsed_obj)): + verify_structure( + orig_item, + parsed_item, + f"{path}[{i}]" + ) + else: + # Primitive types should be preserved exactly + # Note: We need exact type matching here, not isinstance checks + assert type(original) is type(parsed_obj), ( # noqa: E721 + f"At {path}: Type mismatch. Original: {type(original)}, " + f"Parsed: {type(parsed_obj)}" + ) + assert original == parsed_obj, ( + f"At {path}: Values differ. Original: {original}, " + f"Parsed: {parsed_obj}" + ) + + # Verify structure preservation throughout + verify_structure(payload, parsed) + + @given(payload_with_bytes_strategy()) + @settings(max_examples=100) + def test_property_5_round_trip_serialization_preserves_data_integrity( + self, payload + ): + """Property 5: Round-trip serialization preserves data integrity. + + **Validates: Requirements 4.1, 2.5** + + For any valid payload with binary content, the property + deserialize(serialize(payload)) == payload SHALL hold, preserving + byte-for-byte equality of all bytes objects and exact equality of + all other values. 
+ """ + + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Deserialize the payload + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + + # Verify round-trip equality + assert deserialized == payload, ( + "Round-trip serialization should preserve data integrity" + ) + + # Helper function to verify byte-for-byte equality of bytes objects + def verify_bytes_equality(original, restored, path=""): + """Recursively verify that bytes objects are byte-for-byte equal.""" + if isinstance(original, bytes): + assert isinstance(restored, bytes), ( + f"At {path}: Expected bytes, got {type(restored)}" + ) + assert original == restored, ( + f"At {path}: Bytes objects differ" + ) + elif isinstance(original, dict): + assert isinstance(restored, dict), ( + f"At {path}: Expected dict, got {type(restored)}" + ) + assert set(original.keys()) == set(restored.keys()), ( + f"At {path}: Dict keys differ" + ) + for key in original.keys(): + verify_bytes_equality( + original[key], + restored[key], + f"{path}.{key}" if path else key + ) + elif isinstance(original, list): + assert isinstance(restored, list), ( + f"At {path}: Expected list, got {type(restored)}" + ) + assert len(original) == len(restored), ( + f"At {path}: List lengths differ" + ) + for i, (orig_item, rest_item) in enumerate(zip(original, restored)): + verify_bytes_equality( + orig_item, + rest_item, + f"{path}[{i}]" + ) + else: + # For primitive types, equality check is sufficient + assert original == restored, ( + f"At {path}: Values differ: {original} != {restored}" + ) + + # Verify byte-for-byte equality throughout the structure + verify_bytes_equality(payload, deserialized) + + + @given(payload_with_bytes_strategy()) + @settings(max_examples=100) + def test_property_6_round_trip_preserves_dictionary_key_ordering(self, payload): + """Property 6: Round-trip preserves dictionary key ordering. 
+ + **Validates: Requirements 4.4** + + For any payload with ordered dictionaries, round-trip serialization SHALL + preserve the insertion order of dictionary keys. + + # Feature: json-serialization-optimization, Property 6: Round-trip preserves dictionary key ordering + """ + + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterEncoder) + + # Deserialize the payload + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + + # Helper function to verify key ordering + def verify_key_ordering(original, restored, path=""): + """Recursively verify that dictionary key ordering is preserved.""" + if isinstance(original, dict): + assert isinstance(restored, dict), ( + f"At {path}: Expected dict, got {type(restored)}" + ) + + # Get the keys as lists to preserve order + original_keys = list(original.keys()) + restored_keys = list(restored.keys()) + + # Verify the keys are in the same order + assert original_keys == restored_keys, ( + f"At {path}: Key ordering differs. 
" + f"Original: {original_keys}, Restored: {restored_keys}" + ) + + # Recursively verify nested structures + for key in original_keys: + verify_key_ordering( + original[key], + restored[key], + f"{path}.{key}" if path else key + ) + elif isinstance(original, list): + assert isinstance(restored, list), ( + f"At {path}: Expected list, got {type(restored)}" + ) + assert len(original) == len(restored), ( + f"At {path}: List lengths differ" + ) + # Verify each element + for i, (orig_item, rest_item) in enumerate(zip(original, restored)): + verify_key_ordering( + orig_item, + rest_item, + f"{path}[{i}]" + ) + # For non-dict, non-list types, no key ordering to verify + + # Verify key ordering is preserved throughout the structure + verify_key_ordering(payload, deserialized) + + +class TestBackwardCompatibilityProperties: + """Property-based tests for backward compatibility.""" + + @given(payload_without_bytes_strategy()) + @settings(max_examples=100) + def test_property_7_non_binary_payloads_are_backward_compatible(self, payload): + """Property 7: Non-binary payloads are backward compatible. + + **Validates: Requirements 3.2, 3.3** + + For any payload containing no bytes objects, serializing with the new + encoder SHALL produce output identical to serializing with the standard + json.dumps (no marker objects introduced). 
+ + # Feature: json-serialization-optimization, Property 7: Non-binary payloads are backward compatible + """ + # Serialize with the new encoder + serialized_with_encoder = json.dumps(payload, cls=LLMeterEncoder) + + # Serialize with standard json.dumps + serialized_standard = json.dumps(payload) + + # Verify they produce identical output + assert serialized_with_encoder == serialized_standard, ( + "Serialization with LLMeterEncoder should produce identical output " + "to standard json.dumps for payloads without bytes objects" + ) + + # Verify no marker objects were introduced + assert "__llmeter_bytes__" not in serialized_with_encoder, ( + "No marker objects should be introduced for payloads without bytes" + ) + + # Verify both can be parsed identically + parsed_encoder = json.loads(serialized_with_encoder) + parsed_standard = json.loads(serialized_standard) + assert parsed_encoder == parsed_standard, ( + "Parsed output should be identical for both serialization methods" + ) + + # Verify the parsed output matches the original payload + assert parsed_encoder == payload, ( + "Parsed output should match the original payload" + ) + + +class TestErrorHandlingProperties: + """Property-based tests for error handling.""" + + @composite + def unserializable_object_strategy(draw): + """Generate objects that are not JSON serializable and not bytes.""" + # Create various types of unserializable objects + unserializable_types = [ + # Custom class instance + lambda: type('CustomClass', (), {})(), + # Function + lambda: lambda x: x, + # Set (not JSON serializable) + lambda: {1, 2, 3}, + # Complex number + lambda: complex(1, 2), + # Object with __dict__ + lambda: type('ObjWithDict', (), {'attr': 'value'})(), + ] + + # Choose one of the unserializable types + return draw(st.sampled_from(unserializable_types))() + + @given(st.data()) + @settings(max_examples=100) + def test_property_8_serialization_handles_unknown_types(self, data): + """Property 8: Serialization handles unknown 
types via str() fallback. + + **Validates: Requirements 6.1** + + For any payload containing unserializable types (not bytes, not standard + JSON types), serialization SHALL succeed by falling back to str() + representation, producing valid JSON output. + + # Feature: json-serialization-optimization, Property 8: Serialization handles unknown types + """ + # Generate a payload with an unserializable object + unserializable_obj = data.draw( + TestErrorHandlingProperties.unserializable_object_strategy() + ) + + # Create a payload containing the unserializable object + # We'll place it at various locations in the structure + placement_strategy = st.sampled_from([ + # Direct value + lambda obj: {"unserializable": obj}, + # Nested in dict + lambda obj: {"outer": {"inner": {"unserializable": obj}}}, + # In a list + lambda obj: {"items": [1, 2, obj, 4]}, + # Mixed structure + lambda obj: {"data": {"list": [{"nested": obj}]}}, + ]) + + payload_creator = data.draw(placement_strategy) + payload = payload_creator(unserializable_obj) + + # The unified encoder falls back to str() for unknown types + result = json.dumps(payload, cls=LLMeterEncoder) + # Result should be valid JSON + parsed = json.loads(result) + assert isinstance(parsed, dict) + + + @given(st.data()) + @settings(max_examples=100) + def test_property_9_deserialization_errors_are_descriptive(self, data): + """Property 9: Deserialization errors are descriptive. + + **Validates: Requirements 6.2** + + For any invalid JSON string or JSON containing marker objects with invalid + base64 strings, attempting to deserialize SHALL raise an appropriate + exception (JSONDecodeError or binascii.Error) with a descriptive message. 
+ + # Feature: json-serialization-optimization, Property 9: Deserialization errors are descriptive + """ + + # Test invalid JSON strings that will raise JSONDecodeError + # Note: base64.b64decode() is lenient by default and accepts many inputs, + # so we focus on JSON parsing errors which are more common in practice + + # Generate invalid JSON strings that will definitely fail parsing + invalid_json_strategy = st.sampled_from([ + "{invalid json}", + '{"key": undefined}', + "{'single': 'quotes'}", + '{"unclosed": ', + '{"trailing": "comma",}', + 'not json at all', + '["unclosed array"', + '}invalid start{', + '{"double""quotes"}', + '{"key": value}', # unquoted value + '[1, 2, 3,]', # trailing comma in array + ]) + invalid_json = data.draw(invalid_json_strategy) + + # Attempt to deserialize and verify it raises JSONDecodeError + try: + json.loads(invalid_json, object_hook=llmeter_bytes_decoder) + # If we get here, deserialization succeeded when it shouldn't have + raise AssertionError( + f"Expected JSONDecodeError for invalid JSON: {invalid_json}" + ) + except json.JSONDecodeError as e: + # Verify the error message is descriptive + error_msg = str(e) + + # The error message should indicate it's a JSON parsing error + # JSONDecodeError messages typically contain position information + assert len(error_msg) > 0, ( + "Error message should not be empty" + ) + + # Verify the exception has the expected attributes + assert hasattr(e, 'msg'), "JSONDecodeError should have 'msg' attribute" + assert hasattr(e, 'lineno'), "JSONDecodeError should have 'lineno' attribute" + assert hasattr(e, 'colno'), "JSONDecodeError should have 'colno' attribute" + + # Verify the error message contains useful information + # It should mention what went wrong + assert e.msg, "JSONDecodeError msg should not be empty" + + +class TestInvocationResponseProperties: + """Property-based tests for InvocationResponse serialization.""" + + @given(nested_dict_strategy(max_depth=3, allow_bytes=True)) + 
@settings(max_examples=100, deadline=None) + def test_property_10_invocation_response_serializes_binary_payloads( + self, input_payload + ): + """Property 10: InvocationResponse serializes binary payloads. + + **Validates: Requirements 7.1** + + For any InvocationResponse object where the input_payload field contains + bytes objects, calling to_json() SHALL produce valid JSON with marker + objects replacing the bytes. + + # Feature: json-serialization-optimization, Property 10: InvocationResponse serializes binary payloads + """ + from llmeter.endpoints.base import InvocationResponse + + # Create an InvocationResponse with the generated input_payload + response = InvocationResponse( + response_text="Test response", + input_payload=input_payload, + id="test-id", + time_to_first_token=0.1, + time_to_last_token=0.5, + num_tokens_input=10, + num_tokens_output=20, + ) + + # Serialize using to_json() + json_str = response.to_json() + + # Verify it's valid JSON by parsing it + parsed = json.loads(json_str) + assert isinstance(parsed, dict), "to_json() should produce a valid JSON object" + + # Verify the input_payload field exists + assert "input_payload" in parsed, "Serialized response should contain input_payload" + + # Helper function to check for bytes in original and markers in serialized + def has_bytes(obj): + """Recursively check if object contains bytes.""" + if isinstance(obj, bytes): + return True + elif isinstance(obj, dict): + return any(has_bytes(v) for v in obj.values()) + elif isinstance(obj, list): + return any(has_bytes(item) for item in obj) + return False + + def verify_bytes_serialized(original, serialized_obj, path=""): + """Recursively verify bytes are replaced with marker objects.""" + if isinstance(original, bytes): + # Bytes should be replaced with marker object + assert isinstance(serialized_obj, dict), ( + f"At {path}: Expected marker dict for bytes, got {type(serialized_obj)}" + ) + assert "__llmeter_bytes__" in serialized_obj, ( + f"At {path}: 
Expected marker object for bytes" + ) + assert len(serialized_obj) == 1, ( + f"At {path}: Marker object should have only one key" + ) + # Verify the base64 string can be decoded back to original bytes + base64_str = serialized_obj["__llmeter_bytes__"] + assert isinstance(base64_str, str), ( + f"At {path}: Marker value should be a string" + ) + decoded = base64.b64decode(base64_str) + assert decoded == original, ( + f"At {path}: Decoded bytes don't match original" + ) + elif isinstance(original, dict): + assert isinstance(serialized_obj, dict), ( + f"At {path}: Expected dict, got {type(serialized_obj)}" + ) + # Verify all keys are present + for key in original.keys(): + assert key in serialized_obj, ( + f"At {path}: Key '{key}' missing in serialized object" + ) + verify_bytes_serialized( + original[key], + serialized_obj[key], + f"{path}.{key}" if path else key + ) + elif isinstance(original, list): + assert isinstance(serialized_obj, list), ( + f"At {path}: Expected list, got {type(serialized_obj)}" + ) + assert len(original) == len(serialized_obj), ( + f"At {path}: List lengths differ" + ) + for i, (orig_item, ser_item) in enumerate(zip(original, serialized_obj)): + verify_bytes_serialized( + orig_item, + ser_item, + f"{path}[{i}]" + ) + # For other types, no special verification needed + + # If the input_payload contains bytes, verify they're serialized correctly + if has_bytes(input_payload): + assert "__llmeter_bytes__" in json_str, ( + "Serialized JSON should contain marker objects when input_payload has bytes" + ) + verify_bytes_serialized(input_payload, parsed["input_payload"], "input_payload") + + # Verify other fields are serialized correctly + assert parsed["response_text"] == "Test response" + assert parsed["id"] == "test-id" + assert parsed["time_to_first_token"] == 0.1 + assert parsed["time_to_last_token"] == 0.5 + assert parsed["num_tokens_input"] == 10 + assert parsed["num_tokens_output"] == 20 + + +# 
--------------------------------------------------------------------------- +# Strategies for the extended type tests +# --------------------------------------------------------------------------- + +# datetime strategy: aware and naive datetimes +_datetime_strategy = st.one_of( + # Naive datetimes + st.datetimes(min_value=datetime(2000, 1, 1), max_value=datetime(2030, 12, 31)), + # UTC-aware datetimes + st.datetimes( + min_value=datetime(2000, 1, 1), + max_value=datetime(2030, 12, 31), + timezones=st.just(timezone.utc), + ), +) + +_date_strategy = st.dates(min_value=date(2000, 1, 1), max_value=date(2030, 12, 31)) + +_time_strategy = st.times() + +_path_strategy = st.from_regex(r"[a-z][a-z0-9_/]{0,30}", fullmatch=True).map( + PurePosixPath +) + + +@composite +def to_dict_object_strategy(draw): + """Generate an object with a to_dict() method returning a JSON-safe dict.""" + inner = draw( + st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.one_of(st.integers(), st.text(max_size=20), st.booleans()), + max_size=5, + ) + ) + + class _Obj: + def __init__(self, d): + self._d = d + + def to_dict(self): + return self._d + + return _Obj(inner), inner + + +class TestDatetimeSerializationProperties: + """Property-based tests for datetime/date/time encoding.""" + + @given(_datetime_strategy) + @settings(max_examples=100) + def test_datetime_produces_iso_string_with_z_suffix(self, dt): + """Datetime values are serialized to ISO-8601 strings at seconds precision. + + Aware datetimes are converted to UTC and suffixed with 'Z'. + Naive datetimes are serialized as-is with no timezone indicator. + Microseconds are truncated (the encoder uses timespec="seconds"). 
+ """ + result = json.loads(json.dumps({"v": dt}, cls=LLMeterEncoder)) + assert isinstance(result["v"], str) + if dt.tzinfo is not None: + assert result["v"].endswith("Z") + parsed_back = datetime.fromisoformat(result["v"].replace("Z", "+00:00")) + expected = dt.astimezone(timezone.utc).replace(microsecond=0) + assert parsed_back == expected + else: + parsed_back = datetime.fromisoformat(result["v"]) + assert parsed_back == dt.replace(microsecond=0) + + @given(_date_strategy) + @settings(max_examples=100) + def test_date_produces_iso_string(self, d): + """Date values are serialized via isoformat().""" + result = json.loads(json.dumps({"v": d}, cls=LLMeterEncoder)) + assert result["v"] == d.isoformat() + assert date.fromisoformat(result["v"]) == d + + @given(_time_strategy) + @settings(max_examples=100) + def test_time_produces_iso_string(self, t): + """Time values are serialized via isoformat().""" + result = json.loads(json.dumps({"v": t}, cls=LLMeterEncoder)) + assert result["v"] == t.isoformat() + assert time.fromisoformat(result["v"]) == t + + +class TestPathSerializationProperties: + """Property-based tests for PathLike encoding.""" + + @given(_path_strategy) + @settings(max_examples=100) + def test_pathlike_produces_posix_string(self, p): + """PathLike objects are serialized to POSIX path strings.""" + result = json.loads(json.dumps({"v": p}, cls=LLMeterEncoder)) + assert isinstance(result["v"], str) + assert result["v"] == p.as_posix() + + +class TestToDictSerializationProperties: + """Property-based tests for to_dict() delegation.""" + + @given(to_dict_object_strategy()) + @settings(max_examples=100) + def test_to_dict_delegation_produces_expected_dict(self, obj_and_expected): + """Objects with to_dict() are serialized by calling that method.""" + obj, expected = obj_and_expected + result = json.loads(json.dumps({"v": obj}, cls=LLMeterEncoder)) + assert result["v"] == expected + + @given(to_dict_object_strategy(), st.binary(min_size=1, max_size=100)) + 
@settings(max_examples=100) + def test_to_dict_and_bytes_coexist(self, obj_and_expected, raw_bytes): + """Payloads mixing to_dict() objects and bytes round-trip correctly.""" + obj, expected = obj_and_expected + payload = {"obj": obj, "data": raw_bytes} + serialized = json.dumps(payload, cls=LLMeterEncoder) + restored = json.loads(serialized, object_hook=llmeter_bytes_decoder) + assert restored["obj"] == expected + assert restored["data"] == raw_bytes diff --git a/tests/unit/test_tokenizers.py b/tests/unit/test_tokenizers.py index f441392..37dde1b 100644 --- a/tests/unit/test_tokenizers.py +++ b/tests/unit/test_tokenizers.py @@ -13,6 +13,19 @@ save_tokenizer, ) +# Check for optional dependencies +try: + import transformers # noqa: F401 + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + +try: + import tiktoken # noqa: F401 + TIKTOKEN_AVAILABLE = True +except ImportError: + TIKTOKEN_AVAILABLE = False + # Mock classes for testing class MockTransformersTokenizer: @@ -118,15 +131,29 @@ def test_save_tokenizer(tmp_path): # Test _load_tokenizer_from_info function -@pytest.mark.skip(reason="transformers is not installed") +@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="transformers is not installed") def test_load_tokenizer_from_info_transformers(monkeypatch): + # Mock AutoTokenizer.from_pretrained to return our mock + def mock_from_pretrained(name): + return MockTransformersTokenizer() + + from transformers import AutoTokenizer + monkeypatch.setattr(AutoTokenizer, "from_pretrained", mock_from_pretrained) + tokenizer_info = {"tokenizer_module": "transformers", "name": "mock-transformer"} tokenizer = _load_tokenizer_from_info(tokenizer_info) assert isinstance(tokenizer, MockTransformersTokenizer) -@pytest.mark.skip(reason="tiktoken is not installed") +@pytest.mark.skipif(not TIKTOKEN_AVAILABLE, reason="tiktoken is not installed") def test_load_tokenizer_from_info_tiktoken(monkeypatch): + # Mock get_encoding to return our mock + 
def mock_get_encoding(name): + return MockTiktokenTokenizer() + + import tiktoken + monkeypatch.setattr(tiktoken, "get_encoding", mock_get_encoding) + tokenizer_info = {"tokenizer_module": "tiktoken", "name": "mock-tiktoken"} tokenizer = _load_tokenizer_from_info(tokenizer_info) assert isinstance(tokenizer, MockTiktokenTokenizer)
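

For readers skimming the diff, the byte-marker scheme these property tests pin down can be sketched in a few lines. The names below (`BytesMarkerEncoder`, `bytes_marker_decoder`) are illustrative stand-ins written against the stdlib only, not the library's actual `LLMeterEncoder`/`llmeter_bytes_decoder` implementation:

```python
# Minimal standalone sketch of the marker scheme exercised above: bytes values
# are replaced during encoding with {"__llmeter_bytes__": "<base64>"} objects,
# and an object_hook reverses the substitution on load.
import base64
import json


class BytesMarkerEncoder(json.JSONEncoder):
    """Replace bytes with single-key marker dicts so payloads stay valid JSON."""

    def default(self, o):
        if isinstance(o, bytes):
            return {"__llmeter_bytes__": base64.b64encode(o).decode("ascii")}
        return super().default(o)


def bytes_marker_decoder(obj):
    """object_hook that turns marker dicts back into the original bytes."""
    # Only a dict whose sole key is the marker is treated as encoded bytes,
    # matching the "marker object should have only one key" assertion above.
    if set(obj) == {"__llmeter_bytes__"}:
        return base64.b64decode(obj["__llmeter_bytes__"])
    return obj


payload = {"prompt": "describe this image", "image": b"\x89PNG\r\n"}
serialized = json.dumps(payload, cls=BytesMarkerEncoder)
restored = json.loads(serialized, object_hook=bytes_marker_decoder)
assert restored == payload
```

This round-trip (bytes in, marker object in the JSON text, identical bytes out) is exactly the invariant `verify_bytes_serialized` checks recursively through nested dicts and lists.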