Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/aiperf/common/enums/metric_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ class GenericMetricUnit(BaseMetricUnit):
RATIO = _unit("ratio")
USER = _unit("user")
PERCENT = _unit("%")
IMAGE = _unit("image")
IMAGES = _unit("images")
VIDEO = _unit("video")
VIDEOS = _unit("videos")


class PowerMetricUnitInfo(BaseMetricUnitInfo):
Expand Down Expand Up @@ -289,7 +293,11 @@ class MetricOverTimeUnitInfo(BaseMetricUnitInfo):
@model_validator(mode="after")
def _set_tag(self: Self) -> Self:
"""Set the tag based on the existing units. ie. requests/sec, tokens/sec, etc."""
self.tag = f"{self.primary_unit}/{self.time_unit}"
self.tag = (
f"{self.primary_unit}/{self.time_unit}"
if not self.inverted
else f"{self.time_unit}/{self.primary_unit}"
)
if self.third_unit:
# If there is a third unit, add it to the tag. ie. tokens/sec/user
self.tag += f"/{self.third_unit}"
Expand All @@ -302,6 +310,7 @@ def _set_tag(self: Self) -> Self:
primary_unit: "MetricUnitT"
time_unit: MetricTimeUnit | MetricTimeUnitInfo
third_unit: "MetricUnitT | None" = None
inverted: bool = False

def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
"""Convert a value from this unit to another unit."""
Expand Down Expand Up @@ -342,6 +351,24 @@ class MetricOverTimeUnit(BaseMetricUnit):
time_unit=MetricTimeUnit.SECONDS,
third_unit=GenericMetricUnit.USER,
)
IMAGES_PER_SECOND = MetricOverTimeUnitInfo(
primary_unit=GenericMetricUnit.IMAGES,
time_unit=MetricTimeUnit.SECONDS,
)
MS_PER_IMAGE = MetricOverTimeUnitInfo(
time_unit=MetricTimeUnit.MILLISECONDS,
primary_unit=GenericMetricUnit.IMAGE,
inverted=True,
)
VIDEOS_PER_SECOND = MetricOverTimeUnitInfo(
primary_unit=GenericMetricUnit.VIDEOS,
time_unit=MetricTimeUnit.SECONDS,
)
MS_PER_VIDEO = MetricOverTimeUnitInfo(
time_unit=MetricTimeUnit.MILLISECONDS,
primary_unit=GenericMetricUnit.VIDEO,
inverted=True,
)

@cached_property
def info(self) -> MetricOverTimeUnitInfo:
Expand All @@ -363,6 +390,11 @@ def third_unit(self) -> "MetricUnitT | None":
"""Get the third unit (if applicable)."""
return self.info.third_unit

@cached_property
def inverted(self) -> bool:
"""Whether the metric is inverted (e.g. time / metric)."""
return self.info.inverted


class MetricType(CaseInsensitiveStrEnum):
"""Defines the possible types of metrics."""
Expand Down Expand Up @@ -643,6 +675,9 @@ class MetricFlags(Flag):
TOKENIZES_INPUT_ONLY = 1 << 12
"""Metrics that are only applicable when the endpoint tokenizes input text."""

SUPPORTS_VIDEO_ONLY = 1 << 13
"""Metrics that are only applicable to video-based endpoints."""

def has_flags(self, flags: "MetricFlags") -> bool:
"""Return True if the metric has ALL of the given flag(s) (regardless of other flags)."""
# Bitwise AND will return the input flags only if all of the given flags are present.
Expand Down
1 change: 1 addition & 0 deletions src/aiperf/common/enums/plugin_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class EndpointType(CaseInsensitiveStrEnum):
COMPLETIONS = "completions"
EMBEDDINGS = "embeddings"
RANKINGS = "rankings"
IMAGE_RETRIEVAL = "image_retrieval"


class TransportType(CaseInsensitiveStrEnum):
Expand Down
2 changes: 2 additions & 0 deletions src/aiperf/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
BaseInferenceServerResponse,
BaseResponseData,
EmbeddingResponseData,
ImageRetrievalResponseData,
MetricRecordInfo,
MetricRecordMetadata,
MetricResult,
Expand Down Expand Up @@ -140,6 +141,7 @@
"GpuTelemetrySnapshot",
"IOCounters",
"Image",
"ImageRetrievalResponseData",
"InputsFile",
"JsonExportData",
"JsonMetricResult",
Expand Down
13 changes: 13 additions & 0 deletions src/aiperf/common/models/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,18 @@ def get_text(self) -> str:
return ""


class ImageRetrievalResponseData(BaseResponseData):
"""Parsed image retrieval response data."""

data: list[dict[str, Any]] = Field(
..., description="The image retrieval data from the response."
)

def get_text(self) -> str:
"""Get the text of the response (empty for image retrieval)."""
return ""


class ParsedResponse(AIPerfBaseModel):
"""Parsed response from a inference client."""

Expand All @@ -608,6 +620,7 @@ class ParsedResponse(AIPerfBaseModel):
| TextResponseData
| EmbeddingResponseData
| RankingsResponseData
| ImageRetrievalResponseData
| BaseResponseData
] = Field(..., description="The parsed response data.")

Expand Down
8 changes: 8 additions & 0 deletions src/aiperf/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@
)
from aiperf.dataset.utils import (
check_file_exists,
encode_audio,
encode_image,
encode_video,
open_audio,
open_image,
open_video,
sample_normal,
sample_positive_normal,
sample_positive_normal_integer,
Expand Down Expand Up @@ -83,9 +87,13 @@
"SyntheticDatasetComposer",
"VideoGenerator",
"check_file_exists",
"encode_audio",
"encode_image",
"encode_video",
"main",
"open_audio",
"open_image",
"open_video",
"sample_normal",
"sample_positive_normal",
"sample_positive_normal_integer",
Expand Down
133 changes: 130 additions & 3 deletions src/aiperf/dataset/loader/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable
from urllib.parse import urlparse

from aiperf.common.enums.dataset_enums import AudioFormat
from aiperf.common.enums.media_enums import MediaType
from aiperf.common.models import Media
from aiperf.common.types import MediaT
from aiperf.common.types import MediaT, MediaTypeT
from aiperf.dataset import utils
from aiperf.dataset.loader.models import CustomDatasetT


Expand Down Expand Up @@ -51,8 +55,8 @@ def _convert_to_media_objects(

Args:
data: The custom dataset to construct media objects from.
media_class: The target media class (Text, Image, or Audio).
field: The name of the field (e.g., 'text', 'image', 'audio').
media_class: The target media class (Text, Image, Audio, or Video).
field: The name of the field (e.g., 'text', 'image', 'audio', 'video').
name: The name of the media field.

Returns:
Expand All @@ -61,6 +65,9 @@ def _convert_to_media_objects(
# Check singular field first
value = getattr(data, field, None)
if value is not None:
# Handle media content (encode local files to base64)
if field in [MediaType.IMAGE, MediaType.VIDEO, MediaType.AUDIO]:
value = self._handle_media_content(value, media_type=MediaType(field))
return [media_class(name=name, contents=[value])]

# Check plural field
Expand All @@ -72,4 +79,124 @@ def _convert_to_media_objects(
if all(isinstance(v, media_class) for v in values):
return values

# Handle media content (encode local files to base64)
if field in [MediaType.IMAGE, MediaType.VIDEO, MediaType.AUDIO]:
values = [
self._handle_media_content(v, media_type=MediaType(field))
for v in values
]

return [media_class(name=name, contents=values)]

def _is_url(self, content: str) -> bool:
"""Check if content is a valid URL with scheme and netloc.

Args:
content: The content to check.

Returns:
True if content is a URL, False otherwise.

Raises:
ValueError: If URL has only scheme or only netloc (invalid).
"""
url = urlparse(content)

# Valid URL with both scheme and netloc
if url.scheme and url.netloc:
return True

# Invalid URL - has one but not both
if url.scheme or url.netloc:
raise ValueError(f"Valid URL must have both a scheme and netloc: {content}")

# Not a URL
return False

def _is_already_encoded(self, content: str, media_type: MediaTypeT) -> bool:
"""Check if content is already encoded in the expected format.

Args:
content: The content to check.
media_type: The media type (MediaType.IMAGE, MediaType.AUDIO, MediaType.VIDEO).

Returns:
True if content is already encoded, False otherwise.
"""
url = urlparse(content)

if media_type in [MediaType.IMAGE, MediaType.VIDEO]:
# Check for data URL format
return url.scheme == "data"

elif media_type == MediaType.AUDIO:
# Check for "format,base64" format
if "," in content and not url.scheme:
parts = content.split(",", 1)
return len(parts) == 2 and parts[0].lower() in [
AudioFormat.WAV,
AudioFormat.MP3,
]
return False

return False

def _encode_media_file(self, content: str, media_type: MediaTypeT) -> str:
"""Encode a local media file to base64.

Args:
content: The file path to encode.
media_type: The media type (MediaType.IMAGE, MediaType.AUDIO, MediaType.VIDEO).

Returns:
The base64-encoded content in the appropriate format.

Raises:
FileNotFoundError: If the file doesn't exist.
RuntimeError: If the format is unsupported.
"""
if media_type == MediaType.IMAGE:
img = utils.open_image(content)
img_base64 = utils.encode_image(img, img.format)
return f"data:image/{img.format.lower()};base64,{img_base64}"

elif media_type == MediaType.AUDIO:
audio_bytes, audio_format = utils.open_audio(content)
return utils.encode_audio(audio_bytes, audio_format)

elif media_type == MediaType.VIDEO:
video_bytes, video_format = utils.open_video(content)
return utils.encode_video(video_bytes, video_format)

raise ValueError(f"Unsupported media type: {media_type}")

def _handle_media_content(self, content: str, media_type: MediaTypeT) -> str:
"""Generic handler for media content encoding.

If the content is a URL, it's returned as-is.
If it's already encoded, it's returned as-is.
If it's a local file path, it's loaded and encoded to base64.

Args:
content: The media content - URL, encoded string, or local file path.
media_type: The media type (MediaType.IMAGE, MediaType.AUDIO, MediaType.VIDEO).

Returns:
The processed media content.

Raises:
FileNotFoundError: If the local file doesn't exist.
RuntimeError: If the media format is unsupported.
ValueError: If URL format is invalid.
"""
# Check if it's already encoded first (before URL check)
# This handles data URLs which have a scheme but no netloc
if self._is_already_encoded(content, media_type):
return content

# Check if it's a URL
if self._is_url(content):
return content

# Otherwise, it's a local file path - encode it
return self._encode_media_file(content, media_type)
Loading