Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions llmeter/endpoints/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from base64 import b64decode
from copy import deepcopy
import logging
import time
from uuid import uuid4
import warnings

import boto3
from botocore.config import Config
Expand Down Expand Up @@ -55,7 +58,7 @@ def __init__(
)
self._inference_config = inference_config

def _parse_payload(self, payload):
def _parse_payload(self, payload) -> str:
"""
Parse the payload to extract text content.

Expand Down Expand Up @@ -102,6 +105,48 @@ def _parse_payload(self, payload):
logger.error(f"Unexpected error parsing payload: {e}")
return ""

@staticmethod
def _patch_base64_bytes(payload: dict) -> dict:
"""Copy a request payload, undoing base64-encoding on image/multimedia bytes where present

Bedrock's boto3 SDK base64-encodes inline multimedia bytes automatically, and doesn't try
to detect whether it's already been applied - so needs raw bytes input. However, LLMeter
prefers to keep this data encoded so that a load test's payloads are 1/ serializable to
JSON and 2/ consistent with what actually gets sent 'on the wire' to the endpoint.

Therefore this method is called during invoke (before the performance counter starts), to
prepare the payload for boto3.
"""
payload = deepcopy(payload)
messages = payload.get("messages", [])
raw_media_types_found = set()
for msg in messages:
content = msg.get("content", [])
for c in content:
for media_type in ("image", "video"):
if media_type in c:
image_source = c[media_type].get("source", {})
if "bytes" in image_source:
if isinstance(image_source["bytes"], str):
try:
# Undo base64-encoding:
image_source["bytes"] = b64decode(
image_source["bytes"]
)
except Exception:
raw_media_types_found.add(media_type)
else:
raw_media_types_found.add(media_type)
if raw_media_types_found:
warnings.warn(
"Bedrock payload had raw source.bytes in %s content. Although this is "
'supported by boto3, you probably want to `b64encode(img_data).decode("utf-8")` '
"them in LLMeter, so that your payloads and load test results can be saved to "
"(JSON) files." % (sorted(list(raw_media_types_found))),
RuntimeWarning,
)
return payload

@staticmethod
def create_payload(user_message: str | list[str], max_tokens: int = 256, **kwargs):
"""
Expand Down Expand Up @@ -233,9 +278,12 @@ def invoke(self, payload: dict, **kwargs) -> InvocationResponse:
payload["inferenceConfig"] = self._inference_config or {}

payload["modelId"] = self.model_id
# Separate copy to store final modifications that *shouldn't* be visible in LLMeter
# results (as it's not JSON-serializable):
payload_patch = self._patch_base64_bytes(payload)
try:
start_t = time.perf_counter()
client_response = self._bedrock_client.converse(**payload) # type: ignore
client_response = self._bedrock_client.converse(**payload_patch) # type: ignore
time_to_last_token = time.perf_counter() - start_t
except ClientError as e:
logger.error(f"Bedrock API error: {e}")
Expand Down Expand Up @@ -268,9 +316,12 @@ def invoke(self, payload: dict, **kwargs) -> InvocationResponse:
payload["inferenceConfig"] = self._inference_config or {}

payload["modelId"] = self.model_id
start_t = time.perf_counter()
# Separate copy to store final modifications that *shouldn't* be visible in LLMeter
# results (as it's not JSON-serializable):
payload_patch = self._patch_base64_bytes(payload)
try:
client_response = self._bedrock_client.converse_stream(**payload) # type: ignore
start_t = time.perf_counter()
client_response = self._bedrock_client.converse_stream(**payload_patch) # type: ignore
except (ClientError, Exception) as e:
logger.error(e)
return InvocationResponse.error_output(
Expand Down