20 changes: 0 additions & 20 deletions src/aiperf/common/config/input_config.py
@@ -96,26 +96,6 @@ def validate_goodput(self) -> Self:

return self

@model_validator(mode="after")
def validate_dataset_sampling_strategy(self) -> Self:
"""Validate the dataset sampling strategy configuration."""
if self.dataset_sampling_strategy is None:
match self.custom_dataset_type:
case CustomDatasetType.RANDOM_POOL:
self.dataset_sampling_strategy = DatasetSamplingStrategy.SHUFFLE
case (
CustomDatasetType.MOONCAKE_TRACE
| CustomDatasetType.SINGLE_TURN
| CustomDatasetType.MULTI_TURN
):
self.dataset_sampling_strategy = DatasetSamplingStrategy.SEQUENTIAL
case _:
self.dataset_sampling_strategy = (
InputDefaults.DATASET_SAMPLING_STRATEGY
)

return self

extra: Annotated[
Any,
Field(
40 changes: 40 additions & 0 deletions src/aiperf/common/protocols.py
@@ -9,6 +9,7 @@
from aiperf.common.environment import Environment
from aiperf.common.hooks import Hook, HookType
from aiperf.common.models import (
Conversation,
MetricRecordMetadata,
ParsedResponse,
ParsedResponseRecord,
@@ -31,14 +32,17 @@

if TYPE_CHECKING:
import multiprocessing
from pathlib import Path

from rich.console import Console

from aiperf.common.config import ServiceConfig, UserConfig
from aiperf.common.enums import DatasetSamplingStrategy
from aiperf.common.messages.inference_messages import MetricRecordsData
from aiperf.common.models.metadata import EndpointMetadata, TransportMetadata
from aiperf.common.models.model_endpoint_info import ModelEndpointInfo
from aiperf.common.models.record_models import MetricResult
from aiperf.dataset.loader.models import CustomDatasetT
from aiperf.exporters.exporter_config import ExporterConfig, FileExportInfo
from aiperf.metrics.metric_dicts import MetricRecordDict
from aiperf.timing.config import TimingManagerConfig
@@ -316,6 +320,42 @@ def __init__(self, exporter_config: "ExporterConfig") -> None: ...
async def export(self, console: "Console") -> None: ...


@runtime_checkable
class CustomDatasetLoaderProtocol(Protocol):
"""Protocol for custom dataset loaders that load dataset from a file and convert it to a list of Conversation objects."""

@classmethod
def can_load(
cls, data: dict[str, Any] | None = None, filename: "str | Path | None" = None
) -> bool:
"""Check if this loader can handle the given data format.

Args:
data: Optional dictionary representing a single line from the JSONL file.
None indicates path-based detection only (e.g., for directories).
filename: Optional path to the input file/directory for path-based detection

Returns:
True if this loader can handle the data format, False otherwise
"""
...

@classmethod
def get_preferred_sampling_strategy(cls) -> "DatasetSamplingStrategy":
"""Get the preferred dataset sampling strategy for this loader.

Returns:
DatasetSamplingStrategy: The preferred sampling strategy
"""
...

def load_dataset(self) -> dict[str, list["CustomDatasetT"]]: ...

def convert_to_conversations(
self, custom_data: dict[str, list["CustomDatasetT"]]
) -> list[Conversation]: ...


@runtime_checkable
class DataExporterProtocol(Protocol):
"""
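
For reference, a minimal sketch of a class that would satisfy `CustomDatasetLoaderProtocol`. The `SimpleJsonlLoader` name, the `session_id`/`prompt` JSONL fields, and the `Conversation`/`Turn`/`Text` keyword arguments are illustrative assumptions, not code from this PR; real loaders also register themselves via `CustomDatasetFactory.register(...)`.

```python
# Sketch only: field names and model constructor arguments are assumptions.
import json
from pathlib import Path
from typing import Any

from aiperf.common.enums import DatasetSamplingStrategy
from aiperf.common.models import Conversation, Text, Turn


class SimpleJsonlLoader:
    def __init__(self, filename: str) -> None:
        self.filename = filename

    @classmethod
    def can_load(
        cls, data: dict[str, Any] | None = None, filename: str | Path | None = None
    ) -> bool:
        # Content-based detection: claim JSONL rows that carry a "prompt" field.
        return data is not None and "prompt" in data

    @classmethod
    def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy:
        return DatasetSamplingStrategy.SEQUENTIAL

    def load_dataset(self) -> dict[str, list[dict[str, Any]]]:
        # Group rows by session so multi-row sessions become one conversation.
        grouped: dict[str, list[dict[str, Any]]] = {}
        with open(self.filename) as f:
            for line in f:
                if not (line := line.strip()):
                    continue
                row = json.loads(line)
                grouped.setdefault(row.get("session_id", "default"), []).append(row)
        return grouped

    def convert_to_conversations(
        self, custom_data: dict[str, list[dict[str, Any]]]
    ) -> list[Conversation]:
        # The Turn/Text shapes below are assumptions about aiperf.common.models.
        return [
            Conversation(
                session_id=session_id,
                turns=[Turn(texts=[Text(contents=[row["prompt"]])]) for row in rows],
            )
            for session_id, rows in custom_data.items()
        ]
```

Because the protocol is `@runtime_checkable`, an `isinstance` check against `CustomDatasetLoaderProtocol` only verifies that these four members exist, not their signatures.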
11 changes: 8 additions & 3 deletions src/aiperf/dataset/__init__.py
@@ -1,6 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

########################################################################
## 🚩 mkinit flags 🚩 ##
########################################################################
__ignore__ = ["logger"]
########################################################################
## ⚠️ This file is auto-generated by mkinit ⚠️ ##
## ⚠️ Do not edit below this line ⚠️ ##
########################################################################
from aiperf.dataset.composer import (
BaseDatasetComposer,
CustomDatasetComposer,
@@ -29,7 +36,6 @@
from aiperf.dataset.loader import (
AIPERF_DATASET_CACHE_DIR,
BasePublicDatasetLoader,
CustomDatasetLoaderProtocol,
CustomDatasetT,
MediaConversionMixin,
MooncakeTrace,
@@ -59,7 +65,6 @@
"BaseGenerator",
"BasePublicDatasetLoader",
"CustomDatasetComposer",
"CustomDatasetLoaderProtocol",
"CustomDatasetT",
"DEFAULT_CORPUS_FILE",
"DatasetManager",
2 changes: 1 addition & 1 deletion src/aiperf/dataset/composer/__init__.py
@@ -3,7 +3,7 @@
########################################################################
## 🚩 mkinit flags 🚩 ##
########################################################################
__ignore__ = []
__ignore__ = ["logger"]
########################################################################
## ⚠️ This file is auto-generated by mkinit ⚠️ ##
## ⚠️ Do not edit below this line ⚠️ ##
117 changes: 114 additions & 3 deletions src/aiperf/dataset/composer/custom.py
@@ -1,15 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from typing import Any

from aiperf.common.aiperf_logger import AIPerfLogger
from aiperf.common.config import UserConfig
from aiperf.common.decorators import implements_protocol
from aiperf.common.enums import ComposerType, CustomDatasetType
from aiperf.common.factories import ComposerFactory, CustomDatasetFactory
from aiperf.common.models import Conversation
from aiperf.common.protocols import ServiceProtocol
from aiperf.common.tokenizer import Tokenizer
from aiperf.dataset import utils
from aiperf.common.utils import load_json_str
from aiperf.dataset.composer.base import BaseDatasetComposer
from aiperf.dataset.utils import check_file_exists

logger = AIPerfLogger(__name__)


@implements_protocol(ServiceProtocol)
@@ -25,14 +32,118 @@ def create_dataset(self) -> list[Conversation]:
list[Conversation]: A list of conversation objects.
"""
# TODO: (future) for K8s, we need to transfer file data from SC (across node)
utils.check_file_exists(self.config.input.file)
check_file_exists(self.config.input.file)

# Auto-infer dataset type if not provided
dataset_type = self.config.input.custom_dataset_type
if dataset_type is None:
dataset_type = self._infer_dataset_type(self.config.input.file)
if dataset_type is None:
raise ValueError(
f"Could not infer dataset type from file: {self.config.input.file}. "
"Please specify --custom-dataset-type explicitly."
)
self.info(f"Auto-detected dataset type: {dataset_type}")

Comment on lines +35 to +47

⚠️ Potential issue | 🔴 Critical

Persist inferred type to config and guard missing file

Two fixes:

  • If file is None, raise a clear error before calling check_file_exists.
  • After auto-detect, write dataset_type back to config so downstream modules relying on config.input.custom_dataset_type behave correctly (e.g., exact trace replay).
@@
-        # TODO: (future) for K8s, we need to transfer file data from SC (across node)
-        check_file_exists(self.config.input.file)
+        # TODO: (future) for K8s, we need to transfer file data from SC (across node)
+        if not self.config.input.file:
+            raise ValueError(
+                "Custom dataset requires --input-file. Provide --input-file or disable custom dataset."
+            )
+        check_file_exists(self.config.input.file)
@@
-        if dataset_type is None:
+        if dataset_type is None:
             dataset_type = self._infer_dataset_type(self.config.input.file)
             if dataset_type is None:
                 raise ValueError(
                     f"Could not infer dataset type from file: {self.config.input.file}. "
                     "Please specify --custom-dataset-type explicitly."
                 )
             self.info(f"Auto-detected dataset type: {dataset_type}")
+            # Persist inferred type for downstream logic (request count, fixed schedule, etc.)
+        self.config.input.custom_dataset_type = dataset_type

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.1)

42-45: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In src/aiperf/dataset/composer/custom.py around lines 35 to 47, the code calls
check_file_exists without first ensuring self.config.input.file is present and
also doesn't persist the auto-detected dataset_type back into the config; first,
if self.config.input.file is None raise a clear ValueError before calling
check_file_exists, and second, after auto-detecting dataset_type assign it back
to self.config.input.custom_dataset_type (so downstream code sees the inferred
value) and log the detection as before.
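
Since the committable suggestion was skipped, here is an assembled sketch of what `create_dataset` could look like with the reviewer's two fixes merged into the diff above. This is illustrative, not code from this PR:

```python
def create_dataset(self) -> list[Conversation]:
    # TODO: (future) for K8s, we need to transfer file data from SC (across node)
    if not self.config.input.file:
        raise ValueError(
            "Custom dataset requires --input-file. Provide --input-file or "
            "disable custom dataset."
        )
    check_file_exists(self.config.input.file)

    # Auto-infer dataset type if not provided
    dataset_type = self.config.input.custom_dataset_type
    if dataset_type is None:
        dataset_type = self._infer_dataset_type(self.config.input.file)
        if dataset_type is None:
            raise ValueError(
                f"Could not infer dataset type from file: {self.config.input.file}. "
                "Please specify --custom-dataset-type explicitly."
            )
        self.info(f"Auto-detected dataset type: {dataset_type}")
    # Persist the (possibly inferred) type so downstream consumers of
    # config.input.custom_dataset_type see the same value.
    self.config.input.custom_dataset_type = dataset_type

    self._set_sampling_strategy(dataset_type)
    self._create_loader_instance(dataset_type)
    dataset = self.loader.load_dataset()
    conversations = self.loader.convert_to_conversations(dataset)
    self._finalize_conversations(conversations)
    return conversations
```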

# Set dataset sampling strategy based on inferred type if not explicitly set
self._set_sampling_strategy(dataset_type)

self._create_loader_instance(self.config.input.custom_dataset_type)
self._create_loader_instance(dataset_type)
dataset = self.loader.load_dataset()
conversations = self.loader.convert_to_conversations(dataset)
self._finalize_conversations(conversations)
return conversations

@staticmethod
def _infer_dataset_type(file_path: str) -> CustomDatasetType | None:
"""Infer the custom dataset type from the input file.

Queries all registered loaders to check if they can handle the data format.

Args:
file_path: Path to the JSONL file or directory

Returns:
CustomDatasetType if successfully inferred, None otherwise
"""
try:
path = Path(file_path)

# If it's a directory, use path-based detection only
if path.is_dir():
return CustomDatasetComposer._infer_type(data=None, filename=file_path)

# For files, read first non-empty line and use both content and path detection
with open(file_path) as f:
for line in f:
if not (line := line.strip()):
continue
data = load_json_str(line)
return CustomDatasetComposer._infer_type(
data=data, filename=file_path
)

except Exception:
logger.exception(f"Error inferring dataset type from file: {file_path}")
return None

@staticmethod
def _infer_type(
data: dict[str, Any] | None = None, filename: str | Path | None = None
) -> CustomDatasetType | None:
"""Infer the dataset type from data and/or filename.

First checks for explicit 'type' field in the data, then falls back to
structural detection by querying registered loaders via the factory.

Args:
data: Optional dictionary representing a single line from the JSONL file.
None indicates path-based detection only (e.g., for directories).
filename: Optional path to the input file/directory for path-based detection

Returns:
CustomDatasetType if successfully inferred, None otherwise
"""
# Check for explicit type field first (most efficient)
if data is not None and "type" in data:
try:
# Try to convert the type string to enum
explicit_type = CustomDatasetType(data["type"])
logger.info(f"Using explicit type field: {explicit_type}")
return explicit_type
except (ValueError, KeyError):
logger.info(
f"Invalid type field value: {data['type']}, falling back to structural detection"
)

for (
loader_class,
dataset_type,
) in CustomDatasetFactory.get_all_classes_and_types():
if loader_class.can_load(data, filename):
logger.info(
f"Loader {loader_class.__name__} can handle the data format (structural detection)"
)
return dataset_type
return None

def _set_sampling_strategy(self, dataset_type: CustomDatasetType) -> None:
"""Set the dataset sampling strategy based on the dataset type.

If the user has not explicitly set a sampling strategy, use the loader's
preferred strategy.

Args:
dataset_type: The type of custom dataset
"""
if self.config.input.dataset_sampling_strategy is None:
loader_class = CustomDatasetFactory.get_class_from_type(dataset_type)
preferred_strategy = loader_class.get_preferred_sampling_strategy()
self.config.input.dataset_sampling_strategy = preferred_strategy
self.info(
f"Using preferred sampling strategy for {dataset_type}: {preferred_strategy}"
)

def _create_loader_instance(self, dataset_type: CustomDatasetType) -> None:
"""Initializes the dataset loader based on the custom dataset type.

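
To make the detection order concrete, a hedged usage sketch of the inference path above. The `"mooncake_trace"` enum value string and the file names are assumptions:

```python
# Assumes aiperf is importable; enum value strings are guesses for illustration.
from aiperf.dataset.composer.custom import CustomDatasetComposer

# 1. An explicit "type" field short-circuits structural detection.
row = {"type": "mooncake_trace", "input_length": 128}
print(CustomDatasetComposer._infer_type(data=row, filename="trace.jsonl"))

# 2. Without it, each registered loader's can_load() is polled in factory order.
print(CustomDatasetComposer._infer_type(data={"input_length": 128}))

# 3. Directories pass data=None, so only path-based detection can match.
print(CustomDatasetComposer._infer_type(data=None, filename="traces/"))
```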
10 changes: 10 additions & 0 deletions src/aiperf/dataset/composer/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid

from aiperf.common.config import UserConfig
from aiperf.common.config.config_defaults import InputDefaults
from aiperf.common.enums import ComposerType
from aiperf.common.factories import ComposerFactory
from aiperf.common.models import Audio, Conversation, Image, Text, Turn, Video
@@ -17,6 +18,15 @@ class SyntheticDatasetComposer(BaseDatasetComposer):
def __init__(self, config: UserConfig, tokenizer: Tokenizer):
super().__init__(config, tokenizer)

# Set default sampling strategy for synthetic datasets if not explicitly set
if self.config.input.dataset_sampling_strategy is None:
self.config.input.dataset_sampling_strategy = (
InputDefaults.DATASET_SAMPLING_STRATEGY
)
self.info(
f"Using default sampling strategy for synthetic dataset: {InputDefaults.DATASET_SAMPLING_STRATEGY}"
)

if (
not self.include_prompt
and not self.include_image
8 changes: 7 additions & 1 deletion src/aiperf/dataset/dataset_manager.py
Original file line number Diff line number Diff line change
@@ -188,7 +188,13 @@ async def _configure_dataset(self) -> None:
loader = ShareGPTLoader(self.user_config, self.tokenizer)
dataset = await loader.load_dataset()
conversations = await loader.convert_to_conversations(dataset)
elif self.user_config.input.custom_dataset_type is not None:
elif (
self.user_config.input.custom_dataset_type is not None
or self.user_config.input.file is not None
):
# Use CUSTOM composer if either:
# 1. custom_dataset_type is explicitly set, OR
# 2. input file is provided (composer will auto-infer type)
composer = ComposerFactory.create_instance(
ComposerType.CUSTOM,
config=self.user_config,
4 changes: 0 additions & 4 deletions src/aiperf/dataset/loader/__init__.py
@@ -28,9 +28,6 @@
from aiperf.dataset.loader.multi_turn import (
MultiTurnDatasetLoader,
)
from aiperf.dataset.loader.protocol import (
CustomDatasetLoaderProtocol,
)
from aiperf.dataset.loader.random_pool import (
RandomPoolDatasetLoader,
)
@@ -44,7 +41,6 @@
__all__ = [
"AIPERF_DATASET_CACHE_DIR",
"BasePublicDatasetLoader",
"CustomDatasetLoaderProtocol",
"CustomDatasetT",
"MediaConversionMixin",
"MooncakeTrace",
26 changes: 22 additions & 4 deletions src/aiperf/dataset/loader/mooncake_trace.py
@@ -3,19 +3,18 @@

import uuid
from collections import defaultdict
from pathlib import Path
from typing import Any

from aiperf.common.config.user_config import UserConfig
from aiperf.common.decorators import implements_protocol
from aiperf.common.enums import CustomDatasetType
from aiperf.common.enums import CustomDatasetType, DatasetSamplingStrategy
from aiperf.common.factories import CustomDatasetFactory
from aiperf.common.mixins import AIPerfLoggerMixin
from aiperf.common.models import Conversation, Text, Turn
from aiperf.dataset.generator import PromptGenerator
from aiperf.dataset.loader.models import MooncakeTrace
from aiperf.dataset.loader.protocol import CustomDatasetLoaderProtocol


@implements_protocol(CustomDatasetLoaderProtocol)
@CustomDatasetFactory.register(CustomDatasetType.MOONCAKE_TRACE)
class MooncakeTraceDatasetLoader(AIPerfLoggerMixin):
"""A dataset loader that loads Mooncake trace data from a file.
@@ -54,6 +53,25 @@ def __init__(
self._end_offset = user_config.input.fixed_schedule_end_offset
super().__init__(user_config=user_config, **kwargs)

@classmethod
def can_load(
cls, data: dict[str, Any] | None = None, filename: str | Path | None = None
) -> bool:
"""Check if this loader can handle the given data format.

MooncakeTrace format has "input_length" or "text_input" fields,
and optionally "hash_ids".
"""
if data is None:
return False

return "input_length" in data or ("text_input" in data and "hash_ids" in data)

Comment on lines +56 to +69

⚠️ Potential issue | 🟠 Major

can_load too strict; accept text_input without requiring hash_ids, and silence ARG003

Docstring says hash_ids is optional. Current predicate requires both and will fail valid traces, breaking auto-detect. Also rename filename to _filename to avoid Ruff ARG003.

-    def can_load(
-        cls, data: dict[str, Any] | None = None, filename: str | Path | None = None
-    ) -> bool:
+    def can_load(
+        cls, data: dict[str, Any] | None = None, _filename: str | Path | None = None
+    ) -> bool:
@@
-        return "input_length" in data or ("text_input" in data and "hash_ids" in data)
+        return ("input_length" in data) or ("text_input" in data)
📝 Committable suggestion


Suggested change
-    @classmethod
-    def can_load(
-        cls, data: dict[str, Any] | None = None, filename: str | Path | None = None
-    ) -> bool:
-        """Check if this loader can handle the given data format.
-
-        MooncakeTrace format has "input_length" or "text_input" fields,
-        and optionally "hash_ids".
-        """
-        if data is None:
-            return False
-
-        return "input_length" in data or ("text_input" in data and "hash_ids" in data)
+    @classmethod
+    def can_load(
+        cls, data: dict[str, Any] | None = None, _filename: str | Path | None = None
+    ) -> bool:
+        """Check if this loader can handle the given data format.
+
+        MooncakeTrace format has "input_length" or "text_input" fields,
+        and optionally "hash_ids".
+        """
+        if data is None:
+            return False
+
+        return ("input_length" in data) or ("text_input" in data)
🧰 Tools
🪛 Ruff (0.14.1)

58-58: Unused class method argument: filename

(ARG003)

🤖 Prompt for AI Agents
In src/aiperf/dataset/loader/mooncake_trace.py around lines 56 to 69, the
can_load method is too strict and also triggers Ruff ARG003; update the
predicate to return True if data contains "input_length" or "text_input"
(without requiring "hash_ids") to match the docstring and accept valid traces,
and rename the filename parameter to _filename (and any internal usages) to
silence ARG003; keep the early None check and the docstring unchanged.

@classmethod
def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy:
"""Get the preferred dataset sampling strategy for MooncakeTrace."""
return DatasetSamplingStrategy.SEQUENTIAL

def load_dataset(self) -> dict[str, list[MooncakeTrace]]:
"""Load Mooncake trace data from a file.

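
To see the reviewer's point about the merged predicate, a small truth-table sketch. The assertions reflect `can_load` as merged (before the suggested fix), assuming aiperf is importable:

```python
from aiperf.dataset.loader.mooncake_trace import MooncakeTraceDatasetLoader

assert MooncakeTraceDatasetLoader.can_load({"input_length": 4096})
assert MooncakeTraceDatasetLoader.can_load({"text_input": "hi", "hash_ids": [7]})
# The case the review flags: text_input alone is rejected, even though the
# docstring calls hash_ids optional.
assert not MooncakeTraceDatasetLoader.can_load({"text_input": "hi"})
assert not MooncakeTraceDatasetLoader.can_load(None)  # path-based detection only
```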