From 84266da1115fe79e268beea4ea243f68018c0005 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Fri, 24 Oct 2025 15:08:31 -0700 Subject: [PATCH] feat: auto detect custom dataset type based on file info --- src/aiperf/common/config/input_config.py | 20 -- src/aiperf/common/protocols.py | 40 +++ src/aiperf/dataset/__init__.py | 11 +- src/aiperf/dataset/composer/__init__.py | 2 +- src/aiperf/dataset/composer/custom.py | 117 +++++++- src/aiperf/dataset/composer/synthetic.py | 10 + src/aiperf/dataset/dataset_manager.py | 8 +- src/aiperf/dataset/loader/__init__.py | 4 - src/aiperf/dataset/loader/mooncake_trace.py | 26 +- src/aiperf/dataset/loader/multi_turn.py | 23 +- src/aiperf/dataset/loader/protocol.py | 16 -- src/aiperf/dataset/loader/random_pool.py | 30 +- src/aiperf/dataset/loader/single_turn.py | 36 ++- tests/composers/test_custom_composer.py | 55 +++- tests/loaders/test_can_load.py | 287 ++++++++++++++++++++ 15 files changed, 624 insertions(+), 61 deletions(-) delete mode 100644 src/aiperf/dataset/loader/protocol.py create mode 100644 tests/loaders/test_can_load.py diff --git a/src/aiperf/common/config/input_config.py b/src/aiperf/common/config/input_config.py index f9c710d9e..f37962e99 100644 --- a/src/aiperf/common/config/input_config.py +++ b/src/aiperf/common/config/input_config.py @@ -96,26 +96,6 @@ def validate_goodput(self) -> Self: return self - @model_validator(mode="after") - def validate_dataset_sampling_strategy(self) -> Self: - """Validate the dataset sampling strategy configuration.""" - if self.dataset_sampling_strategy is None: - match self.custom_dataset_type: - case CustomDatasetType.RANDOM_POOL: - self.dataset_sampling_strategy = DatasetSamplingStrategy.SHUFFLE - case ( - CustomDatasetType.MOONCAKE_TRACE - | CustomDatasetType.SINGLE_TURN - | CustomDatasetType.MULTI_TURN - ): - self.dataset_sampling_strategy = DatasetSamplingStrategy.SEQUENTIAL - case _: - self.dataset_sampling_strategy = ( - InputDefaults.DATASET_SAMPLING_STRATEGY - ) - - return self - extra: Annotated[ Any, Field( diff --git a/src/aiperf/common/protocols.py b/src/aiperf/common/protocols.py index 04fd14d7d..08b2c8367 100644 --- a/src/aiperf/common/protocols.py +++ b/src/aiperf/common/protocols.py @@ -9,6 +9,7 @@ from aiperf.common.environment import Environment from aiperf.common.hooks import Hook, HookType from aiperf.common.models import ( + Conversation, MetricRecordMetadata, ParsedResponse, ParsedResponseRecord, @@ -31,14 +32,17 @@ if TYPE_CHECKING: import multiprocessing + from pathlib import Path from rich.console import Console from aiperf.common.config import ServiceConfig, UserConfig + from aiperf.common.enums import DatasetSamplingStrategy from aiperf.common.messages.inference_messages import MetricRecordsData from aiperf.common.models.metadata import EndpointMetadata, TransportMetadata from aiperf.common.models.model_endpoint_info import ModelEndpointInfo from aiperf.common.models.record_models import MetricResult + from aiperf.dataset.loader.models import CustomDatasetT from aiperf.exporters.exporter_config import ExporterConfig, FileExportInfo from aiperf.metrics.metric_dicts import MetricRecordDict from aiperf.timing.config import TimingManagerConfig @@ -316,6 +320,42 @@ def __init__(self, exporter_config: "ExporterConfig") -> None: ... async def export(self, console: "Console") -> None: ... 
+@runtime_checkable
+class CustomDatasetLoaderProtocol(Protocol):
+    """Protocol for custom dataset loaders that load a dataset from a file and convert it to a list of Conversation objects."""
+
+    @classmethod
+    def can_load(
+        cls, data: dict[str, Any] | None = None, filename: "str | Path | None" = None
+    ) -> bool:
+        """Check if this loader can handle the given data format.
+
+        Args:
+            data: Optional dictionary representing a single line from the JSONL file.
+                None indicates path-based detection only (e.g., for directories).
+            filename: Optional path to the input file/directory for path-based detection
+
+        Returns:
+            True if this loader can handle the data format, False otherwise
+        """
+        ...
+
+    @classmethod
+    def get_preferred_sampling_strategy(cls) -> "DatasetSamplingStrategy":
+        """Get the preferred dataset sampling strategy for this loader.
+
+        Returns:
+            DatasetSamplingStrategy: The preferred sampling strategy
+        """
+        ...
+
+    def load_dataset(self) -> dict[str, list["CustomDatasetT"]]: ...
+
+    def convert_to_conversations(
+        self, custom_data: dict[str, list["CustomDatasetT"]]
+    ) -> list[Conversation]: ...
+
+
 @runtime_checkable
 class DataExporterProtocol(Protocol):
     """
diff --git a/src/aiperf/dataset/__init__.py b/src/aiperf/dataset/__init__.py
index 05995de87..20362830a 100644
--- a/src/aiperf/dataset/__init__.py
+++ b/src/aiperf/dataset/__init__.py
@@ -1,6 +1,13 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-
+########################################################################
+## 🚩 mkinit flags 🚩 ##
+########################################################################
+__ignore__ = ["logger"]
+########################################################################
+## ⚠️ This file is auto-generated by mkinit ⚠️ ##
+## ⚠️ Do not edit below this line ⚠️ ##
+########################################################################
 from aiperf.dataset.composer import (
     BaseDatasetComposer,
     CustomDatasetComposer,
@@ -29,7 +36,6 @@
 from aiperf.dataset.loader import (
     AIPERF_DATASET_CACHE_DIR,
     BasePublicDatasetLoader,
-    CustomDatasetLoaderProtocol,
     CustomDatasetT,
     MediaConversionMixin,
     MooncakeTrace,
@@ -59,7 +65,6 @@
     "BaseGenerator",
     "BasePublicDatasetLoader",
     "CustomDatasetComposer",
-    "CustomDatasetLoaderProtocol",
     "CustomDatasetT",
     "DEFAULT_CORPUS_FILE",
     "DatasetManager",
diff --git a/src/aiperf/dataset/composer/__init__.py b/src/aiperf/dataset/composer/__init__.py
index 04892cc8e..e2a5e547a 100644
--- a/src/aiperf/dataset/composer/__init__.py
+++ b/src/aiperf/dataset/composer/__init__.py
@@ -3,7 +3,7 @@
 ########################################################################
 ## 🚩 mkinit flags 🚩 ##
 ########################################################################
-__ignore__ = []
+__ignore__ = ["logger"]
 ########################################################################
 ## ⚠️ This file is auto-generated by mkinit ⚠️ ##
 ## ⚠️ Do not edit below this line ⚠️ ##
 ########################################################################
diff --git a/src/aiperf/dataset/composer/custom.py b/src/aiperf/dataset/composer/custom.py
index 35bced43e..7e039cb00 100644
--- a/src/aiperf/dataset/composer/custom.py
+++ b/src/aiperf/dataset/composer/custom.py
@@ -1,6 +1,10 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path +from typing import Any + +from aiperf.common.aiperf_logger import AIPerfLogger from aiperf.common.config import UserConfig from aiperf.common.decorators import implements_protocol from aiperf.common.enums import ComposerType, CustomDatasetType @@ -8,8 +12,11 @@ from aiperf.common.models import Conversation from aiperf.common.protocols import ServiceProtocol from aiperf.common.tokenizer import Tokenizer -from aiperf.dataset import utils +from aiperf.common.utils import load_json_str from aiperf.dataset.composer.base import BaseDatasetComposer +from aiperf.dataset.utils import check_file_exists + +logger = AIPerfLogger(__name__) @implements_protocol(ServiceProtocol) @@ -25,14 +32,118 @@ def create_dataset(self) -> list[Conversation]: list[Conversation]: A list of conversation objects. """ # TODO: (future) for K8s, we need to transfer file data from SC (across node) - utils.check_file_exists(self.config.input.file) + check_file_exists(self.config.input.file) + + # Auto-infer dataset type if not provided + dataset_type = self.config.input.custom_dataset_type + if dataset_type is None: + dataset_type = self._infer_dataset_type(self.config.input.file) + if dataset_type is None: + raise ValueError( + f"Could not infer dataset type from file: {self.config.input.file}. " + "Please specify --custom-dataset-type explicitly." + ) + self.info(f"Auto-detected dataset type: {dataset_type}") + + # Set dataset sampling strategy based on inferred type if not explicitly set + self._set_sampling_strategy(dataset_type) - self._create_loader_instance(self.config.input.custom_dataset_type) + self._create_loader_instance(dataset_type) dataset = self.loader.load_dataset() conversations = self.loader.convert_to_conversations(dataset) self._finalize_conversations(conversations) return conversations + @staticmethod + def _infer_dataset_type(file_path: str) -> CustomDatasetType | None: + """Infer the custom dataset type from the input file. + + Queries all registered loaders to check if they can handle the data format. + + Args: + file_path: Path to the JSONL file or directory + + Returns: + CustomDatasetType if successfully inferred, None otherwise + """ + try: + path = Path(file_path) + + # If it's a directory, use path-based detection only + if path.is_dir(): + return CustomDatasetComposer._infer_type(data=None, filename=file_path) + + # For files, read first non-empty line and use both content and path detection + with open(file_path) as f: + for line in f: + if not (line := line.strip()): + continue + data = load_json_str(line) + return CustomDatasetComposer._infer_type( + data=data, filename=file_path + ) + + except Exception: + logger.exception(f"Error inferring dataset type from file: {file_path}") + return None + + @staticmethod + def _infer_type( + data: dict[str, Any] | None = None, filename: str | Path | None = None + ) -> CustomDatasetType | None: + """Infer the dataset type from data and/or filename. + + First checks for explicit 'type' field in the data, then falls back to + structural detection by querying registered loaders via the factory. + + Args: + data: Optional dictionary representing a single line from the JSONL file. + None indicates path-based detection only (e.g., for directories). 
+ filename: Optional path to the input file/directory for path-based detection + + Returns: + CustomDatasetType if successfully inferred, None otherwise + """ + # Check for explicit type field first (most efficient) + if data is not None and "type" in data: + try: + # Try to convert the type string to enum + explicit_type = CustomDatasetType(data["type"]) + logger.info(f"Using explicit type field: {explicit_type}") + return explicit_type + except (ValueError, KeyError): + logger.info( + f"Invalid type field value: {data['type']}, falling back to structural detection" + ) + + for ( + loader_class, + dataset_type, + ) in CustomDatasetFactory.get_all_classes_and_types(): + if loader_class.can_load(data, filename): + logger.info( + f"Loader {loader_class.__name__} can handle the data format (structural detection)" + ) + return dataset_type + return None + + def _set_sampling_strategy(self, dataset_type: CustomDatasetType) -> None: + """Set the dataset sampling strategy based on the dataset type. + + If the user has not explicitly set a sampling strategy, use the loader's + preferred strategy. + + Args: + dataset_type: The type of custom dataset + """ + if self.config.input.dataset_sampling_strategy is None: + loader_class = CustomDatasetFactory.get_class_from_type(dataset_type) + preferred_strategy = loader_class.get_preferred_sampling_strategy() + self.config.input.dataset_sampling_strategy = preferred_strategy + self.info( + f"Using preferred sampling strategy for {dataset_type}: {preferred_strategy}" + ) + def _create_loader_instance(self, dataset_type: CustomDatasetType) -> None: """Initializes the dataset loader based on the custom dataset type. diff --git a/src/aiperf/dataset/composer/synthetic.py b/src/aiperf/dataset/composer/synthetic.py index 19f54b971..f50912b08 100644 --- a/src/aiperf/dataset/composer/synthetic.py +++ b/src/aiperf/dataset/composer/synthetic.py @@ -4,6 +4,7 @@ import uuid from aiperf.common.config import UserConfig +from aiperf.common.config.config_defaults import InputDefaults from aiperf.common.enums import ComposerType from aiperf.common.factories import ComposerFactory from aiperf.common.models import Audio, Conversation, Image, Text, Turn, Video @@ -17,6 +18,15 @@ class SyntheticDatasetComposer(BaseDatasetComposer): def __init__(self, config: UserConfig, tokenizer: Tokenizer): super().__init__(config, tokenizer) + # Set default sampling strategy for synthetic datasets if not explicitly set + if self.config.input.dataset_sampling_strategy is None: + self.config.input.dataset_sampling_strategy = ( + InputDefaults.DATASET_SAMPLING_STRATEGY + ) + self.info( + f"Using default sampling strategy for synthetic dataset: {InputDefaults.DATASET_SAMPLING_STRATEGY}" + ) + if ( not self.include_prompt and not self.include_image diff --git a/src/aiperf/dataset/dataset_manager.py b/src/aiperf/dataset/dataset_manager.py index e3283c53b..d86683325 100644 --- a/src/aiperf/dataset/dataset_manager.py +++ b/src/aiperf/dataset/dataset_manager.py @@ -188,7 +188,13 @@ async def _configure_dataset(self) -> None: loader = ShareGPTLoader(self.user_config, self.tokenizer) dataset = await loader.load_dataset() conversations = await loader.convert_to_conversations(dataset) - elif self.user_config.input.custom_dataset_type is not None: + elif ( + self.user_config.input.custom_dataset_type is not None + or self.user_config.input.file is not None + ): + # Use CUSTOM composer if either: + # 1. custom_dataset_type is explicitly set, OR + # 2. 
input file is provided (composer will auto-infer type)
             composer = ComposerFactory.create_instance(
                 ComposerType.CUSTOM,
                 config=self.user_config,
diff --git a/src/aiperf/dataset/loader/__init__.py b/src/aiperf/dataset/loader/__init__.py
index 92a72b245..cfe4f9181 100644
--- a/src/aiperf/dataset/loader/__init__.py
+++ b/src/aiperf/dataset/loader/__init__.py
@@ -28,9 +28,6 @@
 from aiperf.dataset.loader.multi_turn import (
     MultiTurnDatasetLoader,
 )
-from aiperf.dataset.loader.protocol import (
-    CustomDatasetLoaderProtocol,
-)
 from aiperf.dataset.loader.random_pool import (
     RandomPoolDatasetLoader,
 )
@@ -44,7 +41,6 @@
 __all__ = [
     "AIPERF_DATASET_CACHE_DIR",
     "BasePublicDatasetLoader",
-    "CustomDatasetLoaderProtocol",
     "CustomDatasetT",
     "MediaConversionMixin",
     "MooncakeTrace",
diff --git a/src/aiperf/dataset/loader/mooncake_trace.py b/src/aiperf/dataset/loader/mooncake_trace.py
index 797b46d9b..bdf63366c 100644
--- a/src/aiperf/dataset/loader/mooncake_trace.py
+++ b/src/aiperf/dataset/loader/mooncake_trace.py
@@ -3,19 +3,18 @@
 import uuid
 from collections import defaultdict
+from pathlib import Path
+from typing import Any
 
 from aiperf.common.config.user_config import UserConfig
-from aiperf.common.decorators import implements_protocol
-from aiperf.common.enums import CustomDatasetType
+from aiperf.common.enums import CustomDatasetType, DatasetSamplingStrategy
 from aiperf.common.factories import CustomDatasetFactory
 from aiperf.common.mixins import AIPerfLoggerMixin
 from aiperf.common.models import Conversation, Text, Turn
 from aiperf.dataset.generator import PromptGenerator
 from aiperf.dataset.loader.models import MooncakeTrace
-from aiperf.dataset.loader.protocol import CustomDatasetLoaderProtocol
 
 
-@implements_protocol(CustomDatasetLoaderProtocol)
 @CustomDatasetFactory.register(CustomDatasetType.MOONCAKE_TRACE)
 class MooncakeTraceDatasetLoader(AIPerfLoggerMixin):
     """A dataset loader that loads Mooncake trace data from a file.
@@ -54,6 +53,25 @@ def __init__(
         self._end_offset = user_config.input.fixed_schedule_end_offset
         super().__init__(user_config=user_config, **kwargs)
 
+    @classmethod
+    def can_load(
+        cls, data: dict[str, Any] | None = None, filename: str | Path | None = None
+    ) -> bool:
+        """Check if this loader can handle the given data format.
+
+        MooncakeTrace format has an "input_length" field, or a "text_input"
+        field accompanied by "hash_ids".
+        """
+        if data is None:
+            return False
+
+        return "input_length" in data or ("text_input" in data and "hash_ids" in data)
+
+    @classmethod
+    def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy:
+        """Get the preferred dataset sampling strategy for MooncakeTrace."""
+        return DatasetSamplingStrategy.SEQUENTIAL
+
     def load_dataset(self) -> dict[str, list[MooncakeTrace]]:
         """Load Mooncake trace data from a file.
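
Reviewer note (illustrative, not part of the diff): a minimal sketch of how the
structural check added above behaves on individual JSONL records. The import
path is the one introduced in this patch; the assertions mirror the can_load()
logic as written.

    # Sketch only: a record matches MooncakeTrace if it has "input_length",
    # or "text_input" together with "hash_ids".
    from aiperf.dataset.loader.mooncake_trace import MooncakeTraceDatasetLoader

    assert MooncakeTraceDatasetLoader.can_load({"input_length": 100})
    assert MooncakeTraceDatasetLoader.can_load({"text_input": "hi", "hash_ids": [1, 2]})
    assert not MooncakeTraceDatasetLoader.can_load({"text_input": "hi"})  # no hash_ids
    assert not MooncakeTraceDatasetLoader.can_load(None)  # content-based only
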
diff --git a/src/aiperf/dataset/loader/multi_turn.py b/src/aiperf/dataset/loader/multi_turn.py index 94fc8775e..38441cb29 100644 --- a/src/aiperf/dataset/loader/multi_turn.py +++ b/src/aiperf/dataset/loader/multi_turn.py @@ -3,8 +3,10 @@ import uuid from collections import defaultdict +from pathlib import Path +from typing import Any -from aiperf.common.enums import CustomDatasetType, MediaType +from aiperf.common.enums import CustomDatasetType, DatasetSamplingStrategy, MediaType from aiperf.common.factories import CustomDatasetFactory from aiperf.common.models import Conversation, Turn from aiperf.dataset.loader.mixins import MediaConversionMixin @@ -93,6 +95,24 @@ class MultiTurnDatasetLoader(MediaConversionMixin): def __init__(self, filename: str): self.filename = filename + @classmethod + def can_load( + cls, data: dict[str, Any] | None = None, filename: str | Path | None = None + ) -> bool: + """Check if this loader can handle the given data format. + + MultiTurn format has a "turns" field containing a list of turns. + """ + if data is None: + return False + + return "turns" in data and isinstance(data.get("turns"), list) + + @classmethod + def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy: + """Get the preferred dataset sampling strategy for MultiTurn.""" + return DatasetSamplingStrategy.SEQUENTIAL + def load_dataset(self) -> dict[str, list[MultiTurn]]: """Load multi-turn data from a JSONL file. @@ -139,6 +159,7 @@ def convert_to_conversations( texts=media[MediaType.TEXT], images=media[MediaType.IMAGE], audios=media[MediaType.AUDIO], + videos=media[MediaType.VIDEO], timestamp=single_turn.timestamp, delay=single_turn.delay, role=single_turn.role, diff --git a/src/aiperf/dataset/loader/protocol.py b/src/aiperf/dataset/loader/protocol.py deleted file mode 100644 index 342e37640..000000000 --- a/src/aiperf/dataset/loader/protocol.py +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from typing import Protocol, runtime_checkable - -from aiperf.common.models import Conversation -from aiperf.dataset.loader.models import CustomDatasetT - - -@runtime_checkable -class CustomDatasetLoaderProtocol(Protocol): - def load_dataset(self) -> dict[str, list[CustomDatasetT]]: ... - - def convert_to_conversations( - self, custom_data: dict[str, list[CustomDatasetT]] - ) -> list[Conversation]: ... diff --git a/src/aiperf/dataset/loader/random_pool.py b/src/aiperf/dataset/loader/random_pool.py index 97fa40812..1be29deaa 100644 --- a/src/aiperf/dataset/loader/random_pool.py +++ b/src/aiperf/dataset/loader/random_pool.py @@ -5,9 +5,9 @@ import uuid from collections import defaultdict from pathlib import Path -from typing import TypeAlias +from typing import Any, TypeAlias -from aiperf.common.enums import CustomDatasetType, MediaType +from aiperf.common.enums import CustomDatasetType, DatasetSamplingStrategy, MediaType from aiperf.common.factories import CustomDatasetFactory from aiperf.common.models import Conversation, Turn from aiperf.dataset.loader.mixins import MediaConversionMixin @@ -73,6 +73,30 @@ def __init__(self, filename: str, num_conversations: int = 1): self.filename = filename self.num_conversations = num_conversations + @classmethod + def can_load( + cls, data: dict[str, Any] | None = None, filename: str | Path | None = None + ) -> bool: + """Check if this loader can handle the given data format. 
+
+        RandomPool is the only loader that supports directory inputs.
+        For structural detection, the RandomPool format is ambiguous with SingleTurn
+        (both have modality fields), so an explicit 'type' field is recommended.
+        """
+        if filename is not None:
+            path = Path(filename) if isinstance(filename, str) else filename
+            if path.is_dir():
+                return True
+
+        # RandomPool schema is very similar to SingleTurn, so we can't reliably
+        # distinguish without an explicit type field or directory path
+        return False
+
+    @classmethod
+    def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy:
+        """Get the preferred dataset sampling strategy for RandomPool."""
+        return DatasetSamplingStrategy.SHUFFLE
+
     def load_dataset(self) -> dict[Filename, list[RandomPool]]:
         """Load random pool data from a file or directory.
@@ -164,6 +188,7 @@
                     texts=media[MediaType.TEXT],
                     images=media[MediaType.IMAGE],
                     audios=media[MediaType.AUDIO],
+                    videos=media[MediaType.VIDEO],
                 )
             )
             sampled_dataset[filename] = turns
@@ -188,5 +213,6 @@
             texts=[text for turn in turns for text in turn.texts],
             images=[image for turn in turns for image in turn.images],
             audios=[audio for turn in turns for audio in turn.audios],
+            videos=[video for turn in turns for video in turn.videos],
         )
         return merged_turn
diff --git a/src/aiperf/dataset/loader/single_turn.py b/src/aiperf/dataset/loader/single_turn.py
index 6ce612f29..39b48bbfb 100644
--- a/src/aiperf/dataset/loader/single_turn.py
+++ b/src/aiperf/dataset/loader/single_turn.py
@@ -3,8 +3,10 @@
 import uuid
 from collections import defaultdict
+from pathlib import Path
+from typing import Any
 
-from aiperf.common.enums import CustomDatasetType, MediaType
+from aiperf.common.enums import CustomDatasetType, DatasetSamplingStrategy, MediaType
 from aiperf.common.factories import CustomDatasetFactory
 from aiperf.common.models import Conversation, Turn
 from aiperf.dataset.loader.mixins import MediaConversionMixin
@@ -67,6 +69,37 @@
 class SingleTurnDatasetLoader(MediaConversionMixin):
     def __init__(self, filename: str):
         self.filename = filename
+
+    @classmethod
+    def can_load(
+        cls, data: dict[str, Any] | None = None, filename: str | Path | None = None
+    ) -> bool:
+        """Check if this loader can handle the given data format.
+
+        SingleTurn format has modality fields (text/texts, image/images, etc.)
+        but does NOT have a "turns" field.
+        """
+        if data is None:
+            return False
+
+        modality_fields = [
+            "text",
+            "texts",
+            "image",
+            "images",
+            "audio",
+            "audios",
+            "video",
+            "videos",
+        ]
+        has_modality = any(field in data for field in modality_fields)
+
+        return has_modality and "turns" not in data
+
+    @classmethod
+    def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy:
+        """Get the preferred dataset sampling strategy for SingleTurn."""
+        return DatasetSamplingStrategy.SEQUENTIAL
+
     def load_dataset(self) -> dict[str, list[SingleTurn]]:
         """Load single-turn data from a JSONL file.
@@ -110,6 +143,7 @@ def convert_to_conversations( texts=media[MediaType.TEXT], images=media[MediaType.IMAGE], audios=media[MediaType.AUDIO], + videos=media[MediaType.VIDEO], timestamp=single_turn.timestamp, delay=single_turn.delay, role=single_turn.role, diff --git a/tests/composers/test_custom_composer.py b/tests/composers/test_custom_composer.py index 088a79415..aa037335b 100644 --- a/tests/composers/test_custom_composer.py +++ b/tests/composers/test_custom_composer.py @@ -5,7 +5,7 @@ import pytest -from aiperf.common.enums import CustomDatasetType +from aiperf.common.enums import CustomDatasetType, DatasetSamplingStrategy from aiperf.common.models import Conversation, Turn from aiperf.dataset import ( MooncakeTraceDatasetLoader, @@ -63,7 +63,7 @@ def test_create_loader_instance_dataset_types( composer._create_loader_instance(dataset_type) assert isinstance(composer.loader, expected_instance) - @patch("aiperf.dataset.composer.custom.utils.check_file_exists") + @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) def test_create_dataset_trace(self, mock_check_file, trace_config, mock_tokenizer): """Test that create_dataset returns correct type.""" @@ -75,7 +75,7 @@ def test_create_dataset_trace(self, mock_check_file, trace_config, mock_tokenize assert all(isinstance(turn, Turn) for c in conversations for turn in c.turns) assert all(len(turn.texts) == 1 for c in conversations for turn in c.turns) - @patch("aiperf.dataset.composer.custom.utils.check_file_exists") + @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) def test_max_tokens_config(self, mock_check_file, trace_config, mock_tokenizer): trace_config.input.prompt.output_tokens.mean = 120 @@ -93,7 +93,7 @@ def test_max_tokens_config(self, mock_check_file, trace_config, mock_tokenizer): for turn in conversation.turns: assert turn.max_tokens == 20 - @patch("aiperf.dataset.composer.custom.utils.check_file_exists") + @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) @patch("pathlib.Path.iterdir", return_value=[]) def test_max_tokens_mooncake( @@ -114,7 +114,7 @@ def test_max_tokens_mooncake( class TestErrorHandling: """Test class for CustomDatasetComposer error handling scenarios.""" - @patch("aiperf.dataset.composer.custom.utils.check_file_exists") + @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("aiperf.dataset.composer.custom.CustomDatasetFactory.create_instance") def test_create_dataset_empty_result( self, mock_factory, mock_check_file, custom_config, mock_tokenizer @@ -131,3 +131,48 @@ def test_create_dataset_empty_result( assert isinstance(result, list) assert len(result) == 0 + + +class TestSamplingStrategy: + """Test class for CustomDatasetComposer sampling strategy configuration.""" + + @pytest.mark.parametrize( + "dataset_type,expected_strategy", + [ + (CustomDatasetType.SINGLE_TURN, DatasetSamplingStrategy.SEQUENTIAL), + (CustomDatasetType.MULTI_TURN, DatasetSamplingStrategy.SEQUENTIAL), + (CustomDatasetType.RANDOM_POOL, DatasetSamplingStrategy.SHUFFLE), + (CustomDatasetType.MOONCAKE_TRACE, DatasetSamplingStrategy.SEQUENTIAL), + ], + ) + def test_set_sampling_strategy_when_none( + self, custom_config, mock_tokenizer, dataset_type, expected_strategy + ): + """Test that _set_sampling_strategy sets the correct strategy when None.""" + custom_config.input.dataset_sampling_strategy = None + composer 
= CustomDatasetComposer(custom_config, mock_tokenizer) + + composer._set_sampling_strategy(dataset_type) + + assert composer.config.input.dataset_sampling_strategy == expected_strategy + + @pytest.mark.parametrize( + "dataset_type", + [ + CustomDatasetType.SINGLE_TURN, + CustomDatasetType.MULTI_TURN, + CustomDatasetType.RANDOM_POOL, + CustomDatasetType.MOONCAKE_TRACE, + ], + ) + def test_set_sampling_strategy_does_not_override( + self, custom_config, mock_tokenizer, dataset_type + ): + """Test that _set_sampling_strategy does not override explicitly set strategy.""" + explicit_strategy = DatasetSamplingStrategy.SHUFFLE + custom_config.input.dataset_sampling_strategy = explicit_strategy + composer = CustomDatasetComposer(custom_config, mock_tokenizer) + + composer._set_sampling_strategy(dataset_type) + + assert composer.config.input.dataset_sampling_strategy == explicit_strategy diff --git a/tests/loaders/test_can_load.py b/tests/loaders/test_can_load.py new file mode 100644 index 000000000..6d6a497c1 --- /dev/null +++ b/tests/loaders/test_can_load.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +import tempfile +from pathlib import Path + +import pytest +from pytest import param + +from aiperf.common.enums import CustomDatasetType +from aiperf.dataset.composer.custom import CustomDatasetComposer +from aiperf.dataset.loader.mooncake_trace import MooncakeTraceDatasetLoader +from aiperf.dataset.loader.multi_turn import MultiTurnDatasetLoader +from aiperf.dataset.loader.random_pool import RandomPoolDatasetLoader +from aiperf.dataset.loader.single_turn import SingleTurnDatasetLoader + + +class TestSingleTurnCanLoad: + """Tests for SingleTurnDatasetLoader.can_load() method. + + Note: Individual loaders perform structural detection only. + Explicit 'type' field checking is done by the CustomDatasetComposer utility. + """ + + @pytest.mark.parametrize( + "data,expected", + [ + param({"text": "Hello world"}, True, id="text_field"), + param({"texts": ["Hello", "World"]}, True, id="texts_field"), + param({"image": "/path/to/image.png"}, True, id="image_field"), + param({"images": ["/path/1.png", "/path/2.png"]}, True, id="images_field"), + param({"audio": "/path/to/audio.wav"}, True, id="audio_field"), + param({"audios": ["/path/1.wav", "/path/2.wav"]}, True, id="audios_field"), + param({"video": "/path/to/video.mp4"}, True, id="video_field"), + param({"videos": ["/path/1.mp4", "/path/2.mp4"]}, True, id="videos_field"), + param({"text": "Describe this", "image": "/path.png", "audio": "/audio.wav"}, True, id="multimodal"), + # Explicit type is ignored by loader (factory handles it) + param({"type": "single_turn", "text": "Hello"}, True, id="with_type_field"), + param({"type": "random_pool", "text": "Hello"}, True, id="wrong_type_but_has_modality"), + param({"turns": [{"text": "Hello"}]}, False, id="has_turns_field"), + param({"session_id": "123", "metadata": "test"}, False, id="no_modality"), + param(None, False, id="none_data"), + ], + ) # fmt: skip + def test_can_load(self, data, expected): + """Test various data formats for SingleTurn structural detection.""" + assert SingleTurnDatasetLoader.can_load(data) is expected + + +class TestMultiTurnCanLoad: + """Tests for MultiTurnDatasetLoader.can_load() method. + + Note: Individual loaders perform structural detection only. + Explicit 'type' field checking is done by the CustomDatasetComposer utility. 
+ """ + + @pytest.mark.parametrize( + "data,expected", + [ + param({"turns": [{"text": "Turn 1"}, {"text": "Turn 2"}]}, True, id="turns_list"), + param({"session_id": "session_123", "turns": [{"text": "Hello"}]}, True, id="with_session_id"), + # Explicit type is ignored by loader (factory handles it) + param({"type": "multi_turn", "turns": [{"text": "Hello"}]}, True, id="with_type_field"), + param({"text": "Hello world"}, False, id="no_turns_field"), + param({"turns": "not a list"}, False, id="turns_not_list_string"), + param({"turns": {"text": "Hello"}}, False, id="turns_not_list_dict"), + param(None, False, id="none_data"), + ], + ) # fmt: skip + def test_can_load(self, data, expected): + """Test various data formats for MultiTurn structural detection.""" + assert MultiTurnDatasetLoader.can_load(data) is expected + + +class TestRandomPoolCanLoad: + """Tests for RandomPoolDatasetLoader.can_load() method. + + Note: Individual loaders perform structural detection only. + Explicit 'type' field checking is done by the CustomDatasetComposer utility. + RandomPool's structural detection is limited to directory paths only. + """ + + @pytest.mark.parametrize( + "data,expected", + [ + # RandomPool cannot distinguish from SingleTurn without directory path + # Factory handles explicit type checking + param({"text": "Hello"}, False, id="no_directory_ambiguous"), + param({"type": "random_pool", "text": "Query"}, False, id="type_ignored_without_directory"), + ], + ) # fmt: skip + def test_can_load_content_based(self, data, expected): + """Test content-based detection for RandomPool (always False without directory).""" + assert RandomPoolDatasetLoader.can_load(data) is expected + + def test_can_load_with_directory_path(self): + """Test detection with directory path (unique to RandomPool).""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + assert ( + RandomPoolDatasetLoader.can_load(data=None, filename=temp_path) is True + ) + + def test_can_load_with_directory_path_as_string(self): + """Test detection with directory path as string.""" + with tempfile.TemporaryDirectory() as temp_dir: + assert ( + RandomPoolDatasetLoader.can_load(data=None, filename=temp_dir) is True + ) + + def test_cannot_load_with_file_path_no_type(self): + """Test rejection with file path but no explicit type (ambiguous with SingleTurn).""" + with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file: + temp_path = Path(temp_file.name) + data = {"text": "Hello"} + # Without explicit type, ambiguous with SingleTurn + assert RandomPoolDatasetLoader.can_load(data, filename=temp_path) is False + + +class TestMooncakeTraceCanLoad: + """Tests for MooncakeTraceDatasetLoader.can_load() method. + + Note: Individual loaders perform structural detection only. + Explicit 'type' field checking is done by the CustomDatasetComposer utility. 
+    """
+
+    @pytest.mark.parametrize(
+        "data,expected",
+        [
+            param({"input_length": 100, "output_length": 50}, True, id="input_length_with_output"),
+            param({"input_length": 100}, True, id="input_length_only"),
+            param({"text_input": "Hello world", "hash_ids": [123, 456]}, True, id="text_input_with_hash_ids"),
+            # Explicit type is ignored by loader (factory handles it)
+            param({"type": "mooncake_trace", "input_length": 100}, True, id="with_type_field"),
+            param({"text_input": "Hello world"}, False, id="text_input_without_hash_ids"),
+            param({"timestamp": 1000, "session_id": "abc"}, False, id="no_required_fields"),
+            param({"output_length": 50}, False, id="only_output_length"),
+            param(None, False, id="none_data"),
+        ],
+    )  # fmt: skip
+    def test_can_load(self, data, expected):
+        """Test various data formats for MooncakeTrace structural detection."""
+        assert MooncakeTraceDatasetLoader.can_load(data) is expected
+
+
+class TestCustomDatasetComposerInferType:
+    """Tests for CustomDatasetComposer._infer_type() method."""
+
+    @pytest.mark.parametrize(
+        "data,expected_type",
+        [
+            param({"text": "Hello world"}, CustomDatasetType.SINGLE_TURN, id="single_turn_text"),
+            param({"type": "single_turn", "text": "Hello"}, CustomDatasetType.SINGLE_TURN, id="single_turn_explicit"),
+            param({"image": "/path.png"}, CustomDatasetType.SINGLE_TURN, id="single_turn_image"),
+            param({"turns": [{"text": "Turn 1"}]}, CustomDatasetType.MULTI_TURN, id="multi_turn_turns"),
+            param({"type": "multi_turn", "turns": [{"text": "Turn 1"}]}, CustomDatasetType.MULTI_TURN, id="multi_turn_explicit"),
+            param({"type": "random_pool", "text": "Query"}, CustomDatasetType.RANDOM_POOL, id="random_pool_explicit"),
+            param({"input_length": 100, "output_length": 50}, CustomDatasetType.MOONCAKE_TRACE, id="mooncake_input_length"),
+            param({"type": "mooncake_trace", "input_length": 100}, CustomDatasetType.MOONCAKE_TRACE, id="mooncake_explicit"),
+            param({"text_input": "Hello", "hash_ids": [1, 2]}, CustomDatasetType.MOONCAKE_TRACE, id="mooncake_text_input"),
+            param({"unknown_field": "value"}, None, id="unknown_format"),
+            param({"metadata": "test"}, None, id="unknown_metadata"),
+        ],
+    )  # fmt: skip
+    def test_infer_from_data(self, data, expected_type):
+        """Test inferring dataset type from various data formats."""
+        result = CustomDatasetComposer._infer_type(data)
+        assert result == expected_type
+
+    def test_infer_random_pool_with_directory(self):
+        """Test inferring RandomPool with directory path."""
+        with tempfile.TemporaryDirectory() as temp_dir:
+            temp_path = Path(temp_dir)
+            result = CustomDatasetComposer._infer_type(data=None, filename=temp_path)
+            assert result == CustomDatasetType.RANDOM_POOL
+
+    def test_infer_with_filename_parameter(self):
+        """Test inference with filename parameter for file path."""
+        with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file:
+            temp_path = Path(temp_file.name)
+            data = {"text": "Hello"}
+            result = CustomDatasetComposer._infer_type(data, filename=temp_path)
+            # Should infer SingleTurn (file, not directory)
+            assert result == CustomDatasetType.SINGLE_TURN
+
+
+class TestCustomDatasetComposerInferDatasetType:
+    """Tests for CustomDatasetComposer._infer_dataset_type() method."""
+
+    @pytest.mark.parametrize(
+        "content,expected_type",
+        [
+            param(['{"text": "Hello world"}'], CustomDatasetType.SINGLE_TURN, id="single_turn_text"),
+            param(['{"image": "/path.png"}'], CustomDatasetType.SINGLE_TURN, id="single_turn_image"),
+            param(['{"turns": [{"text": "Turn 1"}, {"text": "Turn 2"}]}'],
CustomDatasetType.MULTI_TURN, id="multi_turn"), + param(['{"type": "random_pool", "text": "Query"}'], CustomDatasetType.RANDOM_POOL, id="random_pool_explicit"), + param(['{"input_length": 100, "output_length": 50}'], CustomDatasetType.MOONCAKE_TRACE, id="mooncake_input_length"), + param(['{"text_input": "Hello", "hash_ids": [1, 2]}'], CustomDatasetType.MOONCAKE_TRACE, id="mooncake_text_input"), + param([], None, id="empty_file"), + param(["", " ", "\n"], None, id="only_empty_lines"), + param(["not valid json"], None, id="invalid_json"), + ], + ) # fmt: skip + def test_infer_from_file(self, create_jsonl_file, content, expected_type): + """Test inferring dataset type from file with various content.""" + filepath = create_jsonl_file(content) + result = CustomDatasetComposer._infer_dataset_type(filepath) + assert result == expected_type + + def test_infer_from_directory(self): + """Test inferring type from directory (should be RandomPool).""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create some files in the directory + temp_path = Path(temp_dir) + file1 = temp_path / "queries.jsonl" + file1.write_text('{"text": "Query 1"}\n') + + result = CustomDatasetComposer._infer_dataset_type(temp_dir) + assert result == CustomDatasetType.RANDOM_POOL + + +class TestDetectionPriorityAndAmbiguity: + """Tests for detection priority and handling of ambiguous cases. + + Note: Individual loaders perform structural detection only. + Explicit 'type' field checking is done by the CustomDatasetComposer utility. + """ + + def test_explicit_type_handled_by_inference(self): + """Test that explicit type field is handled by type inference, not loaders.""" + # Data with explicit random_pool type but modality fields + data = {"type": "random_pool", "text": "Hello"} + + # Loaders only do structural detection: + # - SingleTurn matches because it has modality fields + # - RandomPool does NOT match because it's not a directory + assert SingleTurnDatasetLoader.can_load(data) is True + assert RandomPoolDatasetLoader.can_load(data) is False + + # Type inference handles explicit type and should return RANDOM_POOL + result = CustomDatasetComposer._infer_type(data) + assert result == CustomDatasetType.RANDOM_POOL + + @pytest.mark.parametrize( + "data,single_turn,random_pool", + [ + param({"text": "Hello"}, True, False, id="text_field"), + param({"image": "/path.png"}, True, False, id="image_field"), + # RandomPool cannot match without directory (structural detection only) + param({"type": "random_pool", "text": "Hello"}, True, False, id="with_explicit_type"), + ], + ) # fmt: skip + def test_single_turn_vs_random_pool_ambiguity(self, data, single_turn, random_pool): + """Test SingleTurn vs RandomPool structural detection. + + Without directory path, RandomPool cannot be detected structurally. 
+ """ + assert SingleTurnDatasetLoader.can_load(data) is single_turn + assert RandomPoolDatasetLoader.can_load(data) is random_pool + + def test_multi_turn_takes_priority_over_single_turn(self): + """Test that MultiTurn is correctly detected over SingleTurn.""" + data = {"turns": [{"text": "Hello"}]} + assert MultiTurnDatasetLoader.can_load(data) is True + assert SingleTurnDatasetLoader.can_load(data) is False + + @pytest.mark.parametrize( + "loader,should_match", + [ + param(MooncakeTraceDatasetLoader, True, id="mooncake"), + param(SingleTurnDatasetLoader, False, id="single_turn"), + param(MultiTurnDatasetLoader, False, id="multi_turn"), + param(RandomPoolDatasetLoader, False, id="random_pool"), + ], + ) # fmt: skip + def test_mooncake_trace_distinct_from_others(self, loader, should_match): + """Test that MooncakeTrace is distinct from other types.""" + data = {"input_length": 100} + assert loader.can_load(data) is should_match + + def test_directory_path_uniquely_identifies_random_pool(self): + """Test that directory path uniquely identifies RandomPool.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + assert RandomPoolDatasetLoader.can_load(data=None, filename=temp_path) is True # fmt: skip + assert SingleTurnDatasetLoader.can_load(data=None, filename=temp_path) is False # fmt: skip + assert MultiTurnDatasetLoader.can_load(data=None, filename=temp_path) is False # fmt: skip + assert MooncakeTraceDatasetLoader.can_load(data=None, filename=temp_path) is False # fmt: skip
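
Reviewer note (illustrative, not part of the diff): pulling the pieces together,
the detection order introduced in composer/custom.py is: an explicit "type"
field wins, then structural can_load() checks across the registered loaders,
and a directory path resolves to RANDOM_POOL. A minimal sketch of that flow,
using only the surface the tests above already exercise:

    # Sketch only: mirrors the behavior covered by test_can_load.py.
    from aiperf.common.enums import CustomDatasetType
    from aiperf.dataset.composer.custom import CustomDatasetComposer

    # 1. Explicit "type" field takes priority over structural detection.
    data = {"type": "random_pool", "text": "query"}
    assert CustomDatasetComposer._infer_type(data) == CustomDatasetType.RANDOM_POOL

    # 2. Structural detection: a "turns" list means MULTI_TURN; bare modality
    #    fields (text/image/audio/video) mean SINGLE_TURN.
    assert CustomDatasetComposer._infer_type({"turns": [{"text": "hi"}]}) == CustomDatasetType.MULTI_TURN
    assert CustomDatasetComposer._infer_type({"text": "hi"}) == CustomDatasetType.SINGLE_TURN

    # 3. Unknown shapes yield None; create_dataset() then raises and asks the
    #    user to pass --custom-dataset-type explicitly.
    assert CustomDatasetComposer._infer_type({"metadata": "x"}) is None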