-
Couldn't load subscription status.
- Fork 3
feat: auto detect custom dataset type based on file info #399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,19 +3,18 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import uuid | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.common.config.user_config import UserConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.common.decorators import implements_protocol | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.common.enums import CustomDatasetType | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.common.enums import CustomDatasetType, DatasetSamplingStrategy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.common.factories import CustomDatasetFactory | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.common.mixins import AIPerfLoggerMixin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.common.models import Conversation, Text, Turn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.dataset.generator import PromptGenerator | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.dataset.loader.models import MooncakeTrace | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from aiperf.dataset.loader.protocol import CustomDatasetLoaderProtocol | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @implements_protocol(CustomDatasetLoaderProtocol) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @CustomDatasetFactory.register(CustomDatasetType.MOONCAKE_TRACE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class MooncakeTraceDatasetLoader(AIPerfLoggerMixin): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """A dataset loader that loads Mooncake trace data from a file. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -54,6 +53,25 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._end_offset = user_config.input.fixed_schedule_end_offset | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(user_config=user_config, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def can_load( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cls, data: dict[str, Any] | None = None, filename: str | Path | None = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Check if this loader can handle the given data format. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| MooncakeTrace format has "input_length" or "text_input" fields, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and optionally "hash_ids". | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if data is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return "input_length" in data or ("text_input" in data and "hash_ids" in data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+56
to
+69
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can_load too strict; accept text_input without requiring hash_ids, and silence ARG003 Docstring says hash_ids is optional. Current predicate requires both and will fail valid traces, breaking auto-detect. Also rename filename to _filename to avoid Ruff ARG003. - def can_load(
- cls, data: dict[str, Any] | None = None, filename: str | Path | None = None
- ) -> bool:
+ def can_load(
+ cls, data: dict[str, Any] | None = None, _filename: str | Path | None = None
+ ) -> bool:
@@
- return "input_length" in data or ("text_input" in data and "hash_ids" in data)
+ return ("input_length" in data) or ("text_input" in data)📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.14.1)58-58: Unused class method argument: (ARG003) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Get the preferred dataset sampling strategy for MooncakeTrace.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return DatasetSamplingStrategy.SEQUENTIAL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def load_dataset(self) -> dict[str, list[MooncakeTrace]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Load Mooncake trace data from a file. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Persist inferred type to config and guard missing file
Two fixes:
🧰 Tools
🪛 Ruff (0.14.1)
42-45: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents