From 3c91762594e1fbbefeba371df60d1efbc32469a5 Mon Sep 17 00:00:00 2001 From: Alessandro Cere Date: Thu, 26 Feb 2026 11:54:35 +0800 Subject: [PATCH 1/9] Fix image payload serialization to avoid base64 encoding overhead Addresses #20 by implementing binary-safe serialization for payloads and results containing images. This prevents double base64 encoding and significantly reduces memory usage and serialization time. Changes: - Add binary serialization support to prompt_utils and results - Update endpoints to use binary-safe serialization - Add property-based tests for serialization correctness - Update integration tests to verify image handling --- llmeter/endpoints/base.py | 109 +- llmeter/endpoints/bedrock_invoke.py | 8 +- llmeter/prompt_utils.py | 249 +++- llmeter/results.py | 34 + tests/integ/test_bedrock_converse.py | 251 ++++ tests/integ/test_bedrock_invoke.py | 206 ++- tests/integ/test_openai_bedrock.py | 265 +++- tests/unit/test_prompt_utils.py | 1487 ++++++++++++++++++- tests/unit/test_results.py | 275 ++++ tests/unit/test_serialization_properties.py | 768 ++++++++++ 10 files changed, 3585 insertions(+), 67 deletions(-) create mode 100644 tests/unit/test_serialization_properties.py diff --git a/llmeter/endpoints/base.py b/llmeter/endpoints/base.py index 9ce6903..bc35b03 100644 --- a/llmeter/endpoints/base.py +++ b/llmeter/endpoints/base.py @@ -47,13 +47,110 @@ class InvocationResponse: retries: int | None = None def to_json(self, **kwargs) -> str: - def default_serializer(obj): - try: - return str(obj) - except Exception: - return None + """ + Convert InvocationResponse to JSON string with binary content support. + + This method serializes the InvocationResponse object to a JSON string, with + automatic handling of binary content (bytes objects) in the input_payload field. 
+ Binary data is converted to base64-encoded strings wrapped in marker objects, + enabling JSON serialization while preserving the ability to restore the original + bytes during deserialization. + + Binary Content Handling: + When the input_payload contains bytes objects (e.g., images, video), they are + automatically converted to base64-encoded strings and wrapped in marker objects + with the key "__llmeter_bytes__". This approach enables JSON serialization of + multimodal payloads while maintaining round-trip integrity. + + The marker object format is: {"__llmeter_bytes__": ""} + + For non-serializable types other than bytes, the encoder falls back to str() + representation to ensure the response can always be serialized. + + Args: + **kwargs: Additional arguments passed to json.dumps (e.g., indent, sort_keys) - return json.dumps(asdict(self), default=default_serializer, **kwargs) + Returns: + str: JSON representation of the response + + Examples: + Serialize a response with binary content in the payload: + + >>> # Create a response with binary image data in the payload + >>> with open("image.jpg", "rb") as f: + ... image_bytes = f.read() + >>> response = InvocationResponse( + ... response_text="The image shows a cat.", + ... input_payload={ + ... "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + ... "messages": [{ + ... "role": "user", + ... "content": [ + ... {"text": "What is in this image?"}, + ... { + ... "image": { + ... "format": "jpeg", + ... "source": {"bytes": image_bytes} + ... } + ... } + ... ] + ... }] + ... }, + ... time_to_last_token=1.23, + ... num_tokens_output=15 + ... 
) + >>> json_str = response.to_json() + >>> # The JSON string contains marker objects for binary data + >>> "__llmeter_bytes__" in json_str + True + + Serialize with pretty printing: + + >>> json_str = response.to_json(indent=2) + >>> print(json_str) + { + "response_text": "The image shows a cat.", + "input_payload": { + "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "messages": [ + { + "role": "user", + "content": [ + {"text": "What is in this image?"}, + { + "image": { + "format": "jpeg", + "source": { + "bytes": {"__llmeter_bytes__": "/9j/4AAQSkZJRg..."} + } + } + } + ] + } + ] + }, + "time_to_last_token": 1.23, + "num_tokens_output": 15, + ... + } + + Round-trip serialization with binary preservation: + + >>> # Serialize to JSON + >>> json_str = response.to_json() + >>> # Parse back to dict + >>> import json + >>> from llmeter.prompt_utils import llmeter_bytes_decoder + >>> response_dict = json.loads(json_str, object_hook=llmeter_bytes_decoder) + >>> # Binary data is preserved + >>> original_bytes = response.input_payload["messages"][0]["content"][1]["image"]["source"]["bytes"] + >>> restored_bytes = response_dict["input_payload"]["messages"][0]["content"][1]["image"]["source"]["bytes"] + >>> original_bytes == restored_bytes + True + """ + from llmeter.results import InvocationResponseEncoder + + return json.dumps(asdict(self), cls=InvocationResponseEncoder, **kwargs) @staticmethod def error_output( diff --git a/llmeter/endpoints/bedrock_invoke.py b/llmeter/endpoints/bedrock_invoke.py index 20a9403..f6216c7 100644 --- a/llmeter/endpoints/bedrock_invoke.py +++ b/llmeter/endpoints/bedrock_invoke.py @@ -261,7 +261,9 @@ def invoke(self, payload: dict) -> InvocationResponse: raise TypeError("Payload must be a dictionary") try: - req_body = json.dumps(payload).encode("utf-8") + from llmeter.prompt_utils import LLMeterBytesEncoder + + req_body = json.dumps(payload, cls=LLMeterBytesEncoder).encode("utf-8") try: start_t = time.perf_counter() client_response = 
self._bedrock_client.invoke_model( # type: ignore @@ -353,7 +355,9 @@ def __init__( ) def invoke(self, payload: dict) -> InvocationResponse: - req_body = json.dumps(payload).encode("utf-8") + from llmeter.prompt_utils import LLMeterBytesEncoder + + req_body = json.dumps(payload, cls=LLMeterBytesEncoder).encode("utf-8") try: start_t = time.perf_counter() client_response = self._bedrock_client.invoke_model_with_response_stream( # type: ignore diff --git a/llmeter/prompt_utils.py b/llmeter/prompt_utils.py index 0cf5126..c4c7e18 100644 --- a/llmeter/prompt_utils.py +++ b/llmeter/prompt_utils.py @@ -1,6 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import base64 import json import logging from dataclasses import dataclass @@ -16,6 +17,63 @@ logger = logging.getLogger(__name__) +class LLMeterBytesEncoder(json.JSONEncoder): + """Custom JSON encoder that handles bytes objects by converting them to base64. + + This encoder wraps bytes objects in a marker object with the key "__llmeter_bytes__" + to enable round-trip serialization and deserialization of binary content in payloads. + + Example: + >>> encoder = LLMeterBytesEncoder() + >>> payload = {"image": {"bytes": b"\\xff\\xd8\\xff\\xe0"}} + >>> json.dumps(payload, cls=LLMeterBytesEncoder) + '{"image": {"bytes": {"__llmeter_bytes__": "/9j/4A=="}}}' + """ + + def default(self, obj): + """Encode bytes objects as marker objects with base64 strings. + + Args: + obj: Object to encode + + Returns: + dict: Marker object for bytes, or delegates to parent for other types + + Raises: + TypeError: If object is not JSON serializable + """ + if isinstance(obj, bytes): + return { + "__llmeter_bytes__": base64.b64encode(obj).decode("utf-8") + } + return super().default(obj) + + +def llmeter_bytes_decoder(dct: dict) -> dict | bytes: + """Decode marker objects back to bytes during JSON deserialization. 
+ + This function is used as an object_hook for json.loads to detect and decode + marker objects created by LLMeterBytesEncoder back to bytes during deserialization. + + Args: + dct: Dictionary from JSON parsing + + Returns: + bytes if marker detected, otherwise original dict + + Example: + >>> marker = {"__llmeter_bytes__": "/9j/4A=="} + >>> llmeter_bytes_decoder(marker) + b'\\xff\\xd8\\xff\\xe0' + >>> regular_dict = {"key": "value"} + >>> llmeter_bytes_decoder(regular_dict) + {'key': 'value'} + """ + if "__llmeter_bytes__" in dct and len(dct) == 1: + return base64.b64decode(dct["__llmeter_bytes__"]) + return dct + + @dataclass class CreatePromptCollection: input_lengths: list[int] @@ -107,16 +165,32 @@ def load_prompts( continue -def load_payloads(file_path: os.PathLike | str) -> Iterator[dict]: +def load_payloads( + file_path: os.PathLike | str, + deserializer: Callable[[str], dict] | None = None, +) -> Iterator[dict]: """ - Load JSON payload(s) from a file or directory. + Load JSON payload(s) from a file or directory with binary content support. This function reads JSON data from either a single file or multiple files - in a directory. It supports both .json and .jsonl file formats. + in a directory. It supports both .json and .jsonl file formats. Binary content + (bytes objects) that were serialized using LLMeterBytesEncoder are automatically + restored during deserialization. + + Binary Content Handling: + When loading payloads saved with save_payloads(), marker objects with the key + "__llmeter_bytes__" are automatically detected and converted back to bytes objects. + The base64-encoded strings are decoded to restore the original binary data, + enabling round-trip preservation of multimodal content like images and video. + + The marker object format is: {"__llmeter_bytes__": ""} Args: file_path (Union[Path, str]): Path to a JSON file or a directory containing JSON files. Can be a string or a Path object. 
+ deserializer (Callable[[str], dict], optional): Custom deserializer function + that takes a JSON string and returns a dict. If None, uses json.loads with + llmeter_bytes_decoder for automatic binary content handling. Defaults to None. Yields: dict: Each JSON object loaded from the file(s). @@ -127,6 +201,55 @@ def load_payloads(file_path: os.PathLike | str) -> Iterator[dict]: ValidationError: If the JSON data does not conform to the expected schema. IOError: If there's an error reading the file. + Examples: + Load a Bedrock Converse API payload with image content: + + >>> # Assuming a file was saved with save_payloads() containing binary data + >>> payloads = list(load_payloads("/tmp/output/payload.jsonl")) + >>> payload = payloads[0] + >>> # Binary content is automatically restored as bytes + >>> image_bytes = payload["messages"][0]["content"][1]["image"]["source"]["bytes"] + >>> isinstance(image_bytes, bytes) + True + >>> # The bytes can be used directly with the API + >>> print(f"Image size: {len(image_bytes)} bytes") + Image size: 52341 bytes + + Load multiple payloads with video content: + + >>> for payload in load_payloads("/tmp/output/multimodal.jsonl"): + ... video_content = payload["messages"][0]["content"][1] + ... if "video" in video_content: + ... video_bytes = video_content["video"]["source"]["bytes"] + ... print(f"Loaded video: {len(video_bytes)} bytes") + Loaded video: 1048576 bytes + + Load all payloads from a directory: + + >>> # Load all .json and .jsonl files in a directory + >>> all_payloads = list(load_payloads("/tmp/output/")) + >>> print(f"Loaded {len(all_payloads)} payloads") + Loaded 5 payloads + + Round-trip example showing binary preservation: + + >>> # Original payload with binary data + >>> original = { + ... "modelId": "test-model", + ... "messages": [{ + ... "role": "user", + ... "content": [ + ... {"image": {"source": {"bytes": b"\\xff\\xd8\\xff\\xe0"}}} + ... ] + ... }] + ... 
} + >>> # Save and load + >>> save_payloads(original, "/tmp/test") + PosixPath('/tmp/test/payload.jsonl') + >>> loaded = list(load_payloads("/tmp/test/payload.jsonl"))[0] + >>> # Binary data is preserved byte-for-byte + >>> original == loaded + True """ file_path = Path(file_path) @@ -134,13 +257,13 @@ def load_payloads(file_path: os.PathLike | str) -> Iterator[dict]: raise FileNotFoundError(f"The specified path does not exist: {file_path}") if file_path.is_file(): - yield from _load_data_file(file_path) + yield from _load_data_file(file_path, deserializer) else: for file in file_path.glob("*.json*"): - yield from _load_data_file(file) + yield from _load_data_file(file, deserializer) -def _load_data_file(file: Path) -> Iterator[dict]: +def _load_data_file(file: Path, deserializer: Callable[[str], dict] | None = None) -> Iterator[dict]: try: with file.open(mode="r") as f: if file.suffix.lower() in [".jsonl", ".manifest"]: @@ -148,11 +271,19 @@ def _load_data_file(file: Path) -> Iterator[dict]: try: if not line.strip(): continue - yield json.loads(line.strip()) + if deserializer is None: + yield json.loads(line.strip(), object_hook=llmeter_bytes_decoder) + else: + yield deserializer(line.strip()) except json.JSONDecodeError as e: print(f"Error decoding JSON in {file}: {e}") else: # Assume it's a regular JSON file - yield json.load(f) + if deserializer is None: + yield json.load(f, object_hook=llmeter_bytes_decoder) + else: + # For custom deserializer, read the entire file as string + f.seek(0) + yield deserializer(f.read()) except IOError as e: print(f"Error reading file {file}: {e}") except json.JSONDecodeError as e: @@ -163,20 +294,109 @@ def save_payloads( payloads: list[dict] | dict, output_path: os.PathLike | str, output_file: str = "payload.jsonl", + serializer: Callable[[dict], str] | None = None, ) -> Path: """ - Save payloads to a file. + Save payloads to a file with support for binary content. 
+ + This function saves payloads to a JSONL file, with automatic handling of binary + content (bytes objects) through base64 encoding. Binary data is wrapped in marker + objects during serialization to enable round-trip preservation. + + Binary Content Handling: + When a payload contains bytes objects (e.g., images, video), they are automatically + converted to base64-encoded strings and wrapped in a marker object with the key + "__llmeter_bytes__". This approach enables JSON serialization while preserving + the ability to restore the original bytes during deserialization with load_payloads(). + + The marker object format is: {"__llmeter_bytes__": ""} Args: - payloads (Iterator[Dict]): An iterator of payloads (dicts). + payloads (Union[list[dict], dict]): Payload(s) to save. May contain bytes objects + at any nesting level. output_path (Union[Path, str]): The directory path where the output file should be saved. - output_file (str, optional): The name of the output file. Defaults to "payloads.jsonl". + output_file (str, optional): The name of the output file. Defaults to "payload.jsonl". + serializer (Callable[[dict], str], optional): Custom serializer function that takes + a dict and returns a JSON string. If None, uses LLMeterBytesEncoder for automatic + binary content handling. Defaults to None. Returns: - output_file_path (UPath): The path to the output file. + Path: The path to the output file. Raises: IOError: If there's an error writing to the file. + TypeError: If payload contains unserializable types. + + Examples: + Save a Bedrock Converse API payload with image content: + + >>> import base64 + >>> # Create a payload with binary image data + >>> with open("image.jpg", "rb") as f: + ... image_bytes = f.read() + >>> payload = { + ... "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + ... "messages": [{ + ... "role": "user", + ... "content": [ + ... {"text": "What is in this image?"}, + ... { + ... "image": { + ... "format": "jpeg", + ... 
"source": {"bytes": image_bytes} + ... } + ... } + ... ] + ... }] + ... } + >>> output_path = save_payloads(payload, "/tmp/output") + >>> print(output_path) + /tmp/output/payload.jsonl + + Save multiple payloads with video content: + + >>> with open("video.mp4", "rb") as f: + ... video_bytes = f.read() + >>> payloads = [ + ... { + ... "modelId": "anthropic.claude-3-sonnet-20240229-v1:0", + ... "messages": [{ + ... "role": "user", + ... "content": [ + ... {"text": "Describe this video"}, + ... { + ... "video": { + ... "format": "mp4", + ... "source": {"bytes": video_bytes} + ... } + ... } + ... ] + ... }] + ... } + ... ] + >>> save_payloads(payloads, "/tmp/output", "multimodal.jsonl") + PosixPath('/tmp/output/multimodal.jsonl') + + The saved JSON file will contain marker objects for binary data: + + >>> # Example of what gets written to the file: + >>> # { + >>> # "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + >>> # "messages": [{ + >>> # "role": "user", + >>> # "content": [ + >>> # {"text": "What is in this image?"}, + >>> # { + >>> # "image": { + >>> # "format": "jpeg", + >>> # "source": { + >>> # "bytes": {"__llmeter_bytes__": "/9j/4AAQSkZJRg..."} + >>> # } + >>> # } + >>> # } + >>> # ] + >>> # }] + >>> # } """ output_path = Path(output_path) output_path.mkdir(parents=True, exist_ok=True) @@ -186,5 +406,8 @@ def save_payloads( payloads = [payloads] with output_file_path.open(mode="w") as f: for payload in payloads: - f.write(json.dumps(payload) + "\n") + if serializer is None: + f.write(json.dumps(payload, cls=LLMeterBytesEncoder) + "\n") + else: + f.write(serializer(payload) + "\n") return output_file_path diff --git a/llmeter/results.py b/llmeter/results.py index 45a1c00..5392d2f 100644 --- a/llmeter/results.py +++ b/llmeter/results.py @@ -14,6 +14,7 @@ from upath import UPath as Path from .endpoints import InvocationResponse +from .prompt_utils import LLMeterBytesEncoder from .utils import summary_stats_from_list logger = logging.getLogger(__name__) @@ 
-39,6 +40,39 @@ def utc_datetime_serializer(obj: Any) -> str: return str(obj) +class InvocationResponseEncoder(LLMeterBytesEncoder): + """Extended encoder for InvocationResponse with fallback to str() for non-serializable types. + + This encoder extends LLMeterBytesEncoder to handle bytes objects (via parent class) + and adds a fallback mechanism for other non-serializable types by converting them + to strings. This is particularly useful for InvocationResponse objects that may + contain various non-standard types. + + Example: + >>> response = InvocationResponse(input_payload={"image": {"bytes": b"\\xff\\xd8"}}) + >>> json.dumps(asdict(response), cls=InvocationResponseEncoder) + '{"input_payload": {"image": {"bytes": {"__llmeter_bytes__": "/9g="}}}}' + """ + + def default(self, obj): + """Encode objects with bytes support and str() fallback. + + Args: + obj: Object to encode + + Returns: + Encoded representation or None if encoding fails + """ + # First try bytes encoding from parent + if isinstance(obj, bytes): + return super().default(obj) + # Fallback to string representation for other non-serializable types + try: + return str(obj) + except Exception: + return None + + @dataclass class Result: """Results of a test run.""" diff --git a/tests/integ/test_bedrock_converse.py b/tests/integ/test_bedrock_converse.py index 76e4652..359541c 100644 --- a/tests/integ/test_bedrock_converse.py +++ b/tests/integ/test_bedrock_converse.py @@ -335,3 +335,254 @@ def test_bedrock_converse_streaming_with_image( # Verify response has an ID assert response.id is not None, "Response should have an ID" + + +def test_save_load_payload_with_image(test_payload_with_image, tmp_path): + """ + Test saving and loading payload with image content using serialization. 
+ + This test validates that: + - Payloads with image bytes can be saved to disk + - Saved payloads contain valid JSON with marker objects + - Loaded payloads restore bytes objects correctly + - Round-trip preserves byte-for-byte equality + + **Validates: Requirements 1.6, 9.3** + + Args: + test_payload_with_image: Test payload with image content (from fixture). + tmp_path: Temporary directory for test files (from pytest). + """ + from llmeter.prompt_utils import save_payloads, load_payloads + + # Save payload with image + output_file = save_payloads(test_payload_with_image, tmp_path, "test_image.jsonl") + assert output_file.exists(), "Output file should be created" + + # Verify file contains valid JSON with marker objects + with output_file.open("r") as f: + content = f.read() + assert "__llmeter_bytes__" in content, "File should contain marker objects" + + # Load payload and verify bytes are restored + loaded_payloads = list(load_payloads(output_file)) + assert len(loaded_payloads) == 1, "Should load exactly one payload" + + loaded = loaded_payloads[0] + original_bytes = test_payload_with_image["messages"][0]["content"][0]["image"]["source"]["bytes"] + loaded_bytes = loaded["messages"][0]["content"][0]["image"]["source"]["bytes"] + + assert isinstance(loaded_bytes, bytes), "Loaded bytes should be bytes type" + assert loaded_bytes == original_bytes, "Bytes should match after round-trip" + assert loaded == test_payload_with_image, "Full payload should match after round-trip" + + +def test_save_load_payload_with_video(tmp_path): + """ + Test saving and loading payload with video content using serialization. 
+ + This test validates that: + - Payloads with video bytes can be saved to disk + - Saved payloads contain valid JSON with marker objects + - Loaded payloads restore bytes objects correctly + - Video content path (messages[].content[].video.source.bytes) is handled + + **Validates: Requirements 1.6, 9.4** + + Args: + tmp_path: Temporary directory for test files (from pytest). + """ + from llmeter.prompt_utils import save_payloads, load_payloads + + # Create a test payload with video content (simulated video bytes) + video_payload = { + "messages": [ + { + "role": "user", + "content": [ + { + "video": { + "format": "mp4", + "source": {"bytes": b"\x00\x00\x00\x18ftypmp42"}, # MP4 header + } + }, + {"text": "What is happening in this video?"}, + ], + } + ], + "inferenceConfig": {"maxTokens": 150}, + } + + # Save payload with video + output_file = save_payloads(video_payload, tmp_path, "test_video.jsonl") + assert output_file.exists(), "Output file should be created" + + # Verify file contains valid JSON with marker objects + with output_file.open("r") as f: + content = f.read() + assert "__llmeter_bytes__" in content, "File should contain marker objects" + + # Load payload and verify bytes are restored + loaded_payloads = list(load_payloads(output_file)) + assert len(loaded_payloads) == 1, "Should load exactly one payload" + + loaded = loaded_payloads[0] + original_bytes = video_payload["messages"][0]["content"][0]["video"]["source"]["bytes"] + loaded_bytes = loaded["messages"][0]["content"][0]["video"]["source"]["bytes"] + + assert isinstance(loaded_bytes, bytes), "Loaded bytes should be bytes type" + assert loaded_bytes == original_bytes, "Bytes should match after round-trip" + assert loaded == video_payload, "Full payload should match after round-trip" + + +def test_save_load_multiple_images(tmp_path): + """ + Test saving and loading payload with multiple images in single payload. 
+ + This test validates that: + - Payloads with multiple image bytes can be saved to disk + - All images are correctly serialized with marker objects + - All images are correctly restored after loading + - Multiple bytes objects in same payload are handled independently + + **Validates: Requirements 9.8** + + Args: + tmp_path: Temporary directory for test files (from pytest). + """ + from llmeter.prompt_utils import save_payloads, load_payloads + + # Create a test payload with multiple images + multi_image_payload = { + "messages": [ + { + "role": "user", + "content": [ + { + "image": { + "format": "png", + "source": {"bytes": b"\x89PNG\r\n\x1a\n"}, # PNG header + } + }, + {"text": "Compare these images:"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": b"\xff\xd8\xff\xe0"}, # JPEG header + } + }, + { + "image": { + "format": "png", + "source": {"bytes": b"\x89PNG\r\n\x1a\n\x00\x00"}, # Different PNG + } + }, + ], + } + ], + "inferenceConfig": {"maxTokens": 200}, + } + + # Save payload with multiple images + output_file = save_payloads(multi_image_payload, tmp_path, "test_multi_image.jsonl") + assert output_file.exists(), "Output file should be created" + + # Verify file contains multiple marker objects + with output_file.open("r") as f: + content = f.read() + marker_count = content.count("__llmeter_bytes__") + assert marker_count == 3, f"File should contain 3 marker objects, found {marker_count}" + + # Load payload and verify all bytes are restored + loaded_payloads = list(load_payloads(output_file)) + assert len(loaded_payloads) == 1, "Should load exactly one payload" + + loaded = loaded_payloads[0] + content = loaded["messages"][0]["content"] + + # Verify first image + assert isinstance(content[0]["image"]["source"]["bytes"], bytes) + assert content[0]["image"]["source"]["bytes"] == b"\x89PNG\r\n\x1a\n" + + # Verify second image + assert isinstance(content[2]["image"]["source"]["bytes"], bytes) + assert content[2]["image"]["source"]["bytes"] == 
b"\xff\xd8\xff\xe0" + + # Verify third image + assert isinstance(content[3]["image"]["source"]["bytes"], bytes) + assert content[3]["image"]["source"]["bytes"] == b"\x89PNG\r\n\x1a\n\x00\x00" + + # Verify full payload matches + assert loaded == multi_image_payload, "Full payload should match after round-trip" + + +@pytest.mark.integ +def test_round_trip_bedrock_converse_structure(test_payload_with_image, tmp_path, aws_credentials, aws_region): + """ + Test round-trip serialization with actual Bedrock Converse API structure. + + This test validates that: + - Complete Bedrock Converse API payload structure is preserved + - All nested fields (modelId, messages, inferenceConfig) are maintained + - Image bytes in messages[].content[].image.source.bytes path are handled + - Round-trip produces identical payload structure + - Loaded payload can be used with the endpoint's invoke method + + **Validates: Requirements 1.6, 9.3** + + Args: + test_payload_with_image: Test payload with image content (from fixture). + tmp_path: Temporary directory for test files (from pytest). + aws_credentials: Boto3 session with valid AWS credentials (from fixture). + aws_region: AWS region for testing (from fixture). 
+ """ + from llmeter.prompt_utils import save_payloads, load_payloads + from llmeter.endpoints.bedrock import BedrockConverse + + # Create a complete Bedrock Converse payload with all typical fields + complete_payload = { + "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "messages": test_payload_with_image["messages"], + "inferenceConfig": { + "maxTokens": 150, + "temperature": 0.7, + "topP": 0.9, + }, + "system": [{"text": "You are a helpful assistant that describes images."}], + } + + # Save and load the complete payload + output_file = save_payloads(complete_payload, tmp_path, "test_complete.jsonl") + loaded_payloads = list(load_payloads(output_file)) + + assert len(loaded_payloads) == 1, "Should load exactly one payload" + loaded = loaded_payloads[0] + + # Verify all top-level fields are preserved + assert loaded["modelId"] == complete_payload["modelId"] + assert loaded["inferenceConfig"] == complete_payload["inferenceConfig"] + assert loaded["system"] == complete_payload["system"] + + # Verify messages structure is preserved + assert len(loaded["messages"]) == len(complete_payload["messages"]) + assert loaded["messages"][0]["role"] == complete_payload["messages"][0]["role"] + + # Verify image bytes are correctly restored + original_bytes = complete_payload["messages"][0]["content"][0]["image"]["source"]["bytes"] + loaded_bytes = loaded["messages"][0]["content"][0]["image"]["source"]["bytes"] + assert isinstance(loaded_bytes, bytes) + assert loaded_bytes == original_bytes + + # Verify complete equality + assert loaded == complete_payload, "Complete payload should match after round-trip" + + # Verify the loaded payload can be used with the endpoint's invoke method + # Extract model_id from the loaded payload + model_id = loaded.pop("modelId") + endpoint = BedrockConverse(model_id=model_id, region=aws_region) + response = endpoint.invoke(loaded) + + # Verify the endpoint successfully processed the loaded payload + assert response.response_text is not None, 
"Response should contain text" + assert len(response.response_text) > 0, "Response text should not be empty" + assert response.error is None, f"Response should not contain errors: {response.error}" diff --git a/tests/integ/test_bedrock_invoke.py b/tests/integ/test_bedrock_invoke.py index a816f3e..00750e1 100644 --- a/tests/integ/test_bedrock_invoke.py +++ b/tests/integ/test_bedrock_invoke.py @@ -95,9 +95,9 @@ def test_bedrock_invoke_non_streaming(aws_credentials, aws_region, bedrock_test_ # Verify token counts are present and positive assert response.num_tokens_input is not None, "Input token count should not be None" assert response.num_tokens_input > 0, "Input token count should be positive" - assert ( - response.num_tokens_output is not None - ), "Output token count should not be None" + assert response.num_tokens_output is not None, ( + "Output token count should not be None" + ) assert response.num_tokens_output > 0, "Output token count should be positive" # Verify response time is measured and positive @@ -105,9 +105,9 @@ def test_bedrock_invoke_non_streaming(aws_credentials, aws_region, bedrock_test_ assert response.time_to_last_token > 0, "Response time should be positive" # Verify no errors in response - assert ( - response.error is None - ), f"Response should not contain errors: {response.error}" + assert response.error is None, ( + f"Response should not contain errors: {response.error}" + ) # Verify response has an ID assert response.id is not None, "Response should have an ID" @@ -187,32 +187,196 @@ def test_bedrock_invoke_streaming(aws_credentials, aws_region, bedrock_test_mode # Verify token counts are present and positive assert response.num_tokens_input is not None, "Input token count should not be None" assert response.num_tokens_input > 0, "Input token count should be positive" - assert ( - response.num_tokens_output is not None - ), "Output token count should not be None" + assert response.num_tokens_output is not None, ( + "Output token count 
should not be None" + ) assert response.num_tokens_output > 0, "Output token count should be positive" # Verify time to first token is measured and positive - assert ( - response.time_to_first_token is not None - ), "Time to first token should not be None" + assert response.time_to_first_token is not None, ( + "Time to first token should not be None" + ) assert response.time_to_first_token > 0, "Time to first token should be positive" # Verify time to last token is measured and positive - assert ( - response.time_to_last_token is not None - ), "Time to last token should not be None" + assert response.time_to_last_token is not None, ( + "Time to last token should not be None" + ) assert response.time_to_last_token > 0, "Time to last token should be positive" # Verify TTLT > TTFT (streaming should take time to complete) - assert ( - response.time_to_last_token > response.time_to_first_token - ), "Time to last token should be greater than time to first token" + assert response.time_to_last_token > response.time_to_first_token, ( + "Time to last token should be greater than time to first token" + ) # Verify no errors in response - assert ( - response.error is None - ), f"Response should not contain errors: {response.error}" + assert response.error is None, ( + f"Response should not contain errors: {response.error}" + ) # Verify response has an ID assert response.id is not None, "Response should have an ID" + + +def test_save_load_invoke_payload_with_image(tmp_path): + """ + Test saving and loading Bedrock Invoke API payload with image content. 
+ + This test validates that: + - Bedrock Invoke API payloads with image bytes can be saved to disk + - Provider-specific payload structure (Anthropic Claude format) is preserved + - Saved payloads contain valid JSON with marker objects + - Loaded payloads restore bytes objects correctly + - Round-trip preserves byte-for-byte equality + + **Validates: Requirements 9.5** + + Args: + tmp_path: Temporary directory for test files (from pytest). + """ + from llmeter.prompt_utils import save_payloads, load_payloads + import io + from PIL import Image + + # Create a simple test image + img = Image.new("RGB", (50, 50), color="blue") + img_buffer = io.BytesIO() + img.save(img_buffer, format="PNG") + img_bytes = img_buffer.getvalue() + + # Create Bedrock Invoke API payload with Anthropic Claude format + # This uses the native Messages API format, not Converse API + invoke_payload = { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 150, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": img_bytes, # Binary data in provider-specific format + }, + }, + { + "type": "text", + "text": "What is in this image?", + }, + ], + } + ], + } + + # Save payload with image + output_file = save_payloads(invoke_payload, tmp_path, "test_invoke_image.jsonl") + assert output_file.exists(), "Output file should be created" + + # Verify file contains valid JSON with marker objects + with output_file.open("r") as f: + content = f.read() + assert "__llmeter_bytes__" in content, "File should contain marker objects" + + # Load payload and verify bytes are restored + loaded_payloads = list(load_payloads(output_file)) + assert len(loaded_payloads) == 1, "Should load exactly one payload" + + loaded = loaded_payloads[0] + original_bytes = invoke_payload["messages"][0]["content"][0]["source"]["data"] + loaded_bytes = loaded["messages"][0]["content"][0]["source"]["data"] + + assert 
@pytest.mark.integ
def test_round_trip_invoke_structure(
    tmp_path, aws_credentials, aws_region, bedrock_test_model
):
    """Round-trip a complete Bedrock InvokeModel payload and replay it live.

    Checks that every provider-specific field (anthropic_version, max_tokens,
    temperature, top_p, system) and the messages structure survive save/load
    unchanged, then feeds the reloaded payload to a real BedrockInvoke
    endpoint to prove it is still usable.

    Validates: Requirements 9.5

    Args:
        tmp_path: Temporary directory for test files (from pytest).
        aws_credentials: Boto3 session with valid AWS credentials (from fixture).
        aws_region: AWS region for testing (from fixture).
        bedrock_test_model: Model ID for testing (from fixture).
    """
    from llmeter.prompt_utils import load_payloads, save_payloads

    # Native Anthropic Claude Messages-API payload as used by InvokeModel.
    original = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 200,
        "temperature": 0.7,
        "top_p": 0.9,
        "messages": [
            {
                "role": "user",
                "content": "Please provide a brief response to this test message.",
            }
        ],
        "system": "You are a helpful assistant.",
    }

    # Write to disk and read straight back.
    saved_file = save_payloads(original, tmp_path, "test_invoke_complete.jsonl")
    loaded_payloads = list(load_payloads(saved_file))

    assert len(loaded_payloads) == 1, "Should load exactly one payload"
    loaded = loaded_payloads[0]

    # Every top-level provider-specific field must be preserved.
    for field in ("anthropic_version", "max_tokens", "temperature", "top_p", "system"):
        assert loaded[field] == original[field]

    # The messages list must be structurally identical.
    assert len(loaded["messages"]) == len(original["messages"])
    assert loaded["messages"][0]["role"] == original["messages"][0]["role"]
    assert loaded["messages"][0]["content"] == original["messages"][0]["content"]

    # And the payloads must compare equal as a whole.
    assert loaded == original, (
        "Complete payload should match after round-trip"
    )

    # The reloaded payload must still work against a live endpoint.
    endpoint = BedrockInvoke(
        model_id=bedrock_test_model,
        region=aws_region,
        generated_text_jmespath="content[0].text",
        generated_token_count_jmespath="usage.output_tokens",
        input_token_count_jmespath="usage.input_tokens",
        input_text_jmespath="messages[0].content",
    )
    response = endpoint.invoke(loaded)

    assert response.response_text is not None, "Response should contain text"
    assert len(response.response_text) > 0, "Response text should not be empty"
    assert response.error is None, (
        f"Response should not contain errors: {response.error}"
    )
def test_save_load_openai_payload_with_image_url(tmp_path):
    """Round-trip an OpenAI chat payload whose image_url holds raw bytes.

    Stores decoded PNG bytes in ``messages[].content[].image_url.url``, saves
    the payload, verifies the JSON on disk uses ``__llmeter_bytes__`` marker
    objects, and confirms the bytes come back byte-for-byte on load.

    Validates: Requirements 9.6, 9.7

    Args:
        tmp_path: Temporary directory for test files (from pytest).
    """
    import base64

    from llmeter.prompt_utils import load_payloads, save_payloads

    # 1x1 red-pixel PNG, decoded to raw bytes.
    test_image_bytes = base64.b64decode(
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
    )

    # chat.completions payload carrying the image as bytes (not a data URI),
    # which is the form the serializer must wrap in marker objects.
    openai_payload = {
        "model": "gpt-4o",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": test_image_bytes
                        },
                    },
                ],
            }
        ],
        "max_tokens": 100,
    }

    output_file = save_payloads(openai_payload, tmp_path, "test_openai_image.jsonl")
    assert output_file.exists(), "Output file should be created"

    # On disk the bytes must appear as marker objects inside valid JSON.
    with output_file.open("r") as f:
        assert "__llmeter_bytes__" in f.read(), "File should contain marker objects"

    loaded_payloads = list(load_payloads(output_file))
    assert len(loaded_payloads) == 1, "Should load exactly one payload"

    loaded = loaded_payloads[0]
    original_bytes = openai_payload["messages"][0]["content"][1]["image_url"]["url"]
    loaded_bytes = loaded["messages"][0]["content"][1]["image_url"]["url"]

    assert isinstance(loaded_bytes, bytes), "Loaded bytes should be bytes type"
    assert loaded_bytes == original_bytes, "Bytes should match after round-trip"
    assert loaded == openai_payload, "Full payload should match after round-trip"
@pytest.mark.integ
@pytest.mark.skipif(not OPENAI_AVAILABLE, reason="OpenAI SDK not installed")
def test_save_load_openai_complete_structure(
    tmp_path, aws_credentials, aws_region, bedrock_openai_test_model
):
    """
    Test round-trip with actual OpenAI chat.completions structure.

    Validates that a complete chat.completions payload — a system message plus
    a user message mixing text and two image_url entries stored as raw bytes —
    survives save/load exactly (including per-image "detail" settings), and
    that the reloaded payload can be replayed against the live endpoint once
    the bytes are re-encoded as data URIs.

    Fix: the final response-verification block was previously duplicated
    verbatim at the end of this test; it now appears exactly once.

    **Validates: Requirements 9.6, 9.7**

    Args:
        tmp_path: Temporary directory for test files (from pytest).
        aws_credentials: Boto3 session with valid AWS credentials (from fixture).
        aws_region: AWS region for testing (from fixture).
        bedrock_openai_test_model: Model ID for OpenAI SDK testing (from fixture;
            requested for its setup side effects — a multimodal model is used
            for the actual call, see below).
    """
    from llmeter.prompt_utils import save_payloads, load_payloads
    import base64

    # Create test images (tiny PNGs decoded from base64 to raw bytes).
    image1_bytes = base64.b64decode(
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
    )
    image2_bytes = base64.b64decode(
        "iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAYAAABytg0kAAAAFElEQVR42mNk+M/wn4EIwDiqkL4KAcT9BAFZhEjRAAAAAElFTkSuQmCC"
    )

    # Use google.gemma-3-4b-it for multi-modal testing (supports images)
    multimodal_model = "google.gemma-3-4b-it"

    # Create complete OpenAI payload with multiple messages and images
    complete_payload = {
        "model": multimodal_model,
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant that analyzes images.",
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Compare these two images:"},
                    {
                        "type": "image_url",
                        "image_url": {"url": image1_bytes, "detail": "high"},
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": image2_bytes, "detail": "low"},
                    },
                    {"type": "text", "text": "What are the differences?"},
                ],
            },
        ],
        "max_tokens": 500,
        "temperature": 0.7,
        "top_p": 0.9,
    }

    # Save and load the complete payload.
    output_file = save_payloads(
        complete_payload, tmp_path, "test_openai_complete.jsonl"
    )
    loaded_payloads = list(load_payloads(output_file))

    assert len(loaded_payloads) == 1, "Should load exactly one payload"
    loaded = loaded_payloads[0]

    # Verify structure is preserved.
    assert loaded["model"] == complete_payload["model"]
    assert len(loaded["messages"]) == 2
    assert loaded["messages"][0]["role"] == "system"
    assert loaded["messages"][1]["role"] == "user"
    assert len(loaded["messages"][1]["content"]) == 4

    # Verify image bytes are restored correctly.
    loaded_image1 = loaded["messages"][1]["content"][1]["image_url"]["url"]
    loaded_image2 = loaded["messages"][1]["content"][2]["image_url"]["url"]

    assert isinstance(loaded_image1, bytes), "First image should be bytes"
    assert isinstance(loaded_image2, bytes), "Second image should be bytes"
    assert loaded_image1 == image1_bytes, "First image bytes should match"
    assert loaded_image2 == image2_bytes, "Second image bytes should match"

    # Verify per-image detail settings survive the round-trip.
    assert loaded["messages"][1]["content"][1]["image_url"]["detail"] == "high"
    assert loaded["messages"][1]["content"][2]["image_url"]["detail"] == "low"

    # Verify full round-trip equality.
    assert loaded == complete_payload, "Full payload should match after round-trip"

    # Replay the reloaded payload against the live endpoint. The OpenAI API
    # expects base64 data URIs rather than raw bytes, so re-encode each image.
    token = provide_token(region=aws_region)
    base_url = f"https://bedrock-runtime.{aws_region}.amazonaws.com/openai/v1"
    client = OpenAI(api_key=token, base_url=base_url)

    def _as_data_uri(raw):
        # Raw PNG bytes -> "data:image/png;base64,..." URI accepted by the API.
        return f"data:image/png;base64,{base64.b64encode(raw).decode()}"

    user_content = loaded["messages"][1]["content"]
    api_payload = {
        "model": loaded["model"],
        "messages": [
            loaded["messages"][0],  # System message
            {
                "role": loaded["messages"][1]["role"],
                "content": [
                    user_content[0],  # Text
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": _as_data_uri(user_content[1]["image_url"]["url"]),
                            "detail": user_content[1]["image_url"]["detail"],
                        },
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": _as_data_uri(user_content[2]["image_url"]["url"]),
                            "detail": user_content[2]["image_url"]["detail"],
                        },
                    },
                    user_content[3],  # Text
                ],
            },
        ],
        "max_tokens": loaded["max_tokens"],
        "temperature": loaded["temperature"],
        "top_p": loaded["top_p"],
    }

    response = client.chat.completions.create(**api_payload)

    # Verify the client successfully processed the loaded payload
    # (previously this block was accidentally duplicated; it runs once now).
    assert response.choices is not None, "Response should contain choices"
    assert len(response.choices) > 0, "Response should have at least one choice"
    assert response.choices[0].message.content is not None, (
        "Response should contain text"
    )
    assert len(response.choices[0].message.content) > 0, (
        "Response text should not be empty"
    )
class TestLLMeterBytesEncoder:
    """Example-based unit tests for the LLMeterBytesEncoder JSON encoder.

    Complements the property-based suite with concrete cases showing how
    bytes values become single-key ``{"__llmeter_bytes__": "<base64>"}``
    marker objects while every other JSON type passes through untouched.

    Requirements: 1.1, 1.2, 1.3, 1.6
    """

    def test_simple_bytes_object_serialization(self):
        """A top-level bytes value is wrapped in a marker object.

        Validates: Requirements 1.1, 1.2, 1.3
        """
        import base64

        parsed = json.loads(
            json.dumps({"data": b"hello world"}, cls=LLMeterBytesEncoder)
        )

        marker = parsed["data"]
        assert "__llmeter_bytes__" in marker
        assert len(marker) == 1
        assert isinstance(marker["__llmeter_bytes__"], str)
        assert base64.b64decode(marker["__llmeter_bytes__"]) == b"hello world"

    def test_nested_bytes_in_dict_structure(self):
        """Bytes buried inside a Converse-style payload are marker-wrapped.

        Validates: Requirements 1.1, 1.3, 1.5
        """
        import base64

        converse_payload = {
            "modelId": "test-model",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"text": "What is in this image?"},
                        {
                            "image": {
                                "format": "jpeg",
                                "source": {"bytes": b"\xff\xd8\xff\xe0\x00\x10JFIF"},
                            }
                        },
                    ],
                }
            ],
        }

        parsed = json.loads(json.dumps(converse_payload, cls=LLMeterBytesEncoder))

        # Non-binary parts of the structure come through untouched.
        assert parsed["modelId"] == "test-model"
        assert parsed["messages"][0]["role"] == "user"
        assert parsed["messages"][0]["content"][0]["text"] == "What is in this image?"

        # The bytes leaf is replaced by a single-key marker object.
        marker = parsed["messages"][0]["content"][1]["image"]["source"]["bytes"]
        assert "__llmeter_bytes__" in marker
        assert len(marker) == 1
        assert (
            base64.b64decode(marker["__llmeter_bytes__"])
            == b"\xff\xd8\xff\xe0\x00\x10JFIF"
        )

    def test_empty_bytes_object(self):
        """Empty bytes serialize to a marker whose payload decodes to b''.

        Validates: Requirements 1.1, 1.3
        """
        import base64

        parsed = json.loads(json.dumps({"empty": b""}, cls=LLMeterBytesEncoder))

        assert "__llmeter_bytes__" in parsed["empty"]
        decoded = base64.b64decode(parsed["empty"]["__llmeter_bytes__"])
        assert decoded == b""
        assert len(decoded) == 0

    def test_large_binary_data_1mb(self):
        """One megabyte of random bytes survives encoding intact.

        Validates: Requirements 1.6, 10.1
        """
        import base64
        import os

        blob = os.urandom(1024 * 1024)
        parsed = json.loads(json.dumps({"large_image": blob}, cls=LLMeterBytesEncoder))

        assert "__llmeter_bytes__" in parsed["large_image"]
        decoded = base64.b64decode(parsed["large_image"]["__llmeter_bytes__"])
        assert decoded == blob
        assert len(decoded) == 1024 * 1024

    def test_multiple_bytes_objects_in_payload(self):
        """Several independent bytes values each get their own marker.

        Validates: Requirements 1.1, 1.3, 9.8
        """
        import base64

        parsed = json.loads(
            json.dumps(
                {
                    "image1": b"first image data",
                    "image2": b"second image data",
                    "nested": {"image3": b"third image data"},
                },
                cls=LLMeterBytesEncoder,
            )
        )

        checks = [
            (parsed["image1"], b"first image data"),
            (parsed["image2"], b"second image data"),
            (parsed["nested"]["image3"], b"third image data"),
        ]
        for marker, expected in checks:
            assert "__llmeter_bytes__" in marker
            assert base64.b64decode(marker["__llmeter_bytes__"]) == expected

    def test_bytes_in_list(self):
        """Bytes elements of a list are each wrapped individually.

        Validates: Requirements 1.1, 1.3, 1.5
        """
        parsed = json.loads(
            json.dumps(
                {"images": [b"image1", b"image2", b"image3"]},
                cls=LLMeterBytesEncoder,
            )
        )

        assert len(parsed["images"]) == 3
        for marker in parsed["images"]:
            assert "__llmeter_bytes__" in marker
            assert len(marker) == 1

    def test_mixed_types_with_bytes(self):
        """Only bytes values are rewritten; all other JSON types pass through.

        Validates: Requirements 1.1, 1.5
        """
        parsed = json.loads(
            json.dumps(
                {
                    "string": "text value",
                    "number": 42,
                    "float": 3.14,
                    "boolean": True,
                    "null": None,
                    "list": [1, 2, 3],
                    "bytes": b"binary data",
                    "nested": {"more_bytes": b"more binary"},
                },
                cls=LLMeterBytesEncoder,
            )
        )

        # Plain JSON types are preserved as-is.
        assert parsed["string"] == "text value"
        assert parsed["number"] == 42
        assert parsed["float"] == 3.14
        assert parsed["boolean"] is True
        assert parsed["null"] is None
        assert parsed["list"] == [1, 2, 3]

        # Bytes values — at any depth — become marker objects.
        assert "__llmeter_bytes__" in parsed["bytes"]
        assert "__llmeter_bytes__" in parsed["nested"]["more_bytes"]
class TestLLMeterBytesDecoder:
    """Example-based unit tests for the llmeter_bytes_decoder function.

    Concrete cases (complementing the property-based suite) showing that
    single-key ``{"__llmeter_bytes__": ...}`` marker objects decode back to
    bytes while every other dict is passed through unchanged.

    Requirements: 2.1, 2.2, 2.3, 2.4, 6.2
    """

    def test_marker_object_decoding(self):
        """A valid marker object decodes to the original bytes.

        Validates: Requirements 2.1, 2.2, 2.3
        """
        from llmeter.prompt_utils import llmeter_bytes_decoder

        # "hello world" in base64
        decoded = llmeter_bytes_decoder({"__llmeter_bytes__": "aGVsbG8gd29ybGQ="})

        assert isinstance(decoded, bytes)
        assert decoded == b"hello world"

    def test_non_marker_dict_passthrough(self):
        """A dict without the marker key comes back as the identical object.

        Validates: Requirements 2.4
        """
        from llmeter.prompt_utils import llmeter_bytes_decoder

        plain = {"key": "value", "number": 42, "nested": {"data": "test"}}
        result = llmeter_bytes_decoder(plain)

        assert result == plain
        assert result is plain  # same object, not a copy

    def test_invalid_base64_error_handling(self):
        """Invalid base64 inside a marker raises binascii.Error.

        Validates: Requirements 6.2
        """
        import binascii

        from llmeter.prompt_utils import llmeter_bytes_decoder

        with pytest.raises(binascii.Error):
            llmeter_bytes_decoder({"__llmeter_bytes__": "not-valid-base64!!!"})

    def test_multi_key_dict_with_marker_key_not_decoded(self):
        """Only single-key marker objects decode; extra keys disable decoding.

        Safety check: a dict that merely *contains* the marker key alongside
        other keys must be returned untouched.

        Validates: Requirements 2.4
        """
        from llmeter.prompt_utils import llmeter_bytes_decoder

        suspicious = {
            "__llmeter_bytes__": "aGVsbG8=",
            "other_key": "other_value",
        }
        result = llmeter_bytes_decoder(suspicious)

        assert result == suspicious
        assert isinstance(result, dict)
        assert "__llmeter_bytes__" in result
        assert "other_key" in result

    def test_empty_bytes_decoding(self):
        """An empty base64 payload decodes to empty bytes.

        Validates: Requirements 2.1, 2.3
        """
        from llmeter.prompt_utils import llmeter_bytes_decoder

        decoded = llmeter_bytes_decoder({"__llmeter_bytes__": ""})

        assert isinstance(decoded, bytes)
        assert decoded == b""
        assert len(decoded) == 0

    def test_large_binary_data_decoding(self):
        """A 1MB marker decodes back to the exact original bytes.

        Validates: Requirements 2.1, 2.3
        """
        import base64
        import os

        from llmeter.prompt_utils import llmeter_bytes_decoder

        blob = os.urandom(1024 * 1024)
        decoded = llmeter_bytes_decoder(
            {"__llmeter_bytes__": base64.b64encode(blob).decode("utf-8")}
        )

        assert isinstance(decoded, bytes)
        assert decoded == blob
        assert len(decoded) == 1024 * 1024

    def test_nested_structure_with_marker(self):
        """The decoder works as a json.loads object_hook on nested payloads.

        Validates: Requirements 2.1, 2.5
        """
        from llmeter.prompt_utils import llmeter_bytes_decoder

        serialized = json.dumps(
            {
                "modelId": "test-model",
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"text": "What is this?"},
                            {
                                "image": {
                                    "source": {
                                        "bytes": {
                                            "__llmeter_bytes__": "aGVsbG8="  # "hello"
                                        }
                                    }
                                }
                            },
                        ],
                    }
                ],
            }
        )

        restored = json.loads(serialized, object_hook=llmeter_bytes_decoder)

        # Surrounding structure is preserved...
        assert restored["modelId"] == "test-model"
        assert restored["messages"][0]["role"] == "user"

        # ...and the marker leaf came back as bytes.
        image_bytes = restored["messages"][0]["content"][1]["image"]["source"]["bytes"]
        assert isinstance(image_bytes, bytes)
        assert image_bytes == b"hello"

    def test_dict_without_marker_key(self):
        """An ordinary dict is returned unchanged.

        Validates: Requirements 2.4
        """
        from llmeter.prompt_utils import llmeter_bytes_decoder

        ordinary = {"data": "value", "count": 123}
        result = llmeter_bytes_decoder(ordinary)

        assert result == ordinary
        assert isinstance(result, dict)

    def test_marker_with_special_characters(self):
        """Binary data full of non-ASCII bytes survives decoding.

        Validates: Requirements 2.1, 2.3
        """
        import base64

        from llmeter.prompt_utils import llmeter_bytes_decoder

        jpeg_header = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01"
        decoded = llmeter_bytes_decoder(
            {"__llmeter_bytes__": base64.b64encode(jpeg_header).decode("utf-8")}
        )

        assert isinstance(decoded, bytes)
        assert decoded == jpeg_header
def test_load_payloads_with_marker_objects(self): + """Test loading payload with marker objects. + + Validates: Requirements 2.1, 3.1, 5.2 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Create a payload with marker objects (simulating saved binary content) + payload_with_markers = { + "modelId": "test-model", + "image": { + "source": { + "bytes": { + "__llmeter_bytes__": "aGVsbG8gd29ybGQ=" # "hello world" + } + } + }, + } + + # Write to file + with jsonl_file.open("w") as f: + f.write(json.dumps(payload_with_markers) + "\n") + + # Load using load_payloads + payloads = list(load_payloads(jsonl_file)) + + # Verify payload was loaded + assert len(payloads) == 1 + + # Verify bytes were restored from marker + loaded_payload = payloads[0] + assert loaded_payload["modelId"] == "test-model" + assert isinstance( + loaded_payload["image"]["source"]["bytes"], bytes + ) + assert loaded_payload["image"]["source"]["bytes"] == b"hello world" + + def test_load_payloads_bytes_correctly_restored(self): + """Test that bytes objects are correctly restored from marker objects. 
+ + Validates: Requirements 2.1, 2.3, 5.2 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Create multiple payloads with different bytes content + payloads_with_markers = [ + {"id": 1, "data": {"__llmeter_bytes__": "Zmlyc3Q="}}, # "first" + {"id": 2, "data": {"__llmeter_bytes__": "c2Vjb25k"}}, # "second" + { + "id": 3, + "nested": { + "deep": {"data": {"__llmeter_bytes__": "dGhpcmQ="}} # "third" + }, + }, + ] + + # Write to file + with jsonl_file.open("w") as f: + for payload in payloads_with_markers: + f.write(json.dumps(payload) + "\n") + + # Load payloads + loaded = list(load_payloads(jsonl_file)) + + # Verify all bytes were restored correctly + assert len(loaded) == 3 + assert loaded[0]["data"] == b"first" + assert loaded[1]["data"] == b"second" + assert loaded[2]["nested"]["deep"]["data"] == b"third" + + # Verify they are bytes objects + assert isinstance(loaded[0]["data"], bytes) + assert isinstance(loaded[1]["data"], bytes) + assert isinstance(loaded[2]["nested"]["deep"]["data"], bytes) + + def test_load_payloads_custom_deserializer_parameter(self): + """Test custom deserializer parameter works correctly. 
+ + Validates: Requirements 5.2, 5.4 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Write a simple payload + payload = {"test": "data", "number": 42} + with jsonl_file.open("w") as f: + f.write(json.dumps(payload) + "\n") + + # Custom deserializer that adds a field + def custom_deserializer(json_str): + data = json.loads(json_str) + data["custom_field"] = "added_by_deserializer" + return data + + # Load with custom deserializer + loaded = list(load_payloads(jsonl_file, deserializer=custom_deserializer)) + + # Verify custom deserializer was used + assert len(loaded) == 1 + assert loaded[0]["test"] == "data" + assert loaded[0]["number"] == 42 + assert loaded[0]["custom_field"] == "added_by_deserializer" + + def test_load_payloads_backward_compatibility_old_format(self): + """Test backward compatibility - old format loads successfully. + + When loading payloads saved with the old format (no binary data, no markers), + they should load successfully without any issues. 
+ + Validates: Requirements 3.1, 5.2, 5.4 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "old_format.jsonl" + + # Create payloads in old format (no marker objects) + old_format_payloads = [ + {"modelId": "test-model-1", "prompt": "Hello world", "maxTokens": 100}, + { + "modelId": "test-model-2", + "messages": [ + {"role": "user", "content": "What is the weather?"} + ], + }, + { + "modelId": "test-model-3", + "config": {"temperature": 0.7, "topP": 0.9}, + }, + ] + + # Write using standard json.dumps (old format) + with jsonl_file.open("w") as f: + for payload in old_format_payloads: + f.write(json.dumps(payload) + "\n") + + # Load using new load_payloads (with binary support) + loaded = list(load_payloads(jsonl_file)) + + # Verify all payloads loaded successfully + assert len(loaded) == 3 + + # Verify content matches exactly + assert loaded[0] == old_format_payloads[0] + assert loaded[1] == old_format_payloads[1] + assert loaded[2] == old_format_payloads[2] + + # Verify no marker objects were introduced + for payload in loaded: + assert "__llmeter_bytes__" not in json.dumps(payload) + + def test_load_payloads_round_trip_with_binary_content(self): + """Test round-trip: save with binary, load, verify bytes restored. 
+ + Validates: Requirements 2.1, 2.5, 4.1 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + + # Original payload with binary content + original_payload = { + "modelId": "test-model", + "messages": [ + { + "role": "user", + "content": [ + {"text": "Describe this image"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": b"\xff\xd8\xff\xe0\x00\x10JFIF"}, + } + }, + ], + } + ], + } + + # Save using save_payloads + saved_path = save_payloads(original_payload, output_path) + + # Load using load_payloads + loaded = list(load_payloads(saved_path)) + + # Verify round-trip integrity + assert len(loaded) == 1 + loaded_payload = loaded[0] + + # Verify structure is preserved + assert loaded_payload["modelId"] == original_payload["modelId"] + assert len(loaded_payload["messages"]) == 1 + assert loaded_payload["messages"][0]["role"] == "user" + + # Verify bytes were restored correctly + loaded_bytes = loaded_payload["messages"][0]["content"][1]["image"][ + "source" + ]["bytes"] + original_bytes = original_payload["messages"][0]["content"][1]["image"][ + "source" + ]["bytes"] + + assert isinstance(loaded_bytes, bytes) + assert loaded_bytes == original_bytes + + def test_load_payloads_multiple_marker_objects_in_single_payload(self): + """Test loading payload with multiple marker objects. 
+ + Validates: Requirements 2.1, 9.8 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Payload with multiple marker objects + payload_with_multiple_markers = { + "modelId": "test-model", + "images": [ + {"id": 1, "data": {"__llmeter_bytes__": "aW1hZ2Ux"}}, # "image1" + {"id": 2, "data": {"__llmeter_bytes__": "aW1hZ2Uy"}}, # "image2" + {"id": 3, "data": {"__llmeter_bytes__": "aW1hZ2Uz"}}, # "image3" + ], + } + + # Write to file + with jsonl_file.open("w") as f: + f.write(json.dumps(payload_with_multiple_markers) + "\n") + + # Load payload + loaded = list(load_payloads(jsonl_file)) + + # Verify all marker objects were decoded + assert len(loaded) == 1 + payload = loaded[0] + + assert len(payload["images"]) == 3 + assert payload["images"][0]["data"] == b"image1" + assert payload["images"][1]["data"] == b"image2" + assert payload["images"][2]["data"] == b"image3" + + # Verify all are bytes + for image in payload["images"]: + assert isinstance(image["data"], bytes) + + def test_load_payloads_empty_bytes_marker(self): + """Test loading marker object with empty bytes. + + Validates: Requirements 2.1, 2.3 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Payload with empty bytes marker + payload_with_empty = { + "id": 1, + "empty_data": {"__llmeter_bytes__": ""}, # Empty base64 = empty bytes + } + + # Write to file + with jsonl_file.open("w") as f: + f.write(json.dumps(payload_with_empty) + "\n") + + # Load payload + loaded = list(load_payloads(jsonl_file)) + + # Verify empty bytes were restored + assert len(loaded) == 1 + assert loaded[0]["empty_data"] == b"" + assert isinstance(loaded[0]["empty_data"], bytes) + assert len(loaded[0]["empty_data"]) == 0 + + def test_load_payloads_large_binary_data(self): + """Test loading payload with large binary data (1MB). 
+ + Validates: Requirements 2.1, 10.2 + """ + import base64 + import os + + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Create 1MB of random binary data + large_data = os.urandom(1024 * 1024) + base64_encoded = base64.b64encode(large_data).decode("utf-8") + + # Payload with large marker object + payload_with_large = { + "id": 1, + "large_image": {"__llmeter_bytes__": base64_encoded}, + } + + # Write to file + with jsonl_file.open("w") as f: + f.write(json.dumps(payload_with_large) + "\n") + + # Load payload + loaded = list(load_payloads(jsonl_file)) + + # Verify large bytes were restored correctly + assert len(loaded) == 1 + assert isinstance(loaded[0]["large_image"], bytes) + assert loaded[0]["large_image"] == large_data + assert len(loaded[0]["large_image"]) == 1024 * 1024 + class TestSavePayloads: """Tests for save_payloads function.""" @@ -402,6 +1137,162 @@ def test_save_payloads_returns_path(self): assert isinstance(result_path, Path) assert result_path.exists() + def test_save_payloads_with_bytes_objects(self): + """Test saving payload with bytes objects. 
+ + Validates: Requirements 1.1, 3.3, 5.1 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payload = { + "modelId": "test-model", + "image": {"source": {"bytes": b"\xff\xd8\xff\xe0\x00\x10JFIF"}}, + } + + result_path = save_payloads(payload, output_path) + + # Verify file was created + assert result_path.exists() + + # Read the file and verify it contains valid JSON + with result_path.open("r") as f: + line = f.readline() + parsed = json.loads(line) + + # Verify structure is preserved + assert parsed["modelId"] == "test-model" + + # Verify bytes are replaced with marker object + assert "__llmeter_bytes__" in parsed["image"]["source"]["bytes"] + assert len(parsed["image"]["source"]["bytes"]) == 1 + + def test_save_payloads_file_contains_valid_json_with_markers(self): + """Test that saved file contains valid JSON with marker objects. + + Validates: Requirements 1.1, 1.3 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payloads = [ + {"id": 1, "data": b"first bytes"}, + {"id": 2, "data": b"second bytes"}, + {"id": 3, "nested": {"data": b"nested bytes"}}, + ] + + result_path = save_payloads(payloads, output_path) + + # Read and parse each line + with result_path.open("r") as f: + lines = f.readlines() + + assert len(lines) == 3 + + # Verify each line is valid JSON + for i, line in enumerate(lines): + parsed = json.loads(line) + assert parsed["id"] == i + 1 + + # Verify marker objects exist + if i < 2: + assert "__llmeter_bytes__" in parsed["data"] + else: + assert "__llmeter_bytes__" in parsed["nested"]["data"] + + def test_save_payloads_custom_serializer_parameter(self): + """Test custom serializer parameter works correctly. 
+ + Validates: Requirements 5.1, 5.3 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payload = {"test": "data", "number": 42} + + # Custom serializer that adds a prefix + def custom_serializer(payload_dict): + return "CUSTOM:" + json.dumps(payload_dict) + + result_path = save_payloads( + payload, output_path, serializer=custom_serializer + ) + + # Verify custom serializer was used + with result_path.open("r") as f: + line = f.readline() + + assert line.startswith("CUSTOM:") + # Remove prefix and verify it's valid JSON + json_part = line[7:].strip() + parsed = json.loads(json_part) + assert parsed["test"] == "data" + assert parsed["number"] == 42 + + def test_save_payloads_backward_compatibility_no_bytes(self): + """Test backward compatibility when payload has no bytes. + + When a payload contains no bytes objects, the output should be identical + to using standard json.dumps (no marker objects introduced). + + Validates: Requirements 3.3, 5.1 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payload = { + "modelId": "test-model", + "prompt": "Hello world", + "maxTokens": 100, + "temperature": 0.7, + } + + # Save with new encoder + result_path = save_payloads(payload, output_path) + + # Read the saved content + with result_path.open("r") as f: + saved_line = f.readline().strip() + + # Compare with standard json.dumps + standard_json = json.dumps(payload) + + # They should be identical (no marker objects introduced) + assert saved_line == standard_json + + # Verify no marker keys exist + parsed = json.loads(saved_line) + assert "__llmeter_bytes__" not in json.dumps(parsed) + + def test_save_payloads_multiple_payloads_with_mixed_content(self): + """Test saving multiple payloads with mixed binary and non-binary content. 
+ + Validates: Requirements 1.1, 3.3 + """ + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payloads = [ + {"id": 1, "text": "no binary"}, + {"id": 2, "image": b"binary data"}, + {"id": 3, "text": "also no binary"}, + ] + + result_path = save_payloads(payloads, output_path) + + # Read all lines + with result_path.open("r") as f: + lines = f.readlines() + + assert len(lines) == 3 + + # First payload: no marker + parsed1 = json.loads(lines[0]) + assert "__llmeter_bytes__" not in json.dumps(parsed1) + + # Second payload: has marker + parsed2 = json.loads(lines[1]) + assert "__llmeter_bytes__" in parsed2["image"] + + # Third payload: no marker + parsed3 = json.loads(lines[2]) + assert "__llmeter_bytes__" not in json.dumps(parsed3) + # Property-based tests class TestPromptUtilsProperties: @@ -410,7 +1301,7 @@ class TestPromptUtilsProperties: @given( st.lists( st.dictionaries( - st.text(min_size=1, max_size=50), + st.text(min_size=1, max_size=50).filter(lambda k: k != "__llmeter_bytes__"), st.one_of(st.text(max_size=100), st.integers(), st.booleans()), min_size=1, max_size=10, @@ -470,3 +1361,585 @@ def test_create_collection_length_combinations(self, input_lengths, output_lengt result = collection.create_collection() expected_length = len(input_lengths) * len(output_lengths) assert len(result) == expected_length + + +class TestErrorHandling: + """Unit tests for error conditions in serialization/deserialization. + + These tests verify that errors are handled gracefully with descriptive messages. + + Requirements: 6.1, 6.2, 6.3, 6.4 + """ + + def test_unserializable_type_error_message(self): + """Test that unserializable types raise TypeError with type information. 
+ + Validates: Requirements 6.1 + """ + + class CustomUnserializableObject: + """A custom class that cannot be serialized to JSON.""" + + def __init__(self, value): + self.value = value + + payload = { + "modelId": "test-model", + "custom_object": CustomUnserializableObject(42), + } + + # Should raise TypeError when trying to serialize + with pytest.raises(TypeError) as exc_info: + json.dumps(payload, cls=LLMeterBytesEncoder) + + # Verify error message contains type information + error_message = str(exc_info.value) + assert "CustomUnserializableObject" in error_message + + def test_invalid_json_error_handling(self): + """Test that invalid JSON raises JSONDecodeError with descriptive message. + + Validates: Requirements 6.2 + """ + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "invalid.jsonl" + + # Write invalid JSON to file + with jsonl_file.open("w") as f: + f.write("This is not valid JSON at all\n") + + # Should handle JSONDecodeError gracefully + # load_payloads skips invalid lines, so we need to check the behavior + payloads = list(load_payloads(jsonl_file)) + + # Invalid line should be skipped + assert len(payloads) == 0 + + def test_invalid_base64_error_handling(self): + """Test that invalid base64 in marker objects raises binascii.Error. + + Validates: Requirements 6.2, 6.3 + """ + import binascii + + from llmeter.prompt_utils import llmeter_bytes_decoder + + # Create a marker object with invalid base64 (incorrect padding) + # Base64 strings must have length that is a multiple of 4 + # A single character will trigger binascii.Error + invalid_marker = { + "__llmeter_bytes__": "a" + } + + # Should raise binascii.Error when trying to decode + with pytest.raises(binascii.Error): + llmeter_bytes_decoder(invalid_marker) + + def test_invalid_base64_in_load_payloads(self): + """Test that invalid base64 in saved payload is handled during load. 
+ + Validates: Requirements 6.2, 6.3 + """ + import binascii + + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "invalid_base64.jsonl" + + # Create a payload with invalid base64 in marker object + # Using a single character which will trigger binascii.Error + invalid_payload = { + "modelId": "test-model", + "image": { + "source": { + "bytes": { + "__llmeter_bytes__": "a" + } + } + }, + } + + # Write to file + with jsonl_file.open("w") as f: + f.write(json.dumps(invalid_payload) + "\n") + + # Should raise binascii.Error when trying to load + with pytest.raises(binascii.Error): + list(load_payloads(jsonl_file)) + + def test_custom_serializer_exception_propagation(self): + """Test that custom serializer exceptions are propagated correctly. + + Validates: Requirements 6.4 + """ + + class CustomSerializerError(Exception): + """Custom exception for testing.""" + + pass + + def failing_serializer(payload): + """A custom serializer that always raises an exception.""" + raise CustomSerializerError("Custom serializer failed!") + + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + payload = {"test": "data"} + + # Should propagate the custom exception + with pytest.raises(CustomSerializerError) as exc_info: + save_payloads(payload, output_path, serializer=failing_serializer) + + # Verify the exception message is preserved + assert "Custom serializer failed!" in str(exc_info.value) + + def test_custom_deserializer_exception_propagation(self): + """Test that custom deserializer exceptions are propagated correctly. 
+ + Validates: Requirements 6.4 + """ + + class CustomDeserializerError(Exception): + """Custom exception for testing.""" + + pass + + def failing_deserializer(json_str): + """A custom deserializer that always raises an exception.""" + raise CustomDeserializerError("Custom deserializer failed!") + + with tempfile.TemporaryDirectory() as tmpdir: + jsonl_file = Path(tmpdir) / "payload.jsonl" + + # Write a valid payload + payload = {"test": "data"} + with jsonl_file.open("w") as f: + f.write(json.dumps(payload) + "\n") + + # Should propagate the custom exception + with pytest.raises(CustomDeserializerError) as exc_info: + list(load_payloads(jsonl_file, deserializer=failing_deserializer)) + + # Verify the exception message is preserved + assert "Custom deserializer failed!" in str(exc_info.value) + + def test_unserializable_nested_object_error(self): + """Test error message for unserializable object in nested structure. + + Validates: Requirements 6.1, 6.3 + """ + + class NestedCustomObject: + """A custom class for testing nested error handling.""" + + def __init__(self): + self.data = "test" + + payload = { + "modelId": "test-model", + "messages": [ + { + "role": "user", + "content": [ + {"text": "Hello"}, + {"custom": NestedCustomObject()}, + ], + } + ], + } + + # Should raise TypeError with type information + with pytest.raises(TypeError) as exc_info: + json.dumps(payload, cls=LLMeterBytesEncoder) + + # Verify error message contains the type name + error_message = str(exc_info.value) + assert "NestedCustomObject" in error_message + + def test_bytes_serialization_with_encoding_error(self): + """Test that bytes serialization handles all byte values correctly. + + This test verifies that bytes with any value (0-255) can be serialized + without encoding errors. 
+ + Validates: Requirements 1.1, 6.1 + """ + # Create bytes with all possible byte values + all_bytes = bytes(range(256)) + payload = {"data": all_bytes} + + # Should serialize without errors + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Should be valid JSON + parsed = json.loads(serialized) + + # Verify marker exists + assert "__llmeter_bytes__" in parsed["data"] + + # Verify round-trip works + from llmeter.prompt_utils import llmeter_bytes_decoder + + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + assert deserialized["data"] == all_bytes + + def test_empty_payload_serialization(self): + """Test that empty payloads are handled correctly. + + Validates: Requirements 1.1, 6.1 + """ + # Empty dict + empty_payload = {} + + # Should serialize without errors + serialized = json.dumps(empty_payload, cls=LLMeterBytesEncoder) + assert serialized == "{}" + + # Empty list + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) + empty_list = [] + + # Should handle empty list + result_path = save_payloads(empty_list, output_path) + + # File should exist but be empty + with result_path.open("r") as f: + content = f.read() + assert content == "" + + def test_none_value_serialization(self): + """Test that None values are handled correctly. + + Validates: Requirements 1.1, 1.5 + """ + payload = { + "modelId": "test-model", + "optional_field": None, + "nested": {"also_none": None}, + } + + # Should serialize without errors + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Should be valid JSON + parsed = json.loads(serialized) + + # Verify None values are preserved + assert parsed["optional_field"] is None + assert parsed["nested"]["also_none"] is None + + def test_unicode_in_payload_with_bytes(self): + """Test that Unicode strings work correctly alongside bytes. 
+ + Validates: Requirements 1.1, 1.5 + """ + payload = { + "text": "Hello 世界 🌍", + "emoji": "🎉🎊🎈", + "bytes": b"\xff\xfe\xfd", + } + + # Should serialize without errors + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Should be valid JSON + parsed = json.loads(serialized) + + # Verify Unicode is preserved + assert parsed["text"] == "Hello 世界 🌍" + assert parsed["emoji"] == "🎉🎊🎈" + + # Verify bytes have marker + assert "__llmeter_bytes__" in parsed["bytes"] + + +class TestPerformance: + """Performance tests for serialization/deserialization. + + These tests verify that serialization and deserialization of large binary data + completes within acceptable time limits and doesn't create unnecessary data copies. + + Requirements: 10.1, 10.2, 10.3, 10.4 + """ + + def test_1mb_image_serialization_performance(self): + """Test that 1MB image serialization completes within 100ms. + + Validates: Requirements 10.1 + """ + import os + import time + + # Create 1MB of random binary data (simulating an image) + large_image = os.urandom(1024 * 1024) + payload = { + "modelId": "test-model", + "messages": [ + { + "role": "user", + "content": [ + {"text": "Describe this image"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": large_image}, + } + }, + ], + } + ], + } + + # Measure serialization time + start_time = time.perf_counter() + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + end_time = time.perf_counter() + + # Calculate elapsed time in milliseconds + elapsed_ms = (end_time - start_time) * 1000 + + # Verify serialization completed within 100ms + assert ( + elapsed_ms < 100 + ), f"Serialization took {elapsed_ms:.2f}ms, expected < 100ms" + + # Verify the result is valid JSON + assert isinstance(serialized, str) + parsed = json.loads(serialized) + assert "__llmeter_bytes__" in parsed["messages"][0]["content"][1]["image"][ + "source" + ]["bytes"] + + def test_1mb_image_deserialization_performance(self): + """Test that 1MB image deserialization 
completes within 100ms. + + Validates: Requirements 10.2 + """ + import os + import time + + # Create 1MB of random binary data + large_image = os.urandom(1024 * 1024) + payload = { + "modelId": "test-model", + "image": {"source": {"bytes": large_image}}, + } + + # First serialize the payload + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Measure deserialization time + from llmeter.prompt_utils import llmeter_bytes_decoder + + start_time = time.perf_counter() + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + end_time = time.perf_counter() + + # Calculate elapsed time in milliseconds + elapsed_ms = (end_time - start_time) * 1000 + + # Verify deserialization completed within 100ms + assert ( + elapsed_ms < 100 + ), f"Deserialization took {elapsed_ms:.2f}ms, expected < 100ms" + + # Verify the result is correct + assert isinstance(deserialized["image"]["source"]["bytes"], bytes) + assert deserialized["image"]["source"]["bytes"] == large_image + + def test_serialization_no_unnecessary_copies(self): + """Test that serialization doesn't create unnecessary data copies. + + This test verifies that the serialization process is memory-efficient + by checking that the base64 encoding is done in-place without creating + multiple intermediate copies of the binary data. 
+ + Validates: Requirements 10.3 + """ + import os + import sys + + # Create a moderately large binary payload (512KB) + binary_data = os.urandom(512 * 1024) + payload = {"data": binary_data} + + # Get initial memory usage (approximate) + initial_size = sys.getsizeof(binary_data) + + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Parse to verify structure + parsed = json.loads(serialized) + assert "__llmeter_bytes__" in parsed["data"] + + # The serialized string should be roughly 4/3 the size of the original + # (base64 encoding overhead) plus JSON structure overhead + # We verify it's not significantly larger (which would indicate copies) + serialized_size = sys.getsizeof(serialized) + base64_expected_size = (initial_size * 4 // 3) + 1000 # +1000 for JSON overhead + + # Allow 50% overhead for Python string internals and JSON structure + max_acceptable_size = base64_expected_size * 1.5 + + assert ( + serialized_size < max_acceptable_size + ), f"Serialized size {serialized_size} exceeds expected {max_acceptable_size}" + + def test_deserialization_no_unnecessary_copies(self): + """Test that deserialization doesn't create unnecessary data copies. + + This test verifies that the deserialization process is memory-efficient + by checking that the base64 decoding is done efficiently without creating + multiple intermediate copies of the binary data. 
+ + Validates: Requirements 10.4 + """ + import os + import sys + + # Create a moderately large binary payload (512KB) + binary_data = os.urandom(512 * 1024) + payload = {"data": binary_data} + + # Serialize first + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Get the size of the serialized data + serialized_size = sys.getsizeof(serialized) + + # Deserialize + from llmeter.prompt_utils import llmeter_bytes_decoder + + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + + # Verify the deserialized bytes match original + assert deserialized["data"] == binary_data + + # The deserialized bytes should be approximately the original size + deserialized_size = sys.getsizeof(deserialized["data"]) + original_size = sys.getsizeof(binary_data) + + # Allow small overhead for Python object internals + # bytes objects have minimal overhead + max_acceptable_size = original_size * 1.1 + + assert ( + deserialized_size < max_acceptable_size + ), f"Deserialized size {deserialized_size} exceeds expected {max_acceptable_size}" + + def test_round_trip_performance_with_multiple_images(self): + """Test round-trip performance with multiple large images. + + This test verifies that serialization and deserialization remain performant + even with multiple large binary objects in a single payload. 
+ + Validates: Requirements 10.1, 10.2 + """ + import os + import time + + # Create payload with 3 images of 512KB each (total ~1.5MB) + images = [os.urandom(512 * 1024) for _ in range(3)] + payload = { + "modelId": "test-model", + "images": [ + {"id": i, "data": img} for i, img in enumerate(images) + ], + } + + # Measure serialization time + start_time = time.perf_counter() + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialize_time = (time.perf_counter() - start_time) * 1000 + + # Measure deserialization time + from llmeter.prompt_utils import llmeter_bytes_decoder + + start_time = time.perf_counter() + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + deserialize_time = (time.perf_counter() - start_time) * 1000 + + # With 3 images, we allow proportionally more time (but still reasonable) + # Each operation should complete in under 200ms for 1.5MB total + assert ( + serialize_time < 200 + ), f"Serialization took {serialize_time:.2f}ms, expected < 200ms" + assert ( + deserialize_time < 200 + ), f"Deserialization took {deserialize_time:.2f}ms, expected < 200ms" + + # Verify correctness + assert len(deserialized["images"]) == 3 + for i, img in enumerate(images): + assert deserialized["images"][i]["data"] == img + + def test_serialization_performance_scales_linearly(self): + """Test that serialization performance scales linearly with data size. + + This test verifies that doubling the data size roughly doubles the time, + indicating no algorithmic inefficiencies. 
+ + Validates: Requirements 10.1, 10.3 + """ + import os + import time + + # Test with 256KB + small_data = os.urandom(256 * 1024) + small_payload = {"data": small_data} + + start_time = time.perf_counter() + json.dumps(small_payload, cls=LLMeterBytesEncoder) + small_time = time.perf_counter() - start_time + + # Test with 512KB (2x size) + large_data = os.urandom(512 * 1024) + large_payload = {"data": large_data} + + start_time = time.perf_counter() + json.dumps(large_payload, cls=LLMeterBytesEncoder) + large_time = time.perf_counter() - start_time + + # Large should take roughly 2x the time (allow 3x for variance) + # This verifies linear scaling, not quadratic or worse + assert ( + large_time < small_time * 3 + ), f"Performance doesn't scale linearly: {small_time:.4f}s vs {large_time:.4f}s" + + def test_deserialization_performance_scales_linearly(self): + """Test that deserialization performance scales linearly with data size. + + This test verifies that doubling the data size roughly doubles the time, + indicating no algorithmic inefficiencies. 
+ + Validates: Requirements 10.2, 10.4 + """ + import os + import time + + from llmeter.prompt_utils import llmeter_bytes_decoder + + # Test with 256KB + small_data = os.urandom(256 * 1024) + small_payload = {"data": small_data} + small_serialized = json.dumps(small_payload, cls=LLMeterBytesEncoder) + + start_time = time.perf_counter() + json.loads(small_serialized, object_hook=llmeter_bytes_decoder) + small_time = time.perf_counter() - start_time + + # Test with 512KB (2x size) + large_data = os.urandom(512 * 1024) + large_payload = {"data": large_data} + large_serialized = json.dumps(large_payload, cls=LLMeterBytesEncoder) + + start_time = time.perf_counter() + json.loads(large_serialized, object_hook=llmeter_bytes_decoder) + large_time = time.perf_counter() - start_time + + # Large should take roughly 2x the time (allow 3x for variance) + # This verifies linear scaling, not quadratic or worse + assert ( + large_time < small_time * 3 + ), f"Performance doesn't scale linearly: {small_time:.4f}s vs {large_time:.4f}s" diff --git a/tests/unit/test_results.py b/tests/unit/test_results.py index 9262d16..cdc4deb 100644 --- a/tests/unit/test_results.py +++ b/tests/unit/test_results.py @@ -306,3 +306,278 @@ def test_save_method_existing_responses(sample_result: Result, temp_dir: UPath): responses = [json.loads(line) for line in f] assert len(responses) == 6 # 5 original + 1 extra assert responses[-1]["id"] == "extra_response" + + +# Tests for InvocationResponseEncoder +# Validates Requirements: 7.1, 7.2 + + +def test_invocation_response_encoder_handles_bytes(): + """Test that InvocationResponseEncoder handles bytes objects via parent class. 
+ + Validates: Requirements 7.1, 7.2 + """ + from llmeter.results import InvocationResponseEncoder + + # Test bytes object encoding + payload = {"image": {"bytes": b"\xff\xd8\xff\xe0"}} + encoded = json.dumps(payload, cls=InvocationResponseEncoder) + + # Verify marker object is present + assert "__llmeter_bytes__" in encoded + + # Verify it can be decoded back + decoded = json.loads(encoded, object_hook=lambda d: d) + assert decoded["image"]["bytes"]["__llmeter_bytes__"] == "/9j/4A==" + + +def test_invocation_response_encoder_str_fallback(): + """Test that InvocationResponseEncoder falls back to str() for custom objects. + + Validates: Requirements 7.1, 7.2 + """ + from llmeter.results import InvocationResponseEncoder + + # Create a custom object with __str__ method + class CustomObject: + def __str__(self): + return "custom_string_representation" + + payload = {"custom": CustomObject()} + encoded = json.dumps(payload, cls=InvocationResponseEncoder) + + # Verify str() fallback was used + decoded = json.loads(encoded) + assert decoded["custom"] == "custom_string_representation" + + +def test_invocation_response_encoder_none_on_str_failure(): + """Test that InvocationResponseEncoder returns None when str() conversion fails. + + Validates: Requirements 7.1, 7.2 + """ + from llmeter.results import InvocationResponseEncoder + + # Create a custom object that raises exception in __str__ + class FailingObject: + def __str__(self): + raise RuntimeError("Cannot convert to string") + + payload = {"failing": FailingObject()} + encoded = json.dumps(payload, cls=InvocationResponseEncoder) + + # Verify None was returned + decoded = json.loads(encoded) + assert decoded["failing"] is None + + +def test_invocation_response_encoder_mixed_types(): + """Test that InvocationResponseEncoder handles mixed types correctly. 
+ + Validates: Requirements 7.1, 7.2 + """ + from llmeter.results import InvocationResponseEncoder + + class CustomObject: + def __str__(self): + return "custom" + + # Mix of bytes, custom objects, and standard types + payload = { + "bytes_field": b"\x00\x01\x02", + "custom_field": CustomObject(), + "string_field": "normal string", + "int_field": 42, + "nested": { + "bytes": b"\xff\xfe", + "custom": CustomObject() + } + } + + encoded = json.dumps(payload, cls=InvocationResponseEncoder) + decoded = json.loads(encoded) + + # Verify bytes were encoded with marker + assert "__llmeter_bytes__" in encoded + + # Verify custom objects were converted to strings + assert decoded["custom_field"] == "custom" + assert decoded["nested"]["custom"] == "custom" + + # Verify standard types remain unchanged + assert decoded["string_field"] == "normal string" + assert decoded["int_field"] == 42 + + +# Tests for InvocationResponse.to_json +# Validates Requirements: 7.1, 7.2, 7.3 + + +def test_invocation_response_to_json_with_binary_content(): + """Test InvocationResponse.to_json with binary content in input_payload. 
+ + Validates: Requirements 7.1, 7.2 + """ + # Create InvocationResponse with binary content in input_payload + response = InvocationResponse( + response_text="This is an image of a cat", + input_payload={ + "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "messages": [{ + "role": "user", + "content": [ + {"text": "What is in this image?"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": b"\xff\xd8\xff\xe0\x00\x10JFIF"} + } + } + ] + }] + }, + id="test-123", + time_to_first_token=0.5, + time_to_last_token=1.2, + num_tokens_input=15, + num_tokens_output=8 + ) + + # Serialize to JSON + json_str = response.to_json() + + # Verify it's valid JSON + parsed = json.loads(json_str) + + # Verify marker object is present for bytes + assert "__llmeter_bytes__" in json_str + + # Verify structure is preserved + assert parsed["response_text"] == "This is an image of a cat" + assert parsed["id"] == "test-123" + assert parsed["input_payload"]["modelId"] == "anthropic.claude-3-haiku-20240307-v1:0" + + # Verify bytes were encoded with marker + bytes_marker = parsed["input_payload"]["messages"][0]["content"][1]["image"]["source"]["bytes"] + assert "__llmeter_bytes__" in bytes_marker + assert bytes_marker["__llmeter_bytes__"] == "/9j/4AAQSkZJRg==" + + +def test_invocation_response_to_json_valid_json_output(): + """Test that InvocationResponse.to_json produces valid JSON. 
+ + Validates: Requirements 7.1, 7.2 + """ + response = InvocationResponse( + response_text="Test response", + input_payload={ + "image_data": b"\x89PNG\r\n\x1a\n", + "text": "Analyze this image" + }, + id="test-456" + ) + + # Serialize to JSON + json_str = response.to_json() + + # Verify it's valid JSON (should not raise exception) + parsed = json.loads(json_str) + + # Verify all fields are present + assert "response_text" in parsed + assert "input_payload" in parsed + assert "id" in parsed + + # Verify bytes were properly encoded + assert isinstance(parsed["input_payload"]["image_data"], dict) + assert "__llmeter_bytes__" in parsed["input_payload"]["image_data"] + + +def test_invocation_response_to_json_round_trip(): + """Test round-trip serialization/deserialization with InvocationResponse. + + Validates: Requirements 7.1, 7.2, 7.3 + """ + from llmeter.prompt_utils import llmeter_bytes_decoder + + # Create original response with binary content + original_payload = { + "video": { + "format": "mp4", + "source": {"bytes": b"\x00\x00\x00\x18ftypmp42"} + }, + "prompt": "Describe this video" + } + + response = InvocationResponse( + response_text="A video of a sunset", + input_payload=original_payload, + id="test-789", + time_to_first_token=1.0, + time_to_last_token=2.5 + ) + + # Serialize to JSON + json_str = response.to_json() + + # Deserialize back + parsed = json.loads(json_str, object_hook=llmeter_bytes_decoder) + + # Verify input_payload was restored correctly + assert parsed["input_payload"] == original_payload + assert isinstance(parsed["input_payload"]["video"]["source"]["bytes"], bytes) + assert parsed["input_payload"]["video"]["source"]["bytes"] == b"\x00\x00\x00\x18ftypmp42" + + # Verify other fields + assert parsed["response_text"] == "A video of a sunset" + assert parsed["id"] == "test-789" + assert parsed["time_to_first_token"] == 1.0 + assert parsed["time_to_last_token"] == 2.5 + + +def test_invocation_response_to_json_no_binary_content(): + """Test 
InvocationResponse.to_json with no binary content (backward compatibility). + + Validates: Requirements 7.1, 7.2 + """ + response = InvocationResponse( + response_text="Simple text response", + input_payload={ + "modelId": "test-model", + "prompt": "Hello, world!" + }, + id="test-no-binary" + ) + + # Serialize to JSON + json_str = response.to_json() + + # Verify no marker objects are present + assert "__llmeter_bytes__" not in json_str + + # Verify it's valid JSON + parsed = json.loads(json_str) + assert parsed["input_payload"]["prompt"] == "Hello, world!" + + +def test_invocation_response_to_json_with_kwargs(): + """Test that InvocationResponse.to_json passes through kwargs to json.dumps. + + Validates: Requirements 7.4 + """ + response = InvocationResponse( + response_text="Test", + input_payload={"data": b"\x01\x02\x03"}, + id="test-kwargs" + ) + + # Test with indent parameter + json_str = response.to_json(indent=2) + + # Verify indentation is present + assert "\n" in json_str + assert " " in json_str + + # Verify it's still valid JSON + parsed = json.loads(json_str) + assert parsed["id"] == "test-kwargs" diff --git a/tests/unit/test_serialization_properties.py b/tests/unit/test_serialization_properties.py new file mode 100644 index 0000000..fcb1f6b --- /dev/null +++ b/tests/unit/test_serialization_properties.py @@ -0,0 +1,768 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Property-based tests for JSON serialization optimization. + +This module contains property-based tests for the binary content serialization +feature using Hypothesis. These tests verify that the serialization and +deserialization of payloads containing bytes objects maintains data integrity +and correctness across a wide range of inputs. 
+ +Feature: json-serialization-optimization +""" + +import base64 +import json + +from hypothesis import given, settings +from hypothesis import strategies as st +from hypothesis.strategies import composite + +from llmeter.prompt_utils import LLMeterBytesEncoder + +# Test infrastructure is set up and ready for property test implementation +# This file will contain property-based tests for: +# - Property 1: Serialization produces valid JSON with marker objects +# - Property 2: Serialization preserves non-binary structure +# - Property 3: Deserialization restores bytes from markers +# - Property 4: Deserialization preserves non-marker dicts +# - Property 5: Round-trip serialization preserves data integrity +# - Property 6: Round-trip preserves dictionary key ordering +# - Property 7: Non-binary payloads are backward compatible +# - Property 8: Serialization errors are descriptive +# - Property 9: Deserialization errors are descriptive +# - Property 10: InvocationResponse serializes binary payloads + + +# Custom strategies for generating test data +@composite +def json_value_strategy(draw, allow_bytes=True): + """Generate JSON-compatible values, optionally including bytes.""" + base_types = st.one_of( + st.none(), + st.booleans(), + st.integers(), + st.floats(allow_nan=False, allow_infinity=False), + st.text(max_size=100), + ) + + if allow_bytes: + return draw(st.one_of(base_types, st.binary(max_size=1000))) + return draw(base_types) + + +@composite +def nested_dict_strategy(draw, max_depth=3, allow_bytes=True): + """Generate nested dictionary structures with optional bytes objects.""" + if max_depth == 0: + return draw(json_value_strategy(allow_bytes=allow_bytes)) + + return draw( + st.dictionaries( + keys=st.text(min_size=1, max_size=20).filter(lambda k: k != "__llmeter_bytes__"), + values=st.one_of( + json_value_strategy(allow_bytes=allow_bytes), + nested_dict_strategy(max_depth=max_depth - 1, allow_bytes=allow_bytes), + 
st.lists(json_value_strategy(allow_bytes=allow_bytes), max_size=5), + ), + max_size=5, + ) + ) + + +@composite +def payload_with_bytes_strategy(draw): + """Generate payloads that contain at least one bytes object.""" + # Start with a nested dict that may or may not have bytes + payload = draw(nested_dict_strategy(max_depth=3, allow_bytes=True)) + + # Ensure at least one bytes object exists + # Add a guaranteed bytes field + payload["_test_bytes"] = draw(st.binary(min_size=1, max_size=500)) + + return payload + + +@composite +def payload_without_bytes_strategy(draw): + """Generate payloads that contain no bytes objects.""" + return draw(nested_dict_strategy(max_depth=3, allow_bytes=False)) + + +class TestSerializationProperties: + """Property-based tests for serialization behavior.""" + + @given(payload_with_bytes_strategy()) + @settings(max_examples=100) + def test_property_1_serialization_produces_valid_json_with_marker_objects( + self, payload + ): + """Property 1: Serialization produces valid JSON with marker objects. + + **Validates: Requirements 1.1, 1.3** + + For any payload containing bytes objects at any nesting level, serializing + the payload SHALL produce valid JSON where each bytes object is replaced + with a marker object containing the key "__llmeter_bytes__" and a + base64-encoded string value. 
+ """ + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Verify it's valid JSON by parsing it + parsed = json.loads(serialized) + assert isinstance(parsed, dict) + + # Verify that marker objects exist in the serialized string + assert "__llmeter_bytes__" in serialized + + # Helper function to verify marker objects in nested structures + def verify_markers(obj): + """Recursively verify that bytes are replaced with marker objects.""" + if isinstance(obj, dict): + # Check if this is a marker object + if "__llmeter_bytes__" in obj: + # Verify it's a valid marker object + assert len(obj) == 1, "Marker object should have only one key" + base64_str = obj["__llmeter_bytes__"] + assert isinstance(base64_str, str), ( + "Marker value should be a string" + ) + # Verify it's valid base64 by attempting to decode + try: + base64.b64decode(base64_str) + except Exception as e: + raise AssertionError(f"Invalid base64 in marker object: {e}") + else: + # Recursively check nested structures + for value in obj.values(): + verify_markers(value) + elif isinstance(obj, list): + for item in obj: + verify_markers(item) + + # Verify all marker objects in the parsed structure + verify_markers(parsed) + + +class TestDeserializationProperties: + """Property-based tests for deserialization behavior.""" + + @given(st.binary(min_size=0, max_size=1000)) + @settings(max_examples=100) + def test_property_3_deserialization_restores_bytes_from_markers( + self, original_bytes + ): + """Property 3: Deserialization restores bytes from markers. + + **Validates: Requirements 2.1** + + For any JSON string containing marker objects with the "__llmeter_bytes__" + key, deserializing SHALL convert each marker object back to the original + bytes object by base64-decoding the string value. 
+ """ + from llmeter.prompt_utils import llmeter_bytes_decoder + + # Create a marker object from the original bytes + base64_str = base64.b64encode(original_bytes).decode("utf-8") + marker_object = {"__llmeter_bytes__": base64_str} + + # Decode the marker object + decoded_bytes = llmeter_bytes_decoder(marker_object) + + # Verify the decoded bytes match the original + assert isinstance(decoded_bytes, bytes), "Decoder should return bytes" + assert decoded_bytes == original_bytes, ( + "Decoded bytes should match original bytes" + ) + + @given(payload_without_bytes_strategy()) + @settings(max_examples=100) + def test_property_4_deserialization_preserves_non_marker_dicts(self, payload): + """Property 4: Deserialization preserves non-marker dicts. + + **Validates: Requirements 2.4** + + For any payload containing dictionaries without the "__llmeter_bytes__" + marker key, deserializing SHALL return those dictionaries unchanged. + """ + from llmeter.prompt_utils import llmeter_bytes_decoder + + # Helper function to recursively apply decoder to all dicts + def apply_decoder_recursively(obj): + """Apply decoder to all dicts in the structure.""" + if isinstance(obj, dict): + # Apply decoder to this dict + decoded = llmeter_bytes_decoder(obj) + # If it's still a dict (not converted to bytes), recurse + if isinstance(decoded, dict): + return {k: apply_decoder_recursively(v) for k, v in decoded.items()} + return decoded + elif isinstance(obj, list): + return [apply_decoder_recursively(item) for item in obj] + else: + return obj + + # Apply decoder recursively to the entire payload + decoded_payload = apply_decoder_recursively(payload) + + # Verify the payload is unchanged + assert decoded_payload == payload, ( + "Decoder should return non-marker dicts unchanged" + ) + + # Verify no bytes objects were introduced + def verify_no_bytes(obj): + """Recursively verify no bytes objects exist.""" + if isinstance(obj, bytes): + raise AssertionError("Decoder introduced bytes object 
unexpectedly") + elif isinstance(obj, dict): + for value in obj.values(): + verify_no_bytes(value) + elif isinstance(obj, list): + for item in obj: + verify_no_bytes(item) + + verify_no_bytes(decoded_payload) + + +class TestRoundTripProperties: + """Property-based tests for round-trip serialization integrity.""" + + @given(payload_with_bytes_strategy()) + @settings(max_examples=100) + def test_property_2_serialization_preserves_non_binary_structure(self, payload): + """Property 2: Serialization preserves non-binary structure. + + **Validates: Requirements 1.5** + + For any payload structure (keys, nesting levels, value types except bytes), + serializing then parsing the JSON SHALL preserve the structure identically, + with only bytes objects replaced by marker objects. + + # Feature: json-serialization-optimization, Property 2: Serialization preserves non-binary structure + """ + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Parse the JSON (without decoding markers back to bytes) + parsed = json.loads(serialized) + + # Helper function to verify structure preservation + def verify_structure(original, parsed_obj, path=""): + """Recursively verify that structure is preserved except for bytes.""" + if isinstance(original, bytes): + # Bytes should be replaced with marker object + assert isinstance(parsed_obj, dict), ( + f"At {path}: Expected marker dict for bytes, got {type(parsed_obj)}" + ) + assert "__llmeter_bytes__" in parsed_obj, ( + f"At {path}: Expected marker object for bytes" + ) + assert len(parsed_obj) == 1, ( + f"At {path}: Marker object should have only one key" + ) + # Verify the base64 string can be decoded back to original bytes + decoded = base64.b64decode(parsed_obj["__llmeter_bytes__"]) + assert decoded == original, ( + f"At {path}: Decoded bytes don't match original" + ) + elif isinstance(original, dict): + # Dict structure should be preserved + assert isinstance(parsed_obj, dict), ( + f"At {path}: Expected 
dict, got {type(parsed_obj)}" + ) + assert set(original.keys()) == set(parsed_obj.keys()), ( + f"At {path}: Dict keys differ. Original: {set(original.keys())}, " + f"Parsed: {set(parsed_obj.keys())}" + ) + # Verify nesting is preserved + for key in original.keys(): + verify_structure( + original[key], + parsed_obj[key], + f"{path}.{key}" if path else key + ) + elif isinstance(original, list): + # List structure should be preserved + assert isinstance(parsed_obj, list), ( + f"At {path}: Expected list, got {type(parsed_obj)}" + ) + assert len(original) == len(parsed_obj), ( + f"At {path}: List lengths differ. Original: {len(original)}, " + f"Parsed: {len(parsed_obj)}" + ) + # Verify each element + for i, (orig_item, parsed_item) in enumerate(zip(original, parsed_obj)): + verify_structure( + orig_item, + parsed_item, + f"{path}[{i}]" + ) + else: + # Primitive types should be preserved exactly + # Note: We need exact type matching here, not isinstance checks + assert type(original) is type(parsed_obj), ( # noqa: E721 + f"At {path}: Type mismatch. Original: {type(original)}, " + f"Parsed: {type(parsed_obj)}" + ) + assert original == parsed_obj, ( + f"At {path}: Values differ. Original: {original}, " + f"Parsed: {parsed_obj}" + ) + + # Verify structure preservation throughout + verify_structure(payload, parsed) + + @given(payload_with_bytes_strategy()) + @settings(max_examples=100) + def test_property_5_round_trip_serialization_preserves_data_integrity( + self, payload + ): + """Property 5: Round-trip serialization preserves data integrity. + + **Validates: Requirements 4.1, 2.5** + + For any valid payload with binary content, the property + deserialize(serialize(payload)) == payload SHALL hold, preserving + byte-for-byte equality of all bytes objects and exact equality of + all other values. 
+ """ + from llmeter.prompt_utils import llmeter_bytes_decoder + + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Deserialize the payload + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + + # Verify round-trip equality + assert deserialized == payload, ( + "Round-trip serialization should preserve data integrity" + ) + + # Helper function to verify byte-for-byte equality of bytes objects + def verify_bytes_equality(original, restored, path=""): + """Recursively verify that bytes objects are byte-for-byte equal.""" + if isinstance(original, bytes): + assert isinstance(restored, bytes), ( + f"At {path}: Expected bytes, got {type(restored)}" + ) + assert original == restored, ( + f"At {path}: Bytes objects differ" + ) + elif isinstance(original, dict): + assert isinstance(restored, dict), ( + f"At {path}: Expected dict, got {type(restored)}" + ) + assert set(original.keys()) == set(restored.keys()), ( + f"At {path}: Dict keys differ" + ) + for key in original.keys(): + verify_bytes_equality( + original[key], + restored[key], + f"{path}.{key}" if path else key + ) + elif isinstance(original, list): + assert isinstance(restored, list), ( + f"At {path}: Expected list, got {type(restored)}" + ) + assert len(original) == len(restored), ( + f"At {path}: List lengths differ" + ) + for i, (orig_item, rest_item) in enumerate(zip(original, restored)): + verify_bytes_equality( + orig_item, + rest_item, + f"{path}[{i}]" + ) + else: + # For primitive types, equality check is sufficient + assert original == restored, ( + f"At {path}: Values differ: {original} != {restored}" + ) + + # Verify byte-for-byte equality throughout the structure + verify_bytes_equality(payload, deserialized) + + + @given(payload_with_bytes_strategy()) + @settings(max_examples=100) + def test_property_6_round_trip_preserves_dictionary_key_ordering(self, payload): + """Property 6: Round-trip preserves dictionary key ordering. 
+ + **Validates: Requirements 4.4** + + For any payload with ordered dictionaries, round-trip serialization SHALL + preserve the insertion order of dictionary keys. + + # Feature: json-serialization-optimization, Property 6: Round-trip preserves dictionary key ordering + """ + from llmeter.prompt_utils import llmeter_bytes_decoder + + # Serialize the payload + serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Deserialize the payload + deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) + + # Helper function to verify key ordering + def verify_key_ordering(original, restored, path=""): + """Recursively verify that dictionary key ordering is preserved.""" + if isinstance(original, dict): + assert isinstance(restored, dict), ( + f"At {path}: Expected dict, got {type(restored)}" + ) + + # Get the keys as lists to preserve order + original_keys = list(original.keys()) + restored_keys = list(restored.keys()) + + # Verify the keys are in the same order + assert original_keys == restored_keys, ( + f"At {path}: Key ordering differs. 
" + f"Original: {original_keys}, Restored: {restored_keys}" + ) + + # Recursively verify nested structures + for key in original_keys: + verify_key_ordering( + original[key], + restored[key], + f"{path}.{key}" if path else key + ) + elif isinstance(original, list): + assert isinstance(restored, list), ( + f"At {path}: Expected list, got {type(restored)}" + ) + assert len(original) == len(restored), ( + f"At {path}: List lengths differ" + ) + # Verify each element + for i, (orig_item, rest_item) in enumerate(zip(original, restored)): + verify_key_ordering( + orig_item, + rest_item, + f"{path}[{i}]" + ) + # For non-dict, non-list types, no key ordering to verify + + # Verify key ordering is preserved throughout the structure + verify_key_ordering(payload, deserialized) + + +class TestBackwardCompatibilityProperties: + """Property-based tests for backward compatibility.""" + + @given(payload_without_bytes_strategy()) + @settings(max_examples=100) + def test_property_7_non_binary_payloads_are_backward_compatible(self, payload): + """Property 7: Non-binary payloads are backward compatible. + + **Validates: Requirements 3.2, 3.3** + + For any payload containing no bytes objects, serializing with the new + encoder SHALL produce output identical to serializing with the standard + json.dumps (no marker objects introduced). 
+ + # Feature: json-serialization-optimization, Property 7: Non-binary payloads are backward compatible + """ + # Serialize with the new encoder + serialized_with_encoder = json.dumps(payload, cls=LLMeterBytesEncoder) + + # Serialize with standard json.dumps + serialized_standard = json.dumps(payload) + + # Verify they produce identical output + assert serialized_with_encoder == serialized_standard, ( + "Serialization with LLMeterBytesEncoder should produce identical output " + "to standard json.dumps for payloads without bytes objects" + ) + + # Verify no marker objects were introduced + assert "__llmeter_bytes__" not in serialized_with_encoder, ( + "No marker objects should be introduced for payloads without bytes" + ) + + # Verify both can be parsed identically + parsed_encoder = json.loads(serialized_with_encoder) + parsed_standard = json.loads(serialized_standard) + assert parsed_encoder == parsed_standard, ( + "Parsed output should be identical for both serialization methods" + ) + + # Verify the parsed output matches the original payload + assert parsed_encoder == payload, ( + "Parsed output should match the original payload" + ) + + +class TestErrorHandlingProperties: + """Property-based tests for error handling.""" + + @composite + def unserializable_object_strategy(draw): + """Generate objects that are not JSON serializable and not bytes.""" + # Create various types of unserializable objects + unserializable_types = [ + # Custom class instance + lambda: type('CustomClass', (), {})(), + # Function + lambda: lambda x: x, + # Set (not JSON serializable) + lambda: {1, 2, 3}, + # Complex number + lambda: complex(1, 2), + # Object with __dict__ + lambda: type('ObjWithDict', (), {'attr': 'value'})(), + ] + + # Choose one of the unserializable types + return draw(st.sampled_from(unserializable_types))() + + @given(st.data()) + @settings(max_examples=100) + def test_property_8_serialization_errors_are_descriptive(self, data): + """Property 8: Serialization errors 
are descriptive. + + **Validates: Requirements 6.1** + + For any payload containing unserializable types (not bytes, not standard + JSON types), attempting to serialize SHALL raise a TypeError with a message + indicating the problematic type. + + # Feature: json-serialization-optimization, Property 8: Serialization errors are descriptive + """ + # Generate a payload with an unserializable object + unserializable_obj = data.draw( + TestErrorHandlingProperties.unserializable_object_strategy() + ) + + # Create a payload containing the unserializable object + # We'll place it at various locations in the structure + placement_strategy = st.sampled_from([ + # Direct value + lambda obj: {"unserializable": obj}, + # Nested in dict + lambda obj: {"outer": {"inner": {"unserializable": obj}}}, + # In a list + lambda obj: {"items": [1, 2, obj, 4]}, + # Mixed structure + lambda obj: {"data": {"list": [{"nested": obj}]}}, + ]) + + payload_creator = data.draw(placement_strategy) + payload = payload_creator(unserializable_obj) + + # Attempt to serialize and verify it raises TypeError + try: + json.dumps(payload, cls=LLMeterBytesEncoder) + # If we get here, serialization succeeded when it shouldn't have + raise AssertionError( + f"Expected TypeError for unserializable object of type " + f"{type(unserializable_obj).__name__}, but serialization succeeded" + ) + except TypeError as e: + # Verify the error message is descriptive + error_msg = str(e) + + # The error message should mention that the object is not serializable + assert "not" in error_msg.lower() and "serializable" in error_msg.lower(), ( + f"Error message should indicate object is not serializable. 
" + f"Got: {error_msg}" + ) + + # The error message should include type information + # Either the class name or the word "type" should appear + type_name = type(unserializable_obj).__name__ + has_type_info = ( + type_name in error_msg or + "type" in error_msg.lower() or + "object" in error_msg.lower() + ) + assert has_type_info, ( + f"Error message should include type information. " + f"Expected reference to '{type_name}' or 'type'. Got: {error_msg}" + ) + + + @given(st.data()) + @settings(max_examples=100) + def test_property_9_deserialization_errors_are_descriptive(self, data): + """Property 9: Deserialization errors are descriptive. + + **Validates: Requirements 6.2** + + For any invalid JSON string or JSON containing marker objects with invalid + base64 strings, attempting to deserialize SHALL raise an appropriate + exception (JSONDecodeError or binascii.Error) with a descriptive message. + + # Feature: json-serialization-optimization, Property 9: Deserialization errors are descriptive + """ + from llmeter.prompt_utils import llmeter_bytes_decoder + + # Test invalid JSON strings that will raise JSONDecodeError + # Note: base64.b64decode() is lenient by default and accepts many inputs, + # so we focus on JSON parsing errors which are more common in practice + + # Generate invalid JSON strings that will definitely fail parsing + invalid_json_strategy = st.sampled_from([ + "{invalid json}", + '{"key": undefined}', + "{'single': 'quotes'}", + '{"unclosed": ', + '{"trailing": "comma",}', + 'not json at all', + '["unclosed array"', + '}invalid start{', + '{"double""quotes"}', + '{"key": value}', # unquoted value + '[1, 2, 3,]', # trailing comma in array + ]) + invalid_json = data.draw(invalid_json_strategy) + + # Attempt to deserialize and verify it raises JSONDecodeError + try: + json.loads(invalid_json, object_hook=llmeter_bytes_decoder) + # If we get here, deserialization succeeded when it shouldn't have + raise AssertionError( + f"Expected JSONDecodeError for 
invalid JSON: {invalid_json}" + ) + except json.JSONDecodeError as e: + # Verify the error message is descriptive + error_msg = str(e) + + # The error message should indicate it's a JSON parsing error + # JSONDecodeError messages typically contain position information + assert len(error_msg) > 0, ( + "Error message should not be empty" + ) + + # Verify the exception has the expected attributes + assert hasattr(e, 'msg'), "JSONDecodeError should have 'msg' attribute" + assert hasattr(e, 'lineno'), "JSONDecodeError should have 'lineno' attribute" + assert hasattr(e, 'colno'), "JSONDecodeError should have 'colno' attribute" + + # Verify the error message contains useful information + # It should mention what went wrong + assert e.msg, "JSONDecodeError msg should not be empty" + + +class TestInvocationResponseProperties: + """Property-based tests for InvocationResponse serialization.""" + + @given(nested_dict_strategy(max_depth=3, allow_bytes=True)) + @settings(max_examples=100, deadline=None) + def test_property_10_invocation_response_serializes_binary_payloads( + self, input_payload + ): + """Property 10: InvocationResponse serializes binary payloads. + + **Validates: Requirements 7.1** + + For any InvocationResponse object where the input_payload field contains + bytes objects, calling to_json() SHALL produce valid JSON with marker + objects replacing the bytes. 
+ + # Feature: json-serialization-optimization, Property 10: InvocationResponse serializes binary payloads + """ + from llmeter.endpoints.base import InvocationResponse + + # Create an InvocationResponse with the generated input_payload + response = InvocationResponse( + response_text="Test response", + input_payload=input_payload, + id="test-id", + time_to_first_token=0.1, + time_to_last_token=0.5, + num_tokens_input=10, + num_tokens_output=20, + ) + + # Serialize using to_json() + json_str = response.to_json() + + # Verify it's valid JSON by parsing it + parsed = json.loads(json_str) + assert isinstance(parsed, dict), "to_json() should produce a valid JSON object" + + # Verify the input_payload field exists + assert "input_payload" in parsed, "Serialized response should contain input_payload" + + # Helper function to check for bytes in original and markers in serialized + def has_bytes(obj): + """Recursively check if object contains bytes.""" + if isinstance(obj, bytes): + return True + elif isinstance(obj, dict): + return any(has_bytes(v) for v in obj.values()) + elif isinstance(obj, list): + return any(has_bytes(item) for item in obj) + return False + + def verify_bytes_serialized(original, serialized_obj, path=""): + """Recursively verify bytes are replaced with marker objects.""" + if isinstance(original, bytes): + # Bytes should be replaced with marker object + assert isinstance(serialized_obj, dict), ( + f"At {path}: Expected marker dict for bytes, got {type(serialized_obj)}" + ) + assert "__llmeter_bytes__" in serialized_obj, ( + f"At {path}: Expected marker object for bytes" + ) + assert len(serialized_obj) == 1, ( + f"At {path}: Marker object should have only one key" + ) + # Verify the base64 string can be decoded back to original bytes + base64_str = serialized_obj["__llmeter_bytes__"] + assert isinstance(base64_str, str), ( + f"At {path}: Marker value should be a string" + ) + decoded = base64.b64decode(base64_str) + assert decoded == original, ( + 
f"At {path}: Decoded bytes don't match original" + ) + elif isinstance(original, dict): + assert isinstance(serialized_obj, dict), ( + f"At {path}: Expected dict, got {type(serialized_obj)}" + ) + # Verify all keys are present + for key in original.keys(): + assert key in serialized_obj, ( + f"At {path}: Key '{key}' missing in serialized object" + ) + verify_bytes_serialized( + original[key], + serialized_obj[key], + f"{path}.{key}" if path else key + ) + elif isinstance(original, list): + assert isinstance(serialized_obj, list), ( + f"At {path}: Expected list, got {type(serialized_obj)}" + ) + assert len(original) == len(serialized_obj), ( + f"At {path}: List lengths differ" + ) + for i, (orig_item, ser_item) in enumerate(zip(original, serialized_obj)): + verify_bytes_serialized( + orig_item, + ser_item, + f"{path}[{i}]" + ) + # For other types, no special verification needed + + # If the input_payload contains bytes, verify they're serialized correctly + if has_bytes(input_payload): + assert "__llmeter_bytes__" in json_str, ( + "Serialized JSON should contain marker objects when input_payload has bytes" + ) + verify_bytes_serialized(input_payload, parsed["input_payload"], "input_payload") + + # Verify other fields are serialized correctly + assert parsed["response_text"] == "Test response" + assert parsed["id"] == "test-id" + assert parsed["time_to_first_token"] == 0.1 + assert parsed["time_to_last_token"] == 0.5 + assert parsed["num_tokens_input"] == 10 + assert parsed["num_tokens_output"] == 20 From 39baca7d01f9d9baa2064855baccd6a3be68ee3a Mon Sep 17 00:00:00 2001 From: Alessandro Cere Date: Wed, 11 Mar 2026 11:54:29 +0800 Subject: [PATCH 2/9] feat(multimodal): Add comprehensive multi-modal payload support across endpoints - Add multi-modal content handling (images, videos, audio, documents) to BedrockConverse, OpenAI, and SageMaker endpoints - Implement automatic format detection using puremagic with fallback to file extensions - Add multimodal utility 
functions for format conversion and content serialization - Support both file paths and raw bytes for multi-modal content - Add endpoint-specific format string handling (Bedrock short format, OpenAI MIME types) - Implement comprehensive unit tests for multi-modal serialization and properties across all endpoints - Add detailed README documentation with usage examples and security warnings for format detection - Fix Path serialization in JSONableBase to handle os.PathLike objects - Update pyproject.toml with optional multimodal extra for puremagic dependency - Improve integration tests with multi-modal payload examples - Enhance prompt utilities with multi-modal content handling --- README.md | 153 ++++ llmeter/callbacks/cost/serde.py | 10 +- llmeter/endpoints/base.py | 16 +- llmeter/endpoints/bedrock.py | 259 ++++++- llmeter/endpoints/openai.py | 227 +++++- llmeter/endpoints/sagemaker.py | 391 +++++++++- llmeter/plotting/plotting.py | 2 +- llmeter/prompt_utils.py | 197 ++++- llmeter/results.py | 6 +- llmeter/runner.py | 7 +- pyproject.toml | 7 +- tests/integ/conftest.py | 68 +- tests/integ/test_openai_bedrock.py | 121 ++-- .../unit/endpoints/test_bedrock_multimodal.py | 247 +++++++ .../endpoints/test_multimodal_properties.py | 682 ++++++++++++++++++ .../test_multimodal_serialization.py | 253 +++++++ .../endpoints/test_multimodal_utilities.py | 212 ++++++ .../unit/endpoints/test_openai_multimodal.py | 144 ++++ .../endpoints/test_sagemaker_multimodal.py | 140 ++++ tests/unit/test_prompt_utils.py | 3 - tests/unit/test_property_save_load.py | 8 +- tests/unit/test_tokenizers.py | 31 +- 22 files changed, 3009 insertions(+), 175 deletions(-) create mode 100644 tests/unit/endpoints/test_bedrock_multimodal.py create mode 100644 tests/unit/endpoints/test_multimodal_properties.py create mode 100644 tests/unit/endpoints/test_multimodal_serialization.py create mode 100644 tests/unit/endpoints/test_multimodal_utilities.py create mode 100644 
tests/unit/endpoints/test_openai_multimodal.py create mode 100644 tests/unit/endpoints/test_sagemaker_multimodal.py diff --git a/README.md b/README.md index a83030f..bcdc27a 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,159 @@ Additional functionality like cost modelling and MLFlow experiment tracking is e For more details, check out our selection of end-to-end code examples in the [examples](https://github.com/awslabs/llmeter/tree/main/examples) folder! +## 🖼️ Multi-Modal Payload Support + +LLMeter supports creating payloads with multi-modal content including images, videos, audio, and documents alongside text. This enables testing of modern multi-modal AI models. + +### Installation for Multi-Modal Support + +For enhanced format detection from file content (recommended), install the optional `multimodal` extra: + +```terminal +pip install 'llmeter[multimodal]' +``` + +Or with uv: + +```terminal +uv pip install 'llmeter[multimodal]' +``` + +This installs the `puremagic` library for content-based format detection using magic bytes. Without it, format detection falls back to file extensions. 
+ +### Basic Multi-Modal Usage + +```python +from llmeter.endpoints import BedrockConverse + +# Single image from file +payload = BedrockConverse.create_payload( + user_message="What is in this image?", + images=["photo.jpg"], + max_tokens=256 +) + +# Multiple images +payload = BedrockConverse.create_payload( + user_message="Compare these images:", + images=["image1.jpg", "image2.png"], + max_tokens=512 +) + +# Image from bytes (requires puremagic for format detection) +with open("photo.jpg", "rb") as f: + image_bytes = f.read() + +payload = BedrockConverse.create_payload( + user_message="What is in this image?", + images=[image_bytes], + max_tokens=256 +) + +# Mixed content types +payload = BedrockConverse.create_payload( + user_message="Analyze this presentation and supporting materials", + documents=["slides.pdf"], + images=["chart.png"], + max_tokens=1024 +) + +# Video analysis +payload = BedrockConverse.create_payload( + user_message="Describe what happens in this video", + videos=["clip.mp4"], + max_tokens=1024 +) +``` + +### Supported Content Types + +- **Images**: JPEG, PNG, GIF, WebP +- **Documents**: PDF +- **Videos**: MP4, MOV, AVI +- **Audio**: MP3, WAV, OGG + +Format support varies by model. The library detects formats automatically and lets the API endpoint validate compatibility. + +### Endpoint-Specific Format Handling + +Different endpoints expect different format strings: + +- **Bedrock**: Uses short format strings (e.g., `"jpeg"`, `"png"`, `"pdf"`) +- **OpenAI**: Uses full MIME types (e.g., `"image/jpeg"`, `"image/png"`) +- **SageMaker**: Uses Bedrock format by default (model-dependent) + +The library handles these differences automatically based on the endpoint you're using. + +### ⚠️ Security Warning: Format Detection Is NOT Input Validation + +**IMPORTANT**: The format detection in this library is for testing and development convenience ONLY. It is NOT a security mechanism and MUST NOT be used with untrusted files without proper validation. 
+ +#### What This Library Does + +- Detects likely file format from magic bytes (puremagic) or extension (mimetypes) +- Reads binary content from files +- Packages content for API endpoints +- Provides type checking (bytes vs strings) + +#### What This Library Does NOT Do + +- ❌ Validate file content safety or integrity +- ❌ Scan for malicious content or malware +- ❌ Sanitize or clean file data +- ❌ Protect against malformed or exploited files +- ❌ Guarantee format correctness beyond detection heuristics +- ❌ Validate file size or prevent memory exhaustion +- ❌ Check for embedded scripts or exploits +- ❌ Verify file authenticity or source + +#### Intended Use Cases + +This format detection is designed for: + +- **Testing and development**: Loading known-safe test files during development +- **Internal tools**: Processing files from trusted internal sources +- **Prototyping**: Quick experimentation with multi-modal models +- **Controlled environments**: Scenarios where file sources are fully trusted + +#### NOT Intended For + +This format detection should NOT be used for: + +- **Production user uploads**: Files uploaded by end users through web forms or APIs +- **External file sources**: Files from untrusted URLs, email attachments, or third-party systems +- **Security-sensitive applications**: Any application where file safety is critical +- **Public-facing services**: Services that accept files from the internet + +#### Recommended Security Practices for Untrusted Files + +When working with untrusted files (user uploads, external sources, etc.), you MUST implement proper security measures: + +1. **Validate file sources**: Only accept files from trusted, authenticated sources +2. **Scan for malware**: Use antivirus/malware scanning (e.g., ClamAV) before processing +3. **Validate file integrity**: Verify checksums, digital signatures, or other integrity mechanisms +4. 
**Sanitize content**: Use specialized libraries to validate and sanitize file content: + - Images: Re-encode with PIL/Pillow to strip metadata and validate structure + - PDFs: Use PDF sanitization libraries to remove scripts and validate structure + - Videos: Re-encode with ffmpeg to validate and sanitize +5. **Limit file sizes**: Enforce maximum file size limits before reading into memory +6. **Sandbox processing**: Process untrusted files in isolated environments (containers, VMs) +7. **Validate API responses**: Check that API endpoints successfully processed the content +8. **Implement rate limiting**: Prevent abuse through excessive file uploads +9. **Log and monitor**: Track file processing for security auditing + +### Backward Compatibility + +Text-only payloads continue to work exactly as before: + +```python +# Still works - no changes needed +payload = BedrockConverse.create_payload( + user_message="Hello, world!", + max_tokens=256 +) +``` + ## Analyze and compare results You can analyze the results of a single run or a load test by generating interactive charts. You can find examples in in the [examples](examples) folder. diff --git a/llmeter/callbacks/cost/serde.py b/llmeter/callbacks/cost/serde.py index 843fad2..4a1bb97 100644 --- a/llmeter/callbacks/cost/serde.py +++ b/llmeter/callbacks/cost/serde.py @@ -60,7 +60,9 @@ def to_dict_recursive_generic(obj: object, **kwargs) -> dict: result.update({k: getattr(obj, k) for k in dir(obj)}) result.update(kwargs) for k, v in result.items(): - if hasattr(v, "to_dict"): + if isinstance(v, os.PathLike): + result[k] = Path(v).as_posix() + elif hasattr(v, "to_dict"): result[k] = v.to_dict() elif isinstance(v, dict): result[k] = to_dict_recursive_generic(v) @@ -192,6 +194,12 @@ def to_file( Returns: output_path: Universal Path representation of the target file. 
""" + if default is None: + def default(obj): + if isinstance(obj, os.PathLike): + return Path(obj).as_posix() + return str(obj) + output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w") as f: diff --git a/llmeter/endpoints/base.py b/llmeter/endpoints/base.py index bc35b03..62659d2 100644 --- a/llmeter/endpoints/base.py +++ b/llmeter/endpoints/base.py @@ -279,7 +279,13 @@ def save(self, output_path: os.PathLike) -> os.PathLike: output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w") as f: endpoint_conf = self.to_dict() - json.dump(endpoint_conf, f, indent=4, default=str) + + def _default_serializer(obj): + if isinstance(obj, os.PathLike): + return Path(obj).as_posix() + return str(obj) + + json.dump(endpoint_conf, f, indent=4, default=_default_serializer) return output_path def to_dict(self) -> dict: @@ -289,7 +295,13 @@ def to_dict(self) -> dict: Returns: Dict: A dictionary representation of the endpoint configuration. """ - endpoint_conf = {k: v for k, v in vars(self).items() if not k.startswith("_")} + endpoint_conf = {} + for k, v in vars(self).items(): + if k.startswith("_"): + continue + if isinstance(v, os.PathLike): + v = Path(v).as_posix() + endpoint_conf[k] = v endpoint_conf["endpoint_type"] = self.__class__.__name__ return endpoint_conf diff --git a/llmeter/endpoints/bedrock.py b/llmeter/endpoints/bedrock.py index 9edf30b..302df9f 100644 --- a/llmeter/endpoints/bedrock.py +++ b/llmeter/endpoints/bedrock.py @@ -19,11 +19,117 @@ from botocore.config import Config from botocore.exceptions import ClientError +from ..prompt_utils import ( + detect_format_from_bytes, + detect_format_from_file, + read_file, +) from .base import Endpoint, InvocationResponse logger = logging.getLogger(__name__) +def _mime_to_format(mime_type: str) -> str | None: + """Convert MIME type to Bedrock format string. + + Maps MIME types to format names used by Bedrock Converse API. 
+ + Args: + mime_type: MIME type (e.g., "image/jpeg", "application/pdf") + + Returns: + str | None: Format string (e.g., "jpeg", "png", "pdf") or None if not recognized + """ + mime_map = { + "image/jpeg": "jpeg", + "image/png": "png", + "image/gif": "gif", + "image/webp": "webp", + "application/pdf": "pdf", + "video/mp4": "mp4", + "video/quicktime": "mov", + "video/x-msvideo": "avi", + "audio/mpeg": "mp3", + "audio/wav": "wav", + "audio/x-wav": "wav", + "audio/ogg": "ogg", + } + return mime_map.get(mime_type) + + +def _build_content_blocks( + user_message: str | list[str] | None, + images: list[bytes] | list[str] | None, + documents: list[bytes] | list[str] | None, + videos: list[bytes] | list[str] | None, + audio: list[bytes] | list[str] | None, +) -> list[dict]: + """Build content blocks from parameters. + + Returns list of content block dictionaries in Bedrock Converse API format. + + Args: + user_message: Text message(s) + images: List of image bytes or file paths + documents: List of document bytes or file paths + videos: List of video bytes or file paths + audio: List of audio bytes or file paths + + Returns: + list[dict]: Content blocks + + Raises: + ValueError: If format cannot be auto-detected from bytes + """ + content = [] + + # Add text blocks first + if user_message: + messages = [user_message] if isinstance(user_message, str) else user_message + for msg in messages: + content.append({"text": msg}) + + # Add media blocks in order: images, videos, audio, documents + for media_list, media_type in [ + (images, "image"), + (videos, "video"), + (audio, "audio"), + (documents, "document"), + ]: + if media_list: + for item in media_list: + if isinstance(item, bytes): + # Bytes provided directly - detect MIME type from content + data = item + mime_type = detect_format_from_bytes(data) + if mime_type is None: + raise ValueError( + f"Cannot detect format from bytes for {media_type}. 
" + "Either install puremagic for content-based detection " + "or provide file path for extension-based detection." + ) + fmt = _mime_to_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for {media_type}" + ) + else: + # File path - read and detect MIME type from file + data = read_file(item) + mime_type = detect_format_from_file(item) + if mime_type is None: + raise ValueError(f"Cannot detect format from file: {item}") + fmt = _mime_to_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for file: {item}" + ) + + content.append({media_type: {"format": fmt, "source": {"bytes": data}}}) + + return content + + class BedrockBase(Endpoint): """Base class for interacting with Amazon Bedrock endpoints. @@ -114,43 +220,156 @@ def _parse_payload(self, payload): @staticmethod def create_payload( - user_message: str | list[str], max_tokens: int = 256, **kwargs: Any + user_message: str | list[str] | None = None, + max_tokens: int | None = None, + *, + images: list[bytes] | list[str] | None = None, + documents: list[bytes] | list[str] | None = None, + videos: list[bytes] | list[str] | None = None, + audio: list[bytes] | list[str] | None = None, + **kwargs: Any, ) -> dict: """ - Create a payload for the Bedrock Converse API request. + Create a payload for the Bedrock Converse API request with optional multi-modal content. + + ⚠️ SECURITY WARNING: Format detection is for testing/development convenience ONLY. + This method does NOT validate file safety, integrity, or protect against malicious + content. DO NOT use with untrusted files (user uploads, external sources) without + proper validation, sanitization, and security measures. Args: - user_message (str | Sequence[str]): The user's message or a sequence of messages. - max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 256. + user_message (str | list[str] | None): The user's message or a sequence of messages. 
+ max_tokens (int | None): The maximum number of tokens to generate. Defaults to 256. + images (list[bytes] | list[str] | None): List of image bytes or file paths (keyword-only). + documents (list[bytes] | list[str] | None): List of document bytes or file paths (keyword-only). + videos (list[bytes] | list[str] | None): List of video bytes or file paths (keyword-only). + audio (list[bytes] | list[str] | None): List of audio bytes or file paths (keyword-only). **kwargs: Additional keyword arguments to include in the payload. Returns: dict: The formatted payload for the Bedrock API request. Raises: - TypeError: If user_message is not a string or list of strings - ValueError: If max_tokens is not a positive integer + TypeError: If parameters have invalid types + ValueError: If parameters have invalid values + FileNotFoundError: If a file path doesn't exist + IOError: If a file cannot be read + + Security: + - Format detection (puremagic/extension) is NOT security validation + - Malicious files can have misleading extensions or forged magic bytes + - This method does NOT scan for malware or sanitize content + - Users MUST validate and sanitize untrusted files before calling this method + - Intended for testing/development with trusted files only + - NOT intended for production user uploads without proper security measures + + Examples: + # Text only (backward compatible) + create_payload(user_message="Hello") + create_payload("Hello", 256) # Positional args still work + + # Single image from file path (trusted source) + create_payload( + user_message="What's in this image?", + images=["photo.jpg"] + ) + + # Multiple images from bytes (trusted source) + create_payload( + user_message="Compare these images", + images=[image_bytes1, image_bytes2] + ) + + # Mixed content (trusted source) + create_payload( + user_message="Analyze this", + images=["chart.png"], + documents=["report.pdf"] + ) """ - if not isinstance(user_message, (str, list)): - raise TypeError("user_message 
must be a string or list of strings") + # Set default for max_tokens if not provided + if max_tokens is None: + max_tokens = 256 + + # Check if any multi-modal content is provided + has_multimodal = any([images, documents, videos, audio]) + + # Backward compatibility: if only user_message provided, use old logic + if not has_multimodal: + if user_message is None: + raise ValueError("user_message is required when no media is provided") + + if not isinstance(user_message, (str, list)): + raise TypeError("user_message must be a string or list of strings") + + if isinstance(user_message, list): + if not all(isinstance(msg, str) for msg in user_message): + raise TypeError("All messages must be strings") + if not user_message: + raise ValueError("user_message list cannot be empty") + + if not isinstance(max_tokens, int) or max_tokens <= 0: + raise ValueError("max_tokens must be a positive integer") - if isinstance(user_message, list): - if not all(isinstance(msg, str) for msg in user_message): - raise TypeError("All messages must be strings") - if not user_message: - raise ValueError("user_message list cannot be empty") + if isinstance(user_message, str): + user_message = [user_message] + + try: + payload: dict = { + "messages": [ + {"role": "user", "content": [{"text": k}]} for k in user_message + ], + } + payload.update(kwargs) + if payload.get("inferenceConfig") is None: + payload["inferenceConfig"] = {} + + payload["inferenceConfig"] = { + **payload["inferenceConfig"], + "maxTokens": max_tokens, + } + return payload + + except Exception as e: + logger.error(f"Error creating payload: {e}") + raise RuntimeError(f"Failed to create payload: {str(e)}") + + # Multi-modal path: validate types + if images is not None and not isinstance(images, list): + raise TypeError("images must be a list") + if documents is not None and not isinstance(documents, list): + raise TypeError("documents must be a list") + if videos is not None and not isinstance(videos, list): + raise 
TypeError("videos must be a list") + if audio is not None and not isinstance(audio, list): + raise TypeError("audio must be a list") + + # Validate list items are bytes or str + for media_list, media_name in [ + (images, "images"), + (documents, "documents"), + (videos, "videos"), + (audio, "audio"), + ]: + if media_list: + for item in media_list: + if not isinstance(item, (bytes, str)): + raise TypeError( + f"Items in {media_name} list must be bytes or str (file path), " + f"got {type(item).__name__}" + ) if not isinstance(max_tokens, int) or max_tokens <= 0: raise ValueError("max_tokens must be a positive integer") - if isinstance(user_message, str): - user_message = [user_message] - try: + # Build content blocks + content_blocks = _build_content_blocks( + user_message, images, documents, videos, audio + ) + payload: dict = { - "messages": [ - {"role": "user", "content": [{"text": k}]} for k in user_message - ], + "messages": [{"role": "user", "content": content_blocks}], } payload.update(kwargs) if payload.get("inferenceConfig") is None: @@ -164,7 +383,7 @@ def create_payload( except Exception as e: logger.error(f"Error creating payload: {e}") - raise RuntimeError(f"Failed to create payload: {str(e)}") + raise class BedrockConverse(BedrockBase): diff --git a/llmeter/endpoints/openai.py b/llmeter/endpoints/openai.py index 88dd041..a3e8d85 100644 --- a/llmeter/endpoints/openai.py +++ b/llmeter/endpoints/openai.py @@ -12,6 +12,108 @@ from openai.types.chat import ChatCompletion from .base import Endpoint, InvocationResponse +from ..prompt_utils import read_file, detect_format_from_bytes, detect_format_from_file + +logger = logging.getLogger(__name__) + + +def _mime_to_openai_format(mime_type: str) -> str | None: + """Convert MIME type to OpenAI format string. + + OpenAI expects full MIME types for media content. 
+ + Args: + mime_type: MIME type (e.g., "image/jpeg", "application/pdf") + + Returns: + str | None: Format string (full MIME type) or None if not recognized + """ + # OpenAI uses full MIME types + supported_mimes = { + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + "application/pdf", + "video/mp4", + "audio/mpeg", + "audio/wav", + } + return mime_type if mime_type in supported_mimes else None + + +def _build_content_blocks_openai( + user_message: str | list[str] | None, + images: list[bytes] | list[str] | None, + documents: list[bytes] | list[str] | None, + videos: list[bytes] | list[str] | None, + audio: list[bytes] | list[str] | None, +) -> list[dict]: + """Build content blocks from parameters for OpenAI API. + + Returns list of content block dictionaries in OpenAI format. + + Args: + user_message: Text message(s) + images: List of image bytes or file paths + documents: List of document bytes or file paths + videos: List of video bytes or file paths + audio: List of audio bytes or file paths + + Returns: + list[dict]: Content blocks + + Raises: + ValueError: If format cannot be auto-detected from bytes + """ + content = [] + + # Add text blocks first + if user_message: + messages = [user_message] if isinstance(user_message, str) else user_message + for msg in messages: + content.append({"text": msg}) + + # Add media blocks in order: images, videos, audio, documents + for media_list, media_type in [ + (images, "image"), + (videos, "video"), + (audio, "audio"), + (documents, "document"), + ]: + if media_list: + for item in media_list: + if isinstance(item, bytes): + # Bytes provided directly - detect MIME type from content + data = item + mime_type = detect_format_from_bytes(data) + if mime_type is None: + raise ValueError( + f"Cannot detect format from bytes for {media_type}. " + "Either install puremagic for content-based detection " + "or provide file path for extension-based detection." 
+ ) + fmt = _mime_to_openai_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for {media_type}" + ) + else: + # File path - read and detect MIME type from file + data = read_file(item) + mime_type = detect_format_from_file(item) + if mime_type is None: + raise ValueError(f"Cannot detect format from file: {item}") + fmt = _mime_to_openai_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for file: {item}" + ) + + content.append({media_type: {"format": fmt, "source": {"bytes": data}}}) + + return content + logger = logging.getLogger(__name__) @@ -65,26 +167,129 @@ def _parse_payload(self, payload): @staticmethod def create_payload( - user_message: str | Sequence[str], max_tokens: int = 256, **kwargs: Any + user_message: str | Sequence[str] | None = None, + max_tokens: int = 256, + *, + images: list[bytes] | list[str] | None = None, + documents: list[bytes] | list[str] | None = None, + videos: list[bytes] | list[str] | None = None, + audio: list[bytes] | list[str] | None = None, + **kwargs: Any, ) -> dict: - """Create a payload for the OpenAI API request. + """Create a payload for the OpenAI API request with optional multi-modal content. + + ⚠️ SECURITY WARNING: Format detection is for testing/development convenience ONLY. + This method does NOT validate file safety, integrity, or protect against malicious + content. DO NOT use with untrusted files (user uploads, external sources) without + proper validation, sanitization, and security measures. Args: - user_message (str | Sequence[str]): User message(s) to send + user_message (str | Sequence[str] | None): User message(s) to send max_tokens (int, optional): Maximum tokens in response. Defaults to 256. + images (list[bytes] | list[str] | None): List of image bytes or file paths (keyword-only). + documents (list[bytes] | list[str] | None): List of document bytes or file paths (keyword-only). 
+ videos (list[bytes] | list[str] | None): List of video bytes or file paths (keyword-only). + audio (list[bytes] | list[str] | None): List of audio bytes or file paths (keyword-only). **kwargs: Additional payload parameters Returns: dict: Formatted payload for API request + + Raises: + TypeError: If parameters have invalid types + ValueError: If parameters have invalid values + FileNotFoundError: If a file path doesn't exist + IOError: If a file cannot be read + + Security: + - Format detection (puremagic/extension) is NOT security validation + - Malicious files can have misleading extensions or forged magic bytes + - This method does NOT scan for malware or sanitize content + - Users MUST validate and sanitize untrusted files before calling this method + - Intended for testing/development with trusted files only + - NOT intended for production user uploads without proper security measures + + Examples: + # Text only (backward compatible) + create_payload(user_message="Hello") + + # Single image from file path (trusted source) + create_payload( + user_message="What's in this image?", + images=["photo.jpg"] + ) + + # Multiple images from bytes (trusted source) + create_payload( + user_message="Compare these images", + images=[image_bytes1, image_bytes2] + ) + + # Mixed content (trusted source) + create_payload( + user_message="Analyze this", + images=["chart.png"], + documents=["report.pdf"] + ) """ - if isinstance(user_message, str): - user_message = [user_message] - payload = { - "messages": [{"role": "user", "content": k} for k in user_message], - "max_tokens": max_tokens, - } - payload.update(kwargs) - return payload + # Check if any multi-modal content is provided + has_multimodal = any([images, documents, videos, audio]) + + # Backward compatibility: if only user_message provided, use old logic + if not has_multimodal: + if isinstance(user_message, str): + user_message = [user_message] + payload = { + "messages": [{"role": "user", "content": k} for k in 
user_message], + "max_tokens": max_tokens, + } + payload.update(kwargs) + return payload + + # Multi-modal path: validate types + if images is not None and not isinstance(images, list): + raise TypeError("images must be a list") + if documents is not None and not isinstance(documents, list): + raise TypeError("documents must be a list") + if videos is not None and not isinstance(videos, list): + raise TypeError("videos must be a list") + if audio is not None and not isinstance(audio, list): + raise TypeError("audio must be a list") + + # Validate list items are bytes or str + for media_list, media_name in [ + (images, "images"), + (documents, "documents"), + (videos, "videos"), + (audio, "audio"), + ]: + if media_list: + for item in media_list: + if not isinstance(item, (bytes, str)): + raise TypeError( + f"Items in {media_name} list must be bytes or str (file path), " + f"got {type(item).__name__}" + ) + + if not isinstance(max_tokens, int) or max_tokens <= 0: + raise ValueError("max_tokens must be a positive integer") + + try: + # Build content blocks + content_blocks = _build_content_blocks_openai( + user_message, images, documents, videos, audio + ) + + payload = { + "messages": [{"role": "user", "content": content_blocks}], + "max_tokens": max_tokens, + } + payload.update(kwargs) + return payload + + except Exception as e: + logger.error(f"Error creating payload: {e}") + raise class OpenAICompletionEndpoint(OpenAIEndpoint): diff --git a/llmeter/endpoints/sagemaker.py b/llmeter/endpoints/sagemaker.py index 05c67af..ec81708 100644 --- a/llmeter/endpoints/sagemaker.py +++ b/llmeter/endpoints/sagemaker.py @@ -13,6 +13,112 @@ from botocore.exceptions import ClientError from .base import Endpoint, InvocationResponse +from ..prompt_utils import read_file, detect_format_from_bytes, detect_format_from_file + +logger = logging.getLogger(__name__) + + +def _mime_to_sagemaker_format(mime_type: str) -> str | None: + """Convert MIME type to SageMaker format string. 
+ + SageMaker format depends on the deployed model, using Bedrock format as default. + + Args: + mime_type: MIME type (e.g., "image/jpeg", "application/pdf") + + Returns: + str | None: Format string (e.g., "jpeg", "png", "pdf") or None if not recognized + """ + # SageMaker uses Bedrock-style short format strings as default + mime_map = { + "image/jpeg": "jpeg", + "image/png": "png", + "image/gif": "gif", + "image/webp": "webp", + "application/pdf": "pdf", + "video/mp4": "mp4", + "video/quicktime": "mov", + "video/x-msvideo": "avi", + "audio/mpeg": "mp3", + "audio/wav": "wav", + "audio/x-wav": "wav", + "audio/ogg": "ogg", + } + return mime_map.get(mime_type) + + +def _build_content_blocks_sagemaker( + user_message: str | list[str] | None, + images: list[bytes] | list[str] | None, + documents: list[bytes] | list[str] | None, + videos: list[bytes] | list[str] | None, + audio: list[bytes] | list[str] | None, +) -> list[dict]: + """Build content blocks from parameters for SageMaker API. + + Returns list of content block dictionaries in SageMaker format. 
+ + Args: + user_message: Text message(s) + images: List of image bytes or file paths + documents: List of document bytes or file paths + videos: List of video bytes or file paths + audio: List of audio bytes or file paths + + Returns: + list[dict]: Content blocks + + Raises: + ValueError: If format cannot be auto-detected from bytes + """ + content = [] + + # Add text blocks first + if user_message: + messages = [user_message] if isinstance(user_message, str) else user_message + for msg in messages: + content.append({"text": msg}) + + # Add media blocks in order: images, videos, audio, documents + for media_list, media_type in [ + (images, "image"), + (videos, "video"), + (audio, "audio"), + (documents, "document"), + ]: + if media_list: + for item in media_list: + if isinstance(item, bytes): + # Bytes provided directly - detect MIME type from content + data = item + mime_type = detect_format_from_bytes(data) + if mime_type is None: + raise ValueError( + f"Cannot detect format from bytes for {media_type}. " + "Either install puremagic for content-based detection " + "or provide file path for extension-based detection." 
+ ) + fmt = _mime_to_sagemaker_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for {media_type}" + ) + else: + # File path - read and detect MIME type from file + data = read_file(item) + mime_type = detect_format_from_file(item) + if mime_type is None: + raise ValueError(f"Cannot detect format from file: {item}") + fmt = _mime_to_sagemaker_format(mime_type) + if fmt is None: + raise ValueError( + f"Unsupported MIME type '{mime_type}' for file: {item}" + ) + + content.append({media_type: {"format": fmt, "source": {"bytes": data}}}) + + return content + logger = logging.getLogger(__name__) @@ -54,22 +160,139 @@ def _parse_input(self, payload: dict) -> str | None: @staticmethod def create_payload( - input_text: str | list[str], + input_text: str | list[str] | None = None, max_tokens: int = 256, inference_parameters: dict = {}, + *, + images: list[bytes] | list[str] | None = None, + documents: list[bytes] | list[str] | None = None, + videos: list[bytes] | list[str] | None = None, + audio: list[bytes] | list[str] | None = None, **kwargs, ): - payload = { - "inputs": input_text, - "parameters": { - "max_new_tokens": max_tokens, - "details": True, - }, - } - if inference_parameters: - payload["parameters"].update(inference_parameters) - payload.update(kwargs) - return payload + """Create a payload for the SageMaker API request with optional multi-modal content. + + ⚠️ SECURITY WARNING: Format detection is for testing/development convenience ONLY. + This method does NOT validate file safety, integrity, or protect against malicious + content. DO NOT use with untrusted files (user uploads, external sources) without + proper validation, sanitization, and security measures. + + Args: + input_text (str | list[str] | None): The input text or a sequence of texts. + max_tokens (int): Maximum tokens to generate. Defaults to 256. + inference_parameters (dict): Additional inference parameters. Defaults to {}. 
+ images (list[bytes] | list[str] | None): List of image bytes or file paths (keyword-only). + documents (list[bytes] | list[str] | None): List of document bytes or file paths (keyword-only). + videos (list[bytes] | list[str] | None): List of video bytes or file paths (keyword-only). + audio (list[bytes] | list[str] | None): List of audio bytes or file paths (keyword-only). + **kwargs: Additional keyword arguments to include in the payload. + + Returns: + dict: The formatted payload for the SageMaker API request. + + Raises: + TypeError: If parameters have invalid types + ValueError: If parameters have invalid values + FileNotFoundError: If a file path doesn't exist + IOError: If a file cannot be read + + Security: + - Format detection (puremagic/extension) is NOT security validation + - Malicious files can have misleading extensions or forged magic bytes + - This method does NOT scan for malware or sanitize content + - Users MUST validate and sanitize untrusted files before calling this method + - Intended for testing/development with trusted files only + - NOT intended for production user uploads without proper security measures + + Examples: + # Text only (backward compatible) + create_payload(input_text="Hello") + + # Single image from file path (trusted source) + create_payload( + input_text="What's in this image?", + images=["photo.jpg"] + ) + + # Multiple images from bytes (trusted source) + create_payload( + input_text="Compare these images", + images=[image_bytes1, image_bytes2] + ) + + # Mixed content (trusted source) + create_payload( + input_text="Analyze this", + images=["chart.png"], + documents=["report.pdf"] + ) + """ + # Check if any multi-modal content is provided + has_multimodal = any([images, documents, videos, audio]) + + # Backward compatibility: if only input_text provided, use old logic + if not has_multimodal: + payload = { + "inputs": input_text, + "parameters": { + "max_new_tokens": max_tokens, + "details": True, + }, + } + if 
inference_parameters: + payload["parameters"].update(inference_parameters) + payload.update(kwargs) + return payload + + # Multi-modal path: validate types + if images is not None and not isinstance(images, list): + raise TypeError("images must be a list") + if documents is not None and not isinstance(documents, list): + raise TypeError("documents must be a list") + if videos is not None and not isinstance(videos, list): + raise TypeError("videos must be a list") + if audio is not None and not isinstance(audio, list): + raise TypeError("audio must be a list") + + # Validate list items are bytes or str + for media_list, media_name in [ + (images, "images"), + (documents, "documents"), + (videos, "videos"), + (audio, "audio"), + ]: + if media_list: + for item in media_list: + if not isinstance(item, (bytes, str)): + raise TypeError( + f"Items in {media_name} list must be bytes or str (file path), " + f"got {type(item).__name__}" + ) + + if not isinstance(max_tokens, int) or max_tokens <= 0: + raise ValueError("max_tokens must be a positive integer") + + try: + # Build content blocks + content_blocks = _build_content_blocks_sagemaker( + input_text, images, documents, videos, audio + ) + + payload = { + "inputs": content_blocks, + "parameters": { + "max_new_tokens": max_tokens, + "details": True, + }, + } + if inference_parameters: + payload["parameters"].update(inference_parameters) + payload.update(kwargs) + return payload + + except Exception as e: + logger.error(f"Error creating payload: {e}") + raise class SageMakerEndpoint(SageMakerBase): @@ -251,23 +474,141 @@ def invoke(self, payload: dict) -> InvocationResponse: @staticmethod def create_payload( - input_text: str | list[str], + input_text: str | list[str] | None = None, max_tokens: int = 256, inference_parameters: dict = {}, + *, + images: list[bytes] | list[str] | None = None, + documents: list[bytes] | list[str] | None = None, + videos: list[bytes] | list[str] | None = None, + audio: list[bytes] | list[str] 
| None = None, **kwargs, ): - payload = { - "inputs": input_text, - "parameters": { - "max_new_tokens": max_tokens, - "details": True, - }, - "stream": True, - } - if inference_parameters: - payload["parameters"].update(inference_parameters) - payload.update(kwargs) - return payload + """Create a payload for the SageMaker streaming API request with optional multi-modal content. + + ⚠️ SECURITY WARNING: Format detection is for testing/development convenience ONLY. + This method does NOT validate file safety, integrity, or protect against malicious + content. DO NOT use with untrusted files (user uploads, external sources) without + proper validation, sanitization, and security measures. + + Args: + input_text (str | list[str] | None): The input text or a sequence of texts. + max_tokens (int): Maximum tokens to generate. Defaults to 256. + inference_parameters (dict): Additional inference parameters. Defaults to {}. + images (list[bytes] | list[str] | None): List of image bytes or file paths (keyword-only). + documents (list[bytes] | list[str] | None): List of document bytes or file paths (keyword-only). + videos (list[bytes] | list[str] | None): List of video bytes or file paths (keyword-only). + audio (list[bytes] | list[str] | None): List of audio bytes or file paths (keyword-only). + **kwargs: Additional keyword arguments to include in the payload. + + Returns: + dict: The formatted payload for the SageMaker streaming API request. 
+ + Raises: + TypeError: If parameters have invalid types + ValueError: If parameters have invalid values + FileNotFoundError: If a file path doesn't exist + IOError: If a file cannot be read + + Security: + - Format detection (puremagic/extension) is NOT security validation + - Malicious files can have misleading extensions or forged magic bytes + - This method does NOT scan for malware or sanitize content + - Users MUST validate and sanitize untrusted files before calling this method + - Intended for testing/development with trusted files only + - NOT intended for production user uploads without proper security measures + + Examples: + # Text only (backward compatible) + create_payload(input_text="Hello") + + # Single image from file path (trusted source) + create_payload( + input_text="What's in this image?", + images=["photo.jpg"] + ) + + # Multiple images from bytes (trusted source) + create_payload( + input_text="Compare these images", + images=[image_bytes1, image_bytes2] + ) + + # Mixed content (trusted source) + create_payload( + input_text="Analyze this", + images=["chart.png"], + documents=["report.pdf"] + ) + """ + # Check if any multi-modal content is provided + has_multimodal = any([images, documents, videos, audio]) + + # Backward compatibility: if only input_text provided, use old logic + if not has_multimodal: + payload = { + "inputs": input_text, + "parameters": { + "max_new_tokens": max_tokens, + "details": True, + }, + "stream": True, + } + if inference_parameters: + payload["parameters"].update(inference_parameters) + payload.update(kwargs) + return payload + + # Multi-modal path: validate types + if images is not None and not isinstance(images, list): + raise TypeError("images must be a list") + if documents is not None and not isinstance(documents, list): + raise TypeError("documents must be a list") + if videos is not None and not isinstance(videos, list): + raise TypeError("videos must be a list") + if audio is not None and not 
isinstance(audio, list): + raise TypeError("audio must be a list") + + # Validate list items are bytes or str + for media_list, media_name in [ + (images, "images"), + (documents, "documents"), + (videos, "videos"), + (audio, "audio"), + ]: + if media_list: + for item in media_list: + if not isinstance(item, (bytes, str)): + raise TypeError( + f"Items in {media_name} list must be bytes or str (file path), " + f"got {type(item).__name__}" + ) + + if not isinstance(max_tokens, int) or max_tokens <= 0: + raise ValueError("max_tokens must be a positive integer") + + try: + # Build content blocks + content_blocks = _build_content_blocks_sagemaker( + input_text, images, documents, videos, audio + ) + + payload = { + "inputs": content_blocks, + "parameters": { + "max_new_tokens": max_tokens, + "details": True, + }, + "stream": True, + } + if inference_parameters: + payload["parameters"].update(inference_parameters) + payload.update(kwargs) + return payload + + except Exception as e: + logger.error(f"Error creating payload: {e}") + raise class TokenIterator: diff --git a/llmeter/plotting/plotting.py b/llmeter/plotting/plotting.py index 8852550..23561b5 100644 --- a/llmeter/plotting/plotting.py +++ b/llmeter/plotting/plotting.py @@ -22,7 +22,7 @@ try: import kaleido -except ModuleNotFoundError as e: +except ImportError as e: kaleido = DeferredError(e) from typing import TYPE_CHECKING diff --git a/llmeter/prompt_utils.py b/llmeter/prompt_utils.py index c4c7e18..61db642 100644 --- a/llmeter/prompt_utils.py +++ b/llmeter/prompt_utils.py @@ -13,54 +13,185 @@ from upath import UPath as Path from .tokenizers import DummyTokenizer, Tokenizer +from .utils import DeferredError logger = logging.getLogger(__name__) +# Optional dependency: puremagic for content-based format detection +try: + import puremagic +except ImportError as e: + logger.debug( + "puremagic not available. Format detection will fall back to file extensions. 
" + "Install with: pip install 'llmeter[multimodal]'" + ) + puremagic = DeferredError(e) + + +# Multi-modal content utilities + + +def read_file(file_path: str) -> bytes: + """Read binary content from a file. + + Args: + file_path: Path to the file + + Returns: + bytes: File content + + Raises: + FileNotFoundError: If file doesn't exist + IOError: If file cannot be read + """ + try: + with open(file_path, "rb") as f: + return f.read() + except FileNotFoundError: + raise FileNotFoundError(f"File not found: {file_path}") + except Exception as e: + raise IOError(f"Failed to read file {file_path}: {e}") + + +def detect_format_from_extension(file_path: str) -> str | None: + """Detect MIME type from file extension. + + Args: + file_path: Path to the file + + Returns: + str | None: MIME type or None if extension not recognized + + Examples: + >>> detect_format_from_extension("image.jpg") + "image/jpeg" + >>> detect_format_from_extension("document.pdf") + "application/pdf" + """ + extension = Path(file_path).suffix.lower() + + # Map common extensions to MIME types + extension_to_mime = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".gif": "image/gif", + ".webp": "image/webp", + ".pdf": "application/pdf", + ".mp4": "video/mp4", + ".mov": "video/quicktime", + ".avi": "video/x-msvideo", + ".mp3": "audio/mpeg", + ".wav": "audio/wav", + ".ogg": "audio/ogg", + } + + return extension_to_mime.get(extension) + + +def detect_format_from_bytes(content: bytes) -> str | None: + """Detect MIME type from bytes content using puremagic. 
+ + Args: + content: Binary content + + Returns: + str | None: MIME type or None if detection fails or puremagic not available + + Examples: + >>> detect_format_from_bytes(b"\\xff\\xd8\\xff\\xe0") # JPEG magic bytes + "image/jpeg" + """ + try: + # Get MIME type from content using puremagic (v2.0+ API) + mime_type = puremagic.from_string(content, mime=True) + return mime_type if mime_type else None + except (ImportError, AttributeError): + # puremagic not available or DeferredError raised + return None + except Exception: + pass + + return None + + +def detect_format_from_file(file_path: str) -> str | None: + """Detect MIME type from file using puremagic or extension fallback. + + Args: + file_path: Path to the file + + Returns: + str | None: MIME type or None if format cannot be detected + + Examples: + >>> detect_format_from_file("photo.jpg") + "image/jpeg" + """ + # Try puremagic first if available + try: + matches = puremagic.magic_file(file_path) + if matches: + # Extract MIME type from first match + mime_type = ( + matches[0].mime_type if hasattr(matches[0], "mime_type") else None + ) + if mime_type: + return mime_type + except (ImportError, AttributeError): + # puremagic not available or DeferredError raised + pass + except Exception: + pass + + # Fallback to extension-based detection + return detect_format_from_extension(file_path) + class LLMeterBytesEncoder(json.JSONEncoder): """Custom JSON encoder that handles bytes objects by converting them to base64. - + This encoder wraps bytes objects in a marker object with the key "__llmeter_bytes__" to enable round-trip serialization and deserialization of binary content in payloads. - + Example: >>> encoder = LLMeterBytesEncoder() >>> payload = {"image": {"bytes": b"\\xff\\xd8\\xff\\xe0"}} >>> json.dumps(payload, cls=LLMeterBytesEncoder) '{"image": {"bytes": {"__llmeter_bytes__": "/9j/4A=="}}}' """ - + def default(self, obj): """Encode bytes objects as marker objects with base64 strings. 
- + Args: obj: Object to encode - + Returns: dict: Marker object for bytes, or delegates to parent for other types - + Raises: TypeError: If object is not JSON serializable """ if isinstance(obj, bytes): - return { - "__llmeter_bytes__": base64.b64encode(obj).decode("utf-8") - } + return {"__llmeter_bytes__": base64.b64encode(obj).decode("utf-8")} + if isinstance(obj, os.PathLike): + return Path(obj).as_posix() return super().default(obj) def llmeter_bytes_decoder(dct: dict) -> dict | bytes: """Decode marker objects back to bytes during JSON deserialization. - + This function is used as an object_hook for json.loads to detect and decode marker objects created by LLMeterBytesEncoder back to bytes during deserialization. - + Args: dct: Dictionary from JSON parsing - + Returns: bytes if marker detected, otherwise original dict - + Example: >>> marker = {"__llmeter_bytes__": "/9j/4A=="} >>> llmeter_bytes_decoder(marker) @@ -176,13 +307,13 @@ def load_payloads( in a directory. It supports both .json and .jsonl file formats. Binary content (bytes objects) that were serialized using LLMeterBytesEncoder are automatically restored during deserialization. - + Binary Content Handling: When loading payloads saved with save_payloads(), marker objects with the key "__llmeter_bytes__" are automatically detected and converted back to bytes objects. The base64-encoded strings are decoded to restore the original binary data, enabling round-trip preservation of multimodal content like images and video. 
- + The marker object format is: {"__llmeter_bytes__": ""} Args: @@ -203,7 +334,7 @@ def load_payloads( Examples: Load a Bedrock Converse API payload with image content: - + >>> # Assuming a file was saved with save_payloads() containing binary data >>> payloads = list(load_payloads("/tmp/output/payload.jsonl")) >>> payload = payloads[0] @@ -214,25 +345,25 @@ def load_payloads( >>> # The bytes can be used directly with the API >>> print(f"Image size: {len(image_bytes)} bytes") Image size: 52341 bytes - + Load multiple payloads with video content: - + >>> for payload in load_payloads("/tmp/output/multimodal.jsonl"): ... video_content = payload["messages"][0]["content"][1] ... if "video" in video_content: ... video_bytes = video_content["video"]["source"]["bytes"] ... print(f"Loaded video: {len(video_bytes)} bytes") Loaded video: 1048576 bytes - + Load all payloads from a directory: - + >>> # Load all .json and .jsonl files in a directory >>> all_payloads = list(load_payloads("/tmp/output/")) >>> print(f"Loaded {len(all_payloads)} payloads") Loaded 5 payloads - + Round-trip example showing binary preservation: - + >>> # Original payload with binary data >>> original = { ... 
"modelId": "test-model", @@ -263,7 +394,9 @@ def load_payloads( yield from _load_data_file(file, deserializer) -def _load_data_file(file: Path, deserializer: Callable[[str], dict] | None = None) -> Iterator[dict]: +def _load_data_file( + file: Path, deserializer: Callable[[str], dict] | None = None +) -> Iterator[dict]: try: with file.open(mode="r") as f: if file.suffix.lower() in [".jsonl", ".manifest"]: @@ -272,7 +405,9 @@ def _load_data_file(file: Path, deserializer: Callable[[str], dict] | None = Non if not line.strip(): continue if deserializer is None: - yield json.loads(line.strip(), object_hook=llmeter_bytes_decoder) + yield json.loads( + line.strip(), object_hook=llmeter_bytes_decoder + ) else: yield deserializer(line.strip()) except json.JSONDecodeError as e: @@ -302,13 +437,13 @@ def save_payloads( This function saves payloads to a JSONL file, with automatic handling of binary content (bytes objects) through base64 encoding. Binary data is wrapped in marker objects during serialization to enable round-trip preservation. - + Binary Content Handling: When a payload contains bytes objects (e.g., images, video), they are automatically converted to base64-encoded strings and wrapped in a marker object with the key "__llmeter_bytes__". This approach enables JSON serialization while preserving the ability to restore the original bytes during deserialization with load_payloads(). - + The marker object format is: {"__llmeter_bytes__": ""} Args: @@ -329,7 +464,7 @@ def save_payloads( Examples: Save a Bedrock Converse API payload with image content: - + >>> import base64 >>> # Create a payload with binary image data >>> with open("image.jpg", "rb") as f: @@ -352,9 +487,9 @@ def save_payloads( >>> output_path = save_payloads(payload, "/tmp/output") >>> print(output_path) /tmp/output/payload.jsonl - + Save multiple payloads with video content: - + >>> with open("video.mp4", "rb") as f: ... 
video_bytes = f.read() >>> payloads = [ @@ -376,9 +511,9 @@ def save_payloads( ... ] >>> save_payloads(payloads, "/tmp/output", "multimodal.jsonl") PosixPath('/tmp/output/multimodal.jsonl') - + The saved JSON file will contain marker objects for binary data: - + >>> # Example of what gets written to the file: >>> # { >>> # "modelId": "anthropic.claude-3-haiku-20240307-v1:0", diff --git a/llmeter/results.py b/llmeter/results.py index 5392d2f..b02cac7 100644 --- a/llmeter/results.py +++ b/llmeter/results.py @@ -37,6 +37,8 @@ def utc_datetime_serializer(obj: Any) -> str: if obj.tzinfo is not None: obj = obj.astimezone(timezone.utc) return obj.isoformat(timespec="seconds").replace("+00:00", "Z") + if isinstance(obj, os.PathLike): + return Path(obj).as_posix() return str(obj) @@ -66,6 +68,8 @@ def default(self, obj): # First try bytes encoding from parent if isinstance(obj, bytes): return super().default(obj) + if isinstance(obj, os.PathLike): + return Path(obj).as_posix() # Fallback to string representation for other non-serializable types try: return str(obj) @@ -160,7 +164,7 @@ def save(self, output_path: os.PathLike | str | None = None): if not responses_path.exists(): with responses_path.open("w") as f: for response in self.responses: - f.write(json.dumps(asdict(response)) + "\n") + f.write(json.dumps(asdict(response), cls=InvocationResponseEncoder) + "\n") def to_json(self, **kwargs): """Return the results as a JSON string.""" diff --git a/llmeter/runner.py b/llmeter/runner.py index 47626e3..e8bf265 100644 --- a/llmeter/runner.py +++ b/llmeter/runner.py @@ -121,8 +121,13 @@ def save( if not isinstance(self.tokenizer, dict): config_copy.tokenizer = Tokenizer.to_dict(self.tokenizer) + def _default_serializer(obj): + if isinstance(obj, os.PathLike): + return Path(obj).as_posix() + return str(obj) + with run_config_path.open("w") as f: - f.write(json.dumps(asdict(config_copy), default=str, indent=4)) + f.write(json.dumps(asdict(config_copy), 
default=_default_serializer, indent=4)) @classmethod def load(cls, load_path: Path | str, file_name: str = "run_config.json"): diff --git a/pyproject.toml b/pyproject.toml index 6a5062d..ba00368 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,8 @@ openai = ["openai>=1.35.1"] litellm = ["litellm>=1.47.1"] plotting = ["plotly>=5.24.1", "kaleido<=0.2.1", "pandas>=2.2.0"] mlflow = ["mlflow-skinny>=3.10.0"] -all = ["openai>=1.35.1", "litellm>=1.47.1", "plotly>=5.24.1", "kaleido<=0.2.1", "pandas>=2.2.0", "mlflow-skinny>=3.10.0"] +multimodal = ["puremagic>=1.28"] +all = ["openai>=1.35.1", "litellm>=1.47.1", "plotly>=5.24.1", "kaleido<=0.2.1", "pandas>=2.2.0", "mlflow-skinny>=3.10.0", "puremagic>=1.28"] [project.urls] Repository = "https://github.com/awslabs/llmeter" @@ -57,12 +58,16 @@ test = [ "pytest-cov>=5.0.0", "pillow>=12.1.1", "aws-bedrock-token-generator>=1.1.0", + # Include all optional dependencies for comprehensive testing "openai>=1.35.1", "litellm>=1.47.1", "plotly>=5.24.1", "kaleido<=0.2.1", "pandas>=2.2.0", "mlflow-skinny>=3.10.0", + "puremagic>=1.28", + "transformers>=4.40.2", + "tiktoken>=0.7.0", ] [tool.pytest.ini_options] diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 8ec054a..2459eb4 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -116,10 +116,59 @@ def bedrock_openai_test_model(): ) +@pytest.fixture(scope="session") +def bedrock_openai_multimodal_test_model(): + """ + Get test model ID for OpenAI SDK multimodal tests. + + The model ID can be overridden via the BEDROCK_OPENAI_MULTIMODAL_TEST_MODEL environment variable. + Defaults to qwen.qwen3-vl-235b-a22b-instruct which supports images and is available in Bedrock. + + Note: This model is specifically for testing multimodal content (images, video, etc.) + via Bedrock's OpenAI-compatible endpoint. + + Returns: + str: OpenAI multimodal model ID for Bedrock OpenAI SDK testing. 
+ """ + return os.environ.get( + "BEDROCK_OPENAI_MULTIMODAL_TEST_MODEL", "qwen.qwen3-vl-235b-a22b-instruct" + ) + + +@pytest.fixture +def test_image_bytes(): + """ + Create test images as binary JPEG data for multimodal testing. + + Returns two small JPEG images (32x32 pixels) as bytes objects. + These are used to test binary content serialization and API calls. + JPEG format is used for broad model compatibility. + + Returns: + tuple: (image1_bytes, image2_bytes) - Two JPEG images as bytes + """ + import io + from PIL import Image + + # 32x32 red square JPEG - binary format + img1 = Image.new("RGB", (32, 32), color=(255, 0, 0)) + buf1 = io.BytesIO() + img1.save(buf1, format="JPEG") + image1_bytes = buf1.getvalue() + + # 32x32 blue square JPEG - binary format + img2 = Image.new("RGB", (32, 32), color=(0, 0, 255)) + buf2 = io.BytesIO() + img2.save(buf2, format="JPEG") + image2_bytes = buf2.getvalue() + + return image1_bytes, image2_bytes + + @pytest.fixture(scope="session") def bedrock_openai_endpoint_url(aws_region): """ - Construct Bedrock OpenAI-compatible endpoint URL. + Construct Bedrock OpenAI-compatible endpoint URL for standard models. Args: aws_region: AWS region from the aws_region fixture. @@ -130,6 +179,23 @@ def bedrock_openai_endpoint_url(aws_region): return f"https://bedrock-runtime.{aws_region}.amazonaws.com/openai/v1" +@pytest.fixture(scope="session") +def bedrock_openai_multimodal_endpoint_url(aws_region): + """ + Construct Bedrock OpenAI-compatible endpoint URL for multimodal models (Bedrock Mantle). + + This endpoint supports non-OpenAI models (e.g., Qwen, Kimi) via the OpenAI-compatible + interface and is required for multimodal content (images, video). + + Args: + aws_region: AWS region from the aws_region fixture. + + Returns: + str: Bedrock Mantle endpoint URL for multimodal testing. 
+ """ + return f"https://bedrock-mantle.{aws_region}.api.aws/v1" + + @pytest.fixture def test_payload(): """ diff --git a/tests/integ/test_openai_bedrock.py b/tests/integ/test_openai_bedrock.py index d3b9b28..974f317 100644 --- a/tests/integ/test_openai_bedrock.py +++ b/tests/integ/test_openai_bedrock.py @@ -301,17 +301,19 @@ def test_save_load_openai_payload_with_image_url(tmp_path): @pytest.mark.integ @pytest.mark.skipif(not OPENAI_AVAILABLE, reason="OpenAI SDK not installed") def test_save_load_openai_complete_structure( - tmp_path, aws_credentials, aws_region, bedrock_openai_test_model + tmp_path, aws_credentials, aws_region, bedrock_openai_multimodal_test_model, + bedrock_openai_multimodal_endpoint_url, test_image_bytes ): """ - Test round-trip with actual OpenAI chat.completions structure. + Test round-trip with actual OpenAI chat.completions structure and API call. This test validates that: - Complete OpenAI chat.completions payloads serialize correctly - All typical OpenAI fields are preserved - Multiple content items with mixed text and images work correctly - The messages[].content[].image_url.url path is handled correctly - - Loaded payload can be used with the OpenAI client + - Binary image data is preserved through save/load cycle + - Loaded payload can be used with the OpenAI client for multimodal requests **Validates: Requirements 9.6, 9.7** @@ -319,49 +321,43 @@ def test_save_load_openai_complete_structure( tmp_path: Temporary directory for test files (from pytest). aws_credentials: Boto3 session with valid AWS credentials (from fixture). aws_region: AWS region for testing (from fixture). - bedrock_openai_test_model: Model ID for OpenAI SDK testing (from fixture). + bedrock_openai_multimodal_test_model: Model ID for OpenAI SDK multimodal testing (from fixture). + bedrock_openai_multimodal_endpoint_url: Bedrock Mantle endpoint URL for multimodal (from fixture). + test_image_bytes: Tuple of test images as binary data (from fixture). 
+ + AWS Permissions Required: + - bedrock:InvokeModel + + Estimated Cost: + ~$0.0002 per test run (using Qwen3-VL with minimal tokens) """ from llmeter.prompt_utils import save_payloads, load_payloads import base64 - # Create test images - image1_bytes = base64.b64decode( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" - ) - image2_bytes = base64.b64decode( - "iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAYAAABytg0kAAAAFElEQVR42mNk+M/wn4EIwDiqkL4KAcT9BAFZhEjRAAAAAElFTkSuQmCC" - ) - - # Use google.gemma-3-4b-it for multi-modal testing (supports images) - multimodal_model = "google.gemma-3-4b-it" + # Get test images as binary data + image1_binary, image2_binary = test_image_bytes - # Create complete OpenAI payload with multiple messages and images + # Create complete OpenAI payload with images stored as binary bytes + # Note: We store raw bytes in the payload, which will be serialized to base64 for storage complete_payload = { - "model": multimodal_model, + "model": bedrock_openai_multimodal_test_model, "messages": [ - { - "role": "system", - "content": "You are a helpful assistant that analyzes images.", - }, { "role": "user", "content": [ - {"type": "text", "text": "Compare these two images:"}, + {"type": "text", "text": "Describe these images briefly:"}, { "type": "image_url", - "image_url": {"url": image1_bytes, "detail": "high"}, + "image_url": {"url": image1_binary}, # Binary bytes }, { "type": "image_url", - "image_url": {"url": image2_bytes, "detail": "low"}, + "image_url": {"url": image2_binary}, # Binary bytes }, - {"type": "text", "text": "What are the differences?"}, ], }, ], - "max_tokens": 500, - "temperature": 0.7, - "top_p": 0.9, + "max_tokens": 100, } # Save and load the complete payload @@ -375,70 +371,59 @@ def test_save_load_openai_complete_structure( # Verify structure is preserved assert loaded["model"] == complete_payload["model"] - assert len(loaded["messages"]) == 2 - assert 
loaded["messages"][0]["role"] == "system" - assert loaded["messages"][1]["role"] == "user" - assert len(loaded["messages"][1]["content"]) == 4 - - # Verify image bytes are restored correctly - loaded_image1 = loaded["messages"][1]["content"][1]["image_url"]["url"] - loaded_image2 = loaded["messages"][1]["content"][2]["image_url"]["url"] + assert len(loaded["messages"]) == 1 + assert loaded["messages"][0]["role"] == "user" + assert len(loaded["messages"][0]["content"]) == 3 - assert isinstance(loaded_image1, bytes), "First image should be bytes" - assert isinstance(loaded_image2, bytes), "Second image should be bytes" - assert loaded_image1 == image1_bytes, "First image bytes should match" - assert loaded_image2 == image2_bytes, "Second image bytes should match" + # Verify image bytes are restored correctly from serialization + loaded_image1_binary = loaded["messages"][0]["content"][1]["image_url"]["url"] + loaded_image2_binary = loaded["messages"][0]["content"][2]["image_url"]["url"] - # Verify detail field is preserved - assert loaded["messages"][1]["content"][1]["image_url"]["detail"] == "high" - assert loaded["messages"][1]["content"][2]["image_url"]["detail"] == "low" + assert isinstance(loaded_image1_binary, bytes), "First image should be bytes" + assert isinstance(loaded_image2_binary, bytes), "Second image should be bytes" + assert loaded_image1_binary == image1_binary, "First image bytes should match" + assert loaded_image2_binary == image2_binary, "Second image bytes should match" # Verify full round-trip equality assert loaded == complete_payload, "Full payload should match after round-trip" # Verify the loaded payload can be used with the OpenAI client - # Note: OpenAI expects base64 data URIs, not raw bytes, so we need to convert - # the bytes back to data URI format for the actual API call + # Note: OpenAI API expects base64-encoded data URIs (ASCII strings), not raw bytes + # Multimodal models require the Bedrock Mantle endpoint, not bedrock-runtime 
token = provide_token(region=aws_region) - base_url = f"https://bedrock-runtime.{aws_region}.amazonaws.com/openai/v1" - client = OpenAI(api_key=token, base_url=base_url) + client = OpenAI(api_key=token, base_url=bedrock_openai_multimodal_endpoint_url) + + # Convert binary bytes to base64-encoded ASCII strings for API call + image1_base64_ascii = base64.b64encode(loaded_image1_binary).decode('utf-8') + image2_base64_ascii = base64.b64encode(loaded_image2_binary).decode('utf-8') - # Convert bytes back to data URIs for OpenAI API + # Build API payload - data URI format with JPEG MIME type api_payload = { - "model": loaded["model"], + "model": bedrock_openai_multimodal_test_model, "messages": [ - loaded["messages"][0], # System message { - "role": loaded["messages"][1]["role"], + "role": "user", "content": [ - loaded["messages"][1]["content"][0], # Text + {"type": "text", "text": "Describe these images briefly:"}, { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{base64.b64encode(loaded['messages'][1]['content'][1]['image_url']['url']).decode()}", - "detail": loaded["messages"][1]["content"][1]["image_url"][ - "detail" - ], + "url": f"data:image/jpeg;base64,{image1_base64_ascii}", }, }, { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{base64.b64encode(loaded['messages'][1]['content'][2]['image_url']['url']).decode()}", - "detail": loaded["messages"][1]["content"][2]["image_url"][ - "detail" - ], + "url": f"data:image/jpeg;base64,{image2_base64_ascii}", }, }, - loaded["messages"][1]["content"][3], # Text ], }, ], - "max_tokens": loaded["max_tokens"], - "temperature": loaded["temperature"], - "top_p": loaded["top_p"], + "max_tokens": 100, } + # Invoke the API with the loaded multimodal payload response = client.chat.completions.create(**api_payload) # Verify the client successfully processed the loaded payload @@ -450,13 +435,3 @@ def test_save_load_openai_complete_structure( assert len(response.choices[0].message.content) > 
0, ( "Response text should not be empty" ) - - # Verify the client successfully processed the loaded payload - assert response.choices is not None, "Response should contain choices" - assert len(response.choices) > 0, "Response should have at least one choice" - assert response.choices[0].message.content is not None, ( - "Response should contain text" - ) - assert len(response.choices[0].message.content) > 0, ( - "Response text should not be empty" - ) diff --git a/tests/unit/endpoints/test_bedrock_multimodal.py b/tests/unit/endpoints/test_bedrock_multimodal.py new file mode 100644 index 0000000..c3a9a06 --- /dev/null +++ b/tests/unit/endpoints/test_bedrock_multimodal.py @@ -0,0 +1,247 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from llmeter.endpoints.bedrock import BedrockBase + + +class TestBedrockMultiModal: + """Test multi-modal functionality for Bedrock endpoints.""" + + def test_create_payload_single_image_from_file(self): + """Test creating payload with single image from file path.""" + # Create a temporary image file + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") # JPEG magic bytes + temp_path = f.name + + try: + payload = BedrockBase.create_payload( + user_message="What's in this image?", images=[temp_path], max_tokens=256 + ) + + assert "messages" in payload + assert len(payload["messages"]) == 1 + assert payload["messages"][0]["role"] == "user" + + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert content[0]["text"] == "What's in this image?" 
+ assert "image" in content[1] + assert content[1]["image"]["format"] == "jpeg" + assert "bytes" in content[1]["image"]["source"] + + finally: + Path(temp_path).unlink() + + def test_create_payload_single_image_from_bytes(self): + """Test creating payload with single image from bytes.""" + # Create a minimal valid JPEG file + # JPEG structure: SOI (FFD8) + APP0 marker + minimal data + EOI (FFD9) + jpeg_bytes = ( + b"\xff\xd8" # SOI (Start of Image) + b"\xff\xe0" # APP0 marker + b"\x00\x10" # APP0 length (16 bytes) + b"JFIF\x00" # JFIF identifier + b"\x01\x01" # JFIF version 1.1 + b"\x00" # density units (0 = no units) + b"\x00\x01" # X density = 1 + b"\x00\x01" # Y density = 1 + b"\x00\x00" # thumbnail width and height = 0 + b"\xff\xd9" # EOI (End of Image) + ) + + try: + payload = BedrockBase.create_payload( + user_message="What's in this image?", + images=[jpeg_bytes], + max_tokens=256, + ) + + assert "messages" in payload + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert "image" in content[1] + assert content[1]["image"]["format"] == "jpeg" + assert content[1]["image"]["source"]["bytes"] == jpeg_bytes + except ValueError as e: + # If puremagic can't detect the format, skip this test + if "Cannot detect format from bytes" in str(e): + pytest.skip("puremagic cannot detect format from minimal JPEG bytes") + raise + + def test_create_payload_multiple_images(self): + """Test creating payload with multiple images.""" + # Create temporary image files + temp_files = [] + for i in range(2): + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n\x1a\n") # PNG magic bytes + temp_files.append(f.name) + + try: + payload = BedrockBase.create_payload( + user_message="Compare these images", images=temp_files, max_tokens=512 + ) + + content = payload["messages"][0]["content"] + assert len(content) == 3 # text + 2 images + assert content[0]["text"] == "Compare these images" + assert "image" in 
content[1] + assert "image" in content[2] + + finally: + for path in temp_files: + Path(path).unlink() + + def test_create_payload_mixed_content(self): + """Test creating payload with mixed content types.""" + # Create temporary files + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + img_file.write(b"\xff\xd8\xff\xe0") + img_path = img_file.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc_file: + doc_file.write(b"%PDF-1.4") + doc_path = doc_file.name + + try: + payload = BedrockBase.create_payload( + user_message="Analyze this", + images=[img_path], + documents=[doc_path], + max_tokens=1024, + ) + + content = payload["messages"][0]["content"] + assert len(content) == 3 # text + image + document + assert content[0]["text"] == "Analyze this" + assert "image" in content[1] + assert content[1]["image"]["format"] == "jpeg" + assert "document" in content[2] + assert content[2]["document"]["format"] == "pdf" + + finally: + Path(img_path).unlink() + Path(doc_path).unlink() + + def test_create_payload_text_only_backward_compatible(self): + """Test that text-only payloads still work (backward compatibility).""" + payload = BedrockBase.create_payload( + user_message="Hello, world!", max_tokens=256 + ) + + assert "messages" in payload + content = payload["messages"][0]["content"] + assert len(content) == 1 + assert content[0]["text"] == "Hello, world!" 
+ + def test_create_payload_empty_media_lists(self): + """Test that empty media lists are handled correctly.""" + payload = BedrockBase.create_payload( + user_message="Hello", images=[], documents=None, max_tokens=256 + ) + + # Should behave like text-only + content = payload["messages"][0]["content"] + assert len(content) == 1 + assert content[0]["text"] == "Hello" + + def test_create_payload_invalid_image_type(self): + """Test that invalid image types raise TypeError.""" + with pytest.raises( + TypeError, match="Items in images list must be bytes or str" + ): + BedrockBase.create_payload( + user_message="Test", + images=[123], # Invalid type + max_tokens=256, + ) + + def test_create_payload_invalid_images_not_list(self): + """Test that non-list images parameter raises TypeError.""" + with pytest.raises(TypeError, match="images must be a list"): + BedrockBase.create_payload( + user_message="Test", images="not_a_list", max_tokens=256 + ) + + def test_create_payload_missing_file(self): + """Test that missing file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File not found"): + BedrockBase.create_payload( + user_message="Test", images=["/nonexistent/file.jpg"], max_tokens=256 + ) + + def test_create_payload_bytes_without_puremagic(self): + """Test that bytes without detectable format raises ValueError.""" + # Random bytes that don't match any known format + random_bytes = b"\x00\x01\x02\x03\x04\x05" + + with pytest.raises(ValueError, match="Cannot detect format from bytes"): + BedrockBase.create_payload( + user_message="Test", images=[random_bytes], max_tokens=256 + ) + + def test_create_payload_file_without_extension(self): + """Test that file without recognized extension raises ValueError.""" + with tempfile.NamedTemporaryFile(suffix="", delete=False) as f: + f.write(b"some content") + temp_path = f.name + + try: + with pytest.raises( + ValueError, match="(Cannot detect format|Unsupported MIME type)" + ): + BedrockBase.create_payload( 
+ user_message="Test", images=[temp_path], max_tokens=256 + ) + finally: + Path(temp_path).unlink() + + def test_create_payload_content_ordering(self): + """Test that content blocks are ordered correctly: text, images, videos, audio, documents.""" + # Create temporary files for each media type with proper extensions + # puremagic may not detect these minimal magic bytes, so we rely on extensions + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img: + img.write( + b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + ) + img_path = img.name + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as vid: + vid.write(b"\x00\x00\x00\x20ftypisom\x00\x00\x02\x00isomiso2mp41") + vid_path = vid.name + + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as aud: + aud.write(b"ID3\x03\x00\x00\x00\x00\x00\x00") + aud_path = aud.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc: + doc.write(b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n") + doc_path = doc.name + + try: + payload = BedrockBase.create_payload( + user_message="Analyze all", + images=[img_path], + videos=[vid_path], + audio=[aud_path], + documents=[doc_path], + max_tokens=1024, + ) + + content = payload["messages"][0]["content"] + assert len(content) == 5 # text + 4 media types + assert "text" in content[0] + assert "image" in content[1] + assert "video" in content[2] + assert "audio" in content[3] + assert "document" in content[4] + + finally: + for path in [img_path, vid_path, aud_path, doc_path]: + Path(path).unlink() diff --git a/tests/unit/endpoints/test_multimodal_properties.py b/tests/unit/endpoints/test_multimodal_properties.py new file mode 100644 index 0000000..e240546 --- /dev/null +++ b/tests/unit/endpoints/test_multimodal_properties.py @@ -0,0 +1,682 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Property-based tests for multi-modal payload functionality. + +This module contains property-based tests using Hypothesis to validate +the correctness properties defined in the multi-modal payload design document. +Each test validates universal properties that should hold across all valid inputs. +""" + +import json +import tempfile +from pathlib import Path + +import pytest +from hypothesis import HealthCheck, given, settings, strategies as st + +from llmeter.endpoints.bedrock import BedrockBase +from llmeter.prompt_utils import ( + LLMeterBytesEncoder, + llmeter_bytes_decoder, + load_payloads, + save_payloads, +) + + +# Test file content generators +@st.composite +def image_bytes(draw): + """Generate valid JPEG image bytes with magic bytes.""" + # JPEG magic bytes: FF D8 FF E0 + jpeg_header = b"\xff\xd8\xff\xe0" + # Add some random content + content_size = draw(st.integers(min_value=10, max_value=100)) + content = draw(st.binary(min_size=content_size, max_size=content_size)) + # JPEG end marker: FF D9 + jpeg_footer = b"\xff\xd9" + return jpeg_header + content + jpeg_footer + + +@st.composite +def png_bytes(draw): + """Generate valid PNG image bytes with magic bytes.""" + # PNG magic bytes: 89 50 4E 47 0D 0A 1A 0A + png_header = b"\x89PNG\r\n\x1a\n" + content_size = draw(st.integers(min_value=10, max_value=100)) + content = draw(st.binary(min_size=content_size, max_size=content_size)) + return png_header + content + + +@st.composite +def pdf_bytes(draw): + """Generate valid PDF bytes with magic bytes.""" + # PDF magic bytes: %PDF- + pdf_header = b"%PDF-1.4\n" + content_size = draw(st.integers(min_value=10, max_value=100)) + content = draw(st.binary(min_size=content_size, max_size=content_size)) + pdf_footer = b"\n%%EOF" + return pdf_header + content + pdf_footer + + +@st.composite +def mp4_bytes(draw): + """Generate valid MP4 video bytes with magic bytes.""" + # MP4 magic bytes typically start with ftyp + mp4_header = 
b"\x00\x00\x00\x20ftypisom" + content_size = draw(st.integers(min_value=10, max_value=100)) + content = draw(st.binary(min_size=content_size, max_size=content_size)) + return mp4_header + content + + +@st.composite +def mp3_bytes(draw): + """Generate valid MP3 audio bytes with magic bytes.""" + # MP3 magic bytes: ID3 or FF FB + mp3_header = b"ID3\x03\x00\x00" + content_size = draw(st.integers(min_value=10, max_value=100)) + content = draw(st.binary(min_size=content_size, max_size=content_size)) + return mp3_header + content + + +# Strategy for generating media bytes +media_bytes_strategy = st.one_of( + image_bytes(), + png_bytes(), + pdf_bytes(), + mp4_bytes(), + mp3_bytes(), +) + + +# Strategy for generating file extensions +image_extensions = st.sampled_from([".jpg", ".jpeg", ".png", ".gif", ".webp"]) +video_extensions = st.sampled_from([".mp4", ".mov", ".avi"]) +audio_extensions = st.sampled_from([".mp3", ".wav", ".ogg"]) +document_extensions = st.sampled_from([".pdf"]) + + +def create_temp_file(tmp_path: Path, content: bytes, extension: str) -> str: + """Create a temporary file with given content and extension.""" + file_path = tmp_path / f"test_file_{id(content)}{extension}" + file_path.write_bytes(content) + return str(file_path) + + +# Property-based tests + + +@given( + file_content=image_bytes(), + extension=image_extensions, +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_1_file_path_content_inclusion(file_content, extension): + """Property 1: File path content inclusion. + + **Validates: Requirements 1.2, 2.2, 3.2, 4.2** + + For any media type (image, video, audio, document) and valid file path, + when creating a payload with that file path in the corresponding parameter list, + the resulting payload SHALL contain a content block with the file's bytes and + the format detected from the file extension. 
+ """ + # Create temporary file and use it within the same context + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + file_path = create_temp_file(tmp_path, file_content, extension) + + # Create payload with the file + payload = BedrockBase.create_payload( + user_message="Test message", images=[file_path], max_tokens=256 + ) + + # Verify payload structure + assert "messages" in payload + assert len(payload["messages"]) == 1 + assert "content" in payload["messages"][0] + + content_blocks = payload["messages"][0]["content"] + + # Find the image content block + image_blocks = [block for block in content_blocks if "image" in block] + assert len(image_blocks) == 1 + + # Verify the image block contains the file's bytes + image_block = image_blocks[0] + assert "image" in image_block + assert "source" in image_block["image"] + assert "bytes" in image_block["image"]["source"] + assert image_block["image"]["source"]["bytes"] == file_content + + # Verify format is detected + assert "format" in image_block["image"] + assert image_block["image"]["format"] in ["jpeg", "png", "gif", "webp"] + + +@given( + num_images=st.integers(min_value=1, max_value=5), + file_content=image_bytes(), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_2_multiple_items_preservation(num_images, file_content): + """Property 2: Multiple items preservation. + + **Validates: Requirements 1.3, 2.3, 3.3, 4.3** + + For any media type and list of file paths, when creating a payload, + the resulting payload SHALL contain exactly as many content blocks of that + media type as there are items in the list, preserving their count. 
+ """ + # Create multiple temporary files + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + file_paths = [] + for i in range(num_images): + file_path = create_temp_file(tmp_path, file_content, ".jpg") + file_paths.append(file_path) + + # Create payload with multiple images + payload = BedrockBase.create_payload( + user_message="Test message", images=file_paths, max_tokens=256 + ) + + # Verify payload structure + content_blocks = payload["messages"][0]["content"] + + # Count image blocks + image_blocks = [block for block in content_blocks if "image" in block] + + # Verify count matches input + assert len(image_blocks) == num_images + + +@given( + num_texts=st.integers(min_value=1, max_value=3), + num_images=st.integers(min_value=1, max_value=3), + num_videos=st.integers(min_value=0, max_value=2), + num_audio=st.integers(min_value=0, max_value=2), + num_documents=st.integers(min_value=0, max_value=2), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_3_content_ordering_preservation( + num_texts, num_images, num_videos, num_audio, num_documents +): + """Property 3: Content ordering preservation. + + **Validates: Requirements 1.4** + + For any combination of text messages and media items, when creating a payload, + the resulting content array SHALL preserve the order: text blocks first (in order), + then media blocks (in the order: images, videos, audio, documents). 
+ """ + # Create text messages + text_messages = [f"Text {i}" for i in range(num_texts)] + + # Create media files + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + images = [ + create_temp_file(tmp_path, b"\xff\xd8\xff\xe0test\xff\xd9", ".jpg") + for _ in range(num_images) + ] + videos = [ + create_temp_file(tmp_path, b"\x00\x00\x00\x20ftypisomtest", ".mp4") + for _ in range(num_videos) + ] + audio = [ + create_temp_file(tmp_path, b"ID3\x03\x00\x00test", ".mp3") + for _ in range(num_audio) + ] + documents = [ + create_temp_file(tmp_path, b"%PDF-1.4\ntest\n%%EOF", ".pdf") + for _ in range(num_documents) + ] + + # Create payload + payload = BedrockBase.create_payload( + user_message=text_messages if len(text_messages) > 1 else text_messages[0], + images=images if images else None, + videos=videos if videos else None, + audio=audio if audio else None, + documents=documents if documents else None, + max_tokens=256, + ) + + # Verify ordering + content_blocks = payload["messages"][0]["content"] + + # Extract block types in order + block_types = [] + for block in content_blocks: + if "text" in block: + block_types.append("text") + elif "image" in block: + block_types.append("image") + elif "video" in block: + block_types.append("video") + elif "audio" in block: + block_types.append("audio") + elif "document" in block: + block_types.append("document") + + # Verify text blocks come first + text_count = block_types.count("text") + assert text_count == num_texts + assert block_types[:text_count] == ["text"] * text_count + + # Verify media blocks follow in order: images, videos, audio, documents + media_blocks = block_types[text_count:] + expected_order = ( + ["image"] * num_images + + ["video"] * num_videos + + ["audio"] * num_audio + + ["document"] * num_documents + ) + assert media_blocks == expected_order + + +def valid_file_path_string(s: str) -> bool: + """Check if a string is a valid file path (no null bytes, not too long).""" + if not s or 
len(s) > 255: + return False + if "\x00" in s: + return False + # Check if it's a valid path that doesn't exist + try: + path = Path(s) + return not path.exists() + except (ValueError, OSError): + return False + + +@given( + non_existent_path=st.text(min_size=1, max_size=50).filter(valid_file_path_string) +) +@settings(max_examples=100) +def test_property_4_missing_file_error_handling(non_existent_path): + """Property 4: Missing file error handling. + + **Validates: Requirements 5.2** + + For any non-existent file path provided in any media parameter, + attempting to create a payload SHALL raise a FileNotFoundError + with a message containing the file path. + """ + with pytest.raises(FileNotFoundError) as exc_info: + BedrockBase.create_payload( + user_message="Test message", images=[non_existent_path], max_tokens=256 + ) + + # Verify error message contains the file path + assert non_existent_path in str(exc_info.value) + + +@given( + invalid_item=st.one_of( + st.integers(), + st.floats(), + st.dictionaries(st.text(), st.text()), + st.lists(st.integers()), + ) +) +@settings(max_examples=100) +def test_property_5_invalid_type_rejection(invalid_item): + """Property 5: Invalid type rejection. + + **Validates: Requirements 5.4** + + For any media parameter list containing items that are neither bytes nor strings, + attempting to create a payload SHALL raise a TypeError with a descriptive message. 
+ """ + with pytest.raises(TypeError) as exc_info: + BedrockBase.create_payload( + user_message="Test message", images=[invalid_item], max_tokens=256 + ) + + # Verify error message is descriptive + error_msg = str(exc_info.value) + assert "images" in error_msg + assert "bytes" in error_msg or "str" in error_msg + + +@given( + user_message=st.one_of( + st.text(min_size=1, max_size=100), + st.lists(st.text(min_size=1, max_size=50), min_size=1, max_size=3), + ), + max_tokens=st.integers(min_value=1, max_value=4096), +) +@settings(max_examples=100) +def test_property_6_backward_compatibility(user_message, max_tokens): + """Property 6: Backward compatibility for text-only payloads. + + **Validates: Requirements 6.1, 6.2, 6.4** + + For any text input provided using the user_message parameter (without media parameters), + the new create_payload implementation SHALL produce output identical to the current + implementation, maintaining the same structure and field values. + """ + # Create payload with text only + payload = BedrockBase.create_payload( + user_message=user_message, max_tokens=max_tokens + ) + + # Verify expected structure + assert "messages" in payload + assert "inferenceConfig" in payload + assert payload["inferenceConfig"]["maxTokens"] == max_tokens + + # Verify messages structure + messages = payload["messages"] + if isinstance(user_message, str): + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert len(messages[0]["content"]) == 1 + assert messages[0]["content"][0]["text"] == user_message + else: + assert len(messages) == len(user_message) + for i, msg in enumerate(messages): + assert msg["role"] == "user" + assert len(msg["content"]) == 1 + assert msg["content"][0]["text"] == user_message[i] + + +@given( + max_tokens=st.integers(min_value=1, max_value=4096), + extra_kwargs=st.dictionaries( + st.text(min_size=1, max_size=20).filter( + lambda x: ( + x + not in [ + "messages", + "inferenceConfig", + "images", + "documents", + "videos", 
+ "audio", + "user_message", + "max_tokens", + ] + ) + ), + st.one_of(st.text(), st.integers(), st.booleans()), + min_size=0, + max_size=3, + ), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_7_existing_parameters_preservation(max_tokens, extra_kwargs): + """Property 7: Existing parameters preservation. + + **Validates: Requirements 6.3** + + For any values of max_tokens and additional kwargs, when creating a payload + (text-only or multi-modal), these parameters SHALL be preserved in the + resulting payload structure exactly as they are in the current implementation. + """ + # Create a test image file + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + image_path = create_temp_file(tmp_path, b"\xff\xd8\xff\xe0test\xff\xd9", ".jpg") + + # Create payload with extra kwargs + payload = BedrockBase.create_payload( + user_message="Test message", + images=[image_path], + max_tokens=max_tokens, + **extra_kwargs, + ) + + # Verify max_tokens is preserved + assert payload["inferenceConfig"]["maxTokens"] == max_tokens + + # Verify extra kwargs are preserved + for key, value in extra_kwargs.items(): + assert key in payload + assert payload[key] == value + + +@given( + image_content=image_bytes(), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_8_serialization_round_trip(image_content): + """Property 8: Multi-modal payload serialization round-trip. + + **Validates: Requirements 7.1, 7.2, 7.3, 7.4, 7.5, 9.2, 9.3** + + For any valid multi-modal payload containing binary content, the round-trip + property SHALL hold: load_payloads(save_payloads(payload)) == payload, + preserving byte-for-byte equality of all binary content and exact equality + of all other values. 
+ """ + # Create a test image file + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + image_path = create_temp_file(tmp_path, image_content, ".jpg") + + # Create payload with binary content + original_payload = BedrockBase.create_payload( + user_message="Test message", images=[image_path], max_tokens=256 + ) + + # Save and load the payload + save_path = save_payloads(original_payload, tmp_dir, "test_payload.jsonl") + loaded_payloads = list(load_payloads(save_path)) + + # Verify round-trip preservation + assert len(loaded_payloads) == 1 + loaded_payload = loaded_payloads[0] + + # Verify structure equality + assert loaded_payload == original_payload + + # Verify binary content is preserved byte-for-byte + original_bytes = original_payload["messages"][0]["content"][1]["image"][ + "source" + ]["bytes"] + loaded_bytes = loaded_payload["messages"][0]["content"][1]["image"]["source"][ + "bytes" + ] + assert original_bytes == loaded_bytes + assert isinstance(loaded_bytes, bytes) + + +@given( + extension=st.sampled_from([".jpg", ".png", ".pdf", ".mp4", ".mp3"]), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_9_format_detection(extension): + """Property 9: Format detection from file content. + + **Validates: Requirements 8.5** + + For any file with recognizable content (JPEG, PNG, PDF, MP4, etc.), + when creating a payload, the format SHALL be correctly detected from + the file's magic bytes using puremagic (or from extension as fallback) + and included in the content block. 
+ """ + # Generate appropriate content for the extension + if extension in [".jpg", ".jpeg"]: + file_content = b"\xff\xd8\xff\xe0test\xff\xd9" + media_param = "images" + media_key = "image" + elif extension == ".png": + file_content = b"\x89PNG\r\n\x1a\ntest" + media_param = "images" + media_key = "image" + elif extension == ".pdf": + file_content = b"%PDF-1.4\ntest\n%%EOF" + media_param = "documents" + media_key = "document" + elif extension == ".mp4": + file_content = b"\x00\x00\x00\x20ftypisomtest" + media_param = "videos" + media_key = "video" + elif extension == ".mp3": + file_content = b"ID3\x03\x00\x00test" + media_param = "audio" + media_key = "audio" + + # Create temporary file and use it within the same context + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + file_path = create_temp_file(tmp_path, file_content, extension) + + # Create payload + kwargs = {media_param: [file_path]} + payload = BedrockBase.create_payload( + user_message="Test message", max_tokens=256, **kwargs + ) + + # Find the media content block + content_blocks = payload["messages"][0]["content"] + media_blocks = [block for block in content_blocks if media_key in block] + assert len(media_blocks) == 1 + + # Verify format is detected and included + media_block = media_blocks[0] + assert "format" in media_block[media_key] + assert isinstance(media_block[media_key]["format"], str) + assert len(media_block[media_key]["format"]) > 0 + + +@given( + num_prompts=st.integers(min_value=1, max_value=5), +) +@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_10_load_prompts_integration(num_prompts): + """Property 10: load_prompts integration with multi-modal payloads. + + **Validates: Requirements 9.1** + + For any create_payload call that produces multi-modal payloads, + when used with load_prompts, the function SHALL yield valid multi-modal + payloads that can be serialized and used with endpoint invoke methods. 
+ """ + # Create a test image file + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + image_path = create_temp_file(tmp_path, b"\xff\xd8\xff\xe0test\xff\xd9", ".jpg") + + # Create a prompts file + prompts_file = tmp_path / "prompts.txt" + prompts = [f"Prompt {i}" for i in range(num_prompts)] + prompts_file.write_text("\n".join(prompts)) + + # Define create_payload function that produces multi-modal payloads + def create_multimodal_payload(input_text, **kwargs): + return BedrockBase.create_payload( + user_message=input_text, images=[image_path], max_tokens=256, **kwargs + ) + + # Use load_prompts with the multi-modal create_payload function + from llmeter.prompt_utils import load_prompts + + payloads = list( + load_prompts( + prompts_file, + create_payload_fn=create_multimodal_payload, + create_payload_kwargs={}, + ) + ) + + # Verify correct number of payloads + assert len(payloads) == num_prompts + + # Verify each payload is valid and contains multi-modal content + for i, payload in enumerate(payloads): + assert "messages" in payload + content_blocks = payload["messages"][0]["content"] + + # Should have text and image blocks + text_blocks = [block for block in content_blocks if "text" in block] + image_blocks = [block for block in content_blocks if "image" in block] + + assert len(text_blocks) >= 1 + assert len(image_blocks) == 1 + + # Verify the payload can be serialized + json_str = json.dumps(payload, cls=LLMeterBytesEncoder) + assert len(json_str) > 0 + + # Verify it can be deserialized + deserialized = json.loads(json_str, object_hook=llmeter_bytes_decoder) + assert deserialized == payload + + +@given( + extra_kwargs=st.dictionaries( + st.text(min_size=1, max_size=20).filter( + lambda x: ( + x + not in [ + "messages", + "inferenceConfig", + "user_message", + "max_tokens", + "images", + "documents", + "videos", + "audio", + ] + ) + ), + st.one_of(st.text(min_size=1), st.integers(), st.booleans()), + min_size=1, + max_size=3, + ) +) 
+@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_property_11_create_payload_kwargs_compatibility(extra_kwargs): + """Property 11: create_payload_kwargs pattern compatibility. + + **Validates: Requirements 9.4** + + For any dictionary of additional parameters passed via create_payload_kwargs, + when used with load_prompts or directly with create_payload, these parameters + SHALL be passed through and included in the resulting payload. + """ + # Create a test image file + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + image_path = create_temp_file(tmp_path, b"\xff\xd8\xff\xe0test\xff\xd9", ".jpg") + + # Create payload with extra kwargs + payload = BedrockBase.create_payload( + user_message="Test message", + images=[image_path], + max_tokens=256, + **extra_kwargs, + ) + + # Verify all extra kwargs are present in the payload + for key, value in extra_kwargs.items(): + assert key in payload + assert payload[key] == value + + # Test with load_prompts pattern + prompts_file = tmp_path / "prompts.txt" + prompts_file.write_text("Test prompt") + + def create_multimodal_payload(input_text, **kwargs): + return BedrockBase.create_payload( + user_message=input_text, images=[image_path], max_tokens=256, **kwargs + ) + + from llmeter.prompt_utils import load_prompts + + payloads = list( + load_prompts( + prompts_file, + create_payload_fn=create_multimodal_payload, + create_payload_kwargs=extra_kwargs, + ) + ) + + # Verify kwargs are passed through + assert len(payloads) == 1 + for key, value in extra_kwargs.items(): + assert key in payloads[0] + assert payloads[0][key] == value diff --git a/tests/unit/endpoints/test_multimodal_serialization.py b/tests/unit/endpoints/test_multimodal_serialization.py new file mode 100644 index 0000000..0d353fb --- /dev/null +++ b/tests/unit/endpoints/test_multimodal_serialization.py @@ -0,0 +1,253 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + + +from llmeter.endpoints.bedrock import BedrockBase +from llmeter.endpoints.openai import OpenAIEndpoint +from llmeter.endpoints.sagemaker import SageMakerBase +from llmeter.prompt_utils import save_payloads, load_payloads, load_prompts + + +class TestMultiModalSerialization: + """Test serialization and deserialization of multi-modal payloads.""" + + def test_save_and_load_single_image_payload(self): + """Test saving and loading a payload with a single image.""" + # Create a payload with image bytes + image_bytes = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + + # Create temporary image file + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_image_path = f.name + + try: + payload = BedrockBase.create_payload( + user_message="What's in this image?", + images=[temp_image_path], + max_tokens=256, + ) + + # Save payload + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payload, temp_dir, "test_payload.jsonl") + assert output_path.exists() + + # Load payload + loaded_payloads = list(load_payloads(output_path)) + assert len(loaded_payloads) == 1 + + loaded_payload = loaded_payloads[0] + + # Verify structure + assert "messages" in loaded_payload + content = loaded_payload["messages"][0]["content"] + assert len(content) == 2 # text + image + + # Verify binary content is preserved + loaded_image_bytes = content[1]["image"]["source"]["bytes"] + assert isinstance(loaded_image_bytes, bytes) + assert loaded_image_bytes == image_bytes + + finally: + Path(temp_image_path).unlink() + + def test_save_and_load_multiple_content_types(self): + """Test saving and loading a payload with multiple content types.""" + # Create test files + image_bytes = b"\xff\xd8\xff\xe0" + pdf_bytes = b"%PDF-1.4\n" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + 
img_file.write(image_bytes) + img_path = img_file.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc_file: + doc_file.write(pdf_bytes) + doc_path = doc_file.name + + try: + payload = BedrockBase.create_payload( + user_message="Analyze this", + images=[img_path], + documents=[doc_path], + max_tokens=1024, + ) + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payload, temp_dir) + loaded_payloads = list(load_payloads(output_path)) + + assert len(loaded_payloads) == 1 + loaded_payload = loaded_payloads[0] + + # Verify all content is preserved + content = loaded_payload["messages"][0]["content"] + assert len(content) == 3 # text + image + document + + # Verify binary content + loaded_image = content[1]["image"]["source"]["bytes"] + loaded_doc = content[2]["document"]["source"]["bytes"] + + assert isinstance(loaded_image, bytes) + assert isinstance(loaded_doc, bytes) + assert loaded_image == image_bytes + assert loaded_doc == pdf_bytes + + finally: + Path(img_path).unlink() + Path(doc_path).unlink() + + def test_round_trip_preservation(self): + """Test that binary content is preserved byte-for-byte in round-trip.""" + # Create payload with bytes directly + image_bytes = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_path = f.name + + try: + original_payload = BedrockBase.create_payload( + user_message="Test", images=[temp_path], max_tokens=256 + ) + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(original_payload, temp_dir) + loaded_payload = list(load_payloads(output_path))[0] + + # Verify exact equality + assert original_payload == loaded_payload + + finally: + Path(temp_path).unlink() + + def test_save_multiple_payloads(self): + """Test saving and loading multiple payloads.""" + image_bytes = 
b"\xff\xd8\xff\xe0" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_path = f.name + + try: + payloads = [ + BedrockBase.create_payload( + user_message=f"Image {i}", images=[temp_path], max_tokens=256 + ) + for i in range(3) + ] + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payloads, temp_dir) + loaded_payloads = list(load_payloads(output_path)) + + assert len(loaded_payloads) == 3 + + # Verify each payload + for i, loaded in enumerate(loaded_payloads): + content = loaded["messages"][0]["content"] + assert content[0]["text"] == f"Image {i}" + assert isinstance(content[1]["image"]["source"]["bytes"], bytes) + + finally: + Path(temp_path).unlink() + + def test_load_prompts_with_multimodal_create_payload(self): + """Test load_prompts integration with multi-modal create_payload.""" + # Create a prompts file + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: + f.write("What is this?\n") + f.write("Describe the image\n") + prompts_path = f.name + + # Create an image file + image_bytes = b"\xff\xd8\xff\xe0" + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + img_file.write(image_bytes) + img_path = img_file.name + + try: + # Define a create_payload function that includes an image + def create_multimodal_payload(input_text, **kwargs): + return BedrockBase.create_payload( + user_message=input_text, images=[img_path], max_tokens=256, **kwargs + ) + + # Load prompts with multi-modal payload creation + payloads = list(load_prompts(prompts_path, create_multimodal_payload)) + + assert len(payloads) == 2 + + # Verify each payload has the image + for payload in payloads: + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert "image" in content[1] + assert isinstance(content[1]["image"]["source"]["bytes"], bytes) + + finally: + Path(prompts_path).unlink() + Path(img_path).unlink() 
+ + def test_openai_payload_serialization(self): + """Test serialization of OpenAI multi-modal payloads.""" + image_bytes = b"\xff\xd8\xff\xe0" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_path = f.name + + try: + payload = OpenAIEndpoint.create_payload( + user_message="Test", images=[temp_path], max_tokens=256 + ) + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payload, temp_dir) + loaded_payload = list(load_payloads(output_path))[0] + + # Verify exact equality + assert payload == loaded_payload + + # Verify OpenAI-specific format + content = loaded_payload["messages"][0]["content"] + assert content[1]["image"]["format"] == "image/jpeg" + + finally: + Path(temp_path).unlink() + + def test_sagemaker_payload_serialization(self): + """Test serialization of SageMaker multi-modal payloads.""" + image_bytes = b"\xff\xd8\xff\xe0" + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(image_bytes) + temp_path = f.name + + try: + payload = SageMakerBase.create_payload( + input_text="Test", images=[temp_path], max_tokens=256 + ) + + # Save and load + with tempfile.TemporaryDirectory() as temp_dir: + output_path = save_payloads(payload, temp_dir) + loaded_payload = list(load_payloads(output_path))[0] + + # Verify exact equality + assert payload == loaded_payload + + # Verify SageMaker-specific format + content = loaded_payload["inputs"] + assert content[1]["image"]["format"] == "jpeg" + + finally: + Path(temp_path).unlink() diff --git a/tests/unit/endpoints/test_multimodal_utilities.py b/tests/unit/endpoints/test_multimodal_utilities.py new file mode 100644 index 0000000..1a7f82f --- /dev/null +++ b/tests/unit/endpoints/test_multimodal_utilities.py @@ -0,0 +1,212 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from llmeter.prompt_utils import ( + read_file, + detect_format_from_extension, + detect_format_from_bytes, + detect_format_from_file, +) + + +class TestReadFile: + """Test the read_file utility function.""" + + def test_read_file_valid_path(self): + """Test reading a valid file.""" + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"test content") + temp_path = f.name + + try: + content = read_file(temp_path) + assert content == b"test content" + finally: + Path(temp_path).unlink() + + def test_read_file_nonexistent(self): + """Test reading a non-existent file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File not found"): + read_file("/nonexistent/file.txt") + + def test_read_file_binary_content(self): + """Test reading binary content.""" + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") # JPEG magic bytes + temp_path = f.name + + try: + content = read_file(temp_path) + assert content == b"\xff\xd8\xff\xe0" + finally: + Path(temp_path).unlink() + + +class TestDetectFormatFromExtension: + """Test the detect_format_from_extension utility function.""" + + def test_detect_jpeg_extension(self): + """Test detecting JPEG format from .jpg extension.""" + mime_type = detect_format_from_extension("image.jpg") + assert mime_type == "image/jpeg" + + def test_detect_jpeg_extension_uppercase(self): + """Test detecting JPEG format from .JPG extension.""" + mime_type = detect_format_from_extension("image.JPG") + assert mime_type == "image/jpeg" + + def test_detect_png_extension(self): + """Test detecting PNG format from .png extension.""" + mime_type = detect_format_from_extension("image.png") + assert mime_type == "image/png" + + def test_detect_pdf_extension(self): + """Test detecting PDF format from .pdf extension.""" + mime_type = 
detect_format_from_extension("document.pdf") + assert mime_type == "application/pdf" + + def test_detect_mp4_extension(self): + """Test detecting MP4 format from .mp4 extension.""" + mime_type = detect_format_from_extension("video.mp4") + assert mime_type == "video/mp4" + + def test_detect_mp3_extension(self): + """Test detecting MP3 format from .mp3 extension.""" + mime_type = detect_format_from_extension("audio.mp3") + assert mime_type == "audio/mpeg" + + def test_detect_wav_extension(self): + """Test detecting WAV format from .wav extension.""" + mime_type = detect_format_from_extension("audio.wav") + assert mime_type == "audio/wav" + + def test_detect_unknown_extension(self): + """Test detecting unknown extension returns None.""" + mime_type = detect_format_from_extension("file.unknown") + assert mime_type is None + + def test_detect_no_extension(self): + """Test detecting file without extension returns None.""" + mime_type = detect_format_from_extension("file") + assert mime_type is None + + +class TestDetectFormatFromBytes: + """Test the detect_format_from_bytes utility function.""" + + def test_detect_jpeg_from_bytes_with_puremagic(self): + """Test detecting JPEG format from bytes with puremagic.""" + jpeg_bytes = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + try: + mime_type = detect_format_from_bytes(jpeg_bytes) + # If puremagic is installed, it should detect the format + if mime_type is not None: + assert mime_type == "image/jpeg" + else: + # If puremagic is not installed, it returns None + pytest.skip("puremagic not installed") + except ImportError: + pytest.skip("puremagic not installed") + + def test_detect_png_from_bytes_with_puremagic(self): + """Test detecting PNG format from bytes with puremagic.""" + png_bytes = b"\x89PNG\r\n\x1a\n" + try: + mime_type = detect_format_from_bytes(png_bytes) + # If puremagic is installed, it should detect the format + if mime_type is not None: + assert mime_type == "image/png" + 
else: + # If puremagic is not installed, it returns None + pytest.skip("puremagic not installed") + except ImportError: + pytest.skip("puremagic not installed") + + def test_detect_format_without_puremagic(self): + """Test that detection returns None when puremagic is not available.""" + # Mock puremagic to raise ImportError when accessed + with patch("llmeter.prompt_utils.puremagic") as mock_puremagic: + mock_puremagic.from_string.side_effect = ImportError("puremagic not available") + mime_type = detect_format_from_bytes(b"\xff\xd8\xff\xe0") + assert mime_type is None + + +class TestDetectFormatFromFile: + """Test the detect_format_from_file utility function.""" + + def test_detect_format_from_jpeg_file(self): + """Test detecting format from JPEG file.""" + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write( + b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xd9" + ) + temp_path = f.name + + try: + mime_type = detect_format_from_file(temp_path) + assert mime_type == "image/jpeg" + finally: + Path(temp_path).unlink() + + def test_detect_format_from_png_file(self): + """Test detecting format from PNG file.""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n\x1a\n") + temp_path = f.name + + try: + mime_type = detect_format_from_file(temp_path) + assert mime_type == "image/png" + finally: + Path(temp_path).unlink() + + def test_detect_format_from_pdf_file(self): + """Test detecting format from PDF file.""" + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: + f.write(b"%PDF-1.4\n") + temp_path = f.name + + try: + mime_type = detect_format_from_file(temp_path) + assert mime_type == "application/pdf" + finally: + Path(temp_path).unlink() + + def test_detect_format_fallback_to_extension(self): + """Test that detection falls back to extension when puremagic not installed.""" + # Create a file with JPEG magic bytes + with 
tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") + temp_path = f.name + + try: + # Mock puremagic to raise ImportError when accessed + with patch("llmeter.prompt_utils.puremagic") as mock_puremagic: + mock_puremagic.magic_file.side_effect = ImportError("puremagic not available") + mime_type = detect_format_from_file(temp_path) + # Should fall back to extension-based detection + assert mime_type == "image/jpeg" + finally: + Path(temp_path).unlink() + + def test_detect_format_no_extension_no_puremagic(self): + """Test that detection returns None for file without extension when puremagic unavailable.""" + with tempfile.NamedTemporaryFile(suffix="", delete=False) as f: + f.write(b"some content") + temp_path = f.name + + try: + # Mock puremagic to raise ImportError when accessed + with patch("llmeter.prompt_utils.puremagic") as mock_puremagic: + mock_puremagic.magic_file.side_effect = ImportError("puremagic not available") + mime_type = detect_format_from_file(temp_path) + # Should fall back to extension-based detection, which returns None for no extension + assert mime_type is None + finally: + Path(temp_path).unlink() diff --git a/tests/unit/endpoints/test_openai_multimodal.py b/tests/unit/endpoints/test_openai_multimodal.py new file mode 100644 index 0000000..d587178 --- /dev/null +++ b/tests/unit/endpoints/test_openai_multimodal.py @@ -0,0 +1,144 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from llmeter.endpoints.openai import OpenAIEndpoint + + +class TestOpenAIMultiModal: + """Test multi-modal functionality for OpenAI endpoints.""" + + def test_create_payload_single_image_from_file(self): + """Test creating payload with single image from file path.""" + # Create a temporary image file + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") # JPEG magic bytes + temp_path = f.name + + try: + payload = OpenAIEndpoint.create_payload( + user_message="What's in this image?", images=[temp_path], max_tokens=256 + ) + + assert "messages" in payload + assert len(payload["messages"]) == 1 + assert payload["messages"][0]["role"] == "user" + + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert content[0]["text"] == "What's in this image?" + assert "image" in content[1] + # OpenAI uses full MIME types + assert content[1]["image"]["format"] == "image/jpeg" + assert "bytes" in content[1]["image"]["source"] + + finally: + Path(temp_path).unlink() + + def test_create_payload_single_image_from_bytes(self): + """Test creating payload with single image from bytes.""" + # Create a minimal valid JPEG file + jpeg_bytes = ( + b"\xff\xd8" # SOI (Start of Image) + b"\xff\xe0" # APP0 marker + b"\x00\x10" # APP0 length (16 bytes) + b"JFIF\x00" # JFIF identifier + b"\x01\x01" # JFIF version 1.1 + b"\x00" # density units (0 = no units) + b"\x00\x01" # X density = 1 + b"\x00\x01" # Y density = 1 + b"\x00\x00" # thumbnail width and height = 0 + b"\xff\xd9" # EOI (End of Image) + ) + + try: + payload = OpenAIEndpoint.create_payload( + user_message="What's in this image?", + images=[jpeg_bytes], + max_tokens=256, + ) + + assert "messages" in payload + content = payload["messages"][0]["content"] + assert len(content) == 2 # text + image + assert "image" in content[1] + # OpenAI uses full MIME 
types + assert content[1]["image"]["format"] == "image/jpeg" + assert content[1]["image"]["source"]["bytes"] == jpeg_bytes + except ValueError as e: + # If puremagic can't detect the format, skip this test + if "Cannot detect format from bytes" in str(e): + pytest.skip("puremagic cannot detect format from minimal JPEG bytes") + raise + + def test_create_payload_mixed_content(self): + """Test creating payload with mixed content types.""" + # Create temporary files + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + img_file.write(b"\xff\xd8\xff\xe0") + img_path = img_file.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc_file: + doc_file.write(b"%PDF-1.4") + doc_path = doc_file.name + + try: + payload = OpenAIEndpoint.create_payload( + user_message="Analyze this", + images=[img_path], + documents=[doc_path], + max_tokens=1024, + ) + + content = payload["messages"][0]["content"] + assert len(content) == 3 # text + image + document + assert content[0]["text"] == "Analyze this" + assert "image" in content[1] + # OpenAI uses full MIME types + assert content[1]["image"]["format"] == "image/jpeg" + assert "document" in content[2] + assert content[2]["document"]["format"] == "application/pdf" + + finally: + Path(img_path).unlink() + Path(doc_path).unlink() + + def test_create_payload_text_only_backward_compatible(self): + """Test that text-only payloads still work (backward compatibility).""" + payload = OpenAIEndpoint.create_payload( + user_message="Hello, world!", max_tokens=256 + ) + + assert "messages" in payload + content = payload["messages"][0]["content"] + # Text-only should be a string, not a list + assert content == "Hello, world!" 
+ + def test_create_payload_invalid_image_type(self): + """Test that invalid image types raise TypeError.""" + with pytest.raises( + TypeError, match="Items in images list must be bytes or str" + ): + OpenAIEndpoint.create_payload( + user_message="Test", + images=[123], # Invalid type + max_tokens=256, + ) + + def test_create_payload_invalid_images_not_list(self): + """Test that non-list images parameter raises TypeError.""" + with pytest.raises(TypeError, match="images must be a list"): + OpenAIEndpoint.create_payload( + user_message="Test", images="not_a_list", max_tokens=256 + ) + + def test_create_payload_missing_file(self): + """Test that missing file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File not found"): + OpenAIEndpoint.create_payload( + user_message="Test", images=["/nonexistent/file.jpg"], max_tokens=256 + ) diff --git a/tests/unit/endpoints/test_sagemaker_multimodal.py b/tests/unit/endpoints/test_sagemaker_multimodal.py new file mode 100644 index 0000000..cb82236 --- /dev/null +++ b/tests/unit/endpoints/test_sagemaker_multimodal.py @@ -0,0 +1,140 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from llmeter.endpoints.sagemaker import SageMakerBase + + +class TestSageMakerMultiModal: + """Test multi-modal functionality for SageMaker endpoints.""" + + def test_create_payload_single_image_from_file(self): + """Test creating payload with single image from file path.""" + # Create a temporary image file + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + f.write(b"\xff\xd8\xff\xe0") # JPEG magic bytes + temp_path = f.name + + try: + payload = SageMakerBase.create_payload( + input_text="What's in this image?", images=[temp_path], max_tokens=256 + ) + + assert "inputs" in payload + content = payload["inputs"] + assert len(content) == 2 # text + image + assert content[0]["text"] == "What's in this image?" + assert "image" in content[1] + # SageMaker uses Bedrock-style short format strings + assert content[1]["image"]["format"] == "jpeg" + assert "bytes" in content[1]["image"]["source"] + + finally: + Path(temp_path).unlink() + + def test_create_payload_single_image_from_bytes(self): + """Test creating payload with single image from bytes.""" + # Create a minimal valid JPEG file + jpeg_bytes = ( + b"\xff\xd8" # SOI (Start of Image) + b"\xff\xe0" # APP0 marker + b"\x00\x10" # APP0 length (16 bytes) + b"JFIF\x00" # JFIF identifier + b"\x01\x01" # JFIF version 1.1 + b"\x00" # density units (0 = no units) + b"\x00\x01" # X density = 1 + b"\x00\x01" # Y density = 1 + b"\x00\x00" # thumbnail width and height = 0 + b"\xff\xd9" # EOI (End of Image) + ) + + try: + payload = SageMakerBase.create_payload( + input_text="What's in this image?", + images=[jpeg_bytes], + max_tokens=256, + ) + + assert "inputs" in payload + content = payload["inputs"] + assert len(content) == 2 # text + image + assert "image" in content[1] + # SageMaker uses Bedrock-style short format strings + assert content[1]["image"]["format"] == "jpeg" + assert 
content[1]["image"]["source"]["bytes"] == jpeg_bytes + except ValueError as e: + # If puremagic can't detect the format, skip this test + if "Cannot detect format from bytes" in str(e): + pytest.skip("puremagic cannot detect format from minimal JPEG bytes") + raise + + def test_create_payload_mixed_content(self): + """Test creating payload with mixed content types.""" + # Create temporary files + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_file: + img_file.write(b"\xff\xd8\xff\xe0") + img_path = img_file.name + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as doc_file: + doc_file.write(b"%PDF-1.4") + doc_path = doc_file.name + + try: + payload = SageMakerBase.create_payload( + input_text="Analyze this", + images=[img_path], + documents=[doc_path], + max_tokens=1024, + ) + + content = payload["inputs"] + assert len(content) == 3 # text + image + document + assert content[0]["text"] == "Analyze this" + assert "image" in content[1] + # SageMaker uses Bedrock-style short format strings + assert content[1]["image"]["format"] == "jpeg" + assert "document" in content[2] + assert content[2]["document"]["format"] == "pdf" + + finally: + Path(img_path).unlink() + Path(doc_path).unlink() + + def test_create_payload_text_only_backward_compatible(self): + """Test that text-only payloads still work (backward compatibility).""" + payload = SageMakerBase.create_payload( + input_text="Hello, world!", max_tokens=256 + ) + + assert "inputs" in payload + # Text-only should be a string, not a list + assert payload["inputs"] == "Hello, world!" 
+ + def test_create_payload_invalid_image_type(self): + """Test that invalid image types raise TypeError.""" + with pytest.raises( + TypeError, match="Items in images list must be bytes or str" + ): + SageMakerBase.create_payload( + input_text="Test", + images=[123], # Invalid type + max_tokens=256, + ) + + def test_create_payload_invalid_images_not_list(self): + """Test that non-list images parameter raises TypeError.""" + with pytest.raises(TypeError, match="images must be a list"): + SageMakerBase.create_payload( + input_text="Test", images="not_a_list", max_tokens=256 + ) + + def test_create_payload_missing_file(self): + """Test that missing file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="File not found"): + SageMakerBase.create_payload( + input_text="Test", images=["/nonexistent/file.jpg"], max_tokens=256 + ) diff --git a/tests/unit/test_prompt_utils.py b/tests/unit/test_prompt_utils.py index c0d1a0f..ca74ba8 100644 --- a/tests/unit/test_prompt_utils.py +++ b/tests/unit/test_prompt_utils.py @@ -1805,9 +1805,6 @@ def test_deserialization_no_unnecessary_copies(self): # Serialize first serialized = json.dumps(payload, cls=LLMeterBytesEncoder) - # Get the size of the serialized data - serialized_size = sys.getsizeof(serialized) - # Deserialize from llmeter.prompt_utils import llmeter_bytes_decoder diff --git a/tests/unit/test_property_save_load.py b/tests/unit/test_property_save_load.py index 5b0a7c9..7b96dc7 100644 --- a/tests/unit/test_property_save_load.py +++ b/tests/unit/test_property_save_load.py @@ -156,7 +156,7 @@ class TestPayloadSaveLoadProperties: @given( st.lists( st.dictionaries( - st.text(min_size=1, max_size=50), + st.text(min_size=1, max_size=50).filter(lambda x: x != "__llmeter_bytes__"), st.one_of( st.text(max_size=200), st.integers(), @@ -172,7 +172,11 @@ class TestPayloadSaveLoadProperties: ) @settings(deadline=None) def test_save_load_payloads_preserves_data(self, payloads): - """Save/load roundtrip should 
preserve all payload data.""" + """Save/load roundtrip should preserve all payload data. + + Note: Excludes the reserved key '__llmeter_bytes__' which is used internally + for binary content serialization. + """ with tempfile.TemporaryDirectory() as tmpdir: output_path = Path(tmpdir) diff --git a/tests/unit/test_tokenizers.py b/tests/unit/test_tokenizers.py index f441392..37dde1b 100644 --- a/tests/unit/test_tokenizers.py +++ b/tests/unit/test_tokenizers.py @@ -13,6 +13,19 @@ save_tokenizer, ) +# Check for optional dependencies +try: + import transformers # noqa: F401 + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + +try: + import tiktoken # noqa: F401 + TIKTOKEN_AVAILABLE = True +except ImportError: + TIKTOKEN_AVAILABLE = False + # Mock classes for testing class MockTransformersTokenizer: @@ -118,15 +131,29 @@ def test_save_tokenizer(tmp_path): # Test _load_tokenizer_from_info function -@pytest.mark.skip(reason="transformers is not installed") +@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="transformers is not installed") def test_load_tokenizer_from_info_transformers(monkeypatch): + # Mock AutoTokenizer.from_pretrained to return our mock + def mock_from_pretrained(name): + return MockTransformersTokenizer() + + from transformers import AutoTokenizer + monkeypatch.setattr(AutoTokenizer, "from_pretrained", mock_from_pretrained) + tokenizer_info = {"tokenizer_module": "transformers", "name": "mock-transformer"} tokenizer = _load_tokenizer_from_info(tokenizer_info) assert isinstance(tokenizer, MockTransformersTokenizer) -@pytest.mark.skip(reason="tiktoken is not installed") +@pytest.mark.skipif(not TIKTOKEN_AVAILABLE, reason="tiktoken is not installed") def test_load_tokenizer_from_info_tiktoken(monkeypatch): + # Mock get_encoding to return our mock + def mock_get_encoding(name): + return MockTiktokenTokenizer() + + import tiktoken + monkeypatch.setattr(tiktoken, "get_encoding", mock_get_encoding) + tokenizer_info = 
{"tokenizer_module": "tiktoken", "name": "mock-tiktoken"} tokenizer = _load_tokenizer_from_info(tokenizer_info) assert isinstance(tokenizer, MockTiktokenTokenizer) From 7a64e1049661550710d8f44d14e2aab9b4cde1c2 Mon Sep 17 00:00:00 2001 From: Alessandro Cere Date: Wed, 11 Mar 2026 12:20:46 +0800 Subject: [PATCH 3/9] refactor: Standardize file operations to use UPath for cross-platform compatibility - Replace built-in open() calls with UPath.open() across all file operations - Add UPath import to cost model for consistent path handling - Update cost model save_to_file() to create parent directories automatically - Standardize file reading in prompt_utils, results, runner, and tokenizers modules - Improves cross-platform file system support and enables cloud storage integration --- llmeter/callbacks/cost/model.py | 9 +++++++-- llmeter/prompt_utils.py | 3 ++- llmeter/runner.py | 2 +- llmeter/tokenizers.py | 5 +++-- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/llmeter/callbacks/cost/model.py b/llmeter/callbacks/cost/model.py index d491aad..3641695 100644 --- a/llmeter/callbacks/cost/model.py +++ b/llmeter/callbacks/cost/model.py @@ -4,6 +4,8 @@ from dataclasses import dataclass, field import importlib +from upath import UPath as Path + # Local Dependencies: from ...endpoints.base import InvocationResponse from ...results import Result @@ -201,7 +203,9 @@ async def after_run(self, result: Result) -> None: def save_to_file(self, path: str) -> None: """Save the cost model (including all dimensions) to a JSON file""" - with open(path, "w") as f: + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: f.write(self.to_json()) @classmethod @@ -222,5 +226,6 @@ def from_dict(cls, raw: dict, alt_classes: dict = {}, **kwargs) -> "CostModel": @classmethod def _load_from_file(cls, path: str): """Load the cost model (including all dimensions) from a JSON file""" - with open(path, "r") as f: + path = Path(path) + with 
path.open("r") as f: return cls.from_json(f.read()) diff --git a/llmeter/prompt_utils.py b/llmeter/prompt_utils.py index 61db642..882d318 100644 --- a/llmeter/prompt_utils.py +++ b/llmeter/prompt_utils.py @@ -45,7 +45,8 @@ def read_file(file_path: str) -> bytes: IOError: If file cannot be read """ try: - with open(file_path, "rb") as f: + _path = Path(file_path) + with _path.open("rb") as f: return f.read() except FileNotFoundError: raise FileNotFoundError(f"File not found: {file_path}") diff --git a/llmeter/runner.py b/llmeter/runner.py index e8bf265..db2f94a 100644 --- a/llmeter/runner.py +++ b/llmeter/runner.py @@ -138,7 +138,7 @@ def load(cls, load_path: Path | str, file_name: str = "run_config.json"): file_name: File name within `output_path` for the run configuration JSON. """ load_path = Path(load_path) - with open(load_path / file_name) as f: + with (load_path / file_name).open() as f: config = json.load(f) config["endpoint"] = Endpoint.load(config["endpoint"]) config["tokenizer"] = Tokenizer.load(config["tokenizer"]) diff --git a/llmeter/tokenizers.py b/llmeter/tokenizers.py index 6f4e893..91a606b 100644 --- a/llmeter/tokenizers.py +++ b/llmeter/tokenizers.py @@ -54,7 +54,8 @@ def load_from_file(tokenizer_path: UPath | None) -> Tokenizer: """ if tokenizer_path is None: return DummyTokenizer() - with open(tokenizer_path, "r") as f: + tokenizer_path = UPath(tokenizer_path) + with tokenizer_path.open("r") as f: tokenizer_info = json.load(f) return _load_tokenizer_from_info(tokenizer_info) @@ -123,7 +124,7 @@ def save_tokenizer(tokenizer: Any, output_path: UPath | str) -> UPath: output_path = UPath(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "w") as f: + with output_path.open("w") as f: json.dump(tokenizer_info, f) return output_path From e0ecacb70dbb02f34d40770d3bcf958b9eea14cd Mon Sep 17 00:00:00 2001 From: Alessandro Cere Date: Mon, 16 Mar 2026 23:23:47 +0800 Subject: [PATCH 4/9] fix: serialize datetime fields 
in Result.to_dict for JSON-safe stats Result.stats included raw datetime objects (start_time, end_time) from to_dict(), causing TypeError when users called json.dumps() without a custom serializer. Now to_dict() converts datetime fields via utc_datetime_serializer so stats is always directly serializable. --- llmeter/results.py | 13 ++++++++----- tests/unit/test_results.py | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/llmeter/results.py b/llmeter/results.py index b02cac7..04bf01a 100644 --- a/llmeter/results.py +++ b/llmeter/results.py @@ -174,12 +174,15 @@ def to_json(self, **kwargs): return json.dumps(summary, default=utc_datetime_serializer, **kwargs) def to_dict(self, include_responses: bool = False): - """Return the results as a dictionary.""" + """Return the results as a dictionary with JSON-serializable values.""" + data = asdict(self) + # Serialize datetime objects so stats dict is always JSON-safe + for key in ("start_time", "end_time"): + if key in data and isinstance(data[key], datetime): + data[key] = utc_datetime_serializer(data[key]) if include_responses: - return asdict(self) - return { - k: o for k, o in asdict(self).items() if k not in ["responses", "stats"] - } + return data + return {k: v for k, v in data.items() if k not in ["responses", "stats"]} def load_responses(self) -> list[InvocationResponse]: """ diff --git a/tests/unit/test_results.py b/tests/unit/test_results.py index cdc4deb..2273b6b 100644 --- a/tests/unit/test_results.py +++ b/tests/unit/test_results.py @@ -217,6 +217,27 @@ def test_stats_property_empty_result(): assert f"{metric}-{stat}" not in stats +def test_stats_json_serializable_with_datetimes(): + """stats dict should be directly JSON-serializable even with datetime fields.""" + from datetime import datetime, timezone + + result = Result( + responses=sample_responses_successful[:2], + total_requests=2, + clients=1, + n_requests=2, + total_test_time=1.0, + start_time=datetime(2025, 
1, 1, 12, 0, 0, tzinfo=timezone.utc), + end_time=datetime(2025, 1, 1, 12, 0, 1, tzinfo=timezone.utc), + ) + stats = result.stats + # Should not raise TypeError + json_str = json.dumps(stats) + parsed = json.loads(json_str) + assert parsed["start_time"] == "2025-01-01T12:00:00Z" + assert parsed["end_time"] == "2025-01-01T12:00:01Z" + + @pytest.fixture def temp_dir(tmp_path: Path): return UPath(tmp_path) From 4f2e79a195072e4c80506959e1e1a190f79b1cc0 Mon Sep 17 00:00:00 2001 From: Alessandro Cere Date: Mon, 30 Mar 2026 14:13:52 -0700 Subject: [PATCH 5/9] refactor: replace os.PathLike with UPath type aliases and add ensure_path helper Replace scattered `os.PathLike | str` type annotations with the proper UPath type aliases from `upath.types`: - `ReadablePathLike` for parameters used to load/read data - `WritablePathLike` for parameters used to save/write data Add `ensure_path()` utility to `llmeter.utils` to centralize the `Path(x)` normalization boilerplate that was duplicated at the top of nearly every function accepting a path argument. The helper handles None passthrough and uses a lazy UPath import. Runtime `isinstance(obj, os.PathLike)` checks in serialization helpers are left unchanged since TypeAliases cannot be used for runtime checks. 
--- llmeter/callbacks/base.py | 9 +++++---- llmeter/callbacks/cost/model.py | 10 ++++----- llmeter/callbacks/cost/serde.py | 15 ++++++++------ llmeter/endpoints/base.py | 11 ++++++---- llmeter/experiments.py | 26 +++++++++++++----------- llmeter/prompt_utils.py | 29 +++++++++++++------------- llmeter/results.py | 18 ++++++++--------- llmeter/runner.py | 21 ++++++++++--------- llmeter/utils.py | 36 +++++++++++++++++++++++++++++++++ 9 files changed, 111 insertions(+), 64 deletions(-) diff --git a/llmeter/callbacks/base.py b/llmeter/callbacks/base.py index 5c6a5d2..f4c4dba 100644 --- a/llmeter/callbacks/base.py +++ b/llmeter/callbacks/base.py @@ -4,10 +4,11 @@ from __future__ import annotations -import os from abc import ABC from typing import final +from upath.types import ReadablePathLike, WritablePathLike + from ..endpoints.base import InvocationResponse from ..results import Result from ..runner import _RunConfig @@ -70,7 +71,7 @@ async def after_run(self, result: Result) -> None: """ pass - def save_to_file(self, path: os.PathLike | str) -> None: + def save_to_file(self, path: WritablePathLike) -> None: """Save this Callback to file Individual Callbacks implement this method to save their configuration to a file that will @@ -83,7 +84,7 @@ def save_to_file(self, path: os.PathLike | str) -> None: @staticmethod @final - def load_from_file(path: os.PathLike | str) -> Callback: + def load_from_file(path: ReadablePathLike) -> Callback: """Load (any type of) Callback from file `Callback.load_from_file()` attempts to detect the type of Callback saved in a given file, @@ -99,7 +100,7 @@ def load_from_file(path: os.PathLike | str) -> Callback: ) @classmethod - def _load_from_file(cls, path: os.PathLike | str) -> Callback: + def _load_from_file(cls, path: ReadablePathLike) -> Callback: """Load this Callback from file Individual Callbacks implement this method to define how they can be loaded from files diff --git a/llmeter/callbacks/cost/model.py 
b/llmeter/callbacks/cost/model.py index 3641695..8e17664 100644 --- a/llmeter/callbacks/cost/model.py +++ b/llmeter/callbacks/cost/model.py @@ -1,10 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 # Python Built-Ins: -from dataclasses import dataclass, field import importlib +from dataclasses import dataclass, field -from upath import UPath as Path +from llmeter.utils import ensure_path # Local Dependencies: from ...endpoints.base import InvocationResponse @@ -13,7 +13,7 @@ from ..base import Callback from .dimensions import IRequestCostDimension, IRunCostDimension from .results import CalculatedCostWithDimensions -from .serde import from_dict_with_class_map, JSONableBase +from .serde import JSONableBase, from_dict_with_class_map @dataclass @@ -203,7 +203,7 @@ async def after_run(self, result: Result) -> None: def save_to_file(self, path: str) -> None: """Save the cost model (including all dimensions) to a JSON file""" - path = Path(path) + path = ensure_path(path) path.parent.mkdir(parents=True, exist_ok=True) with path.open("w") as f: f.write(self.to_json()) @@ -226,6 +226,6 @@ def from_dict(cls, raw: dict, alt_classes: dict = {}, **kwargs) -> "CostModel": @classmethod def _load_from_file(cls, path: str): """Load the cost model (including all dimensions) from a JSON file""" - path = Path(path) + path = ensure_path(path) with path.open("r") as f: return cls.from_json(f.read()) diff --git a/llmeter/callbacks/cost/serde.py b/llmeter/callbacks/cost/serde.py index 4a1bb97..5df9a7f 100644 --- a/llmeter/callbacks/cost/serde.py +++ b/llmeter/callbacks/cost/serde.py @@ -3,15 +3,18 @@ """(De/re)serialization interfaces for saving Cost Model objects to file and loading them back""" # Python Built-Ins: -from dataclasses import is_dataclass -from datetime import date, datetime, time import json import logging import os +from dataclasses import is_dataclass +from datetime import date, datetime, time from 
typing import Any, Callable, Dict, Protocol, Type, TypeVar # External Dependencies: from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike + +from llmeter.utils import ensure_path logger = logging.getLogger(__name__) @@ -146,13 +149,13 @@ def from_dict( return from_dict_with_class(raw=raw, cls=cls, **kwargs) @classmethod - def from_file(cls: Type[TJSONable], input_path: os.PathLike, **kwargs) -> TJSONable: + def from_file(cls: Type[TJSONable], input_path: ReadablePathLike, **kwargs) -> TJSONable: """Initialize an instance of this class from a (local or Cloud) JSON file Args: input_path: The path to the JSON data file. **kwargs: Optional extra keyword arguments to pass to `from_dict()` """ - input_path = Path(input_path) + input_path = ensure_path(input_path) with input_path.open("r") as f: return cls.from_json(f.read(), **kwargs) @@ -177,7 +180,7 @@ def to_dict(self, **kwargs) -> dict: def to_file( self, - output_path: os.PathLike, + output_path: WritablePathLike, indent: int | str | None = 4, default: Callable[[Any], Any] | None = str, **kwargs: Any, @@ -200,7 +203,7 @@ def default(obj): return Path(obj).as_posix() return str(obj) - output_path = Path(output_path) + output_path = ensure_path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w") as f: f.write(self.to_json(indent=indent, default=default, **kwargs)) diff --git a/llmeter/endpoints/base.py b/llmeter/endpoints/base.py index 62659d2..b1d6136 100644 --- a/llmeter/endpoints/base.py +++ b/llmeter/endpoints/base.py @@ -14,6 +14,9 @@ from uuid import uuid4 from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike + +from llmeter.utils import ensure_path # @dataclass(slots=True) @@ -262,7 +265,7 @@ def __subclasshook__(cls, C: type) -> bool: return True return NotImplemented - def save(self, output_path: os.PathLike) -> os.PathLike: + def save(self, output_path: WritablePathLike) -> Path: """ Save the 
endpoint configuration to a JSON file. @@ -275,7 +278,7 @@ def save(self, output_path: os.PathLike) -> os.PathLike: Returns: None """ - output_path = Path(output_path) + output_path = ensure_path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w") as f: endpoint_conf = self.to_dict() @@ -306,7 +309,7 @@ def to_dict(self) -> dict: return endpoint_conf @classmethod - def load_from_file(cls, input_path: os.PathLike) -> "Endpoint": + def load_from_file(cls, input_path: ReadablePathLike) -> "Endpoint": """ Load an endpoint configuration from a JSON file. @@ -322,7 +325,7 @@ def load_from_file(cls, input_path: os.PathLike) -> "Endpoint": with the configuration from the file. """ - input_path = Path(input_path) + input_path = ensure_path(input_path) with input_path.open("r") as f: data = json.load(f) endpoint_type = data.pop("endpoint_type") diff --git a/llmeter/experiments.py b/llmeter/experiments.py index 0a70a17..dcf6a0e 100644 --- a/llmeter/experiments.py +++ b/llmeter/experiments.py @@ -15,12 +15,14 @@ from tqdm.auto import tqdm from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike from llmeter.callbacks.base import Callback from llmeter.results import Result +from llmeter.utils import ensure_path from .endpoints.base import Endpoint -from .plotting import plot_heatmap, plot_load_test_results, color_sequences +from .plotting import color_sequences, plot_heatmap, plot_load_test_results from .prompt_utils import CreatePromptCollection from .runner import Runner from .tokenizers import Tokenizer @@ -38,7 +40,7 @@ class LoadTestResult: results: dict[int, Result] test_name: str - output_path: os.PathLike | str | None = None + output_path: WritablePathLike | None = None def plot_results(self, show: bool = True, format: Literal["html", "png"] = "html"): figs = plot_load_test_results(self) @@ -57,7 +59,7 @@ def plot_results(self, show: bool = True, format: Literal["html", "png"] = "html 
f.update_layout(colorway=c_seqs[i % len(c_seqs)]) if self.output_path is not None: - output_path = Path(self.output_path) + output_path = ensure_path(self.output_path) # save figure to the output path output_path.parent.mkdir(parents=True, exist_ok=True) for k, f in figs.items(): @@ -96,7 +98,7 @@ def load( raise FileNotFoundError("Load path cannot be None or empty") if isinstance(load_path, str): - load_path = Path(load_path) + load_path = ensure_path(load_path) if not load_path.exists(): raise FileNotFoundError(f"Load path {load_path} does not exist") @@ -130,7 +132,7 @@ class LoadTest: sequence_of_clients: list[int] min_requests_per_client: int = 1 min_requests_per_run: int = 10 - output_path: os.PathLike | str | None = None + output_path: WritablePathLike | None = None tokenizer: Tokenizer | None = None test_name: str | None = None callbacks: list[Callback] | None = None @@ -143,9 +145,9 @@ def _get_n_requests(self, clients): return int(ceil(self.min_requests_per_run / clients)) return int(self.min_requests_per_client) - async def run(self, output_path: os.PathLike | None = None): + async def run(self, output_path: WritablePathLike | None = None): try: - output_path = Path(output_path or self.output_path) / self._test_name + output_path = ensure_path(output_path or self.output_path) / self._test_name except Exception: output_path = None _runner = Runner( @@ -209,9 +211,9 @@ class LatencyHeatmap: """ endpoint: Endpoint - source_file: os.PathLike | str + source_file: ReadablePathLike clients: int = 4 - output_path: os.PathLike | str | None = None + output_path: WritablePathLike | None = None input_lengths: list[int] = field(default_factory=lambda: [10, 50, 200, 500]) output_lengths: list[int] = field(default_factory=lambda: [128, 256, 512, 1024]) requests_per_combination: int = 1 @@ -224,7 +226,7 @@ def __post_init__(self) -> None: requests_per_combination=self.requests_per_combination, input_lengths=self.input_lengths, output_lengths=self.output_lengths, - 
source_file=Path(self.source_file), + source_file=ensure_path(self.source_file), tokenizer=self.tokenizer, # type: ignore ) @@ -239,7 +241,7 @@ def __post_init__(self) -> None: self._runner = Runner( endpoint=self.endpoint, - output_path=Path(self.output_path) + output_path=ensure_path(self.output_path) if self.output_path is not None else None, tokenizer=self.tokenizer, @@ -249,7 +251,7 @@ async def run(self, output_path=None): # Handle None output_path properly final_output_path = output_path or self.output_path if final_output_path is not None: - final_output_path = Path(final_output_path) + final_output_path = ensure_path(final_output_path) heatmap_results = await self._runner.run( payload=self.payload, diff --git a/llmeter/prompt_utils.py b/llmeter/prompt_utils.py index 882d318..f0a6d56 100644 --- a/llmeter/prompt_utils.py +++ b/llmeter/prompt_utils.py @@ -4,16 +4,17 @@ import base64 import json import logging -from dataclasses import dataclass -from itertools import product import os import random +from dataclasses import dataclass +from itertools import product from typing import Any, Callable, Iterator from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike from .tokenizers import DummyTokenizer, Tokenizer -from .utils import DeferredError +from .utils import DeferredError, ensure_path logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ def read_file(file_path: str) -> bytes: IOError: If file cannot be read """ try: - _path = Path(file_path) + _path = ensure_path(file_path) with _path.open("rb") as f: return f.read() except FileNotFoundError: @@ -69,7 +70,7 @@ def detect_format_from_extension(file_path: str) -> str | None: >>> detect_format_from_extension("document.pdf") "application/pdf" """ - extension = Path(file_path).suffix.lower() + extension = ensure_path(file_path).suffix.lower() # Map common extensions to MIME types extension_to_mime = { @@ -210,7 +211,7 @@ def llmeter_bytes_decoder(dct: dict) -> dict | bytes: 
class CreatePromptCollection: input_lengths: list[int] output_lengths: list[int] - source_file: os.PathLike + source_file: ReadablePathLike requests_per_combination: int = 1 tokenizer: Tokenizer | None = None source_file_encoding: str = "utf-8-sig" @@ -226,8 +227,8 @@ def create_collection(self) -> list[Any]: ) return random.sample(collection, k=len(collection)) - def _generate_sample(self, source_file: os.PathLike, sample_size: int) -> str: - source_file = Path(source_file) + def _generate_sample(self, source_file: ReadablePathLike, sample_size: int) -> str: + source_file = ensure_path(source_file) sample = [] with source_file.open(encoding=self.source_file_encoding, mode="r") as f: for line in f: @@ -244,7 +245,7 @@ def _generate_samples(self) -> None: def load_prompts( - file_path: os.PathLike, + file_path: ReadablePathLike, create_payload_fn: Callable, create_payload_kwargs: dict = {}, file_pattern: str | None = None, @@ -273,7 +274,7 @@ def load_prompts( """ - file_path = Path(file_path) + file_path = ensure_path(file_path) if file_path.is_file(): with file_path.open(mode="r") as f: for line in f: @@ -298,7 +299,7 @@ def load_prompts( def load_payloads( - file_path: os.PathLike | str, + file_path: ReadablePathLike, deserializer: Callable[[str], dict] | None = None, ) -> Iterator[dict]: """ @@ -383,7 +384,7 @@ def load_payloads( >>> original == loaded True """ - file_path = Path(file_path) + file_path = ensure_path(file_path) if not file_path.exists(): raise FileNotFoundError(f"The specified path does not exist: {file_path}") @@ -428,7 +429,7 @@ def _load_data_file( def save_payloads( payloads: list[dict] | dict, - output_path: os.PathLike | str, + output_path: WritablePathLike, output_file: str = "payload.jsonl", serializer: Callable[[dict], str] | None = None, ) -> Path: @@ -534,7 +535,7 @@ def save_payloads( >>> # }] >>> # } """ - output_path = Path(output_path) + output_path = ensure_path(output_path) output_path.mkdir(parents=True, exist_ok=True) 
output_file_path = output_path / output_file diff --git a/llmeter/results.py b/llmeter/results.py index 04bf01a..d8dcf42 100644 --- a/llmeter/results.py +++ b/llmeter/results.py @@ -12,10 +12,11 @@ import jmespath from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike from .endpoints import InvocationResponse from .prompt_utils import LLMeterBytesEncoder -from .utils import summary_stats_from_list +from .utils import ensure_path, summary_stats_from_list logger = logging.getLogger(__name__) @@ -87,7 +88,7 @@ class Result: n_requests: int total_test_time: float | None = None model_id: str | None = None - output_path: os.PathLike | None = None + output_path: WritablePathLike | None = None endpoint_name: str | None = None provider: str | None = None run_name: str | None = None @@ -119,7 +120,7 @@ def _update_contributed_stats(self, stats: dict[str, Number]): ) self._contributed_stats.update(stats) - def save(self, output_path: os.PathLike | str | None = None): + def save(self, output_path: WritablePathLike | None = None): """ Save the results to disk or cloud storage. @@ -147,9 +148,8 @@ def save(self, output_path: os.PathLike | str | None = None): which provides a unified interface for working with different file systems. """ - try: - output_path = Path(self.output_path or output_path) - except TypeError: + output_path = ensure_path(self.output_path or output_path) + if output_path is None: raise ValueError("No output path provided") output_path.mkdir(parents=True, exist_ok=True) @@ -204,7 +204,7 @@ def load_responses(self) -> list[InvocationResponse]: raise ValueError( "No output_path set on this Result. Cannot locate responses file." 
) - responses_path = Path(self.output_path) / "responses.jsonl" + responses_path = ensure_path(self.output_path) / "responses.jsonl" with responses_path.open("r") as f: self.responses = [ InvocationResponse(**json.loads(line)) for line in f if line @@ -216,7 +216,7 @@ def load_responses(self) -> list[InvocationResponse]: @classmethod def load( - cls, result_path: os.PathLike | str, load_responses: bool = True + cls, result_path: ReadablePathLike, load_responses: bool = True ) -> "Result": """ Load run results from disk or cloud storage. @@ -246,7 +246,7 @@ def load( either file. """ - result_path = Path(result_path) + result_path = ensure_path(result_path) summary_path = result_path / "summary.json" with summary_path.open("r") as g: diff --git a/llmeter/runner.py b/llmeter/runner.py index db2f94a..d9b61c8 100644 --- a/llmeter/runner.py +++ b/llmeter/runner.py @@ -19,8 +19,9 @@ from tqdm.auto import tqdm, trange from upath import UPath as Path +from upath.types import ReadablePathLike, WritablePathLike -from llmeter.utils import now_utc +from llmeter.utils import ensure_path, now_utc if TYPE_CHECKING: # Avoid circular import: We only need typing for Callback @@ -52,11 +53,11 @@ class _RunConfig: """ endpoint: Endpoint | dict | None = None - output_path: str | Path | None = None + output_path: WritablePathLike | None = None tokenizer: Tokenizer | Any | None = None clients: int = 1 n_requests: int | None = None - payload: dict | list[dict] | os.PathLike | str | None = None + payload: dict | list[dict] | ReadablePathLike | None = None run_name: str | None = None run_description: str | None = None timeout: int | float = 60 @@ -84,7 +85,7 @@ def __post_init__(self, disable_client_progress_bar, disable_clients_progress_ba self._endpoint = self.endpoint if self.output_path is not None: - self.output_path = Path(self.output_path) + self.output_path = ensure_path(self.output_path) if self.tokenizer is None: self.tokenizer = DummyTokenizer() @@ -95,7 +96,7 @@ def 
__post_init__(self, disable_client_progress_bar, disable_clients_progress_ba def save( self, - output_path: os.PathLike | str | None = None, + output_path: WritablePathLike | None = None, file_name: str = "run_config.json", ): """Save the configuration to a disk or cloud storage. @@ -104,7 +105,7 @@ def save( output_path: Optional override for output folder. By default, self.output_path is used. file_name: File name to create under `output_path`. """ - output_path = Path(output_path or self.output_path) + output_path = ensure_path(output_path or self.output_path) output_path.mkdir(parents=True, exist_ok=True) run_config_path = output_path / file_name @@ -137,7 +138,7 @@ def load(cls, load_path: Path | str, file_name: str = "run_config.json"): output_path: Folder under which the configuration is stored file_name: File name within `output_path` for the run configuration JSON. """ - load_path = Path(load_path) + load_path = ensure_path(load_path) with (load_path / file_name).open() as f: config = json.load(f) config["endpoint"] = Endpoint.load(config["endpoint"]) @@ -460,7 +461,7 @@ async def _run(self): run_start_time = now_utc() _, (total_test_time, start_time, end_time) = await asyncio.gather( self._process_results_from_q( - output_path=Path(self.output_path) / "responses.jsonl" + output_path=ensure_path(self.output_path) / "responses.jsonl" if self.output_path else None, ), @@ -580,7 +581,7 @@ def _prepare_run(self, **kwargs) -> _Run: run_params["run_name"] = f"{datetime.now():%Y%m%d-%H%M}" if self.output_path and not kwargs.get("output_path"): # Run output path is nested under run name subfolder unless explicitly set: - run_params["output_path"] = Path(self.output_path) / run_params["run_name"] + run_params["output_path"] = ensure_path(self.output_path) / run_params["run_name"] # Validate that clients parameter is set and is a positive integer clients = run_params.get("clients") if clients is None: @@ -600,7 +601,7 @@ async def run( tokenizer: Tokenizer | Any | 
None = None, clients: int | None = None, n_requests: int | None = None, - payload: dict | list[dict] | os.PathLike | str | None = None, + payload: dict | list[dict] | ReadablePathLike | None = None, run_name: str | None = None, run_description: str | None = None, timeout: int | float | None = None, diff --git a/llmeter/utils.py b/llmeter/utils.py index d072e58..9e639d5 100644 --- a/llmeter/utils.py +++ b/llmeter/utils.py @@ -1,8 +1,15 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from datetime import datetime, timezone from itertools import filterfalse from math import isnan +from typing import TYPE_CHECKING, overload + +if TYPE_CHECKING: + from upath import UPath + from upath.types import ReadablePathLike, WritablePathLike from statistics import StatisticsError, mean, median, quantiles from typing import Any, Sequence @@ -90,3 +97,32 @@ def now_utc() -> datetime: datetime: Current UTC datetime object """ return datetime.now(timezone.utc) + + +@overload +def ensure_path(path: ReadablePathLike | WritablePathLike) -> UPath: ... + + +@overload +def ensure_path(path: ReadablePathLike | WritablePathLike | None) -> UPath | None: ... + + +def ensure_path( + path: ReadablePathLike | WritablePathLike | None = None, +) -> UPath | None: + """Normalize a path-like argument to a UPath instance. + + Converts strings, os.PathLike objects, and UPath instances into a + consistent UPath representation. Passes through None unchanged. + + Args: + path: A string, path-like object, or None. + + Returns: + A UPath instance, or None if the input was None. 
+ """ + if path is None: + return None + from upath import UPath + + return UPath(path) From 964ac451220f62f00af9bfbd2ab2020e7de0f6a3 Mon Sep 17 00:00:00 2001 From: Alessandro Cere Date: Mon, 30 Mar 2026 15:06:10 -0700 Subject: [PATCH 6/9] fix: extend os.PathLike isinstance checks to also catch cloud UPath instances Cloud-backed UPath instances (e.g. S3Path) do not implement os.PathLike, so isinstance checks against os.PathLike alone would miss them. This caused: - Serialization: cloud UPaths skipped the path branch in JSON serializers - runner.py: cloud UPath payloads not recognized as path references, leading to unnecessary re-saving or failure to load from path Fix by checking isinstance(obj, (os.PathLike, Path)) where Path is UPath, which catches both plain pathlib.Path (via os.PathLike) and cloud UPaths (via UPath). Serialization keeps .as_posix() for cross-platform safety. --- llmeter/callbacks/cost/serde.py | 4 ++-- llmeter/endpoints/base.py | 4 ++-- llmeter/prompt_utils.py | 2 +- llmeter/results.py | 4 ++-- llmeter/runner.py | 6 +++--- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/llmeter/callbacks/cost/serde.py b/llmeter/callbacks/cost/serde.py index 5df9a7f..c03de44 100644 --- a/llmeter/callbacks/cost/serde.py +++ b/llmeter/callbacks/cost/serde.py @@ -63,7 +63,7 @@ def to_dict_recursive_generic(obj: object, **kwargs) -> dict: result.update({k: getattr(obj, k) for k in dir(obj)}) result.update(kwargs) for k, v in result.items(): - if isinstance(v, os.PathLike): + if isinstance(v, (os.PathLike, Path)): result[k] = Path(v).as_posix() elif hasattr(v, "to_dict"): result[k] = v.to_dict() @@ -199,7 +199,7 @@ def to_file( """ if default is None: def default(obj): - if isinstance(obj, os.PathLike): + if isinstance(obj, (os.PathLike, Path)): return Path(obj).as_posix() return str(obj) diff --git a/llmeter/endpoints/base.py b/llmeter/endpoints/base.py index b1d6136..f14bf56 100644 --- a/llmeter/endpoints/base.py +++ b/llmeter/endpoints/base.py @@ 
-284,7 +284,7 @@ def save(self, output_path: WritablePathLike) -> Path: endpoint_conf = self.to_dict() def _default_serializer(obj): - if isinstance(obj, os.PathLike): + if isinstance(obj, (os.PathLike, Path)): return Path(obj).as_posix() return str(obj) @@ -302,7 +302,7 @@ def to_dict(self) -> dict: for k, v in vars(self).items(): if k.startswith("_"): continue - if isinstance(v, os.PathLike): + if isinstance(v, (os.PathLike, Path)): v = Path(v).as_posix() endpoint_conf[k] = v endpoint_conf["endpoint_type"] = self.__class__.__name__ diff --git a/llmeter/prompt_utils.py b/llmeter/prompt_utils.py index f0a6d56..a7d0a6d 100644 --- a/llmeter/prompt_utils.py +++ b/llmeter/prompt_utils.py @@ -177,7 +177,7 @@ def default(self, obj): """ if isinstance(obj, bytes): return {"__llmeter_bytes__": base64.b64encode(obj).decode("utf-8")} - if isinstance(obj, os.PathLike): + if isinstance(obj, (os.PathLike, Path)): return Path(obj).as_posix() return super().default(obj) diff --git a/llmeter/results.py b/llmeter/results.py index d8dcf42..81f1e58 100644 --- a/llmeter/results.py +++ b/llmeter/results.py @@ -38,7 +38,7 @@ def utc_datetime_serializer(obj: Any) -> str: if obj.tzinfo is not None: obj = obj.astimezone(timezone.utc) return obj.isoformat(timespec="seconds").replace("+00:00", "Z") - if isinstance(obj, os.PathLike): + if isinstance(obj, (os.PathLike, Path)): return Path(obj).as_posix() return str(obj) @@ -69,7 +69,7 @@ def default(self, obj): # First try bytes encoding from parent if isinstance(obj, bytes): return super().default(obj) - if isinstance(obj, os.PathLike): + if isinstance(obj, (os.PathLike, Path)): return Path(obj).as_posix() # Fallback to string representation for other non-serializable types try: diff --git a/llmeter/runner.py b/llmeter/runner.py index d9b61c8..f41c39d 100644 --- a/llmeter/runner.py +++ b/llmeter/runner.py @@ -111,7 +111,7 @@ def save( config_copy = replace(self) - if self.payload and (not isinstance(self.payload, (os.PathLike, str))): + if 
self.payload and (not isinstance(self.payload, (Path, str))): payload_path = save_payloads(self.payload, output_path) config_copy.payload = payload_path @@ -123,7 +123,7 @@ def save( config_copy.tokenizer = Tokenizer.to_dict(self.tokenizer) def _default_serializer(obj): - if isinstance(obj, os.PathLike): + if isinstance(obj, (os.PathLike, Path)): return Path(obj).as_posix() return str(obj) @@ -174,7 +174,7 @@ def _validate_and_prepare_payload(self): This method ensures that the payload is valid and prepared for the test run. """ assert self.payload, "No payload provided" - if isinstance(self.payload, (os.PathLike, str)): + if isinstance(self.payload, (Path, str)): self.payload = list(load_payloads(self.payload)) if isinstance(self.payload, dict): self.payload = [self.payload] From a263977cbfd6957420b7f5040737e5d8417d479d Mon Sep 17 00:00:00 2001 From: Alessandro Cere Date: Mon, 30 Mar 2026 17:15:00 -0700 Subject: [PATCH 7/9] refactor: consolidate serialization into llmeter/json_utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rationalize scattered serialization utilities into a single unified module: - Create llmeter/json_utils.py with LLMeterEncoder (handles bytes, datetime, date, time, PathLike, to_dict() objects, str() fallback) and llmeter_bytes_decoder (restores __llmeter_bytes__ markers to bytes). - Remove redundant encoders: LLMeterBytesEncoder (prompt_utils), InvocationResponseEncoder (results), utc_datetime_serializer (results), and inline _default_serializer lambdas (runner, endpoints/base). - Slim down callbacks/cost/serde.py to cost-specific helpers (JSONableBase, ISerializable, from_dict_with_class, from_dict_with_class_map, to_dict_recursive_generic). Update to Python 3.10 typing (dict/type builtins instead of Dict/Type). - Standardize all to_json() methods to default to cls=LLMeterEncoder via kwargs.setdefault(), ensuring consistent encoding across InvocationResponse, Result, and JSONableBase. 
- Remove serializer/deserializer/cls customization params from save_payloads, load_payloads, _load_data_file — hardcode LLMeterEncoder and llmeter_bytes_decoder since custom encoders produce files that can't be loaded back without metadata. - LLMeterEncoder.default() delegates to to_dict() for objects that implement it, enabling json.dump(self, f, cls=LLMeterEncoder) without manual to_dict() calls (used in Endpoint.save). - Convert all changed files to relative imports, run ruff check + format + import sorting. - Clean up llmeter/utils.py: move upath imports to top level (it's a hard dependency), remove unnecessary from __future__ import annotations. - Add docs/reference/json_utils.md and mkdocs.yml nav entry. - Add property-based tests for datetime, date, time, PathLike, and to_dict() encoding (TestDatetimeSerializationProperties, TestPathSerializationProperties, TestToDictSerializationProperties). - Update existing tests to use LLMeterEncoder/llmeter_bytes_decoder from llmeter.json_utils instead of old aliases. All 581 unit tests pass. 
--- docs/reference/json_utils.md | 1 + llmeter/callbacks/cost/serde.py | 51 +++-- llmeter/endpoints/base.py | 37 ++-- llmeter/endpoints/bedrock_invoke.py | 11 +- llmeter/json_utils.py | 133 ++++++++++++ llmeter/prompt_utils.py | 99 +-------- llmeter/results.py | 74 +------ llmeter/runner.py | 30 ++- llmeter/utils.py | 14 +- mkdocs.yml | 1 + .../endpoints/test_multimodal_properties.py | 8 +- tests/unit/test_prompt_utils.py | 205 +++++++----------- tests/unit/test_results.py | 36 +-- tests/unit/test_serialization_properties.py | 197 ++++++++++++----- 14 files changed, 459 insertions(+), 438 deletions(-) create mode 100644 docs/reference/json_utils.md create mode 100644 llmeter/json_utils.py diff --git a/docs/reference/json_utils.md b/docs/reference/json_utils.md new file mode 100644 index 0000000..a30ad78 --- /dev/null +++ b/docs/reference/json_utils.md @@ -0,0 +1 @@ +:::: llmeter.json_utils diff --git a/llmeter/callbacks/cost/serde.py b/llmeter/callbacks/cost/serde.py index c03de44..28acbfa 100644 --- a/llmeter/callbacks/cost/serde.py +++ b/llmeter/callbacks/cost/serde.py @@ -8,13 +8,14 @@ import os from dataclasses import is_dataclass from datetime import date, datetime, time -from typing import Any, Callable, Dict, Protocol, Type, TypeVar +from typing import Any, Protocol, TypeVar # External Dependencies: from upath import UPath as Path from upath.types import ReadablePathLike, WritablePathLike -from llmeter.utils import ensure_path +from ...json_utils import LLMeterEncoder +from ...utils import ensure_path logger = logging.getLogger(__name__) @@ -34,6 +35,7 @@ def from_dict(cls, **kwargs) -> TSerializable: def is_dataclass_instance(obj): """Check whether `obj` is an instance of any dataclass + See: https://docs.python.org/3/library/dataclasses.html#dataclasses.is_dataclass """ return is_dataclass(obj) and not isinstance(obj, type) @@ -41,17 +43,19 @@ def is_dataclass_instance(obj): def to_dict_recursive_generic(obj: object, **kwargs) -> dict: """Convert a vaguely 
dataclass-like object (with maybe IJSONable fields) to a JSON-ready dict + The output dict is augmented with `_type` storing the `__class__.__name__` of the provided `obj`. + Args: obj: The object to convert **kwargs: Optional extra parameters to insert in the output dictionary """ obj_classname = obj.__class__.__name__ - result = {} if obj_classname == "dict" else {"_type": obj.__class__.__name__} + result: dict = {} if obj_classname == "dict" else {"_type": obj.__class__.__name__} if hasattr(obj, "__dict__"): - # We *don't* use dataclass asdict() here because we want our custom behaviour instead of its - # recursion: + # We *don't* use dataclass asdict() here because we want our custom behaviour instead of + # its recursion: result.update(obj.__dict__) elif ( isinstance(obj, dict) @@ -79,10 +83,12 @@ def to_dict_recursive_generic(obj: object, **kwargs) -> dict: TFromDict = TypeVar("TFromDict") -def from_dict_with_class(raw: dict, cls: Type[TFromDict], **kwargs) -> TFromDict: +def from_dict_with_class(raw: dict, cls: type[TFromDict], **kwargs) -> TFromDict: """Initialize an instance of a class from a plain dict (with optional extra kwargs) + If the input dictionary contains a `_type` key, and this doesn't match the provided `cls.__name__`, a warning will be logged. 
+ Args: raw: A plain Python dict, for example loaded from a JSON file cls: The class to create an instance of @@ -99,12 +105,13 @@ def from_dict_with_class(raw: dict, cls: Type[TFromDict], **kwargs) -> TFromDict def from_dict_with_class_map( - raw: dict, class_map: Dict[str, Type[TFromDict]], **kwargs + raw: dict, class_map: dict[str, type[TFromDict]], **kwargs ) -> TFromDict: """Initialize an instance of a class from a plain dict (with optional extra kwargs) + Args: raw: A plain Python dict which must contain a `_type` key - classes: A mapping from `_type` string to class to create an instance of + class_map: A mapping from `_type` string to class to create an instance of **kwargs: Optional extra keyword arguments to pass to the constructor """ if "_type" not in raw: @@ -125,9 +132,9 @@ class JSONableBase: @classmethod def from_dict( - cls: Type[TJSONable], + cls: type[TJSONable], raw: dict, - alt_classes: Dict[str, TJSONable] = {}, + alt_classes: dict[str, TJSONable] = {}, **kwargs: Any, ) -> TJSONable: """Initialize an instance of this class from a plain dict (with optional extra kwargs) @@ -149,8 +156,11 @@ def from_dict( return from_dict_with_class(raw=raw, cls=cls, **kwargs) @classmethod - def from_file(cls: Type[TJSONable], input_path: ReadablePathLike, **kwargs) -> TJSONable: + def from_file( + cls: type[TJSONable], input_path: ReadablePathLike, **kwargs + ) -> TJSONable: """Initialize an instance of this class from a (local or Cloud) JSON file + Args: input_path: The path to the JSON data file. 
**kwargs: Optional extra keyword arguments to pass to `from_dict()` @@ -160,12 +170,12 @@ def from_file(cls: Type[TJSONable], input_path: ReadablePathLike, **kwargs) -> T return cls.from_json(f.read(), **kwargs) @classmethod - def from_json(cls: Type[TJSONable], json_string: str, **kwargs: Any) -> TJSONable: + def from_json(cls: type[TJSONable], json_string: str, **kwargs: Any) -> TJSONable: """Initialize an instance of this class from a JSON string (with optional extra kwargs) Args: json_string: A string containing valid JSON data - **kwargs: Optional extra keyword arguments to pass to `from_dict()`` + **kwargs: Optional extra keyword arguments to pass to `from_dict()` """ return cls.from_dict(json.loads(json_string), **kwargs) @@ -182,33 +192,26 @@ def to_file( self, output_path: WritablePathLike, indent: int | str | None = 4, - default: Callable[[Any], Any] | None = str, **kwargs: Any, ) -> Path: """Save the state of the object to a (local or Cloud) JSON file Args: output_path: The path where the configuration file will be saved. - indent: Optional indentation passed through to `to_json()` and therefore `json.dumps()` - default: Optional function to convert non-JSON-serializable objects to strings, passed - through to `to_json()` and therefore to `json.dumps()` + indent: Optional indentation passed through to `to_json()` and therefore + `json.dumps()` **kwargs: Optional extra keyword arguments to pass to `to_json()` Returns: output_path: Universal Path representation of the target file. 
""" - if default is None: - def default(obj): - if isinstance(obj, (os.PathLike, Path)): - return Path(obj).as_posix() - return str(obj) - output_path = ensure_path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w") as f: - f.write(self.to_json(indent=indent, default=default, **kwargs)) + f.write(self.to_json(indent=indent, **kwargs)) return output_path def to_json(self, **kwargs) -> str: """Serialize this object to JSON, with optional kwargs passed through to `json.dumps()`""" + kwargs.setdefault("cls", LLMeterEncoder) return json.dumps(self.to_dict(), **kwargs) diff --git a/llmeter/endpoints/base.py b/llmeter/endpoints/base.py index f14bf56..32c55a8 100644 --- a/llmeter/endpoints/base.py +++ b/llmeter/endpoints/base.py @@ -16,7 +16,8 @@ from upath import UPath as Path from upath.types import ReadablePathLike, WritablePathLike -from llmeter.utils import ensure_path +from ..json_utils import LLMeterEncoder +from ..utils import ensure_path # @dataclass(slots=True) @@ -52,21 +53,21 @@ class InvocationResponse: def to_json(self, **kwargs) -> str: """ Convert InvocationResponse to JSON string with binary content support. - + This method serializes the InvocationResponse object to a JSON string, with automatic handling of binary content (bytes objects) in the input_payload field. Binary data is converted to base64-encoded strings wrapped in marker objects, enabling JSON serialization while preserving the ability to restore the original bytes during deserialization. - + Binary Content Handling: When the input_payload contains bytes objects (e.g., images, video), they are automatically converted to base64-encoded strings and wrapped in marker objects with the key "__llmeter_bytes__". This approach enables JSON serialization of multimodal payloads while maintaining round-trip integrity. 
- + The marker object format is: {"__llmeter_bytes__": ""} - + For non-serializable types other than bytes, the encoder falls back to str() representation to ensure the response can always be serialized. @@ -78,7 +79,7 @@ def to_json(self, **kwargs) -> str: Examples: Serialize a response with binary content in the payload: - + >>> # Create a response with binary image data in the payload >>> with open("image.jpg", "rb") as f: ... image_bytes = f.read() @@ -106,9 +107,9 @@ def to_json(self, **kwargs) -> str: >>> # The JSON string contains marker objects for binary data >>> "__llmeter_bytes__" in json_str True - + Serialize with pretty printing: - + >>> json_str = response.to_json(indent=2) >>> print(json_str) { @@ -136,14 +137,14 @@ def to_json(self, **kwargs) -> str: "num_tokens_output": 15, ... } - + Round-trip serialization with binary preservation: - + >>> # Serialize to JSON >>> json_str = response.to_json() >>> # Parse back to dict >>> import json - >>> from llmeter.prompt_utils import llmeter_bytes_decoder + >>> from llmeter.json_utils import llmeter_bytes_decoder >>> response_dict = json.loads(json_str, object_hook=llmeter_bytes_decoder) >>> # Binary data is preserved >>> original_bytes = response.input_payload["messages"][0]["content"][1]["image"]["source"]["bytes"] @@ -151,9 +152,8 @@ def to_json(self, **kwargs) -> str: >>> original_bytes == restored_bytes True """ - from llmeter.results import InvocationResponseEncoder - - return json.dumps(asdict(self), cls=InvocationResponseEncoder, **kwargs) + kwargs.setdefault("cls", LLMeterEncoder) + return json.dumps(asdict(self), **kwargs) @staticmethod def error_output( @@ -281,14 +281,7 @@ def save(self, output_path: WritablePathLike) -> Path: output_path = ensure_path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w") as f: - endpoint_conf = self.to_dict() - - def _default_serializer(obj): - if isinstance(obj, (os.PathLike, Path)): - return Path(obj).as_posix() - return 
str(obj) - - json.dump(endpoint_conf, f, indent=4, default=_default_serializer) + json.dump(self, f, indent=4, cls=LLMeterEncoder) return output_path def to_dict(self) -> dict: diff --git a/llmeter/endpoints/bedrock_invoke.py b/llmeter/endpoints/bedrock_invoke.py index f6216c7..0b0c118 100644 --- a/llmeter/endpoints/bedrock_invoke.py +++ b/llmeter/endpoints/bedrock_invoke.py @@ -8,10 +8,11 @@ from uuid import uuid4 import boto3 +import jmespath from botocore.config import Config from botocore.exceptions import ClientError -import jmespath +from ..json_utils import LLMeterEncoder from .base import Endpoint, InvocationResponse logger = logging.getLogger(__name__) @@ -261,9 +262,7 @@ def invoke(self, payload: dict) -> InvocationResponse: raise TypeError("Payload must be a dictionary") try: - from llmeter.prompt_utils import LLMeterBytesEncoder - - req_body = json.dumps(payload, cls=LLMeterBytesEncoder).encode("utf-8") + req_body = json.dumps(payload, cls=LLMeterEncoder).encode("utf-8") try: start_t = time.perf_counter() client_response = self._bedrock_client.invoke_model( # type: ignore @@ -355,9 +354,7 @@ def __init__( ) def invoke(self, payload: dict) -> InvocationResponse: - from llmeter.prompt_utils import LLMeterBytesEncoder - - req_body = json.dumps(payload, cls=LLMeterBytesEncoder).encode("utf-8") + req_body = json.dumps(payload, cls=LLMeterEncoder).encode("utf-8") try: start_t = time.perf_counter() client_response = self._bedrock_client.invoke_model_with_response_stream( # type: ignore diff --git a/llmeter/json_utils.py b/llmeter/json_utils.py new file mode 100644 index 0000000..167e755 --- /dev/null +++ b/llmeter/json_utils.py @@ -0,0 +1,133 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""JSON encoding and decoding helpers used across LLMeter + +Provides a unified :class:`json.JSONEncoder` subclass and a matching decoder hook for +round-tripping binary content (``bytes``) through JSON via base64 marker objects, while also +handling ``datetime``, ``os.PathLike``, and objects that implement ``to_dict()``. + +Example:: + + import json + from llmeter.json_utils import LLMeterEncoder, llmeter_bytes_decoder + + payload = {"image": {"bytes": b"\\xff\\xd8\\xff\\xe0"}} + + # Serialize + encoded = json.dumps(payload, cls=LLMeterEncoder) + + # Deserialize (bytes are restored automatically) + decoded = json.loads(encoded, object_hook=llmeter_bytes_decoder) + assert decoded == payload +""" + +import base64 +import json +import os +from datetime import date, datetime, time, timezone +from typing import Any + +from upath import UPath as Path + + +class LLMeterEncoder(json.JSONEncoder): + """JSON encoder that handles common non-serializable types found in LLMeter. + + Type handling (checked in order): + + * Objects with a ``to_dict()`` method — delegates to that method. + * ``bytes`` — wrapped in a ``{"__llmeter_bytes__": ""}`` marker object + so that :func:`llmeter_bytes_decoder` can restore them on the way back. + * ``datetime`` — converted to a UTC ISO-8601 string with a ``Z`` suffix. + * ``date`` / ``time`` — converted via ``.isoformat()``. + * ``os.PathLike`` — converted to a POSIX path string. + * Anything else — ``str()`` fallback (returns ``None`` if that also fails). + + Customization: + To handle additional types, subclass ``LLMeterEncoder`` and override + :meth:`default`. Call ``super().default(obj)`` as a fallback so that the + built-in type handling is preserved. Because the encoder is used + consistently across LLMeter (payloads, results, run configs), any type + that implements a ``to_dict()`` method will be serialized automatically + without needing a custom encoder. 
+ + Example:: + + Subclassing to handle a custom type: + + >>> import json + >>> import numpy as np + >>> from llmeter.json_utils import LLMeterEncoder + >>> + >>> class MyEncoder(LLMeterEncoder): + ... def default(self, obj): + ... if isinstance(obj, np.ndarray): + ... return obj.tolist() + ... return super().default(obj) + ... + >>> json.dumps({"data": np.array([1, 2, 3])}, cls=MyEncoder) + '{"data": [1, 2, 3]}' + + Using ``to_dict()`` (no subclassing needed): + + >>> class MyPayload: + ... def __init__(self, model_id, temperature): + ... self.model_id = model_id + ... self.temperature = temperature + ... def to_dict(self): + ... return {"model_id": self.model_id, "temperature": self.temperature} + ... + >>> json.dumps({"payload": MyPayload("gpt-4", 0.7)}, cls=LLMeterEncoder) + '{"payload": {"model_id": "gpt-4", "temperature": 0.7}}' + """ + + def default(self, obj: Any) -> Any: + """Encode a single non-serializable object. + + Args: + obj: The object that the default encoder could not handle. + + Returns: + A JSON-serializable representation of *obj*. + """ + if hasattr(obj, "to_dict") and callable(obj.to_dict): + return obj.to_dict() + if isinstance(obj, bytes): + return {"__llmeter_bytes__": base64.b64encode(obj).decode("utf-8")} + if isinstance(obj, datetime): + if obj.tzinfo is not None: + obj = obj.astimezone(timezone.utc) + return obj.isoformat(timespec="seconds").replace("+00:00", "Z") + if isinstance(obj, (date, time)): + return obj.isoformat() + if isinstance(obj, (os.PathLike, Path)): + return Path(obj).as_posix() + try: + return str(obj) + except Exception: + return None + + +def llmeter_bytes_decoder(dct: dict) -> dict | bytes: + """Decode ``__llmeter_bytes__`` marker objects back to ``bytes``. + + Intended for use as the ``object_hook`` argument to :func:`json.loads` or + :func:`json.load`. Marker objects produced by :class:`LLMeterEncoder` are + detected and converted back to ``bytes``; all other dicts pass through unchanged. 
+ + Args: + dct: A dictionary produced by the JSON parser. + + Returns: + The original ``bytes`` if *dct* is a marker object, otherwise *dct* unchanged. + + Example:: + + >>> import json + >>> from llmeter.json_utils import llmeter_bytes_decoder + >>> json.loads('{"__llmeter_bytes__": "/9j/4A=="}', object_hook=llmeter_bytes_decoder) + b'\\xff\\xd8\\xff\\xe0' + """ + if "__llmeter_bytes__" in dct and len(dct) == 1: + return base64.b64decode(dct["__llmeter_bytes__"]) + return dct diff --git a/llmeter/prompt_utils.py b/llmeter/prompt_utils.py index a7d0a6d..b7b190c 100644 --- a/llmeter/prompt_utils.py +++ b/llmeter/prompt_utils.py @@ -1,10 +1,8 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import base64 import json import logging -import os import random from dataclasses import dataclass from itertools import product @@ -13,6 +11,7 @@ from upath import UPath as Path from upath.types import ReadablePathLike, WritablePathLike +from .json_utils import LLMeterEncoder, llmeter_bytes_decoder from .tokenizers import DummyTokenizer, Tokenizer from .utils import DeferredError, ensure_path @@ -150,63 +149,6 @@ def detect_format_from_file(file_path: str) -> str | None: return detect_format_from_extension(file_path) -class LLMeterBytesEncoder(json.JSONEncoder): - """Custom JSON encoder that handles bytes objects by converting them to base64. - - This encoder wraps bytes objects in a marker object with the key "__llmeter_bytes__" - to enable round-trip serialization and deserialization of binary content in payloads. - - Example: - >>> encoder = LLMeterBytesEncoder() - >>> payload = {"image": {"bytes": b"\\xff\\xd8\\xff\\xe0"}} - >>> json.dumps(payload, cls=LLMeterBytesEncoder) - '{"image": {"bytes": {"__llmeter_bytes__": "/9j/4A=="}}}' - """ - - def default(self, obj): - """Encode bytes objects as marker objects with base64 strings. 
- - Args: - obj: Object to encode - - Returns: - dict: Marker object for bytes, or delegates to parent for other types - - Raises: - TypeError: If object is not JSON serializable - """ - if isinstance(obj, bytes): - return {"__llmeter_bytes__": base64.b64encode(obj).decode("utf-8")} - if isinstance(obj, (os.PathLike, Path)): - return Path(obj).as_posix() - return super().default(obj) - - -def llmeter_bytes_decoder(dct: dict) -> dict | bytes: - """Decode marker objects back to bytes during JSON deserialization. - - This function is used as an object_hook for json.loads to detect and decode - marker objects created by LLMeterBytesEncoder back to bytes during deserialization. - - Args: - dct: Dictionary from JSON parsing - - Returns: - bytes if marker detected, otherwise original dict - - Example: - >>> marker = {"__llmeter_bytes__": "/9j/4A=="} - >>> llmeter_bytes_decoder(marker) - b'\\xff\\xd8\\xff\\xe0' - >>> regular_dict = {"key": "value"} - >>> llmeter_bytes_decoder(regular_dict) - {'key': 'value'} - """ - if "__llmeter_bytes__" in dct and len(dct) == 1: - return base64.b64decode(dct["__llmeter_bytes__"]) - return dct - - @dataclass class CreatePromptCollection: input_lengths: list[int] @@ -300,14 +242,13 @@ def load_prompts( def load_payloads( file_path: ReadablePathLike, - deserializer: Callable[[str], dict] | None = None, ) -> Iterator[dict]: """ Load JSON payload(s) from a file or directory with binary content support. This function reads JSON data from either a single file or multiple files in a directory. It supports both .json and .jsonl file formats. Binary content - (bytes objects) that were serialized using LLMeterBytesEncoder are automatically + (bytes objects) that were serialized using LLMeterEncoder are automatically restored during deserialization. Binary Content Handling: @@ -321,9 +262,6 @@ def load_payloads( Args: file_path (Union[Path, str]): Path to a JSON file or a directory containing JSON files. Can be a string or a Path object. 
- deserializer (Callable[[str], dict], optional): Custom deserializer function - that takes a JSON string and returns a dict. If None, uses json.loads with - llmeter_bytes_decoder for automatic binary content handling. Defaults to None. Yields: dict: Each JSON object loaded from the file(s). @@ -390,15 +328,13 @@ def load_payloads( raise FileNotFoundError(f"The specified path does not exist: {file_path}") if file_path.is_file(): - yield from _load_data_file(file_path, deserializer) + yield from _load_data_file(file_path) else: for file in file_path.glob("*.json*"): - yield from _load_data_file(file, deserializer) + yield from _load_data_file(file) -def _load_data_file( - file: Path, deserializer: Callable[[str], dict] | None = None -) -> Iterator[dict]: +def _load_data_file(file: Path) -> Iterator[dict]: try: with file.open(mode="r") as f: if file.suffix.lower() in [".jsonl", ".manifest"]: @@ -406,21 +342,13 @@ def _load_data_file( try: if not line.strip(): continue - if deserializer is None: - yield json.loads( - line.strip(), object_hook=llmeter_bytes_decoder - ) - else: - yield deserializer(line.strip()) + yield json.loads( + line.strip(), object_hook=llmeter_bytes_decoder + ) except json.JSONDecodeError as e: print(f"Error decoding JSON in {file}: {e}") else: # Assume it's a regular JSON file - if deserializer is None: - yield json.load(f, object_hook=llmeter_bytes_decoder) - else: - # For custom deserializer, read the entire file as string - f.seek(0) - yield deserializer(f.read()) + yield json.load(f, object_hook=llmeter_bytes_decoder) except IOError as e: print(f"Error reading file {file}: {e}") except json.JSONDecodeError as e: @@ -431,7 +359,6 @@ def save_payloads( payloads: list[dict] | dict, output_path: WritablePathLike, output_file: str = "payload.jsonl", - serializer: Callable[[dict], str] | None = None, ) -> Path: """ Save payloads to a file with support for binary content. @@ -453,9 +380,6 @@ def save_payloads( at any nesting level. 
output_path (Union[Path, str]): The directory path where the output file should be saved. output_file (str, optional): The name of the output file. Defaults to "payload.jsonl". - serializer (Callable[[dict], str], optional): Custom serializer function that takes - a dict and returns a JSON string. If None, uses LLMeterBytesEncoder for automatic - binary content handling. Defaults to None. Returns: Path: The path to the output file. @@ -543,8 +467,5 @@ def save_payloads( payloads = [payloads] with output_file_path.open(mode="w") as f: for payload in payloads: - if serializer is None: - f.write(json.dumps(payload, cls=LLMeterBytesEncoder) + "\n") - else: - f.write(serializer(payload) + "\n") + f.write(json.dumps(payload, cls=LLMeterEncoder) + "\n") return output_file_path diff --git a/llmeter/results.py b/llmeter/results.py index 81f1e58..f5481a3 100644 --- a/llmeter/results.py +++ b/llmeter/results.py @@ -3,81 +3,22 @@ import json import logging -import os from dataclasses import asdict, dataclass -from datetime import datetime, timezone +from datetime import datetime from functools import cached_property from numbers import Number from typing import Any, Sequence import jmespath -from upath import UPath as Path from upath.types import ReadablePathLike, WritablePathLike from .endpoints import InvocationResponse -from .prompt_utils import LLMeterBytesEncoder +from .json_utils import LLMeterEncoder from .utils import ensure_path, summary_stats_from_list logger = logging.getLogger(__name__) -def utc_datetime_serializer(obj: Any) -> str: - """ - Serialize datetime objects to UTC ISO format strings. - - Args: - obj: Object to serialize. If datetime, converts to ISO format string with 'Z' timezone. - Otherwise returns string representation. - - Returns: - str: ISO format string with 'Z' timezone for datetime objects, or string representation - for other objects. 
- """ - if isinstance(obj, datetime): - # Convert to UTC if timezone is set - if obj.tzinfo is not None: - obj = obj.astimezone(timezone.utc) - return obj.isoformat(timespec="seconds").replace("+00:00", "Z") - if isinstance(obj, (os.PathLike, Path)): - return Path(obj).as_posix() - return str(obj) - - -class InvocationResponseEncoder(LLMeterBytesEncoder): - """Extended encoder for InvocationResponse with fallback to str() for non-serializable types. - - This encoder extends LLMeterBytesEncoder to handle bytes objects (via parent class) - and adds a fallback mechanism for other non-serializable types by converting them - to strings. This is particularly useful for InvocationResponse objects that may - contain various non-standard types. - - Example: - >>> response = InvocationResponse(input_payload={"image": {"bytes": b"\\xff\\xd8"}}) - >>> json.dumps(asdict(response), cls=InvocationResponseEncoder) - '{"input_payload": {"image": {"bytes": {"__llmeter_bytes__": "/9g="}}}}' - """ - - def default(self, obj): - """Encode objects with bytes support and str() fallback. 
- - Args: - obj: Object to encode - - Returns: - Encoded representation or None if encoding fails - """ - # First try bytes encoding from parent - if isinstance(obj, bytes): - return super().default(obj) - if isinstance(obj, (os.PathLike, Path)): - return Path(obj).as_posix() - # Fallback to string representation for other non-serializable types - try: - return str(obj) - except Exception: - return None - - @dataclass class Result: """Results of a test run.""" @@ -97,7 +38,7 @@ class Result: end_time: datetime | None = None def __str__(self): - return json.dumps(self.stats, indent=4, default=utc_datetime_serializer) + return json.dumps(self.stats, indent=4, cls=LLMeterEncoder) def __post_init__(self): """Initialize the Result instance.""" @@ -158,20 +99,21 @@ def save(self, output_path: WritablePathLike | None = None): stats_path = output_path / "stats.json" with summary_path.open("w") as f, stats_path.open("w") as s: f.write(self.to_json(indent=4)) - s.write(json.dumps(self.stats, indent=4, default=utc_datetime_serializer)) + s.write(json.dumps(self.stats, indent=4, cls=LLMeterEncoder)) responses_path = output_path / "responses.jsonl" if not responses_path.exists(): with responses_path.open("w") as f: for response in self.responses: - f.write(json.dumps(asdict(response), cls=InvocationResponseEncoder) + "\n") + f.write(json.dumps(asdict(response), cls=LLMeterEncoder) + "\n") def to_json(self, **kwargs): """Return the results as a JSON string.""" + kwargs.setdefault("cls", LLMeterEncoder) summary = { k: o for k, o in asdict(self).items() if k not in ["responses", "stats"] } - return json.dumps(summary, default=utc_datetime_serializer, **kwargs) + return json.dumps(summary, **kwargs) def to_dict(self, include_responses: bool = False): """Return the results as a dictionary with JSON-serializable values.""" @@ -179,7 +121,7 @@ def to_dict(self, include_responses: bool = False): # Serialize datetime objects so stats dict is always JSON-safe for key in ("start_time", 
"end_time"): if key in data and isinstance(data[key], datetime): - data[key] = utc_datetime_serializer(data[key]) + data[key] = LLMeterEncoder().default(data[key]) if include_responses: return data return {k: v for k, v in data.items() if k not in ["responses", "stats"]} diff --git a/llmeter/runner.py b/llmeter/runner.py index f41c39d..40496ad 100644 --- a/llmeter/runner.py +++ b/llmeter/runner.py @@ -21,13 +21,14 @@ from upath import UPath as Path from upath.types import ReadablePathLike, WritablePathLike -from llmeter.utils import ensure_path, now_utc +from .utils import ensure_path, now_utc if TYPE_CHECKING: # Avoid circular import: We only need typing for Callback from .callbacks.base import Callback from .endpoints.base import Endpoint, InvocationResponse +from .json_utils import LLMeterEncoder from .prompt_utils import load_payloads, save_payloads from .results import Result from .tokenizers import DummyTokenizer, Tokenizer @@ -122,13 +123,8 @@ def save( if not isinstance(self.tokenizer, dict): config_copy.tokenizer = Tokenizer.to_dict(self.tokenizer) - def _default_serializer(obj): - if isinstance(obj, (os.PathLike, Path)): - return Path(obj).as_posix() - return str(obj) - with run_config_path.open("w") as f: - f.write(json.dumps(asdict(config_copy), default=_default_serializer, indent=4)) + f.write(json.dumps(asdict(config_copy), cls=LLMeterEncoder, indent=4)) @classmethod def load(cls, load_path: Path | str, file_name: str = "run_config.json"): @@ -155,15 +151,15 @@ class _Run(_RunConfig): """ def __post_init__(self, disable_client_progress_bar, disable_clients_progress_bar): - assert ( - self.run_name is not None - ), "Test Run must be created with an explicit run_name" + assert self.run_name is not None, ( + "Test Run must be created with an explicit run_name" + ) super().__post_init__(disable_client_progress_bar, disable_clients_progress_bar) - assert ( - self.endpoint is not None - ), "Test Run must be created with an explicit Endpoint" + assert 
self.endpoint is not None, ( + "Test Run must be created with an explicit Endpoint" + ) self._validate_and_prepare_payload() self._responses = [] @@ -409,7 +405,7 @@ async def _invoke_n_c( end_t = time.perf_counter() total_test_time = end_t - start_t logger.info( - f"Generated {clients} connections with {n_requests} invocations each in {total_test_time*1000:.2f} seconds" + f"Generated {clients} connections with {n_requests} invocations each in {total_test_time * 1000:.2f} seconds" ) # Signal the token counting task to exit @@ -480,7 +476,7 @@ async def _run(self): return result self._progress_bar.close() - logger.info(f"Test completed in {total_test_time*1000:.2f} seconds.") + logger.info(f"Test completed in {total_test_time * 1000:.2f} seconds.") result = replace( result, @@ -581,7 +577,9 @@ def _prepare_run(self, **kwargs) -> _Run: run_params["run_name"] = f"{datetime.now():%Y%m%d-%H%M}" if self.output_path and not kwargs.get("output_path"): # Run output path is nested under run name subfolder unless explicitly set: - run_params["output_path"] = ensure_path(self.output_path) / run_params["run_name"] + run_params["output_path"] = ( + ensure_path(self.output_path) / run_params["run_name"] + ) # Validate that clients parameter is set and is a positive integer clients = run_params.get("clients") if clients is None: diff --git a/llmeter/utils.py b/llmeter/utils.py index 9e639d5..c9a0170 100644 --- a/llmeter/utils.py +++ b/llmeter/utils.py @@ -1,17 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - from datetime import datetime, timezone from itertools import filterfalse from math import isnan -from typing import TYPE_CHECKING, overload - -if TYPE_CHECKING: - from upath import UPath - from upath.types import ReadablePathLike, WritablePathLike from statistics import StatisticsError, mean, median, quantiles -from typing import Any, Sequence +from typing import Any, Sequence, overload + +from upath import UPath +from upath.types import ReadablePathLike, WritablePathLike class DeferredError: @@ -123,6 +119,4 @@ def ensure_path( """ if path is None: return None - from upath import UPath - return UPath(path) diff --git a/mkdocs.yml b/mkdocs.yml index 99d5198..1c7e808 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -61,6 +61,7 @@ nav: - openai: reference/endpoints/openai.md - sagemaker: reference/endpoints/sagemaker.md - experiments: reference/experiments.md + - json_utils: reference/json_utils.md - plotting: - reference/plotting/index.md - results: reference/results.md diff --git a/tests/unit/endpoints/test_multimodal_properties.py b/tests/unit/endpoints/test_multimodal_properties.py index e240546..8d8736a 100644 --- a/tests/unit/endpoints/test_multimodal_properties.py +++ b/tests/unit/endpoints/test_multimodal_properties.py @@ -13,12 +13,12 @@ from pathlib import Path import pytest -from hypothesis import HealthCheck, given, settings, strategies as st +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st from llmeter.endpoints.bedrock import BedrockBase +from llmeter.json_utils import LLMeterEncoder, llmeter_bytes_decoder from llmeter.prompt_utils import ( - LLMeterBytesEncoder, - llmeter_bytes_decoder, load_payloads, save_payloads, ) @@ -598,7 +598,7 @@ def create_multimodal_payload(input_text, **kwargs): assert len(image_blocks) == 1 # Verify the payload can be serialized - json_str = json.dumps(payload, cls=LLMeterBytesEncoder) + json_str = 
json.dumps(payload, cls=LLMeterEncoder) assert len(json_str) > 0 # Verify it can be deserialized diff --git a/tests/unit/test_prompt_utils.py b/tests/unit/test_prompt_utils.py index ca74ba8..cac7ddb 100644 --- a/tests/unit/test_prompt_utils.py +++ b/tests/unit/test_prompt_utils.py @@ -13,18 +13,18 @@ from llmeter.prompt_utils import ( CreatePromptCollection, - LLMeterBytesEncoder, load_payloads, load_prompts, save_payloads, ) +from llmeter.json_utils import LLMeterEncoder, llmeter_bytes_decoder from llmeter.tokenizers import DummyTokenizer -class TestLLMeterBytesEncoder: - """Unit tests for LLMeterBytesEncoder class. +class TestLLMeterEncoder: + """Unit tests for LLMeterEncoder class. - These tests verify specific examples and edge cases for the LLMeterBytesEncoder + These tests verify specific examples and edge cases for the LLMeterEncoder class, complementing the property-based tests. Requirements: 1.1, 1.2, 1.3, 1.6 @@ -38,7 +38,7 @@ def test_simple_bytes_object_serialization(self): payload = {"data": b"hello world"} # Serialize using the encoder - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Verify it's valid JSON parsed = json.loads(serialized) @@ -78,7 +78,7 @@ def test_nested_bytes_in_dict_structure(self): } # Serialize - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Verify it's valid JSON parsed = json.loads(serialized) @@ -107,7 +107,7 @@ def test_empty_bytes_object(self): payload = {"empty": b""} # Serialize - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Verify it's valid JSON parsed = json.loads(serialized) @@ -134,7 +134,7 @@ def test_large_binary_data_1mb(self): payload = {"large_image": large_data} # Serialize - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Verify it's valid 
JSON parsed = json.loads(serialized) @@ -161,7 +161,7 @@ def test_multiple_bytes_objects_in_payload(self): } # Serialize - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Verify it's valid JSON parsed = json.loads(serialized) @@ -195,7 +195,7 @@ def test_bytes_in_list(self): payload = {"images": [b"image1", b"image2", b"image3"]} # Serialize - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Verify it's valid JSON parsed = json.loads(serialized) @@ -223,7 +223,7 @@ def test_mixed_types_with_bytes(self): } # Serialize - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Verify it's valid JSON parsed = json.loads(serialized) @@ -255,7 +255,6 @@ def test_marker_object_decoding(self): Validates: Requirements 2.1, 2.2, 2.3 """ - from llmeter.prompt_utils import llmeter_bytes_decoder # Create a marker object with base64-encoded bytes marker = {"__llmeter_bytes__": "aGVsbG8gd29ybGQ="} # "hello world" in base64 @@ -272,7 +271,6 @@ def test_non_marker_dict_passthrough(self): Validates: Requirements 2.4 """ - from llmeter.prompt_utils import llmeter_bytes_decoder # Regular dict without marker key regular_dict = {"key": "value", "number": 42, "nested": {"data": "test"}} @@ -291,7 +289,6 @@ def test_invalid_base64_error_handling(self): """ import binascii - from llmeter.prompt_utils import llmeter_bytes_decoder # Marker with invalid base64 string invalid_marker = {"__llmeter_bytes__": "not-valid-base64!!!"} @@ -307,7 +304,6 @@ def test_multi_key_dict_with_marker_key_not_decoded(self): Validates: Requirements 2.4 """ - from llmeter.prompt_utils import llmeter_bytes_decoder # Dict with marker key but also other keys (should not be decoded) multi_key_dict = { @@ -329,7 +325,6 @@ def test_empty_bytes_decoding(self): Validates: Requirements 2.1, 2.3 """ - from llmeter.prompt_utils import 
llmeter_bytes_decoder # Marker with empty base64 string (empty bytes) empty_marker = {"__llmeter_bytes__": ""} @@ -350,7 +345,6 @@ def test_large_binary_data_decoding(self): import base64 import os - from llmeter.prompt_utils import llmeter_bytes_decoder # Create 1MB of random binary data large_data = os.urandom(1024 * 1024) @@ -372,7 +366,6 @@ def test_nested_structure_with_marker(self): Validates: Requirements 2.1, 2.5 """ - from llmeter.prompt_utils import llmeter_bytes_decoder # JSON string with nested marker objects json_str = json.dumps( @@ -415,7 +408,6 @@ def test_dict_without_marker_key(self): Validates: Requirements 2.4 """ - from llmeter.prompt_utils import llmeter_bytes_decoder # Dict without the marker key normal_dict = {"data": "value", "count": 123} @@ -433,7 +425,6 @@ def test_marker_with_special_characters(self): """ import base64 - from llmeter.prompt_utils import llmeter_bytes_decoder # Binary data with special characters special_data = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01" @@ -849,33 +840,23 @@ def test_load_payloads_bytes_correctly_restored(self): assert isinstance(loaded[1]["data"], bytes) assert isinstance(loaded[2]["nested"]["deep"]["data"], bytes) - def test_load_payloads_custom_deserializer_parameter(self): - """Test custom deserializer parameter works correctly. + def test_load_payloads_bytes_marker_round_trip(self): + """Test that bytes markers are decoded automatically during load. 
Validates: Requirements 5.2, 5.4 """ with tempfile.TemporaryDirectory() as tmpdir: jsonl_file = Path(tmpdir) / "payload.jsonl" - # Write a simple payload - payload = {"test": "data", "number": 42} - with jsonl_file.open("w") as f: - f.write(json.dumps(payload) + "\n") - - # Custom deserializer that adds a field - def custom_deserializer(json_str): - data = json.loads(json_str) - data["custom_field"] = "added_by_deserializer" - return data + # Write a payload with a bytes marker (as save_payloads would) + payload = {"image": b"\xff\xd8\xff\xe0"} + save_payloads(payload, Path(tmpdir)) - # Load with custom deserializer - loaded = list(load_payloads(jsonl_file, deserializer=custom_deserializer)) + # Load and verify bytes are restored + loaded = list(load_payloads(jsonl_file)) - # Verify custom deserializer was used assert len(loaded) == 1 - assert loaded[0]["test"] == "data" - assert loaded[0]["number"] == 42 - assert loaded[0]["custom_field"] == "added_by_deserializer" + assert loaded[0]["image"] == b"\xff\xd8\xff\xe0" def test_load_payloads_backward_compatibility_old_format(self): """Test backward compatibility - old format loads successfully. @@ -1198,33 +1179,30 @@ def test_save_payloads_file_contains_valid_json_with_markers(self): else: assert "__llmeter_bytes__" in parsed["nested"]["data"] - def test_save_payloads_custom_serializer_parameter(self): - """Test custom serializer parameter works correctly. + def test_save_payloads_handles_to_dict_objects(self): + """Test that payloads with to_dict() objects are serialized correctly. + + Objects implementing to_dict() are handled by LLMeterEncoder automatically. 
Validates: Requirements 5.1, 5.3 """ with tempfile.TemporaryDirectory() as tmpdir: output_path = Path(tmpdir) - payload = {"test": "data", "number": 42} - # Custom serializer that adds a prefix - def custom_serializer(payload_dict): - return "CUSTOM:" + json.dumps(payload_dict) + class CustomObj: + def to_dict(self): + return {"custom": "value"} - result_path = save_payloads( - payload, output_path, serializer=custom_serializer - ) + payload = {"test": "data", "obj": CustomObj()} + + result_path = save_payloads(payload, output_path) - # Verify custom serializer was used with result_path.open("r") as f: - line = f.readline() + line = f.readline().strip() - assert line.startswith("CUSTOM:") - # Remove prefix and verify it's valid JSON - json_part = line[7:].strip() - parsed = json.loads(json_part) + parsed = json.loads(line) assert parsed["test"] == "data" - assert parsed["number"] == 42 + assert parsed["obj"] == {"custom": "value"} def test_save_payloads_backward_compatibility_no_bytes(self): """Test backward compatibility when payload has no bytes. @@ -1371,8 +1349,8 @@ class TestErrorHandling: Requirements: 6.1, 6.2, 6.3, 6.4 """ - def test_unserializable_type_error_message(self): - """Test that unserializable types raise TypeError with type information. + def test_unserializable_type_str_fallback(self): + """Test that unserializable types are serialized via str() fallback. 
Validates: Requirements 6.1 """ @@ -1388,13 +1366,9 @@ def __init__(self, value): "custom_object": CustomUnserializableObject(42), } - # Should raise TypeError when trying to serialize - with pytest.raises(TypeError) as exc_info: - json.dumps(payload, cls=LLMeterBytesEncoder) - - # Verify error message contains type information - error_message = str(exc_info.value) - assert "CustomUnserializableObject" in error_message + # The unified encoder falls back to str() for unknown types + result = json.dumps(payload, cls=LLMeterEncoder) + assert "test-model" in result def test_invalid_json_error_handling(self): """Test that invalid JSON raises JSONDecodeError with descriptive message. @@ -1422,7 +1396,6 @@ def test_invalid_base64_error_handling(self): """ import binascii - from llmeter.prompt_utils import llmeter_bytes_decoder # Create a marker object with invalid base64 (incorrect padding) # Base64 strings must have length that is a multiple of 4 @@ -1466,64 +1439,45 @@ def test_invalid_base64_in_load_payloads(self): with pytest.raises(binascii.Error): list(load_payloads(jsonl_file)) - def test_custom_serializer_exception_propagation(self): - """Test that custom serializer exceptions are propagated correctly. + def test_save_payloads_encoder_error_propagation(self): + """Test that encoder errors from LLMeterEncoder are propagated correctly. 
Validates: Requirements 6.4 """ - - class CustomSerializerError(Exception): - """Custom exception for testing.""" - - pass - - def failing_serializer(payload): - """A custom serializer that always raises an exception.""" - raise CustomSerializerError("Custom serializer failed!") - with tempfile.TemporaryDirectory() as tmpdir: output_path = Path(tmpdir) - payload = {"test": "data"} - # Should propagate the custom exception - with pytest.raises(CustomSerializerError) as exc_info: - save_payloads(payload, output_path, serializer=failing_serializer) + # An object whose str() raises — LLMeterEncoder returns None for these, + # which is valid JSON, so no error is raised. + class FailingStr: + def __str__(self): + raise RuntimeError("Cannot convert") + + payload = {"test": FailingStr()} + result_path = save_payloads(payload, output_path) - # Verify the exception message is preserved - assert "Custom serializer failed!" in str(exc_info.value) + with result_path.open("r") as f: + parsed = json.loads(f.readline()) + assert parsed["test"] is None - def test_custom_deserializer_exception_propagation(self): - """Test that custom deserializer exceptions are propagated correctly. + def test_invalid_json_deserialization_error(self): + """Test that invalid JSON in files produces errors gracefully. 
Validates: Requirements 6.4 """ - - class CustomDeserializerError(Exception): - """Custom exception for testing.""" - - pass - - def failing_deserializer(json_str): - """A custom deserializer that always raises an exception.""" - raise CustomDeserializerError("Custom deserializer failed!") - with tempfile.TemporaryDirectory() as tmpdir: jsonl_file = Path(tmpdir) / "payload.jsonl" - # Write a valid payload - payload = {"test": "data"} + # Write invalid JSON with jsonl_file.open("w") as f: - f.write(json.dumps(payload) + "\n") - - # Should propagate the custom exception - with pytest.raises(CustomDeserializerError) as exc_info: - list(load_payloads(jsonl_file, deserializer=failing_deserializer)) + f.write("not valid json\n") - # Verify the exception message is preserved - assert "Custom deserializer failed!" in str(exc_info.value) + # load_payloads prints errors for invalid JSON lines but doesn't raise + loaded = list(load_payloads(jsonl_file)) + assert len(loaded) == 0 - def test_unserializable_nested_object_error(self): - """Test error message for unserializable object in nested structure. + def test_unserializable_nested_object_str_fallback(self): + """Test str() fallback for unserializable object in nested structure. Validates: Requirements 6.1, 6.3 """ @@ -1547,13 +1501,9 @@ def __init__(self): ], } - # Should raise TypeError with type information - with pytest.raises(TypeError) as exc_info: - json.dumps(payload, cls=LLMeterBytesEncoder) - - # Verify error message contains the type name - error_message = str(exc_info.value) - assert "NestedCustomObject" in error_message + # The unified encoder falls back to str() for unknown types + result = json.dumps(payload, cls=LLMeterEncoder) + assert "test-model" in result def test_bytes_serialization_with_encoding_error(self): """Test that bytes serialization handles all byte values correctly. 
@@ -1568,7 +1518,7 @@ def test_bytes_serialization_with_encoding_error(self): payload = {"data": all_bytes} # Should serialize without errors - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Should be valid JSON parsed = json.loads(serialized) @@ -1577,7 +1527,6 @@ def test_bytes_serialization_with_encoding_error(self): assert "__llmeter_bytes__" in parsed["data"] # Verify round-trip works - from llmeter.prompt_utils import llmeter_bytes_decoder deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) assert deserialized["data"] == all_bytes @@ -1591,7 +1540,7 @@ def test_empty_payload_serialization(self): empty_payload = {} # Should serialize without errors - serialized = json.dumps(empty_payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(empty_payload, cls=LLMeterEncoder) assert serialized == "{}" # Empty list @@ -1619,7 +1568,7 @@ def test_none_value_serialization(self): } # Should serialize without errors - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Should be valid JSON parsed = json.loads(serialized) @@ -1640,7 +1589,7 @@ def test_unicode_in_payload_with_bytes(self): } # Should serialize without errors - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Should be valid JSON parsed = json.loads(serialized) @@ -1692,7 +1641,7 @@ def test_1mb_image_serialization_performance(self): # Measure serialization time start_time = time.perf_counter() - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) end_time = time.perf_counter() # Calculate elapsed time in milliseconds @@ -1726,10 +1675,9 @@ def test_1mb_image_deserialization_performance(self): } # First serialize the payload - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, 
cls=LLMeterEncoder) # Measure deserialization time - from llmeter.prompt_utils import llmeter_bytes_decoder start_time = time.perf_counter() deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) @@ -1767,7 +1715,7 @@ def test_serialization_no_unnecessary_copies(self): initial_size = sys.getsizeof(binary_data) # Serialize the payload - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Parse to verify structure parsed = json.loads(serialized) @@ -1803,10 +1751,9 @@ def test_deserialization_no_unnecessary_copies(self): payload = {"data": binary_data} # Serialize first - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Deserialize - from llmeter.prompt_utils import llmeter_bytes_decoder deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) @@ -1847,11 +1794,10 @@ def test_round_trip_performance_with_multiple_images(self): # Measure serialization time start_time = time.perf_counter() - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) serialize_time = (time.perf_counter() - start_time) * 1000 # Measure deserialization time - from llmeter.prompt_utils import llmeter_bytes_decoder start_time = time.perf_counter() deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) @@ -1887,7 +1833,7 @@ def test_serialization_performance_scales_linearly(self): small_payload = {"data": small_data} start_time = time.perf_counter() - json.dumps(small_payload, cls=LLMeterBytesEncoder) + json.dumps(small_payload, cls=LLMeterEncoder) small_time = time.perf_counter() - start_time # Test with 512KB (2x size) @@ -1895,7 +1841,7 @@ def test_serialization_performance_scales_linearly(self): large_payload = {"data": large_data} start_time = time.perf_counter() - json.dumps(large_payload, cls=LLMeterBytesEncoder) + json.dumps(large_payload, 
cls=LLMeterEncoder) large_time = time.perf_counter() - start_time # Large should take roughly 2x the time (allow 3x for variance) @@ -1915,12 +1861,11 @@ def test_deserialization_performance_scales_linearly(self): import os import time - from llmeter.prompt_utils import llmeter_bytes_decoder # Test with 256KB small_data = os.urandom(256 * 1024) small_payload = {"data": small_data} - small_serialized = json.dumps(small_payload, cls=LLMeterBytesEncoder) + small_serialized = json.dumps(small_payload, cls=LLMeterEncoder) start_time = time.perf_counter() json.loads(small_serialized, object_hook=llmeter_bytes_decoder) @@ -1929,7 +1874,7 @@ def test_deserialization_performance_scales_linearly(self): # Test with 512KB (2x size) large_data = os.urandom(512 * 1024) large_payload = {"data": large_data} - large_serialized = json.dumps(large_payload, cls=LLMeterBytesEncoder) + large_serialized = json.dumps(large_payload, cls=LLMeterEncoder) start_time = time.perf_counter() json.loads(large_serialized, object_hook=llmeter_bytes_decoder) diff --git a/tests/unit/test_results.py b/tests/unit/test_results.py index 2273b6b..c2da0a8 100644 --- a/tests/unit/test_results.py +++ b/tests/unit/test_results.py @@ -329,20 +329,20 @@ def test_save_method_existing_responses(sample_result: Result, temp_dir: UPath): assert responses[-1]["id"] == "extra_response" -# Tests for InvocationResponseEncoder +# Tests for LLMeterEncoder # Validates Requirements: 7.1, 7.2 -def test_invocation_response_encoder_handles_bytes(): - """Test that InvocationResponseEncoder handles bytes objects via parent class. +def test_llmeter_encoder_handles_bytes(): + """Test that LLMeterEncoder handles bytes objects via base64 marker. 
Validates: Requirements 7.1, 7.2 """ - from llmeter.results import InvocationResponseEncoder + from llmeter.json_utils import LLMeterEncoder # Test bytes object encoding payload = {"image": {"bytes": b"\xff\xd8\xff\xe0"}} - encoded = json.dumps(payload, cls=InvocationResponseEncoder) + encoded = json.dumps(payload, cls=LLMeterEncoder) # Verify marker object is present assert "__llmeter_bytes__" in encoded @@ -352,12 +352,12 @@ def test_invocation_response_encoder_handles_bytes(): assert decoded["image"]["bytes"]["__llmeter_bytes__"] == "/9j/4A==" -def test_invocation_response_encoder_str_fallback(): - """Test that InvocationResponseEncoder falls back to str() for custom objects. +def test_llmeter_encoder_str_fallback(): + """Test that LLMeterEncoder falls back to str() for custom objects. Validates: Requirements 7.1, 7.2 """ - from llmeter.results import InvocationResponseEncoder + from llmeter.json_utils import LLMeterEncoder # Create a custom object with __str__ method class CustomObject: @@ -365,19 +365,19 @@ def __str__(self): return "custom_string_representation" payload = {"custom": CustomObject()} - encoded = json.dumps(payload, cls=InvocationResponseEncoder) + encoded = json.dumps(payload, cls=LLMeterEncoder) # Verify str() fallback was used decoded = json.loads(encoded) assert decoded["custom"] == "custom_string_representation" -def test_invocation_response_encoder_none_on_str_failure(): - """Test that InvocationResponseEncoder returns None when str() conversion fails. +def test_llmeter_encoder_none_on_str_failure(): + """Test that LLMeterEncoder returns None when str() conversion fails. 
Validates: Requirements 7.1, 7.2 """ - from llmeter.results import InvocationResponseEncoder + from llmeter.json_utils import LLMeterEncoder # Create a custom object that raises exception in __str__ class FailingObject: @@ -385,19 +385,19 @@ def __str__(self): raise RuntimeError("Cannot convert to string") payload = {"failing": FailingObject()} - encoded = json.dumps(payload, cls=InvocationResponseEncoder) + encoded = json.dumps(payload, cls=LLMeterEncoder) # Verify None was returned decoded = json.loads(encoded) assert decoded["failing"] is None -def test_invocation_response_encoder_mixed_types(): - """Test that InvocationResponseEncoder handles mixed types correctly. +def test_llmeter_encoder_mixed_types(): + """Test that LLMeterEncoder handles mixed types correctly. Validates: Requirements 7.1, 7.2 """ - from llmeter.results import InvocationResponseEncoder + from llmeter.json_utils import LLMeterEncoder class CustomObject: def __str__(self): @@ -415,7 +415,7 @@ def __str__(self): } } - encoded = json.dumps(payload, cls=InvocationResponseEncoder) + encoded = json.dumps(payload, cls=LLMeterEncoder) decoded = json.loads(encoded) # Verify bytes were encoded with marker @@ -519,7 +519,7 @@ def test_invocation_response_to_json_round_trip(): Validates: Requirements 7.1, 7.2, 7.3 """ - from llmeter.prompt_utils import llmeter_bytes_decoder + from llmeter.json_utils import llmeter_bytes_decoder # Create original response with binary content original_payload = { diff --git a/tests/unit/test_serialization_properties.py b/tests/unit/test_serialization_properties.py index fcb1f6b..637b082 100644 --- a/tests/unit/test_serialization_properties.py +++ b/tests/unit/test_serialization_properties.py @@ -1,24 +1,26 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -"""Property-based tests for JSON serialization optimization. +"""Property-based tests for LLMeterEncoder and llmeter_bytes_decoder. 
-This module contains property-based tests for the binary content serialization -feature using Hypothesis. These tests verify that the serialization and -deserialization of payloads containing bytes objects maintains data integrity -and correctness across a wide range of inputs. +This module contains property-based tests using Hypothesis to verify that +LLMeterEncoder correctly handles all supported types (bytes, datetime, date, +time, PathLike, to_dict() objects, and str() fallback) and that +llmeter_bytes_decoder restores binary content from marker objects. Feature: json-serialization-optimization """ import base64 import json +from datetime import date, datetime, time, timezone +from pathlib import PurePosixPath from hypothesis import given, settings from hypothesis import strategies as st from hypothesis.strategies import composite -from llmeter.prompt_utils import LLMeterBytesEncoder +from llmeter.json_utils import LLMeterEncoder, llmeter_bytes_decoder # Test infrastructure is set up and ready for property test implementation # This file will contain property-based tests for: @@ -107,7 +109,7 @@ def test_property_1_serialization_produces_valid_json_with_marker_objects( base64-encoded string value. """ # Serialize the payload - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Verify it's valid JSON by parsing it parsed = json.loads(serialized) @@ -161,7 +163,6 @@ def test_property_3_deserialization_restores_bytes_from_markers( key, deserializing SHALL convert each marker object back to the original bytes object by base64-decoding the string value. 
""" - from llmeter.prompt_utils import llmeter_bytes_decoder # Create a marker object from the original bytes base64_str = base64.b64encode(original_bytes).decode("utf-8") @@ -186,7 +187,6 @@ def test_property_4_deserialization_preserves_non_marker_dicts(self, payload): For any payload containing dictionaries without the "__llmeter_bytes__" marker key, deserializing SHALL return those dictionaries unchanged. """ - from llmeter.prompt_utils import llmeter_bytes_decoder # Helper function to recursively apply decoder to all dicts def apply_decoder_recursively(obj): @@ -243,7 +243,7 @@ def test_property_2_serialization_preserves_non_binary_structure(self, payload): # Feature: json-serialization-optimization, Property 2: Serialization preserves non-binary structure """ # Serialize the payload - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Parse the JSON (without decoding markers back to bytes) parsed = json.loads(serialized) @@ -328,10 +328,9 @@ def test_property_5_round_trip_serialization_preserves_data_integrity( byte-for-byte equality of all bytes objects and exact equality of all other values. 
""" - from llmeter.prompt_utils import llmeter_bytes_decoder # Serialize the payload - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Deserialize the payload deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) @@ -399,10 +398,9 @@ def test_property_6_round_trip_preserves_dictionary_key_ordering(self, payload): # Feature: json-serialization-optimization, Property 6: Round-trip preserves dictionary key ordering """ - from llmeter.prompt_utils import llmeter_bytes_decoder # Serialize the payload - serialized = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized = json.dumps(payload, cls=LLMeterEncoder) # Deserialize the payload deserialized = json.loads(serialized, object_hook=llmeter_bytes_decoder) @@ -469,14 +467,14 @@ def test_property_7_non_binary_payloads_are_backward_compatible(self, payload): # Feature: json-serialization-optimization, Property 7: Non-binary payloads are backward compatible """ # Serialize with the new encoder - serialized_with_encoder = json.dumps(payload, cls=LLMeterBytesEncoder) + serialized_with_encoder = json.dumps(payload, cls=LLMeterEncoder) # Serialize with standard json.dumps serialized_standard = json.dumps(payload) # Verify they produce identical output assert serialized_with_encoder == serialized_standard, ( - "Serialization with LLMeterBytesEncoder should produce identical output " + "Serialization with LLMeterEncoder should produce identical output " "to standard json.dumps for payloads without bytes objects" ) @@ -523,16 +521,16 @@ def unserializable_object_strategy(draw): @given(st.data()) @settings(max_examples=100) - def test_property_8_serialization_errors_are_descriptive(self, data): - """Property 8: Serialization errors are descriptive. + def test_property_8_serialization_handles_unknown_types(self, data): + """Property 8: Serialization handles unknown types via str() fallback. 
**Validates: Requirements 6.1** For any payload containing unserializable types (not bytes, not standard - JSON types), attempting to serialize SHALL raise a TypeError with a message - indicating the problematic type. + JSON types), serialization SHALL succeed by falling back to str() + representation, producing valid JSON output. - # Feature: json-serialization-optimization, Property 8: Serialization errors are descriptive + # Feature: json-serialization-optimization, Property 8: Serialization handles unknown types """ # Generate a payload with an unserializable object unserializable_obj = data.draw( @@ -555,36 +553,11 @@ def test_property_8_serialization_errors_are_descriptive(self, data): payload_creator = data.draw(placement_strategy) payload = payload_creator(unserializable_obj) - # Attempt to serialize and verify it raises TypeError - try: - json.dumps(payload, cls=LLMeterBytesEncoder) - # If we get here, serialization succeeded when it shouldn't have - raise AssertionError( - f"Expected TypeError for unserializable object of type " - f"{type(unserializable_obj).__name__}, but serialization succeeded" - ) - except TypeError as e: - # Verify the error message is descriptive - error_msg = str(e) - - # The error message should mention that the object is not serializable - assert "not" in error_msg.lower() and "serializable" in error_msg.lower(), ( - f"Error message should indicate object is not serializable. " - f"Got: {error_msg}" - ) - - # The error message should include type information - # Either the class name or the word "type" should appear - type_name = type(unserializable_obj).__name__ - has_type_info = ( - type_name in error_msg or - "type" in error_msg.lower() or - "object" in error_msg.lower() - ) - assert has_type_info, ( - f"Error message should include type information. " - f"Expected reference to '{type_name}' or 'type'. 
Got: {error_msg}" - ) + # The unified encoder falls back to str() for unknown types + result = json.dumps(payload, cls=LLMeterEncoder) + # Result should be valid JSON + parsed = json.loads(result) + assert isinstance(parsed, dict) @given(st.data()) @@ -600,7 +573,6 @@ def test_property_9_deserialization_errors_are_descriptive(self, data): # Feature: json-serialization-optimization, Property 9: Deserialization errors are descriptive """ - from llmeter.prompt_utils import llmeter_bytes_decoder # Test invalid JSON strings that will raise JSONDecodeError # Note: base64.b64decode() is lenient by default and accepts many inputs, @@ -766,3 +738,124 @@ def verify_bytes_serialized(original, serialized_obj, path=""): assert parsed["time_to_last_token"] == 0.5 assert parsed["num_tokens_input"] == 10 assert parsed["num_tokens_output"] == 20 + + +# --------------------------------------------------------------------------- +# Strategies for the extended type tests +# --------------------------------------------------------------------------- + +# datetime strategy: aware and naive datetimes +_datetime_strategy = st.one_of( + # Naive datetimes + st.datetimes(min_value=datetime(2000, 1, 1), max_value=datetime(2030, 12, 31)), + # UTC-aware datetimes + st.datetimes( + min_value=datetime(2000, 1, 1), + max_value=datetime(2030, 12, 31), + timezones=st.just(timezone.utc), + ), +) + +_date_strategy = st.dates(min_value=date(2000, 1, 1), max_value=date(2030, 12, 31)) + +_time_strategy = st.times() + +_path_strategy = st.from_regex(r"[a-z][a-z0-9_/]{0,30}", fullmatch=True).map( + PurePosixPath +) + + +@composite +def to_dict_object_strategy(draw): + """Generate an object with a to_dict() method returning a JSON-safe dict.""" + inner = draw( + st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.one_of(st.integers(), st.text(max_size=20), st.booleans()), + max_size=5, + ) + ) + + class _Obj: + def __init__(self, d): + self._d = d + + def to_dict(self): + return self._d + 
+ return _Obj(inner), inner + + +class TestDatetimeSerializationProperties: + """Property-based tests for datetime/date/time encoding.""" + + @given(_datetime_strategy) + @settings(max_examples=100) + def test_datetime_produces_iso_string_with_z_suffix(self, dt): + """Datetime values are serialized to ISO-8601 strings at seconds precision. + + Aware datetimes are converted to UTC and suffixed with 'Z'. + Naive datetimes are serialized as-is with no timezone indicator. + Microseconds are truncated (the encoder uses timespec="seconds"). + """ + result = json.loads(json.dumps({"v": dt}, cls=LLMeterEncoder)) + assert isinstance(result["v"], str) + if dt.tzinfo is not None: + assert result["v"].endswith("Z") + parsed_back = datetime.fromisoformat(result["v"].replace("Z", "+00:00")) + expected = dt.astimezone(timezone.utc).replace(microsecond=0) + assert parsed_back == expected + else: + parsed_back = datetime.fromisoformat(result["v"]) + assert parsed_back == dt.replace(microsecond=0) + + @given(_date_strategy) + @settings(max_examples=100) + def test_date_produces_iso_string(self, d): + """Date values are serialized via isoformat().""" + result = json.loads(json.dumps({"v": d}, cls=LLMeterEncoder)) + assert result["v"] == d.isoformat() + assert date.fromisoformat(result["v"]) == d + + @given(_time_strategy) + @settings(max_examples=100) + def test_time_produces_iso_string(self, t): + """Time values are serialized via isoformat().""" + result = json.loads(json.dumps({"v": t}, cls=LLMeterEncoder)) + assert result["v"] == t.isoformat() + assert time.fromisoformat(result["v"]) == t + + +class TestPathSerializationProperties: + """Property-based tests for PathLike encoding.""" + + @given(_path_strategy) + @settings(max_examples=100) + def test_pathlike_produces_posix_string(self, p): + """PathLike objects are serialized to POSIX path strings.""" + result = json.loads(json.dumps({"v": p}, cls=LLMeterEncoder)) + assert isinstance(result["v"], str) + assert result["v"] == 
p.as_posix() + + +class TestToDictSerializationProperties: + """Property-based tests for to_dict() delegation.""" + + @given(to_dict_object_strategy()) + @settings(max_examples=100) + def test_to_dict_delegation_produces_expected_dict(self, obj_and_expected): + """Objects with to_dict() are serialized by calling that method.""" + obj, expected = obj_and_expected + result = json.loads(json.dumps({"v": obj}, cls=LLMeterEncoder)) + assert result["v"] == expected + + @given(to_dict_object_strategy(), st.binary(min_size=1, max_size=100)) + @settings(max_examples=100) + def test_to_dict_and_bytes_coexist(self, obj_and_expected, raw_bytes): + """Payloads mixing to_dict() objects and bytes round-trip correctly.""" + obj, expected = obj_and_expected + payload = {"obj": obj, "data": raw_bytes} + serialized = json.dumps(payload, cls=LLMeterEncoder) + restored = json.loads(serialized, object_hook=llmeter_bytes_decoder) + assert restored["obj"] == expected + assert restored["data"] == raw_bytes From 10558dbcea43c5e0a8bdf20f3afd650e4ab910e8 Mon Sep 17 00:00:00 2001 From: Alessandro Cere Date: Mon, 30 Mar 2026 17:29:20 -0700 Subject: [PATCH 8/9] refactor: remove premature type coercion from to_dict methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move path and datetime string conversion out of to_dict() and to_dict_recursive_generic() — these are Python dict builders, not serializers. Type coercion to strings is now exclusively handled by LLMeterEncoder at JSON serialization time. - Endpoint.to_dict(): remove PathLike → as_posix() coercion, simplify to a dict comprehension. Remove unused os import. - to_dict_recursive_generic(): remove PathLike → as_posix() and datetime/date/time → isoformat() coercions. Keep structural recursion (nested dicts, lists, to_dict() delegation). Remove unused os, datetime imports. 
--- llmeter/callbacks/cost/serde.py | 8 +------- llmeter/endpoints/base.py | 11 +++-------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/llmeter/callbacks/cost/serde.py b/llmeter/callbacks/cost/serde.py index 28acbfa..f04dc39 100644 --- a/llmeter/callbacks/cost/serde.py +++ b/llmeter/callbacks/cost/serde.py @@ -5,9 +5,7 @@ # Python Built-Ins: import json import logging -import os from dataclasses import is_dataclass -from datetime import date, datetime, time from typing import Any, Protocol, TypeVar # External Dependencies: @@ -67,16 +65,12 @@ def to_dict_recursive_generic(obj: object, **kwargs) -> dict: result.update({k: getattr(obj, k) for k in dir(obj)}) result.update(kwargs) for k, v in result.items(): - if isinstance(v, (os.PathLike, Path)): - result[k] = Path(v).as_posix() - elif hasattr(v, "to_dict"): + if hasattr(v, "to_dict"): result[k] = v.to_dict() elif isinstance(v, dict): result[k] = to_dict_recursive_generic(v) elif isinstance(v, (list, tuple)): result[k] = [to_dict_recursive_generic(item) for item in v] - elif isinstance(v, (date, datetime, time)): - result[k] = v.isoformat() return result diff --git a/llmeter/endpoints/base.py b/llmeter/endpoints/base.py index 32c55a8..932c83d 100644 --- a/llmeter/endpoints/base.py +++ b/llmeter/endpoints/base.py @@ -7,7 +7,6 @@ import importlib import json -import os from abc import ABC, abstractmethod from dataclasses import asdict, dataclass from typing import Any @@ -291,13 +290,9 @@ def to_dict(self) -> dict: Returns: Dict: A dictionary representation of the endpoint configuration. 
""" - endpoint_conf = {} - for k, v in vars(self).items(): - if k.startswith("_"): - continue - if isinstance(v, (os.PathLike, Path)): - v = Path(v).as_posix() - endpoint_conf[k] = v + endpoint_conf = { + k: v for k, v in vars(self).items() if not k.startswith("_") + } endpoint_conf["endpoint_type"] = self.__class__.__name__ return endpoint_conf From 1a41bfe31ce06999eb7b670e04de49cb4c9ba556 Mon Sep 17 00:00:00 2001 From: Alessandro Cere Date: Tue, 31 Mar 2026 07:56:56 -0700 Subject: [PATCH 9/9] feat: add dynamic callback serialization via _callback_type marker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement to_dict/from_dict on Callback base class using a "module:ClassName" type marker (_callback_type) for dynamic dispatch. Any Callback subclass — built-in or third-party — can now round-trip through JSON without a hardcoded registry. - Callback.to_dict() emits _callback_type with fully-qualified class path - Callback.from_dict() dynamically imports and instantiates the correct subclass, delegating to overridden from_dict when present - CostModel.to_dict() injects _callback_type alongside existing _type - CostModel.from_dict() strips both markers before construction - MlflowCallback gains to_dict/from_dict (replaces NotImplementedError stubs) - _RunConfig.save()/load() now serialize and restore callbacks - Replace old Callback base tests with serialization round-trip tests --- llmeter/callbacks/base.py | 179 ++++++++++++++++++++---- llmeter/callbacks/cost/model.py | 19 ++- llmeter/callbacks/mlflow.py | 29 ++-- llmeter/runner.py | 7 + tests/unit/callbacks/cost/test_model.py | 1 + tests/unit/callbacks/test_base.py | 92 +++++++----- 6 files changed, 249 insertions(+), 78 deletions(-) diff --git a/llmeter/callbacks/base.py b/llmeter/callbacks/base.py index f4c4dba..449f287 100644 --- a/llmeter/callbacks/base.py +++ b/llmeter/callbacks/base.py @@ -4,14 +4,21 @@ from __future__ import annotations +import importlib 
+import json +import logging from abc import ABC -from typing import final +from typing import Any, final from upath.types import ReadablePathLike, WritablePathLike from ..endpoints.base import InvocationResponse +from ..json_utils import LLMeterEncoder from ..results import Result from ..runner import _RunConfig +from ..utils import ensure_path + +logger = logging.getLogger(__name__) class Callback(ABC): @@ -22,8 +29,17 @@ class Callback(ABC): associated with test runs or individual model invocations. A Callback object may implement multiple of the defined lifecycle hooks (such as - `before_invoke`, `after_run`, etc). Callbacks must support serializing their configuration to - a file (by implementing `save_to_file`), and loading back (via `load_from_file`). + `before_invoke`, `after_run`, etc). Callbacks support serializing their configuration via + ``to_dict()`` / ``from_dict()`` (and the convenience wrappers ``save_to_file()`` / + ``load_from_file()``). + + Serialization uses a ``_callback_type`` marker (``"module:ClassName"``) so that + ``Callback.from_dict()`` can dynamically import and reconstruct the correct subclass + without a hardcoded registry. This means third-party callbacks round-trip through JSON + automatically, as long as the defining module is importable at load time. + + Subclasses with complex nested state (like ``CostModel``) can override ``to_dict()`` and + ``from_dict()`` while preserving the type marker by calling ``super()``. """ async def before_invoke(self, payload: dict) -> None: @@ -71,46 +87,149 @@ async def after_run(self, result: Result) -> None: """ pass - def save_to_file(self, path: WritablePathLike) -> None: - """Save this Callback to file + # -- Serialization ----------------------------------------------------------------- - Individual Callbacks implement this method to save their configuration to a file that will - be re-loadable with the equivalent `_load_from_file()` method. 
+ def to_dict(self) -> dict: + """Serialize this callback's configuration to a JSON-safe dict. - Args: - path: (Local or Cloud) path where the callback is saved + The returned dict includes a ``_callback_type`` key with the fully-qualified + class path (``"module:ClassName"``), enabling ``Callback.from_dict`` to + reconstruct the correct subclass without a hardcoded registry. + + By default, all public (non-underscore-prefixed) instance attributes are + included. Subclasses with richer state should override this method and call + ``super().to_dict()`` to preserve the type marker. + + Returns: + dict: A JSON-serializable dictionary representation of this callback. + + Example:: + + >>> from llmeter.callbacks import CostModel + >>> from llmeter.callbacks.cost.dimensions import InputTokens + >>> model = CostModel(request_dims=[InputTokens(price_per_million=3.0)]) + >>> d = model.to_dict() + >>> d["_callback_type"] + 'llmeter.callbacks.cost.model:CostModel' """ - raise NotImplementedError("TODO: Callback.save_to_file is not yet implemented!") + cls = self.__class__ + data: dict[str, Any] = { + "_callback_type": f"{cls.__module__}:{cls.__qualname__}", + } + data.update({k: v for k, v in vars(self).items() if not k.startswith("_")}) + return data - @staticmethod - @final - def load_from_file(path: ReadablePathLike) -> Callback: - """Load (any type of) Callback from file + @classmethod + def from_dict(cls, raw: dict, **kwargs: Any) -> Callback: + """Reconstruct a Callback from a dict produced by ``to_dict()``. - `Callback.load_from_file()` attempts to detect the type of Callback saved in a given file, - and use the relevant implementation's `_load_from_file` method to load it. + Uses the ``_callback_type`` field to dynamically import and instantiate + the correct subclass. If called on a concrete subclass (e.g. + ``CostModel.from_dict(...)``), the ``_callback_type`` is still respected + so that the dict always controls which class is created. 
+ + Args: + raw: A dictionary previously produced by ``to_dict()`` (or loaded from + JSON). Must contain a ``_callback_type`` key. + **kwargs: Extra keyword arguments forwarded to the resolved class + constructor (or its own ``from_dict`` if it overrides this method). + + Returns: + Callback: An instance of the appropriate Callback subclass. + + Raises: + ValueError: If ``_callback_type`` is missing from *raw*. + ImportError: If the module referenced by ``_callback_type`` cannot be + imported. + AttributeError: If the class name cannot be found in the referenced + module. + + Example:: + + >>> from llmeter.callbacks.base import Callback + >>> d = { + ... "_callback_type": "llmeter.callbacks.mlflow:MlflowCallback", + ... "step": 1, + ... "nested": False, + ... } + >>> cb = Callback.from_dict(d) # returns an MlflowCallback instance + """ + raw = dict(raw) # shallow copy — don't mutate caller's dict + callback_type = raw.pop("_callback_type", None) + if callback_type is None: + raise ValueError( + "Cannot deserialize Callback: dict is missing '_callback_type' key. " + f"Got keys: {list(raw.keys())}" + ) + + module_path, class_name = callback_type.rsplit(":", 1) + module = importlib.import_module(module_path) + callback_cls = getattr(module, class_name) + + # If the resolved class has its own from_dict (e.g. CostModel), delegate to it + # so that subclass-specific deserialization logic is honoured. + if callback_cls is not cls and "from_dict" in callback_cls.__dict__: + # Re-inject _callback_type so the subclass from_dict can pop it if needed + return callback_cls.from_dict(raw, **kwargs) + + # Remove any keys the constructor doesn't expect (e.g. _type from JSONableBase) + raw.pop("_type", None) + return callback_cls(**raw, **kwargs) + + def to_json(self, **kwargs: Any) -> str: + """Serialize this callback to a JSON string. Args: - path: (Local or Cloud) path where the callback is saved + **kwargs: Extra keyword arguments forwarded to ``json.dumps`` + (e.g. 
``indent``). + Returns: - callback: A loaded Callback - for example an `MlflowCallback`. + str: JSON representation of this callback. """ - raise NotImplementedError( - "TODO: Callback.load_from_file is not yet implemented!" - ) + kwargs.setdefault("cls", LLMeterEncoder) + return json.dumps(self.to_dict(), **kwargs) @classmethod - def _load_from_file(cls, path: ReadablePathLike) -> Callback: - """Load this Callback from file + def from_json(cls, json_string: str, **kwargs: Any) -> Callback: + """Reconstruct a Callback from a JSON string produced by ``to_json()``. + + Args: + json_string: A valid JSON string. + **kwargs: Extra keyword arguments forwarded to ``from_dict``. + + Returns: + Callback: An instance of the appropriate Callback subclass. + """ + return cls.from_dict(json.loads(json_string), **kwargs) + + def save_to_file(self, path: WritablePathLike) -> None: + """Save this Callback's configuration to a JSON file. - Individual Callbacks implement this method to define how they can be loaded from files - created by the equivalent `save_to_file()` method. + The file can be loaded back with ``Callback.load_from_file(path)``. Args: - path: (Local or Cloud) path where the callback is saved + path: (Local or Cloud) path where the callback should be saved. + """ + path = ensure_path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: + f.write(self.to_json(indent=4)) + + @staticmethod + @final + def load_from_file(path: ReadablePathLike) -> Callback: + """Load (any type of) Callback from a JSON file. + + The ``_callback_type`` field inside the file determines which subclass is + instantiated, so callers don't need to know the concrete type in advance. + + Args: + path: (Local or Cloud) path to a JSON file previously created by + ``save_to_file()``. + Returns: - callback: The loaded Callback object + Callback: The deserialized callback instance. """ - raise NotImplementedError( - "TODO: Callback._load_from_file is not yet implemented!" 
- ) + path = ensure_path(path) + with path.open("r") as f: + return Callback.from_dict(json.load(f)) diff --git a/llmeter/callbacks/cost/model.py b/llmeter/callbacks/cost/model.py index 8e17664..9f3170b 100644 --- a/llmeter/callbacks/cost/model.py +++ b/llmeter/callbacks/cost/model.py @@ -13,7 +13,7 @@ from ..base import Callback from .dimensions import IRequestCostDimension, IRunCostDimension from .results import CalculatedCostWithDimensions -from .serde import JSONableBase, from_dict_with_class_map +from .serde import JSONableBase, from_dict_with_class, from_dict_with_class_map @dataclass @@ -201,6 +201,18 @@ async def after_run(self, result: Result) -> None: result, recalculate_request_costs=False, save=True ) + def to_dict(self, **kwargs) -> dict: + """Serialize the cost model to a JSON-safe dict. + + Injects the ``_callback_type`` marker from ``Callback.to_dict()`` into the + dict produced by ``JSONableBase.to_dict()``, so that + ``Callback.from_dict()`` can reconstruct this ``CostModel`` dynamically. 
+ """ + data = JSONableBase.to_dict(self, **kwargs) + cls = self.__class__ + data["_callback_type"] = f"{cls.__module__}:{cls.__qualname__}" + return data + def save_to_file(self, path: str) -> None: """Save the cost model (including all dimensions) to a JSON file""" path = ensure_path(path) @@ -215,13 +227,16 @@ def from_dict(cls, raw: dict, alt_classes: dict = {}, **kwargs) -> "CostModel": **alt_classes, } raw_args = {**raw} + # Strip callback/serde type markers — they're not constructor args + raw_args.pop("_callback_type", None) + raw_args.pop("_type", None) for key in ("request_dims", "run_dims"): if key in raw_args: raw_args[key] = { name: from_dict_with_class_map(d, class_map=dim_classes) for name, d in raw_args[key].items() } - return super().from_dict(raw_args, alt_classes=alt_classes, **kwargs) + return from_dict_with_class(raw=raw_args, cls=cls, **kwargs) @classmethod def _load_from_file(cls, path: str): diff --git a/llmeter/callbacks/mlflow.py b/llmeter/callbacks/mlflow.py index b83635b..bfb10de 100644 --- a/llmeter/callbacks/mlflow.py +++ b/llmeter/callbacks/mlflow.py @@ -66,16 +66,29 @@ def __init__(self, step: int | None = None, nested: bool = False) -> None: # Check MLFlow is installed by polling any attribute on the module to trigger DeferredError mlflow.__version__ + def to_dict(self) -> dict: + """Serialize this callback's configuration to a JSON-safe dict. + + Returns: + dict: Contains ``_callback_type``, ``step``, and ``nested``. + """ + data = super().to_dict() + # Exclude the mlflow version check side-effect attribute if present + return data + @classmethod - async def _load_from_file(cls, path: str): - raise NotImplementedError( - "TODO: MlflowCallback does not yet support loading from file" - ) + def from_dict(cls, raw: dict, **kwargs) -> "MlflowCallback": + """Reconstruct an MlflowCallback from a dict produced by ``to_dict()``. 
- def save_to_file(self) -> str | None: - raise NotImplementedError( - "TODO: MlflowCallback does not yet support saving to file" - ) + Args: + raw: Dictionary with ``step`` and ``nested`` keys. + + Returns: + MlflowCallback: The deserialized callback. + """ + raw = dict(raw) + raw.pop("_callback_type", None) + return cls(**raw, **kwargs) async def _log_llmeter_run(self, result: Result): """Log parameters and metrics from an LLMeter run to MLflow. diff --git a/llmeter/runner.py b/llmeter/runner.py index 40496ad..6f56d1b 100644 --- a/llmeter/runner.py +++ b/llmeter/runner.py @@ -123,6 +123,9 @@ def save( if not isinstance(self.tokenizer, dict): config_copy.tokenizer = Tokenizer.to_dict(self.tokenizer) + if self.callbacks: + config_copy.callbacks = [cb.to_dict() for cb in self.callbacks] + with run_config_path.open("w") as f: f.write(json.dumps(asdict(config_copy), cls=LLMeterEncoder, indent=4)) @@ -139,6 +142,10 @@ def load(cls, load_path: Path | str, file_name: str = "run_config.json"): config = json.load(f) config["endpoint"] = Endpoint.load(config["endpoint"]) config["tokenizer"] = Tokenizer.load(config["tokenizer"]) + if config.get("callbacks"): + from .callbacks.base import Callback # deferred: callbacks.base imports _RunConfig + + config["callbacks"] = [Callback.from_dict(cb) for cb in config["callbacks"]] return cls(**config) diff --git a/tests/unit/callbacks/cost/test_model.py b/tests/unit/callbacks/cost/test_model.py index b1ea1a8..6eed5bd 100644 --- a/tests/unit/callbacks/cost/test_model.py +++ b/tests/unit/callbacks/cost/test_model.py @@ -25,6 +25,7 @@ def test_cost_model_serialization(): assert model.run_dims["ComputeSeconds"].price_per_hour == 50 assert model.to_dict() == { "_type": "CostModel", + "_callback_type": "llmeter.callbacks.cost.model:CostModel", "request_dims": { "TokensIn": { "_type": "InputTokens", diff --git a/tests/unit/callbacks/test_base.py b/tests/unit/callbacks/test_base.py index b3aa824..c236350 100644 --- 
a/tests/unit/callbacks/test_base.py +++ b/tests/unit/callbacks/test_base.py @@ -1,51 +1,67 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import json import pytest from llmeter.callbacks.base import Callback -class TestBase: - def test__load_from_file_not_implemented(self): - """ - Test that _load_from_file raises NotImplementedError. - """ - with pytest.raises(NotImplementedError): - Callback._load_from_file("valid_path.json") - - def test_load_from_file_not_implemented(self): - """ - Test that load_from_file raises a NotImplementedError as it's not yet implemented. - """ - with pytest.raises(NotImplementedError): - Callback.load_from_file("valid_path.txt") - - def test_load_from_file_raises_not_implemented_error(self): - """ - Test that Callback.load_from_file raises a NotImplementedError. - - This test verifies that calling the static method load_from_file - on the Callback class raises a NotImplementedError, as the method - is not yet implemented. - """ - with pytest.raises(NotImplementedError) as excinfo: - Callback.load_from_file("dummy_path") +class _DummyCallback(Callback): + """A minimal concrete callback for testing serialization.""" - assert ( - str(excinfo.value) - == "TODO: Callback.load_from_file is not yet implemented!" - ) + def __init__(self, alpha: int = 1, beta: str = "hello"): + self.alpha = alpha + self.beta = beta - def test_load_from_file_raises_not_implemented_error_2(self): - """ - Test that _load_from_file raises NotImplementedError when called. - This method is not yet implemented in the base Callback class. - """ - with pytest.raises(NotImplementedError) as excinfo: - Callback._load_from_file("dummy_path") +class TestCallbackSerialization: + def test_to_dict_includes_callback_type(self): + cb = _DummyCallback(alpha=42, beta="world") + d = cb.to_dict() assert ( - str(excinfo.value) - == "TODO: Callback._load_from_file is not yet implemented!" 
+ d["_callback_type"] + == f"{_DummyCallback.__module__}:{_DummyCallback.__qualname__}" ) + assert d["alpha"] == 42 + assert d["beta"] == "world" + + def test_to_dict_excludes_private_attrs(self): + cb = _DummyCallback() + cb._internal = "secret" + d = cb.to_dict() + assert "_internal" not in d + + def test_from_dict_round_trip(self): + cb = _DummyCallback(alpha=7, beta="test") + d = cb.to_dict() + restored = Callback.from_dict(d) + assert isinstance(restored, _DummyCallback) + assert restored.alpha == 7 + assert restored.beta == "test" + + def test_from_dict_missing_callback_type_raises(self): + with pytest.raises(ValueError, match="_callback_type"): + Callback.from_dict({"alpha": 1}) + + def test_to_json_round_trip(self): + cb = _DummyCallback(alpha=99) + json_str = cb.to_json() + restored = Callback.from_json(json_str) + assert isinstance(restored, _DummyCallback) + assert restored.alpha == 99 + + def test_save_and_load_from_file(self, tmp_path): + cb = _DummyCallback(alpha=5, beta="file_test") + file_path = tmp_path / "callback.json" + cb.save_to_file(file_path) + + loaded = Callback.load_from_file(file_path) + assert isinstance(loaded, _DummyCallback) + assert loaded.alpha == 5 + assert loaded.beta == "file_test" + + # Verify the file is valid JSON with the type marker + with open(file_path) as f: + data = json.load(f) + assert "_callback_type" in data