diff --git a/llmeter/endpoints/bedrock.py b/llmeter/endpoints/bedrock.py index 04012ca..400c995 100644 --- a/llmeter/endpoints/bedrock.py +++ b/llmeter/endpoints/bedrock.py @@ -1,9 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from base64 import b64decode +from copy import deepcopy import logging import time from uuid import uuid4 +import warnings import boto3 from botocore.config import Config @@ -55,7 +58,7 @@ def __init__( ) self._inference_config = inference_config - def _parse_payload(self, payload): + def _parse_payload(self, payload) -> str: """ Parse the payload to extract text content. @@ -102,6 +105,48 @@ def _parse_payload(self, payload): logger.error(f"Unexpected error parsing payload: {e}") return "" + @staticmethod + def _patch_base64_bytes(payload: dict) -> dict: + """Copy a request payload, undoing base64-encoding on image/multimedia bytes where present + + Bedrock's boto3 SDK base64-encodes inline multimedia bytes automatically, and doesn't try + to detect whether it's already been applied - so needs raw bytes input. However, LLMeter + prefers to keep this data encoded so that a load test's payloads are 1/ serializable to + JSON and 2/ consistent with what actually gets sent 'on the wire' to the endpoint. + + Therefore this method is called during invoke (before the performance counter starts), to + prepare the payload for boto3. + """ + payload = deepcopy(payload) + messages = payload.get("messages", []) + raw_media_types_found = set() + for msg in messages: + content = msg.get("content", []) + for c in content: + for media_type in ("image", "video"): + if media_type in c: + image_source = c[media_type].get("source", {}) + if "bytes" in image_source: + if isinstance(image_source["bytes"], str): + try: + # Undo base64-encoding: + image_source["bytes"] = b64decode( + image_source["bytes"] + ) + except Exception: + raw_media_types_found.add(media_type) + else: + raw_media_types_found.add(media_type) + if raw_media_types_found: + warnings.warn( + "Bedrock payload had raw source.bytes in %s content. Although this is " + 'supported by boto3, you probably want to `b64encode(img_data).decode("utf-8")` ' + "them in LLMeter, so that your payloads and load test results can be saved to " + "(JSON) files." % (sorted(list(raw_media_types_found))), + RuntimeWarning, + ) + return payload + @staticmethod def create_payload(user_message: str | list[str], max_tokens: int = 256, **kwargs): """ @@ -233,9 +278,12 @@ def invoke(self, payload: dict, **kwargs) -> InvocationResponse: payload["inferenceConfig"] = self._inference_config or {} payload["modelId"] = self.model_id + # Separate copy to store final modifications that *shouldn't* be visible in LLMeter + # results (as it's not JSON-serializable): + payload_patch = self._patch_base64_bytes(payload) try: start_t = time.perf_counter() - client_response = self._bedrock_client.converse(**payload) # type: ignore + client_response = self._bedrock_client.converse(**payload_patch) # type: ignore time_to_last_token = time.perf_counter() - start_t except ClientError as e: logger.error(f"Bedrock API error: {e}") @@ -268,9 +316,12 @@ def invoke(self, payload: dict, **kwargs) -> InvocationResponse: payload["inferenceConfig"] = self._inference_config or {} payload["modelId"] = self.model_id - start_t = time.perf_counter() + # Separate copy to store final modifications that *shouldn't* be visible in LLMeter + # results (as it's not JSON-serializable): + payload_patch = self._patch_base64_bytes(payload) try: - client_response = self._bedrock_client.converse_stream(**payload) # type: ignore + start_t = time.perf_counter() + client_response = self._bedrock_client.converse_stream(**payload_patch) # type: ignore except (ClientError, Exception) as e: logger.error(e) return InvocationResponse.error_output(