diff --git a/examples/llm_finetune/packed_parquet_example.yaml b/examples/llm_finetune/packed_parquet_example.yaml new file mode 100644 index 000000000..b9bdb917e --- /dev/null +++ b/examples/llm_finetune/packed_parquet_example.yaml @@ -0,0 +1,91 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Example: Fine-tuning with pre-packed Parquet SFT data. +# +# Pre-packed Parquet files use the RFC packed SFT format: +# - input_ids: list (variable-length token IDs, already packed) +# - loss_mask: list (1 = compute loss, 0 = ignore) +# - seq_start_id: list (sequence boundary positions within each pack) +# +# Files should be named *.idx.parquet and can be produced by the Nemotron +# data prep pipeline or any tool that writes this schema. 
+# +# To run: +# torchrun --nproc-per-node=8 examples/llm_finetune/finetune.py \ +# --config examples/llm_finetune/packed_parquet_example.yaml +# +# Override data_path on the command line: +# --dataset.data_path /data/packed_sft/shard_*.idx.parquet + + +step_scheduler: + global_batch_size: 32 + local_batch_size: 4 + ckpt_every_steps: 500 + num_epochs: 2 + +dist_env: + backend: nccl + timeout_minutes: 10 + +rng: + _target_: nemo_automodel.components.training.rng.StatefulRNG + seed: 1111 + ranked: true + +model: + _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained + pretrained_model_name_or_path: meta-llama/Llama-3.2-1B + +checkpoint: + enabled: false + checkpoint_dir: checkpoints/ + +distributed: + strategy: fsdp2 + tp_size: 1 + cp_size: 1 + pp_size: 1 + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +# Pre-packed Parquet dataset — data is already packed, no runtime packing needed. +# packed_sequence_size must match the pack size used during data preparation. +dataset: + _target_: nemo_automodel.components.datasets.llm.packed_parquet_dataset.PackedParquetDataset + data_path: /path/to/packed_sft/ # single file, glob, or directory + packed_sequence_size: 4096 + padding_idx: 0 + split: train + +# packed_sequence_size > 0 tells the model to use THD attention. +# The is_pre_packed attribute on the dataset prevents re-packing. 
+packed_sequence: + packed_sequence_size: 4096 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: nemo_automodel.components.datasets.utils.packed_sequence_thd_collater + shuffle: true + num_workers: 4 + +optimizer: + _target_: torch.optim.Adam + betas: [0.9, 0.999] + eps: 1e-5 + lr: 1.0e-4 + weight_decay: 0 diff --git a/nemo_automodel/components/datasets/llm/__init__.py b/nemo_automodel/components/datasets/llm/__init__.py index 2a5270953..f87add8b6 100644 --- a/nemo_automodel/components/datasets/llm/__init__.py +++ b/nemo_automodel/components/datasets/llm/__init__.py @@ -20,6 +20,7 @@ is_delta_lake_path, ) from .nanogpt_dataset import NanogptDataset # noqa: F401 +from .packed_parquet_dataset import PackedParquetDataset # noqa: F401 from .retrieval_collator import RetrievalBiencoderCollator # noqa: F401 from .retrieval_dataset import make_retrieval_dataset # noqa: F401 from .squad import make_squad_dataset # noqa: F401 @@ -36,4 +37,5 @@ "ChatDataset", "DeltaLakeDataset", "is_delta_lake_path", + "PackedParquetDataset", ] diff --git a/nemo_automodel/components/datasets/llm/packed_parquet_dataset.py b/nemo_automodel/components/datasets/llm/packed_parquet_dataset.py new file mode 100644 index 000000000..4c1181246 --- /dev/null +++ b/nemo_automodel/components/datasets/llm/packed_parquet_dataset.py @@ -0,0 +1,323 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Packed Parquet SFT dataset for reading pre-packed sequences from Parquet files.

Reads Parquet files with the RFC packed SFT format (one row == one pack):
    - input_ids:    list[int]  variable-length token IDs (concatenated sequences)
    - loss_mask:    list[int]  1 = compute loss, 0 = ignore
    - seq_start_id: list[int]  start offset of every sequence within the pack

Outputs the AutoModel packed format expected by ``packed_sequence_thd_collater``:
    - input_ids, labels, position_ids, seq_lens, seq_lens_padded
      (plain Python lists, one dict per pack)

Usage::

    # YAML config
    dataset:
      _target_: nemo_automodel.components.datasets.llm.packed_parquet_dataset.PackedParquetDataset
      data_path: /data/packed_sft/shard_*.idx.parquet
      packed_sequence_size: 4096

    packed_sequence:
      packed_sequence_size: 4096

    dataloader:
      _target_: torchdata.stateful_dataloader.StatefulDataLoader
      collate_fn: nemo_automodel.components.datasets.utils.packed_sequence_thd_collater
      shuffle: true
"""

import bisect
import glob as glob_module
import logging
from pathlib import Path
from typing import Optional, Union

from torch.utils.data import Dataset

from nemo_automodel.components.datasets.llm.packed_sequence import (
    CROSS_ENTROPY_IGNORE_IDX,
    _tensorize_and_pad_pack,
)

logger = logging.getLogger(__name__)

_REQUIRED_COLUMNS = {"input_ids", "loss_mask", "seq_start_id"}


class _ReaderState:
    """Reader state holding open Parquet handles and a one-row-group cache.

    Created lazily on the first ``PackedParquetDataset.__getitem__`` call so
    the dataset object itself carries no file handles when it is pickled for
    DataLoader workers.

    Only the most recently read row group is cached, so consecutive reads
    from the same row group cost a single I/O; random access across row
    groups re-reads on every switch. Opened ``ParquetFile`` handles are kept
    for the lifetime of this object (one per file touched).
    """

    def __init__(self, files: list[str]):
        self._files = files
        self._pf_cache: dict[int, object] = {}  # file_idx -> ParquetFile
        self._cached_rg_flat_idx: int = -1  # -1 means "nothing cached yet"
        self._cached_table = None

    def _get_parquet_file(self, file_idx: int):
        """Return the (memory-mapped) ``ParquetFile`` for *file_idx*, opening it once."""
        import pyarrow.parquet as pq

        pf = self._pf_cache.get(file_idx)
        if pf is None:
            pf = pq.ParquetFile(self._files[file_idx], memory_map=True)
            self._pf_cache[file_idx] = pf
        return pf

    def read_row(
        self,
        rg_flat_idx: int,
        row_within_rg: int,
        rg_index: list[tuple[int, int, int]],
    ) -> dict:
        """Read a single row, using the row-group cache.

        Args:
            rg_flat_idx: Position of the row group in *rg_index*.
            row_within_rg: Row offset inside that row group.
            rg_index: Flat list of ``(file_idx, rg_idx, num_rows)`` tuples.

        Returns:
            Mapping of each required column name to its Python value for the row.
        """
        if self._cached_rg_flat_idx != rg_flat_idx:
            file_idx, rg_idx, _num_rows = rg_index[rg_flat_idx]
            pf = self._get_parquet_file(file_idx)
            # sorted() only fixes the column order passed to pyarrow; lookups below are by name.
            self._cached_table = pf.read_row_group(rg_idx, columns=sorted(_REQUIRED_COLUMNS))
            self._cached_rg_flat_idx = rg_flat_idx

        table = self._cached_table
        return {col: table.column(col)[row_within_rg].as_py() for col in _REQUIRED_COLUMNS}


class PackedParquetDataset(Dataset):
    """Map-style dataset that reads pre-packed Parquet files in RFC format.

    Args:
        data_path: Path to a Parquet file, glob pattern, directory, or list of
            any of the above. Directories are scanned for ``*.idx.parquet`` /
            ``*.idx.pq`` first, then ``*.parquet`` / ``*.pq``.
        packed_sequence_size: Target pack length. Rows shorter than this are
            padded; rows longer raise ``ValueError``.
        padding_idx: Token ID used for padding ``input_ids``.
        cp_size: Context-parallel size for CP-aware padding.
        split: Accepted for config compatibility but unused (data is pre-split
            in files).
        tokenizer: Accepted for config compatibility but unused.

    Raises:
        FileNotFoundError: If *data_path* resolves to no files.
        ValueError: If any resolved file lacks one of the required columns.
    """

    # Marker read by the fine-tuning recipe via ``getattr(ds, "is_pre_packed", False)``:
    # when true, runtime packing is skipped even if ``packed_sequence_size > 0``.
    is_pre_packed: bool = True

    def __init__(
        self,
        data_path: Union[str, list[str]],
        packed_sequence_size: int,
        padding_idx: int = 0,
        cp_size: int = 1,
        split: str = "train",
        tokenizer=None,
    ):
        self._data_path = data_path
        self._packed_sequence_size = packed_sequence_size
        self._padding_idx = padding_idx
        self._cp_size = cp_size

        # Resolve file list eagerly (path resolution only, no file handles)
        self._files = self._resolve_files(data_path)
        if not self._files:
            raise FileNotFoundError(f"No Parquet files found at: {data_path}")

        # Build row-group index from metadata (no row data read)
        self._rg_index: list[tuple[int, int, int]] = []  # (file_idx, rg_idx, num_rows)
        self._rg_cumulative: list[int] = []  # cumulative row count at start of each rg
        self._total_rows = 0
        self._build_index()

        # Lazy reader state (created on first __getitem__)
        self._reader_state: Optional[_ReaderState] = None

    # ------------------------------------------------------------------
    # Pickling
    # ------------------------------------------------------------------

    def __getstate__(self) -> dict:
        """Return picklable state, dropping any live reader.

        If ``__getitem__`` has already been called in this process, the reader
        holds open ``ParquetFile`` handles and a cached table. Those must not
        travel to DataLoader workers; each unpickled copy re-creates its own
        reader lazily.
        """
        state = self.__dict__.copy()
        state["_reader_state"] = None
        return state

    # ------------------------------------------------------------------
    # File resolution
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_files(data_path: Union[str, list[str]]) -> list[str]:
        """Resolve *data_path* to a sorted, deduplicated list of Parquet paths."""
        if isinstance(data_path, (list, tuple)):
            files: list[str] = []
            for p in data_path:
                files.extend(PackedParquetDataset._resolve_files(p))
            return sorted(set(files))

        path_str = str(data_path)

        # Glob pattern
        if "*" in path_str or "?" in path_str:
            return sorted(glob_module.glob(path_str))

        p = Path(path_str)
        if p.is_file():
            return [str(p)]

        if p.is_dir():
            # Prefer *.idx.parquet / *.idx.pq (RFC naming convention)
            files = sorted(glob_module.glob(str(p / "*.idx.parquet")))
            files += sorted(glob_module.glob(str(p / "*.idx.pq")))
            if not files:
                files = sorted(glob_module.glob(str(p / "*.parquet")))
                files += sorted(glob_module.glob(str(p / "*.pq")))
            return sorted(set(files))

        # Neither file nor directory: last-chance glob (covers e.g. "[0-9]" patterns);
        # yields [] for a plain non-existent path.
        return sorted(glob_module.glob(path_str))

    # ------------------------------------------------------------------
    # Index building (metadata only)
    # ------------------------------------------------------------------

    def _build_index(self) -> None:
        """Read Parquet metadata to build a cumulative row-group index.

        Every file's schema is checked for the required columns so that a bad
        shard fails here, at construction time, instead of inside a DataLoader
        worker on first access.

        Raises:
            ValueError: If a file is missing any of ``_REQUIRED_COLUMNS``.
        """
        import pyarrow.parquet as pq

        cumulative = 0

        for file_idx, filepath in enumerate(self._files):
            pf = pq.ParquetFile(filepath)

            missing = _REQUIRED_COLUMNS - set(pf.schema_arrow.names)
            if missing:
                raise ValueError(
                    f"Parquet file {filepath} is missing required columns: {missing}. "
                    f"Expected columns: {_REQUIRED_COLUMNS}"
                )

            metadata = pf.metadata
            for rg_idx in range(metadata.num_row_groups):
                num_rows = metadata.row_group(rg_idx).num_rows
                self._rg_index.append((file_idx, rg_idx, num_rows))
                self._rg_cumulative.append(cumulative)
                cumulative += num_rows

        self._total_rows = cumulative
        # Sentinel for bisect
        self._rg_cumulative.append(cumulative)

        logger.info(
            "PackedParquetDataset: %d file(s), %d row group(s), %d total rows",
            len(self._files),
            len(self._rg_index),
            self._total_rows,
        )

    # ------------------------------------------------------------------
    # Dataset interface
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        return self._total_rows

    def __getitem__(self, idx: int) -> dict[str, list]:
        """Return pack *idx* (negative indices count from the end) as a dict of lists.

        Raises:
            IndexError: If *idx* is outside ``[-len(self), len(self))``.
            ValueError: If the stored row is malformed (see ``_convert_row``).
        """
        if idx < 0:
            idx += self._total_rows
        if idx < 0 or idx >= self._total_rows:
            raise IndexError(f"Index {idx} out of range [0, {self._total_rows})")

        # Lazy init keeps file handles out of the object until first use
        if self._reader_state is None:
            self._reader_state = _ReaderState(self._files)

        rg_flat_idx, row_within_rg = self._locate_row(idx)
        raw = self._reader_state.read_row(rg_flat_idx, row_within_rg, self._rg_index)
        return self._convert_row(raw)

    # ------------------------------------------------------------------
    # Row location
    # ------------------------------------------------------------------

    def _locate_row(self, global_idx: int) -> tuple[int, int]:
        """Map *global_idx* → (flat_rg_index, row_offset_within_rg)."""
        # bisect_right - 1 picks the last row group whose start is <= global_idx,
        # which also skips over empty row groups sharing the same start offset.
        rg_flat_idx = bisect.bisect_right(self._rg_cumulative, global_idx) - 1
        row_within_rg = global_idx - self._rg_cumulative[rg_flat_idx]
        return rg_flat_idx, row_within_rg

    # ------------------------------------------------------------------
    # Format conversion
    # ------------------------------------------------------------------

    def _convert_row(self, raw: dict) -> dict[str, list]:
        """Convert an RFC Parquet row to AutoModel packed format.

        Returns lists (not tensors) to match the contract of HuggingFace
        ``Dataset.__getitem__``, which is what the ``packed_sequence_thd_collater``
        expects.

        Steps:
            1. labels = input_ids where loss_mask == 1, else -100
            2. seq_lens = diffs between consecutive seq_start_id values
            3. position_ids = range resetting at each boundary
            4. Pad via ``_tensorize_and_pad_pack`` then convert to lists

        Raises:
            ValueError: If the row is longer than ``packed_sequence_size``, if
                ``loss_mask`` and ``input_ids`` differ in length, or if
                ``seq_start_id`` does not start at 0 / is not non-decreasing /
                points past the end of the row.
        """
        input_ids: list[int] = raw["input_ids"]
        loss_mask: list[int] = raw["loss_mask"]
        seq_start_id: list[int] = raw["seq_start_id"]

        n = len(input_ids)

        if n > self._packed_sequence_size:
            raise ValueError(
                f"Parquet row has {n} tokens but packed_sequence_size is "
                f"{self._packed_sequence_size}. Increase packed_sequence_size or "
                f"ensure data is pre-packed to fit."
            )

        # zip() below would silently truncate to the shorter list and misalign labels.
        if len(loss_mask) != n:
            raise ValueError(
                f"Parquet row has {n} tokens but loss_mask has {len(loss_mask)} entries; "
                f"input_ids and loss_mask must have the same length."
            )

        # Sequence i spans bounds[i] .. bounds[i + 1]; the row length closes the last one.
        bounds = [*seq_start_id, n]
        spans = list(zip(bounds, bounds[1:]))
        if seq_start_id and (bounds[0] != 0 or any(start > end for start, end in spans)):
            raise ValueError(
                f"Invalid seq_start_id {seq_start_id} for a row of {n} tokens: boundaries "
                f"must start at 0, be non-decreasing and not exceed the row length."
            )

        # 1. labels
        labels = [tok if mask == 1 else CROSS_ENTROPY_IGNORE_IDX for tok, mask in zip(input_ids, loss_mask)]

        # 2. seq_lens
        seq_lens = [end - start for start, end in spans]

        # 3. position_ids (reset at each boundary)
        position_ids = [0] * n
        for start, end in spans:
            position_ids[start:end] = range(end - start)

        # 4. Pad and tensorize via existing helper
        pack = {
            "input_ids": input_ids,
            "labels": labels,
            "position_ids": position_ids,
            "seq_lens": seq_lens,
        }

        result = _tensorize_and_pad_pack(
            pack,
            padding_idx=self._padding_idx,
            packed_sequence_size=self._packed_sequence_size,
            cp_size=self._cp_size,
        )

        # Convert tensors to lists to match HF Dataset.__getitem__ contract.
        # The packed_sequence_thd_collater expects lists, not tensors.
        return {k: v.tolist() for k, v in result.items()}
import pickle

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from nemo_automodel.components.datasets.llm.packed_parquet_dataset import PackedParquetDataset


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Column name -> Arrow element type of the list column, in schema order.
_COLUMN_TYPES = (
    ("input_ids", pa.int32()),
    ("loss_mask", pa.uint8()),
    ("seq_start_id", pa.int32()),
)


def _write_parquet(path, rows, row_group_size=None):
    """Serialize *rows* (dicts keyed by the RFC column names) into one Parquet file."""
    arrays = {
        name: pa.array([row[name] for row in rows], type=pa.list_(elem_type))
        for name, elem_type in _COLUMN_TYPES
    }
    write_kwargs = {} if row_group_size is None else {"row_group_size": row_group_size}
    pq.write_table(pa.table(arrays), str(path), **write_kwargs)


def _open(path, size=10, **kwargs):
    """Construct a dataset over a single path with pack size *size*."""
    return PackedParquetDataset(data_path=str(path), packed_sequence_size=size, **kwargs)


# A simple pack: 8 tokens, 3 sequences at positions [0, 3, 6]
SIMPLE_ROW = {
    "input_ids": [10, 20, 30, 40, 50, 60, 70, 80],
    "loss_mask": [1, 1, 0, 1, 1, 1, 0, 0],
    "seq_start_id": [0, 3, 6],
}


@pytest.fixture
def single_file(tmp_path):
    """One Parquet file holding exactly one row; yields (directory, file path)."""
    target = tmp_path / "data.idx.parquet"
    _write_parquet(target, [SIMPLE_ROW])
    return tmp_path, target


@pytest.fixture
def multi_row_file(tmp_path):
    """One Parquet file holding three rows; yields (directory, file path, rows)."""
    five_tokens = {
        "input_ids": [1, 2, 3, 4, 5],
        "loss_mask": [1, 1, 1, 1, 1],
        "seq_start_id": [0, 3],
    }
    two_tokens = {
        "input_ids": [100, 200],
        "loss_mask": [0, 1],
        "seq_start_id": [0],
    }
    rows = [SIMPLE_ROW, five_tokens, two_tokens]
    target = tmp_path / "data.idx.parquet"
    _write_parquet(target, rows)
    return tmp_path, target, rows


@pytest.fixture
def multi_file_dir(tmp_path):
    """Directory containing two single-row shard files."""
    second_shard_row = {
        "input_ids": [1, 2, 3],
        "loss_mask": [1, 1, 1],
        "seq_start_id": [0],
    }
    _write_parquet(tmp_path / "shard_000.idx.parquet", [SIMPLE_ROW])
    _write_parquet(tmp_path / "shard_001.idx.parquet", [second_shard_row])
    return tmp_path


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestBasicRead:
    def test_len(self, single_file):
        assert len(_open(single_file[1])) == 1

    def test_output_keys(self, single_file):
        expected = {"input_ids", "labels", "position_ids", "seq_lens", "seq_lens_padded"}
        assert set(_open(single_file[1])[0].keys()) == expected

    def test_input_ids_padded(self, single_file):
        ids = _open(single_file[1])[0]["input_ids"]
        # Eight real tokens followed by two padding_idx=0 slots, ten in total.
        assert ids == [10, 20, 30, 40, 50, 60, 70, 80, 0, 0]

    def test_values_are_lists(self, single_file):
        """Every field is a plain list (HF Dataset contract) so the collater can consume it."""
        sample = _open(single_file[1])[0]
        for key in ("input_ids", "labels", "position_ids", "seq_lens", "seq_lens_padded"):
            assert isinstance(sample[key], list), f"{key} should be a list"


class TestLossMaskToLabels:
    def test_masked_positions_get_ignore_idx(self, single_file):
        labels = _open(single_file[1])[0]["labels"]
        # loss_mask = [1, 1, 0, 1, 1, 1, 0, 0] -> positions 2, 6, 7 become -100;
        # the two padding slots (8, 9) are ignored as well.
        assert labels[:10] == [10, 20, -100, 40, 50, 60, -100, -100, -100, -100]


class TestSeqStartIdToSeqLens:
    def test_seq_lens(self, single_file):
        # seq_start_id = [0, 3, 6] over 8 tokens -> lengths 3, 3, 2
        assert _open(single_file[1])[0]["seq_lens"] == [3, 3, 2]

    def test_single_sequence(self, tmp_path):
        """A lone sequence covers the whole pack."""
        target = tmp_path / "data.parquet"
        _write_parquet(
            target,
            [{"input_ids": [1, 2, 3, 4], "loss_mask": [1, 1, 1, 1], "seq_start_id": [0]}],
        )
        assert _open(target, size=6)[0]["seq_lens"] == [4]


class TestPositionIds:
    def test_reset_at_boundaries(self, single_file):
        positions = _open(single_file[1])[0]["position_ids"]
        # Three sequences (lengths 3, 3, 2); counting restarts at every boundary.
        assert positions[:8] == [0, 1, 2, 0, 1, 2, 0, 1]


class TestPadding:
    def test_seq_lens_padded_sum_equals_pack_size(self, single_file):
        padded = _open(single_file[1])[0]["seq_lens_padded"]
        assert sum(padded) == 10

    def test_seq_lens_padded_last_includes_padding(self, single_file):
        sample = _open(single_file[1])[0]
        raw_lens, padded_lens = sample["seq_lens"], sample["seq_lens_padded"]
        # All but the last entry are untouched ...
        assert padded_lens[:-1] == raw_lens[:-1]
        # ... and the last one absorbs the pack-level padding.
        assert padded_lens[-1] >= raw_lens[-1]

    def test_exact_fit_no_padding(self, tmp_path):
        """A row that already fills the pack is returned without padding."""
        target = tmp_path / "data.parquet"
        _write_parquet(
            target,
            [{"input_ids": [1, 2, 3, 4, 5], "loss_mask": [1, 1, 1, 1, 1], "seq_start_id": [0, 3]}],
        )
        sample = _open(target, size=5)[0]
        assert sample["input_ids"] == [1, 2, 3, 4, 5]
        assert sample["seq_lens"] == [3, 2]
        assert sample["seq_lens_padded"] == [3, 2]


class TestOversizedRow:
    def test_raises_value_error(self, single_file):
        dataset = _open(single_file[1], size=5)
        with pytest.raises(ValueError, match="packed_sequence_size"):
            dataset[0]


class TestMultiRow:
    def test_len(self, multi_row_file):
        assert len(_open(multi_row_file[1])) == 3

    def test_independent_getitem(self, multi_row_file):
        dataset = _open(multi_row_file[1])
        # Every row carries its own distinct tokens.
        expected_prefixes = (
            [10, 20, 30, 40, 50, 60, 70, 80],
            [1, 2, 3, 4, 5],
            [100, 200],
        )
        for index, prefix in enumerate(expected_prefixes):
            assert dataset[index]["input_ids"][: len(prefix)] == prefix


class TestMultiFile:
    def test_directory_resolution(self, multi_file_dir):
        assert len(_open(multi_file_dir)) == 2

    def test_cross_file_reads(self, multi_file_dir):
        dataset = _open(multi_file_dir)
        assert dataset[0]["input_ids"][:8] == [10, 20, 30, 40, 50, 60, 70, 80]
        assert dataset[1]["input_ids"][:3] == [1, 2, 3]

    def test_glob_pattern(self, multi_file_dir):
        assert len(_open(multi_file_dir / "shard_*.idx.parquet")) == 2


class TestErrorHandling:
    def test_no_files_raises_file_not_found(self, tmp_path):
        with pytest.raises(FileNotFoundError):
            _open(tmp_path / "nonexistent")

    def test_missing_columns_raises_value_error(self, tmp_path):
        target = tmp_path / "bad.parquet"
        incomplete = pa.table({"input_ids": pa.array([[1, 2, 3]], type=pa.list_(pa.int32()))})
        pq.write_table(incomplete, str(target))
        with pytest.raises(ValueError, match="missing required columns"):
            _open(target)

    def test_negative_indexing(self, single_file):
        dataset = _open(single_file[1])
        tail, head = dataset[-1], dataset[0]
        assert tail["input_ids"] == head["input_ids"]

    def test_index_out_of_range(self, single_file):
        dataset = _open(single_file[1])
        with pytest.raises(IndexError):
            dataset[1]


class TestPickleSafety:
    def test_pickle_roundtrip(self, single_file):
        original = _open(single_file[1])
        # A dumps/loads cycle stands in for shipping the dataset to a DataLoader worker.
        clone = pickle.loads(pickle.dumps(original))
        assert clone[0]["input_ids"][:8] == [10, 20, 30, 40, 50, 60, 70, 80]


class TestCPAwarePadding:
    def test_cp_size_2(self, tmp_path):
        """With cp_size=2 the non-final seq_lens_padded entries round up to 2*cp_size=4."""
        target = tmp_path / "data.parquet"
        _write_parquet(
            target,
            [
                {
                    "input_ids": [1, 2, 3, 4, 5, 6, 7],
                    "loss_mask": [1, 1, 1, 1, 1, 1, 1],
                    "seq_start_id": [0, 3],
                }
            ],
        )
        sample = _open(target, size=16, cp_size=2)[0]
        padded = sample["seq_lens_padded"]
        # Unpadded lengths.
        assert sample["seq_lens"] == [3, 4]
        # CP rounding uses divisibility factor 2 * cp_size = 4:
        #   seq 1: 3 -> 4 (rounded up)
        #   seq 2: 4 -> 4 (already divisible), then + 9 pack padding = 13
        # The final entry soaks up the pack padding and need not be CP-divisible.
        assert padded[0] == 4
        # Total is 4 + 13 = 17 rather than the pack size: CP rounding inflates the
        # individual lengths beyond the real token count, and the pack padding is
        # stacked on top of the last CP-rounded length.
        assert padded == [4, 13]


class TestIsPrePacked:
    def test_attribute_exists(self, single_file):
        assert _open(single_file[1]).is_pre_packed is True

    def test_getattr_works(self, single_file):
        assert getattr(_open(single_file[1]), "is_pre_packed", False) is True


class TestRowGroupCaching:
    def test_multiple_reads_same_rg(self, tmp_path):
        """Rows living in one row group are all served from the same cached read."""
        token_triples = ([1, 2, 3], [4, 5, 6], [7, 8, 9])
        rows = [{"input_ids": ids, "loss_mask": [1, 1, 1], "seq_start_id": [0]} for ids in token_triples]
        target = tmp_path / "data.parquet"
        _write_parquet(target, rows, row_group_size=3)
        dataset = _open(target, size=5)
        # Touch every row in order.
        for index, ids in enumerate(token_triples):
            assert dataset[index]["input_ids"][:3] == ids
        # The cache still points at the one and only row group.
        assert dataset._reader_state._cached_rg_flat_idx == 0


class TestCollaterCompatibility:
    def test_works_with_thd_collater(self, single_file):
        """End-to-end: a one-sample batch goes through packed_sequence_thd_collater."""
        from nemo_automodel.components.datasets.utils import packed_sequence_thd_collater

        dataset = _open(single_file[1])
        batch = packed_sequence_thd_collater([dataset[0]])
        for key in ("input_ids", "labels", "position_ids", "seq_lens", "seq_lens_padded"):
            assert key in batch
        assert batch["qkv_format"] == "thd"
        assert batch["input_ids"].shape == (1, 10)


class TestListPath:
    def test_list_of_files(self, multi_file_dir):
        """data_path may be an explicit list of file paths."""
        shard_paths = [str(multi_file_dir / f"shard_00{i}.idx.parquet") for i in range(2)]
        dataset = PackedParquetDataset(data_path=shard_paths, packed_sequence_size=10)
        assert len(dataset) == 2