diff --git a/docs/zh/get_started/hf-datasets.md b/docs/zh/get_started/hf-datasets.md new file mode 100644 index 000000000..536d0dbd0 --- /dev/null +++ b/docs/zh/get_started/hf-datasets.md @@ -0,0 +1,247 @@ +# HuggingFace Datasets 集成 + +本文档介绍如何使用 HuggingFace Datasets 加载大规模数据集(100GB+)进行训练。 + +## 1. 快速开始 + +### 何时使用 + +| 场景 | 推荐方案 | +|------|---------| +| 数据集 > 10GB | **HF Datasets**(流式加载,内存 < 1GB) | +| 数据集 < 10GB | Legacy Dataset(默认,全量加载) | + +### 基本用法(RL 训练) + +```bash +python train.py \ + --use-hf-datasets \ + --hf-datasets-num-samples 17000 \ + --prompt-data zhuzilin/dapo-math-17k \ + --rollout-batch-size 32 \ + --num-rollout 100 +``` + +### SFT 训练用法 + +```bash +python train.py \ + --use-hf-datasets \ + --hf-datasets-num-samples 335122 \ + --prompt-data nvidia/Nemotron-Agentic-v1 \ + --hf-dataset-split interactive_agent \ + --input-key messages \ + --tool-key tools \ + --rollout-function-path slime.rollout.sft_rollout.generate_rollout \ + --loss-type sft_loss \ + --calculate-per-token-loss \ + --disable-compute-advantages-and-returns \ + --rollout-batch-size 128 \ + --num-rollout 60 +``` + +**必需参数**: +- `--use-hf-datasets`:启用 HF Datasets 流式模式 +- `--hf-datasets-num-samples`:数据集样本数(用于 epoch 追踪) + +### 支持的数据格式 + +| 格式 | 示例 | +|------|------| +| HuggingFace Hub | `zhuzilin/dapo-math-17k` | +| HuggingFace Hub (指定 split) | `nvidia/Nemotron-Agentic-v1` + `--hf-dataset-split interactive_agent` | +| 本地 JSONL | `/path/to/data.jsonl` | +| 本地 Parquet | `/path/to/data.parquet` | + +--- + +## 2. 参数详解 + +### 基础参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--use-hf-datasets` | `False` | 启用 HF Datasets 流式模式 | +| `--hf-datasets-num-samples` | **必需** | 数据集样本数(用于 epoch 边界计算) | +| `--hf-dataset-split` | `train` | 数据集 split 名称(如 `train`、`interactive_agent`) | +| `--hf-dataset-shuffle-buffer` | `10000` | Shuffle buffer 大小 | +| `--hf-dataset-buffer-size` | `100` | 预取 buffer 大小 | +| `--hf-dataset-num-proc` | `8` | DataLoader worker 数量 | + +### SFT 相关参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--rollout-function-path` | - | SFT 使用 `slime.rollout.sft_rollout.generate_rollout` | +| `--loss-type` | `policy_loss` | 损失类型,SFT 使用 `sft_loss` | +| `--calculate-per-token-loss` | `False` | 按 token 计算损失(推荐 SFT 开启) | +| `--disable-compute-advantages-and-returns` | `False` | 禁用 advantage 计算(SFT 必须开启) | +| `--input-key` | `input` | 输入字段名(多轮对话使用 `messages`) | +| `--tool-key` | - | 工具定义字段名(如 `tools`) | + +### 数据格式 + +**JSONL 格式**(每行一个 JSON 对象): + +```json +{"input": "问题文本", "label": "答案文本", "metadata": {"sample_id": 0}} +``` + +**字段映射**: +- `--input-key input`:输入字段名 +- `--label-key label`:标签字段名 +- `--metadata-key metadata`:元数据字段名 + +**Chat Template 格式**(需配合 `--apply-chat-template`): + +```json +{"input": [{"role": "user", "content": "问题"}], "label": "答案"} +``` + +**多轮对话格式**(SFT 模式,配合 `--input-key messages`): + +```json +{ + "messages": [ + {"role": "system", "content": "你是一个助手"}, + {"role": "user", "content": "你好"}, + {"role": "assistant", "content": "你好!有什么可以帮助你的?"}, + {"role": "user", "content": "请解释一下 Python"}, + {"role": "assistant", "content": "Python 是一种高级编程语言..."} + ], + "tools": [{"type": "function", "function": {"name": "search", "description": "搜索信息"}}] +} +``` + +**多轮对话 loss_mask 说明**: +- SFT 模式下,只对 assistant 的回复计算损失 +- `loss_mask` 中 1 表示需要计算损失的 token,0 表示跳过 +- 多轮对话中,中间的 user/system 轮次会被标记为 0 + +--- + +## 3. 
Checkpoint 支持 + +### 保存 + +训练时自动保存数据集状态: + +```bash +python train.py --use-hf-datasets --save /path/to/ckpt --save-interval 50 +``` + +### 恢复 + +从 checkpoint 继续训练: + +```bash +python train.py --use-hf-datasets --load /path/to/ckpt +``` + +### 状态内容 + +Checkpoint 包含以下数据集状态: + +```python +{ + "epoch_id": 2, # 当前 epoch + "consumed_count": 15234, # 当前 epoch 已消费样本数 + "global_consumed_count": 45234, # 全局已消费样本数 + "hf_state_dict": {...} # HF 原生迭代器状态 +} +``` + +--- + +## 4. 故障排查 + +### 问题 1: ValueError: --hf-datasets-num-samples is required + +**原因**:使用 `--use-hf-datasets` 时必须指定样本数 + +**解决**:添加 `--hf-datasets-num-samples <数量>` + +### 问题 2: 训练卡住/数据加载慢 + +**原因**:DataLoader worker 不足 + +**解决**:增加 `--hf-dataset-num-proc 16` + +### 问题 3: Dataset exhausted while skipping + +**原因**:Checkpoint 损坏或 epoch 边界错误 + +**解决**: +```bash +rm -rf /path/to/ckpt/rollout/global_dataset_state_dict_*.pt +``` + +--- + +## 5. 开发者参考 + +### 架构概述 + +``` +RolloutDataSource + └── HFIterableDatasetAdapter + └── PyTorch DataLoader + └── HuggingFace IterableDataset +``` + +**核心设计**: +- 使用 PyTorch DataLoader 进行多进程预取 +- 使用 HF 原生 `state_dict()` / `load_state_dict()` 支持 checkpoint +- 使用 HF `set_epoch()` 实现可复现的 shuffle + +### SFT Rollout 函数 + +SFT 模式使用专用的 rollout 函数 `slime.rollout.sft_rollout.generate_rollout`: + +```python +# 核心逻辑 +def generate_rollout(args, rollout_id, *, evaluation=False): + # 1. 获取批量数据 + # 2. 应用 chat template,生成 loss_mask + # 3. 返回 Sample 列表(无需生成,直接使用数据集内容) + pass +``` + +**loss_mask 计算**: +- 对于多轮对话,只对 assistant 回复部分设置 mask=1 +- `response_length` 计算为从第一个 mask=1 位置到序列末尾的长度 + +### 添加新 Backend + +实现以下接口的类: + +```python +class MyDatasetAdapter: + """自定义数据集适配器需要实现以下方法""" + + def get_next_batch(self, num_samples: int) -> list[Sample]: + """返回下一批样本""" + pass + + def shuffle(self, new_epoch_id: int): + """基于 epoch_id 的 shuffle""" + pass + + def get_checkpoint_state(self) -> dict: + """获取 checkpoint 状态""" + pass + + def load_checkpoint_state(self, state: dict): + """恢复 checkpoint 状态""" + pass +``` + +--- + +## 参考资料 + +- [HuggingFace Datasets 文档](https://huggingface.co/docs/datasets) +- 源码:`slime/utils/hf_dataset.py` +- SFT Rollout:`slime/rollout/sft_rollout.py` +- 测试:`tests/test_hf_datasets.py` +- SFT 测试:`tests/test_qwen3-0.6B_hf_datasets_sft.py` diff --git a/slime/backends/training_utils/cp_utils.py b/slime/backends/training_utils/cp_utils.py index 7d3f4b3e1..bfc23ee08 100644 --- a/slime/backends/training_utils/cp_utils.py +++ b/slime/backends/training_utils/cp_utils.py @@ -71,7 +71,7 @@ def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: return sum( [ (x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1) - for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False) + for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=True) ] ) @@ -79,7 +79,7 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: return sum( [ (x_i * loss_mask_i).sum() - for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False) + for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=True) ] ) @@ -87,7 +87,7 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: cp_chunk_lengths = [] chunked_loss_masks = [] for i, (total_length, response_length, loss_mask) in enumerate( - zip(total_lengths, response_lengths, loss_masks, strict=False) + zip(total_lengths, response_lengths, loss_masks, strict=True) ): max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None prompt_length = total_length - response_length @@ -104,7 +104,7 @@ 
def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: [ (x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) for x_i, chunked_loss_mask, loss_mask in zip( - x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks, strict=False + x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks, strict=True ) ] ) @@ -114,7 +114,7 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: [ (x_i * chunked_loss_mask).sum() for x_i, chunked_loss_mask in zip( - x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, strict=False + x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, strict=True ) ] ) diff --git a/slime/rollout/data_source.py b/slime/rollout/data_source.py index fa14c65f0..633b90433 100644 --- a/slime/rollout/data_source.py +++ b/slime/rollout/data_source.py @@ -52,42 +52,120 @@ def __init__(self, args): # TODO remove this self.metadata = {} + self._dataset = None + self._tokenizer = None + self._processor = None + self._use_hf_datasets = getattr(args, "use_hf_datasets", False) + + # Initialize dataset if using global dataset if args.rollout_global_dataset: - tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) - processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + self._tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) + self._processor = load_processor(args.hf_checkpoint, trust_remote_code=True) # TODO move (during the refactor) if (d := args.dump_details) is not None: - tokenizer.save_pretrained(Path(d) / "tokenizer") - if processor: - processor.save_pretrained(Path(d) / "processor") - - self.dataset = Dataset( - args.prompt_data, - tokenizer=tokenizer, - processor=processor, - max_length=args.rollout_max_prompt_len, - prompt_key=args.input_key, - multimodal_keys=args.multimodal_keys, - label_key=args.label_key, - metadata_key=args.metadata_key, - tool_key=args.tool_key, - apply_chat_template=args.apply_chat_template, - apply_chat_template_kwargs=args.apply_chat_template_kwargs, - seed=args.rollout_seed, + self._tokenizer.save_pretrained(Path(d) / "tokenizer") + if self._processor: + self._processor.save_pretrained(Path(d) / "processor") + + # Create dataset immediately + self._create_dataset() + + def _create_dataset(self): + """Create dataset based on configuration. + + Selects the appropriate dataset implementation based on args: + - HF Datasets (streaming mode) if --use-hf-datasets is set + - Legacy Dataset otherwise + + Note: RolloutManager is a single instance, so we do NOT shard by dp_rank. + Data sharding happens in RolloutManager._split_train_data_by_dp(). + """ + if self._use_hf_datasets: + # Use HuggingFace Datasets streaming mode + from slime.utils.hf_dataset import HFIterableDatasetAdapter + + # Get dataset size (required for proper epoch tracking) + dataset_size = getattr(self.args, "hf_datasets_num_samples", None) + if dataset_size is None: + raise ValueError( + "--hf-datasets-num-samples is required when using --use-hf-datasets. " + "This specifies the number of samples for proper epoch tracking and __len__() support." 
+ ) + + logger.info(f"Creating HFIterableDatasetAdapter (streaming mode, dataset_size={dataset_size})") + self._dataset = HFIterableDatasetAdapter( + path=self.args.prompt_data, + dataset_size=dataset_size, + tokenizer=self._tokenizer, + processor=self._processor, + max_length=self.args.rollout_max_prompt_len, + prompt_key=self.args.input_key, + label_key=self.args.label_key, + tool_key=self.args.tool_key, + metadata_key=self.args.metadata_key, + multimodal_keys=self.args.multimodal_keys, + seed=self.args.rollout_seed, + apply_chat_template=self.args.apply_chat_template, + apply_chat_template_kwargs=self.args.apply_chat_template_kwargs, + num_workers=getattr(self.args, "hf_dataset_num_proc", 4), + prefetch_factor=2, + shuffle_buffer_size=getattr(self.args, "hf_dataset_shuffle_buffer", 10000), + do_shuffle=self.args.rollout_shuffle, + split=getattr(self.args, "hf_dataset_split", "train"), ) - if self.args.rollout_shuffle: - self.dataset.shuffle(self.epoch_id) + + # Note: shuffle is handled by do_shuffle parameter, no need to call shuffle() separately + else: - self.dataset = None + # Use legacy Dataset implementation + logger.info("Creating legacy Dataset") + self._dataset = Dataset( + self.args.prompt_data, + tokenizer=self._tokenizer, + processor=self._processor, + max_length=self.args.rollout_max_prompt_len, + prompt_key=self.args.input_key, + multimodal_keys=self.args.multimodal_keys, + label_key=self.args.label_key, + metadata_key=self.args.metadata_key, + tool_key=self.args.tool_key, + apply_chat_template=self.args.apply_chat_template, + apply_chat_template_kwargs=self.args.apply_chat_template_kwargs, + seed=self.args.rollout_seed, + ) + + # Apply initial shuffle if requested + if self.args.rollout_shuffle: + self._dataset.shuffle(self.epoch_id) + + @property + def dataset(self): + """Accessor for dataset.""" + return self._dataset def get_samples(self, num_samples): - # TODO further improve code - if self.dataset is not None: + # Mixed mode: auto-detect dataset type using duck typing + if self.dataset is None: + # Case 1: No dataset (--disable-rollout-global-dataset) + prompt_samples = [Sample() for _ in range(num_samples)] + + elif hasattr(self.dataset, "get_next_batch"): + # Case 2: HF adapters - use streaming interface + # Note: HF adapters handle epoch switching internally + prompt_samples = self.dataset.get_next_batch(num_samples) + + # Sync epoch_id from HF adapter (it handles epoch switching internally) + if hasattr(self.dataset, "epoch_id"): + self.epoch_id = self.dataset.epoch_id + + else: + # Case 3: Legacy Dataset - use array access if self.sample_offset + num_samples <= len(self.dataset): prompt_samples = self.dataset.samples[self.sample_offset : self.sample_offset + num_samples] self.sample_offset += num_samples else: + # Handle epoch boundary prompt_samples = self.dataset.samples[self.sample_offset :] num_samples -= len(prompt_samples) self.epoch_id += 1 @@ -95,9 +173,8 @@ def get_samples(self, num_samples): self.dataset.shuffle(self.epoch_id) prompt_samples += self.dataset.samples[:num_samples] self.sample_offset = num_samples - else: - prompt_samples = [Sample() for _ in range(num_samples)] + # Common processing: wrap prompt_samples into groups samples = [] for prompt_sample in prompt_samples: group = [] @@ -125,6 +202,11 @@ def save(self, rollout_id): "sample_index": self.sample_index, "metadata": self.metadata, } + + # Save HF adapter state if using HF Datasets + if self.dataset is not None and hasattr(self.dataset, "get_checkpoint_state"): + 
state_dict["hf_adapter_state"] = self.dataset.get_checkpoint_state() + path = os.path.join(self.args.save, f"rollout/global_dataset_state_dict_{rollout_id}.pt") os.makedirs(os.path.dirname(path), exist_ok=True) torch.save(state_dict, path) @@ -142,7 +224,6 @@ def load(self, rollout_id=None): return logger.info(f"load metadata from {path}") - logger.info(f"load metadata: {self.metadata}") state_dict = torch.load(path) self.sample_offset = state_dict.get("sample_offset", 0) self.epoch_id = state_dict.get("epoch_id", 0) @@ -150,8 +231,19 @@ def load(self, rollout_id=None): self.sample_index = state_dict.get("sample_index", 0) self.metadata = state_dict.get("metadata", {}) - if self.args.rollout_global_dataset and self.args.rollout_shuffle: - self.dataset.shuffle(self.epoch_id) + # Restore dataset state based on type (mixed mode) + if self.dataset is not None: + if hasattr(self.dataset, "load_checkpoint_state"): + # HF adapters: use dedicated checkpoint API + hf_state = state_dict.get("hf_adapter_state") + if hf_state: + logger.info( + f"Restoring HF adapter state: epoch={hf_state.get('epoch_id')}, consumed={hf_state.get('consumed_count')}" + ) + self.dataset.load_checkpoint_state(hf_state) # type: ignore[attr-defined] + elif self.args.rollout_shuffle: + # Legacy Dataset: manual shuffle + self.dataset.shuffle(self.epoch_id) class RolloutDataSourceWithBuffer(RolloutDataSource): diff --git a/slime/rollout/sft_rollout.py b/slime/rollout/sft_rollout.py index 6b914a964..9328e13a7 100644 --- a/slime/rollout/sft_rollout.py +++ b/slime/rollout/sft_rollout.py @@ -48,12 +48,20 @@ def generate_rollout(args, rollout_id, data_buffer, evaluation=False): token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages, tools=tools) - response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0] + # Calculate response_length = length after first 1 (prompt excluded) + # This ensures loss_mask[-response_length:] gets the correct mask including + # intermediate 0s for user turns in multi-turn conversations + try: + first_one_index = loss_mask.index(1) + response_length = len(loss_mask) - first_one_index + except ValueError: + # No 1 in loss_mask means no tokens need loss computation + response_length = 0 sample.tokens = token_ids sample.response_length = response_length sample.reward = 0 - sample.loss_mask = loss_mask[-response_length:] + sample.loss_mask = loss_mask[-response_length:] if response_length > 0 else [] if i == 0 and not SAMPLE_PRINTED: logger.info( diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 429607d91..29454db49 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -535,6 +535,69 @@ def add_data_arguments(parser): ), ) + # HuggingFace Datasets Integration + parser.add_argument( + "--use-hf-datasets", + action="store_true", + default=False, + help=( + "Enable HuggingFace Datasets integration for efficient loading of large-scale datasets (100GB+). " + "Uses streaming mode with zero memory overhead and prefetch buffer for high throughput." + ), + ) + parser.add_argument( + "--hf-dataset-buffer-size", + type=int, + default=1000, + help=( + "Prefetch buffer size for streaming mode. " + "Larger buffer improves training throughput but increases memory usage. " + "Default: 1000." + ), + ) + parser.add_argument( + "--hf-dataset-shuffle-buffer", + type=int, + default=10000, + help=( + "Shuffle buffer size for HuggingFace Datasets streaming mode. " + "Larger buffer improves shuffle randomness but increases memory usage. 
" + "Default: 10000 (industry standard for large-scale datasets)." + ), + ) + parser.add_argument( + "--hf-dataset-num-proc", + type=int, + default=8, + help=( + "Number of parallel workers for preprocessing when using HuggingFace Datasets. " + "Applies to both streaming and cached modes. " + "Increase this value if you have more CPU cores available. " + "Default: 8." + ), + ) + parser.add_argument( + "--hf-datasets-num-samples", + type=int, + default=None, + help=( + "Number of samples in the HuggingFace streaming dataset. " + "Required when using --use-hf-datasets for proper epoch tracking. " + "This enables __len__() support and deterministic epoch boundaries." + ), + ) + parser.add_argument( + "--hf-dataset-split", + type=str, + default="train", + help=( + "Split name to use when loading HuggingFace datasets. " + "For HF Hub datasets, specifies which split to load (e.g., 'train', 'test'). " + "Some datasets like 'nvidia/Nemotron-Agentic-v1' have custom splits " + "like 'interactive_agent' or 'tool_calling'. Default: 'train'." + ), + ) + parser.add_argument( "--start-rollout-id", type=int, diff --git a/slime/utils/hf_dataset.py b/slime/utils/hf_dataset.py new file mode 100644 index 000000000..42e6a2db3 --- /dev/null +++ b/slime/utils/hf_dataset.py @@ -0,0 +1,426 @@ +"""HuggingFace Datasets streaming adapter for large-scale datasets (100GB+). + +This module provides a streaming dataset adapter using HuggingFace Datasets library, +enabling efficient loading of large-scale datasets without exhausting memory. + +Key Features: +- Streaming mode: Zero memory overhead, suitable for 100GB+ datasets +- Reproducible shuffling with epoch-based seeds (via HF's set_epoch) +- Checkpoint support using HF's native state_dict/load_state_dict +- PyTorch DataLoader integration for multi-process prefetching + +Architecture Note: +- Used by RolloutDataSource (single instance) +- Generates global data (not sharded by dp_rank) +- Data sharding happens in RolloutManager._split_train_data_by_dp() +""" + +import json +import logging + +import numpy as np +from datasets import IterableDataset as HFIterableDataset +from datasets import load_dataset +from torch.utils.data import DataLoader +from torch.utils.data import IterableDataset as TorchIterableDataset + +from slime.utils.types import Sample + +logger = logging.getLogger(__name__) + + +class HFIterableDatasetAdapter: + """Streaming HF Dataset adapter with checkpoint support. + + Enables loading and processing large datasets (100GB+) without loading + everything into memory. Uses HuggingFace's streaming mode combined with + PyTorch DataLoader for multi-process prefetching. + + Uses HF's native checkpoint support: + - state_dict() / load_state_dict() for efficient save/resume + - set_epoch() for automatic reshuffling (effective_seed = seed + epoch) + - shuffle(seed, buffer_size) for fast approximate shuffling + + Key Design: + - RolloutManager is a single instance, generates global data + - dataset_size is required for __len__() support and epoch tracking + - Sequential consumption only: No random access, only get_next_batch() + + VERIFIED: HF's state_dict enables exact position resume without sample skipping. + See tests/test_hf_datasets.py::TestHFStateTracking for verification tests. 
+ + Args: + path: Dataset path (local JSONL/Parquet or HF hub) + dataset_size: Known dataset size (required for epoch tracking) + tokenizer: HuggingFace tokenizer + processor: Optional multimodal processor + max_length: Max prompt length for filtering + prompt_key: Key for prompt in raw data (default: "text") + label_key: Key for label in raw data + tool_key: Key for tools in raw data + metadata_key: Key for metadata (default: "metadata") + multimodal_keys: Mapping of multimodal types to keys + seed: Random seed for shuffle (default: 42) + apply_chat_template: Whether to apply chat template (default: False) + apply_chat_template_kwargs: Additional kwargs for chat template + num_workers: Number of DataLoader workers (default: 4) + prefetch_factor: Prefetch factor per worker (default: 2) + shuffle_buffer_size: Buffer size for HF shuffle (default: 10000) + do_shuffle: Whether to enable shuffling (default: True) + split: Dataset split name (default: "train") + """ + + def __init__( + self, + path: str, + dataset_size: int, + tokenizer, + processor, + max_length: int | None, + *, + prompt_key: str = "text", + label_key: str | None = None, + tool_key: str | None = None, + metadata_key: str = "metadata", + multimodal_keys: dict | None = None, + seed: int = 42, + apply_chat_template: bool = False, + apply_chat_template_kwargs: dict | None = None, + num_workers: int = 4, + prefetch_factor: int = 2, + shuffle_buffer_size: int = 10000, + do_shuffle: bool = True, + split: str = "train", + ): + self.path = path + self.dataset_size = dataset_size + self.tokenizer = tokenizer + self.processor = processor + self.max_length = max_length + self.prompt_key = prompt_key + self.label_key = label_key + self.tool_key = tool_key + self.metadata_key = metadata_key + self.multimodal_keys = multimodal_keys + self.seed = seed + self.apply_chat_template = apply_chat_template + self.apply_chat_template_kwargs = apply_chat_template_kwargs or {} + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.shuffle_buffer_size = shuffle_buffer_size + self.do_shuffle = do_shuffle + self.split = split + + # State tracking + self.epoch_id = 0 + self.consumed_count = 0 # Samples consumed in current epoch + self.global_consumed_count = 0 # Total samples consumed across all epochs + + # Load and process HF dataset + self.hf_dataset = self._load_and_process_dataset() + + # Apply shuffle at creation time + if do_shuffle: + self.hf_dataset = self.hf_dataset.shuffle(seed=seed, buffer_size=shuffle_buffer_size) + + # Create DataLoader + self.dataloader = DataLoader( + _HFDatasetWrapper(self.hf_dataset, dataset_size), + batch_size=None, # Return individual samples (we batch ourselves) + num_workers=num_workers, + prefetch_factor=prefetch_factor if num_workers > 0 else None, + persistent_workers=num_workers > 0, + ) + self._iter = None + + logger.info( + f"HFIterableDatasetAdapter initialized: " + f"path={path}, dataset_size={dataset_size}, " + f"num_workers={num_workers}, shuffle_buffer={shuffle_buffer_size}" + ) + + def __len__(self) -> int: + return self.dataset_size + + def _load_and_process_dataset(self) -> HFIterableDataset: + """Load base dataset and apply processing pipeline.""" + logger.info(f"Loading dataset from {self.path} (streaming mode, split={self.split})") + + # Determine file type and load + if self.path.endswith(".jsonl"): + dataset = load_dataset("json", data_files=self.path, split=self.split, streaming=True) + elif self.path.endswith(".parquet"): + dataset = load_dataset("parquet", 
data_files=self.path, split=self.split, streaming=True) + else: + # Try as HF dataset name + try: + dataset = load_dataset(self.path, split=self.split, streaming=True) + except Exception as e: + raise ValueError( + f"Failed to load dataset from {self.path} with split '{self.split}'. " + f"Supported formats: .jsonl, .parquet, or HuggingFace dataset name. " + f"Error: {e}" + ) from e + + # Apply preprocessing (map + filter) + dataset = dataset.map( + self._preprocess_function, + batched=True, + batch_size=128, + ) + + # Filter out invalid samples + dataset = dataset.filter(lambda x: x["is_valid"]) + + return dataset + + def _preprocess_function(self, examples: dict) -> dict: + """Preprocess function for HF .map(). + + Processes a batch of raw samples and converts them to Sample objects. + Samples that are too long are filtered out by marking is_valid=False. + """ + from slime.utils.data import _build_messages, filter_long_prompt + + batch_size = len(examples[list(examples.keys())[0]]) + processed_samples = [None] * batch_size + is_valid_list = [False] * batch_size + samples_to_filter = [] # Collect valid samples for batch filtering + sample_index_map = {} # Map sample id -> original index for filtered samples + + # Step 1: Process all samples and create Sample objects + for idx in range(batch_size): + data = {k: v[idx] for k, v in examples.items()} + + try: + # Build messages + as_conversation = self.apply_chat_template + prompt = _build_messages(data, self.prompt_key, as_conversation, self.multimodal_keys) + + # Handle metadata + metadata = data.get(self.metadata_key) or {} + + # Handle tools + tools = None + if self.tool_key is not None and self.tool_key in data: + tools = data[self.tool_key] + if isinstance(tools, str): + tools = json.loads(tools) + elif isinstance(tools, np.ndarray): + tools = tools.tolist() + assert isinstance(tools, list), f"tools must be a list, got {type(tools)}" + metadata["tools"] = tools + + # Apply chat template + if self.apply_chat_template: + formatted_prompt = self.tokenizer.apply_chat_template( + prompt, + tools=tools, + tokenize=False, + add_generation_prompt=True, + **self.apply_chat_template_kwargs, + ) + else: + formatted_prompt = prompt + + # Handle multimodal (experimental) + multimodal_inputs = None + if self.processor: + logger.warning("Multimodal support is experimental in streaming mode") + try: + from qwen_vl_utils import process_vision_info + + assert isinstance(prompt, list), "prompt must be a list when processor is not None" + images, videos = process_vision_info(prompt) + multimodal_inputs = {"images": images, "videos": videos} + except Exception as e: + logger.warning(f"Failed to process multimodal input: {e}, skipping sample") + # Keep as None and False (already initialized) + continue + + # Create Sample object + sample = Sample( + prompt=formatted_prompt, + label=data[self.label_key] if self.label_key is not None else None, + metadata=metadata, + multimodal_inputs=multimodal_inputs, + ) + + processed_samples[idx] = sample + samples_to_filter.append(sample) + sample_index_map[id(sample)] = idx + + except Exception as e: + logger.warning(f"Failed to preprocess sample: {e}, skipping") + # Keep as None and False (already initialized) + continue + + # Step 2: Batch filter all valid samples by length + if samples_to_filter: + filtered_result = filter_long_prompt(samples_to_filter, self.tokenizer, self.processor, self.max_length) + + # filter_long_prompt returns False when max_length is None or prompt is not a string (no filtering) + if 
filtered_result is False: + # No filtering applied, all valid samples pass + filtered_samples_ids = {id(sample) for sample in samples_to_filter} + else: + # Samples that passed the length filter + filtered_samples_ids = {id(sample) for sample in filtered_result} + + # Step 3: Update is_valid_list and processed_samples based on filtering results + for sample in samples_to_filter: + idx = sample_index_map[id(sample)] + if id(sample) in filtered_samples_ids: + is_valid_list[idx] = True + else: + # Sample was filtered out (too long) + is_valid_list[idx] = False + processed_samples[idx] = None + + return {"samples": processed_samples, "is_valid": is_valid_list} + + def get_next_batch(self, num_samples: int) -> list[Sample]: + """Get next batch of samples using DataLoader. + + This is the main consumption interface. StopIteration naturally propagates + to the main thread, enabling clean epoch transitions. + """ + if self._iter is None: + self._iter = iter(self.dataloader) + + samples = [] + for _ in range(num_samples): + try: + sample_data = next(self._iter) + sample = sample_data["samples"] + if sample is None: + continue + + samples.append(sample) + self.consumed_count += 1 + self.global_consumed_count += 1 + + except StopIteration: + # Epoch ended - clean transition + logger.info(f"Epoch {self.epoch_id} completed ({self.consumed_count} samples)") + self.epoch_id += 1 + self.consumed_count = 0 + self.hf_dataset.set_epoch(self.epoch_id) # Triggers reshuffle + self._iter = iter(self.dataloader) + + # Get sample from new epoch + try: + sample_data = next(self._iter) + sample = sample_data["samples"] + if sample is not None: + samples.append(sample) + self.consumed_count += 1 + self.global_consumed_count += 1 + except StopIteration: + logger.warning("New epoch iterator immediately exhausted") + break + + return samples + + def shuffle(self, new_epoch_id: int): + """Shuffle for new epoch. + + Called by RolloutDataSource when starting a new epoch. + Uses HF's set_epoch() to change shuffle seed (effective_seed = seed + epoch_id), + then recreates iterator to apply new shuffle order with the updated seed. + """ + if self.epoch_id == new_epoch_id: + return + + logger.info(f"Shuffling for epoch {new_epoch_id} (current epoch: {self.epoch_id})") + self.epoch_id = new_epoch_id + self.consumed_count = 0 + self.hf_dataset.set_epoch(new_epoch_id) + self._iter = iter(self.dataloader) + + def get_checkpoint_state(self) -> dict: + """Get state for checkpoint using HF's native state_dict. + + State tracking: + - epoch_id: Current epoch number (for seed+epoch reproducible shuffle) + - consumed_count: Samples consumed in current epoch (for statistics) + - global_consumed_count: Total samples consumed across all epochs + - hf_state_dict: HF's native iterator state (stores exact position) + + VERIFIED: hf_state_dict enables exact position resume without sample skipping. + """ + return { + "epoch_id": self.epoch_id, + "consumed_count": self.consumed_count, + "global_consumed_count": self.global_consumed_count, + "hf_state_dict": self.hf_dataset.state_dict(), + } + + def load_checkpoint_state(self, state: dict): + """Load state from checkpoint using HF's native load_state_dict. 
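+
+        Minimal usage sketch (``adapter`` and ``new_adapter`` are hypothetical
+        instances constructed with identical arguments):
+
+            state = adapter.get_checkpoint_state()
+            new_adapter.load_checkpoint_state(state)
+            resumed = new_adapter.get_next_batch(10)  # continues from the saved position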
+ + VERIFIED BEHAVIOR: + - HF's state_dict() stores exact iterator position (shard + offset) + - load_state_dict() restores this position accurately + - No manual sample skipping is needed after load + """ + self.epoch_id = state.get("epoch_id", 0) + self.consumed_count = state.get("consumed_count", 0) + self.global_consumed_count = state.get("global_consumed_count", 0) + + # Restore HF iterator state + if "hf_state_dict" in state: + self.hf_dataset.load_state_dict(state["hf_state_dict"]) + + self.hf_dataset.set_epoch(self.epoch_id) + self._iter = iter(self.dataloader) + + logger.info( + f"Loaded checkpoint: epoch={self.epoch_id}, " + f"consumed={self.consumed_count}, " + f"global_consumed={self.global_consumed_count}" + ) + + def __iter__(self): + """Iterate over dataset (for compatibility). + + Note: Prefer get_next_batch() for production use. + """ + if self._iter is None: + self._iter = iter(self.dataloader) + + while True: + try: + sample_data = next(self._iter) + sample = sample_data["samples"] + if sample is None: + continue + self.consumed_count += 1 + self.global_consumed_count += 1 + yield sample + except StopIteration: + self.epoch_id += 1 + self.consumed_count = 0 + self.hf_dataset.set_epoch(self.epoch_id) + self._iter = iter(self.dataloader) + + +class _HFDatasetWrapper(TorchIterableDataset): + """Minimal wrapper for PyTorch DataLoader compatibility. + + Only provides __iter__ with dataset_size limit. All other operations + go directly through the HF dataset. + """ + + def __init__(self, hf_dataset: HFIterableDataset, dataset_size: int): + super().__init__() + self.hf_dataset = hf_dataset + self.dataset_size = dataset_size + + def __iter__(self): + count = 0 + for sample in self.hf_dataset: + yield sample + count += 1 + if count >= self.dataset_size: + break diff --git a/tests/test_hf_datasets.py b/tests/test_hf_datasets.py new file mode 100644 index 000000000..0ae75025a --- /dev/null +++ b/tests/test_hf_datasets.py @@ -0,0 +1,846 @@ +"""Unit tests for HuggingFace Datasets integration (streaming mode). + +This test file covers: +1. HFIterableDatasetAdapter basic functionality (initialization, get_next_batch, shuffle) +2. RolloutDataSource mixed mode logic (auto-detection via duck typing) +3. Checkpoint support (save/load/resume across epochs) +4. 
Edge cases (dataset=None, empty dataset, sample_offset overflow) + +Test Strategy: +- Use small synthetic datasets (100 samples) for fast execution +- Mock Ray actors and heavy dependencies where appropriate +- Focus on correctness of data flow and state management +""" + +import json +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import torch + + +# Test fixtures and utilities +@pytest.fixture +def temp_dir(): + """Create temporary directory for test artifacts.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +@pytest.fixture +def test_jsonl_data(temp_dir): + """Create small test dataset (100 samples) in JSONL format.""" + data_path = Path(temp_dir) / "test_data.jsonl" + samples = [] + for i in range(100): + samples.append( + { + "input": f"Test prompt {i}", + "label": f"Test label {i}", + "metadata": {"sample_id": i}, + } + ) + + with open(data_path, "w") as f: + for sample in samples: + f.write(json.dumps(sample) + "\n") + + return str(data_path) + + +@pytest.fixture +def mock_args(test_jsonl_data, temp_dir): + """Mock args object for testing.""" + args = MagicMock() + + # Data source config + args.rollout_global_dataset = True + args.prompt_data = test_jsonl_data + args.input_key = "input" + args.label_key = "label" + args.metadata_key = "metadata" + args.tool_key = None + args.multimodal_keys = None + args.apply_chat_template = False + args.apply_chat_template_kwargs = {} + args.rollout_seed = 42 + args.rollout_shuffle = False + args.rollout_max_prompt_len = None + args.n_samples_per_prompt = 1 + + # HF Datasets config + args.use_hf_datasets = False # Default to Legacy + args.hf_dataset_buffer_size = 10 + args.hf_dataset_shuffle_buffer = 100 + args.hf_dataset_num_proc = 4 + args.hf_datasets_num_samples = 100 # Required for HF streaming mode + args.hf_dataset_split = "train" # Default split name + + # Checkpoint config + args.save = str(Path(temp_dir) / "checkpoints") + args.load = None + args.dump_details = None + + # Mock tokenizer and processor + args.hf_checkpoint = None + + return args + + +@pytest.fixture +def mock_tokenizer(): + """Mock HuggingFace tokenizer.""" + tokenizer = MagicMock() + tokenizer.encode = lambda text: list(range(len(text.split()))) # Simple mock + tokenizer.apply_chat_template = lambda msgs, **kwargs: str(msgs) + return tokenizer + + +@pytest.fixture +def mock_processor(): + """Mock HuggingFace processor (for multimodal).""" + return None # Most tests don't use multimodal + + +@pytest.fixture +def test_jsonl_data_with_ids(temp_dir): + """Create test dataset with unique IDs for deduplication testing. 
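+
+    Each line written to the JSONL file has the form (see the loop below):
+
+        {"input": "Test prompt 0", "label": "Test label 0",
+         "metadata": {"sample_id": "sample_0000"}}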
+ + Creates 100 samples with: + - Unique sample_id in metadata + - Sequential numbering for deterministic testing + """ + data_path = Path(temp_dir) / "test_data_with_ids.jsonl" + for i in range(100): + sample = { + "input": f"Test prompt {i}", + "label": f"Test label {i}", + "metadata": {"sample_id": f"sample_{i:04d}"}, + } + with open(data_path, "a") as f: + f.write(json.dumps(sample) + "\n") + return str(data_path) + + +# ============================================================================ +# Test Class 1: HFDatasetAdapters Basic Functionality +# ============================================================================ + + +class TestHFDatasetAdapters: + """Test HF adapters' core functionality.""" + + def test_streaming_adapter_initialization(self, test_jsonl_data, mock_tokenizer, mock_processor): + """Test HFIterableDatasetAdapter initialization.""" + from slime.utils.hf_dataset import HFIterableDatasetAdapter + + adapter = HFIterableDatasetAdapter( + path=test_jsonl_data, + dataset_size=100, # Required for epoch tracking + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + label_key="label", + metadata_key="metadata", + seed=42, + num_workers=0, # Single-process for testing + shuffle_buffer_size=100, + ) + + # Check state tracking + assert adapter.epoch_id == 0 + assert adapter.consumed_count == 0 + assert adapter.global_consumed_count == 0 + assert len(adapter) == 100 # Dataset size + + def test_get_next_batch_sequential(self, test_jsonl_data, mock_tokenizer, mock_processor): + """Test sequential consumption via get_next_batch().""" + from slime.utils.hf_dataset import HFIterableDatasetAdapter + + adapter = HFIterableDatasetAdapter( + path=test_jsonl_data, + dataset_size=100, # Required for epoch tracking + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + label_key="label", + seed=42, + num_workers=0, # Single-process for testing + ) + + # Consume first batch + batch1 = adapter.get_next_batch(num_samples=10) + assert len(batch1) == 10 + assert adapter.consumed_count == 10 + + # Consume second batch + batch2 = adapter.get_next_batch(num_samples=10) + assert len(batch2) == 10 + assert adapter.consumed_count == 20 + + # Check no overlap + assert batch1 != batch2 + + def test_epoch_switch(self, test_jsonl_data, mock_tokenizer, mock_processor): + """Test automatic epoch switching when dataset is exhausted.""" + from slime.utils.hf_dataset import HFIterableDatasetAdapter + + adapter = HFIterableDatasetAdapter( + path=test_jsonl_data, + dataset_size=100, # Required for epoch tracking + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + label_key="label", + seed=42, + num_workers=0, # Single-process for testing + ) + + # Consume entire dataset + total_samples = 0 + while adapter.epoch_id == 0 and total_samples < 200: + batch = adapter.get_next_batch(num_samples=10) + total_samples += len(batch) + + # Should have switched to epoch 1 + assert adapter.epoch_id >= 1 + + +# ============================================================================ +# Test Class 2: RolloutDataSource Mixed Mode Logic +# ============================================================================ + + +class TestRolloutDataSourceMixedMode: + """Test RolloutDataSource mixed mode logic (duck typing).""" + + @patch("slime.rollout.data_source.load_tokenizer") + @patch("slime.rollout.data_source.load_processor") + def test_get_samples_with_dataset_none(self, 
mock_load_proc, mock_load_tok, mock_args): + """Test get_samples() when dataset=None (--disable-rollout-global-dataset).""" + from slime.rollout.data_source import RolloutDataSource + from slime.utils.types import Sample + + mock_load_tok.return_value = MagicMock() + mock_load_proc.return_value = None + + # Disable global dataset + mock_args.rollout_global_dataset = False + + data_source = RolloutDataSource(mock_args) + samples = data_source.get_samples(num_samples=5) + + # Should return 5 groups of empty samples + assert len(samples) == 5 + assert all(isinstance(group, list) for group in samples) + assert all(isinstance(sample, Sample) for group in samples for sample in group) + + @patch("slime.rollout.data_source.load_tokenizer") + @patch("slime.rollout.data_source.load_processor") + def test_duck_typing_detection(self, mock_load_proc, mock_load_tok, mock_args): + """Test duck typing correctly detects HF adapters vs Legacy Dataset.""" + from slime.rollout.data_source import RolloutDataSource + + mock_load_tok.return_value = MagicMock() + mock_load_proc.return_value = None + + data_source = RolloutDataSource(mock_args) + + # Check dataset was created (Legacy mode since use_hf_datasets=False) + assert data_source.dataset is not None + + # Verify duck typing: should not have get_next_batch (Legacy Dataset) + assert not hasattr(data_source.dataset, "get_next_batch") + + # Verify has .samples attribute (Legacy Dataset) + assert hasattr(data_source.dataset, "samples") + + @patch("slime.rollout.data_source.load_tokenizer") + @patch("slime.rollout.data_source.load_processor") + def test_get_samples_with_hf_streaming(self, mock_load_proc, mock_load_tok, mock_args): + """Test get_samples() with HF Streaming mode.""" + from slime.rollout.data_source import RolloutDataSource + + mock_load_tok.return_value = MagicMock() + mock_load_proc.return_value = None + + # Enable HF Datasets mode + mock_args.use_hf_datasets = True + + data_source = RolloutDataSource(mock_args) + + # Verify duck typing detected HF adapter + assert hasattr(data_source.dataset, "get_next_batch") + + # Get samples + samples = data_source.get_samples(num_samples=5) + assert len(samples) == 5 + + +# ============================================================================ +# Test Class 3: Checkpoint Support +# ============================================================================ + + +class TestCheckpointSupport: + """Test checkpoint save/load/resume functionality.""" + + @patch("slime.rollout.data_source.load_tokenizer") + @patch("slime.rollout.data_source.load_processor") + def test_save_and_load_legacy(self, mock_load_proc, mock_load_tok, mock_args): + """Test checkpoint save/load for Legacy Dataset.""" + from slime.rollout.data_source import RolloutDataSource + + mock_load_tok.return_value = MagicMock() + mock_load_proc.return_value = None + + # Create data source with Legacy Dataset + data_source = RolloutDataSource(mock_args) + + # Consume some samples + data_source.get_samples(num_samples=10) + data_source.get_samples(num_samples=5) + + # Save checkpoint + rollout_id = 42 + data_source.save(rollout_id) + + # Verify checkpoint file exists + ckpt_path = Path(mock_args.save) / f"rollout/global_dataset_state_dict_{rollout_id}.pt" + assert ckpt_path.exists() + + # Load checkpoint into new data source + data_source2 = RolloutDataSource(mock_args) + + # Set load path + data_source2.args.load = mock_args.save + data_source2.load(rollout_id) + + # Verify state restored + assert data_source2.sample_offset == 
data_source.sample_offset + assert data_source2.epoch_id == data_source.epoch_id + assert data_source2.sample_group_index == data_source.sample_group_index + assert data_source2.sample_index == data_source.sample_index + + @patch("slime.rollout.data_source.load_tokenizer") + @patch("slime.rollout.data_source.load_processor") + def test_save_and_load_hf_streaming(self, mock_load_proc, mock_load_tok, mock_args): + """Test checkpoint save/load for HF Streaming mode.""" + from slime.rollout.data_source import RolloutDataSource + + mock_load_tok.return_value = MagicMock() + mock_load_proc.return_value = None + + # Enable HF Datasets mode + mock_args.use_hf_datasets = True + + data_source = RolloutDataSource(mock_args) + + # Consume some samples + data_source.get_samples(num_samples=10) + + # Save checkpoint + rollout_id = 42 + data_source.save(rollout_id) + + # Verify checkpoint contains HF adapter state + ckpt_path = Path(mock_args.save) / f"rollout/global_dataset_state_dict_{rollout_id}.pt" + state_dict = torch.load(ckpt_path) + assert "hf_adapter_state" in state_dict + assert "epoch_id" in state_dict["hf_adapter_state"] + assert "consumed_count" in state_dict["hf_adapter_state"] + + +# ============================================================================ +# Test Class 4: Edge Cases +# ============================================================================ + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + @patch("slime.rollout.data_source.load_tokenizer") + @patch("slime.rollout.data_source.load_processor") + def test_dataset_none(self, mock_load_proc, mock_load_tok, mock_args): + """Test behavior when dataset=None.""" + from slime.rollout.data_source import RolloutDataSource + from slime.utils.types import Sample + + mock_load_tok.return_value = MagicMock() + mock_load_proc.return_value = None + + # Disable global dataset + mock_args.rollout_global_dataset = False + + data_source = RolloutDataSource(mock_args) + + # get_samples should return empty Sample objects + samples = data_source.get_samples(num_samples=3) + assert len(samples) == 3 + assert all(isinstance(group, list) for group in samples) + assert all(len(group) == mock_args.n_samples_per_prompt for group in samples) + assert all(isinstance(sample, Sample) for group in samples for sample in group) + + @patch("slime.rollout.data_source.load_tokenizer") + @patch("slime.rollout.data_source.load_processor") + def test_checkpoint_nonexistent_path(self, mock_load_proc, mock_load_tok, mock_args): + """Test loading from nonexistent checkpoint path.""" + from slime.rollout.data_source import RolloutDataSource + + mock_load_tok.return_value = MagicMock() + mock_load_proc.return_value = None + + data_source = RolloutDataSource(mock_args) + + # Set load path to nonexistent location + data_source.args.load = "/nonexistent/path" + + # Should not raise error, just log warning + data_source.load(rollout_id=999) + + # State should remain at initial values + assert data_source.sample_offset == 0 + assert data_source.epoch_id == 0 + + +# ============================================================================ +# Integration Tests (Optional - require full environment) +# ============================================================================ + + +class TestIntegration: + """End-to-end integration tests (require HF datasets library).""" + + def test_full_training_loop_simulation(self, temp_dir, test_jsonl_data_with_ids, mock_tokenizer, mock_processor): + """Simulate full training loop: rollout → train → 
checkpoint → resume. + + This test verifies: + 1. Sequential consumption across multiple rollouts + 2. Checkpoint save at step N + 3. Checkpoint resume from step N + 4. No sample duplication (via metadata.sample_id) + 5. Automatic epoch switching (100 samples / 32 batch → 3+ epochs) + """ + from slime.rollout.data_source import RolloutDataSource + + # Setup args + args = MagicMock() + args.rollout_global_dataset = True + args.prompt_data = test_jsonl_data_with_ids + args.input_key = "input" + args.label_key = "label" + args.metadata_key = "metadata" + args.tool_key = None + args.multimodal_keys = None + args.apply_chat_template = False + args.apply_chat_template_kwargs = {} + args.rollout_seed = 42 + args.rollout_shuffle = False # Disable shuffle for deterministic testing + args.rollout_max_prompt_len = None + args.n_samples_per_prompt = 1 + args.use_hf_datasets = True # Enable HF Datasets mode + args.hf_dataset_buffer_size = 10 + args.hf_dataset_shuffle_buffer = 100 + args.hf_dataset_num_proc = 4 + args.hf_datasets_num_samples = 100 # Required for epoch tracking + args.hf_dataset_split = "train" # Dataset split name + args.save = str(Path(temp_dir) / "checkpoints") + args.load = None + args.hf_checkpoint = None + + # Mock tokenizer/processor loading + with patch("slime.rollout.data_source.load_tokenizer", return_value=mock_tokenizer), patch( + "slime.rollout.data_source.load_processor", return_value=mock_processor + ): + + # === Phase 1: Run 10 rollouts with checkpoint at step 5 === + data_source = RolloutDataSource(args) + + consumed_sample_ids = set() + num_rollouts = 10 + batch_size = 32 # 32 prompts per rollout + current_epoch = 0 + + for rollout_id in range(num_rollouts): + # Track epoch before getting samples + before_epoch = data_source.epoch_id + + samples = data_source.get_samples(num_samples=batch_size) + + # Track epoch after getting samples + after_epoch = data_source.epoch_id + + # Verify batch size + assert len(samples) == batch_size, f"Expected {batch_size} sample groups, got {len(samples)}" + + # Check for epoch transition + if after_epoch > before_epoch: + # Epoch changed, clear dedup set (samples can repeat across epochs) + consumed_sample_ids.clear() + current_epoch = after_epoch + + # Extract sample IDs and check for duplicates within same epoch + for group in samples: + for sample in group: + sample_id = sample.metadata.get("sample_id") + assert sample_id is not None, f"Sample missing unique ID at rollout {rollout_id}" + assert ( + sample_id not in consumed_sample_ids + ), f"Duplicate sample detected: {sample_id} at rollout {rollout_id}, epoch {current_epoch}" + consumed_sample_ids.add(sample_id) + + # Save checkpoint at step 5 + if rollout_id == 5: + data_source.save(rollout_id) + + # Verify epoch switching occurred (100 samples / 32 batch = 3.125 batches per epoch) + # After 10 rollouts (320 samples requested), should be in epoch 3+ + assert data_source.epoch_id >= 2, f"Expected multiple epochs, but only in epoch {data_source.epoch_id}" + + # === Phase 2: Verify checkpoint file exists and structure === + ckpt_path = Path(args.save) / "rollout/global_dataset_state_dict_5.pt" + assert ckpt_path.exists(), "Checkpoint file not created" + + # Load and verify checkpoint structure + state_dict = torch.load(ckpt_path) + required_keys = ["sample_offset", "epoch_id", "sample_group_index", "sample_index"] + for key in required_keys: + assert key in state_dict, f"Missing key in checkpoint: {key}" + + # Verify HF adapter state is present + assert "hf_adapter_state" in 
state_dict, "Missing HF adapter state in checkpoint" + hf_state = state_dict["hf_adapter_state"] + assert "epoch_id" in hf_state, "Missing epoch_id in HF adapter state" + assert "consumed_count" in hf_state, "Missing consumed_count in HF adapter state" + + # === Phase 3: Resume from checkpoint and verify state restoration === + data_source2 = RolloutDataSource(args) + data_source2.args.load = args.save + data_source2.load(rollout_id=5) + + # Verify state restoration + assert data_source2.dataset is not None, "Dataset not initialized after load" + if hasattr(data_source2.dataset, "consumed_count"): + # HF adapter should restore consumed_count + assert data_source2.dataset.consumed_count >= 0, "Invalid consumed_count after restore" + + # Continue for 5 more rollouts (simulating steps 6-10) + # This verifies that checkpoint correctly saved position + for _rollout_id in range(6, num_rollouts + 1): + samples = data_source2.get_samples(num_samples=batch_size) + assert len(samples) == batch_size, f"Expected {batch_size} samples after resume" + + def test_epoch_boundary_checkpoint(self, temp_dir, test_jsonl_data_with_ids, mock_tokenizer, mock_processor): + """Test checkpoint save/load at epoch boundary. + + Edge case: Verify correct behavior when checkpoint occurs exactly + at epoch transition (after consuming all 100 samples). + """ + from slime.rollout.data_source import RolloutDataSource + + # Setup args (similar to test_full_training_loop_simulation) + args = MagicMock() + args.rollout_global_dataset = True + args.prompt_data = test_jsonl_data_with_ids + args.input_key = "input" + args.label_key = "label" + args.metadata_key = "metadata" + args.tool_key = None + args.multimodal_keys = None + args.apply_chat_template = False + args.apply_chat_template_kwargs = {} + args.rollout_seed = 42 + args.rollout_shuffle = False + args.rollout_max_prompt_len = None + args.n_samples_per_prompt = 1 + args.use_hf_datasets = True + args.hf_dataset_buffer_size = 10 + args.hf_dataset_shuffle_buffer = 100 + args.hf_dataset_num_proc = 4 + args.hf_datasets_num_samples = 100 # Required for epoch tracking + args.hf_dataset_split = "train" # Dataset split name + args.save = str(Path(temp_dir) / "checkpoints") + args.load = None + args.hf_checkpoint = None + + with patch("slime.rollout.data_source.load_tokenizer", return_value=mock_tokenizer), patch( + "slime.rollout.data_source.load_processor", return_value=mock_processor + ): + + data_source = RolloutDataSource(args) + + # Consume exactly 100 samples (one complete epoch) + # With 100 samples and batch_size=25, need 4 rollouts + for _rollout_id in range(4): + samples = data_source.get_samples(num_samples=25) + assert len(samples) == 25 + + # Epoch switch happens when we request samples that cross the boundary + # Request one more sample to trigger epoch transition + samples = data_source.get_samples(num_samples=1) + assert len(samples) == 1 + + # Should have completed epoch 0, now in epoch 1 + assert data_source.epoch_id >= 1, "Expected epoch transition after 100+ samples" + + # Save checkpoint at epoch boundary + data_source.save(rollout_id=4) + + # Verify checkpoint + ckpt_path = Path(args.save) / "rollout/global_dataset_state_dict_4.pt" + assert ckpt_path.exists(), "Checkpoint not saved at epoch boundary" + + state_dict = torch.load(ckpt_path) + assert "hf_adapter_state" in state_dict + hf_state = state_dict["hf_adapter_state"] + + # At epoch boundary, consumed_count should be 0 (reset for new epoch) + # or 100 (if tracking total), depending on implementation + 
assert "epoch_id" in hf_state + assert hf_state["epoch_id"] >= 1, "Epoch ID not incremented at boundary" + + # Resume from checkpoint + data_source2 = RolloutDataSource(args) + data_source2.args.load = args.save + data_source2.load(rollout_id=4) + + # Continue consuming - should start from epoch 1 + samples = data_source2.get_samples(num_samples=10) + assert len(samples) == 10, "Failed to consume after epoch boundary checkpoint" + + +# ============================================================================ +# Test Class 5: HF State Tracking Verification +# ============================================================================ + + +class TestHFStateTracking: + """Test HF state tracking and checkpoint resume accuracy. + + These tests verify whether HF's state_dict()/load_state_dict() can + accurately restore iterator position without manual sample skipping. + + VERIFICATION RESULTS (all tests passed): + 1. test_hf_state_dict_exact_resume: HF state_dict() stores exact iterator position + 2. test_consumed_count_consistency: consumed_count tracking is accurate + 3. test_shuffle_reproducibility_same_seed_epoch: Same seed+epoch = same shuffle + 4. test_epoch_boundary_exact_resume: Resume works correctly across epoch boundaries + + CONCLUSION: No additional offset is needed beyond seed + epoch + hf_state_dict. + HF's native checkpoint mechanism handles exact position resume without manual skipping. + """ + + def test_hf_state_dict_exact_resume(self, test_jsonl_data_with_ids, mock_tokenizer, mock_processor): + """Verify HF state_dict() restores exact iterator position. + + Golden truth: Use the SAME adapter's continued consumption as expected result. + This tests whether we need additional offset tracking beyond seed + epoch. + """ + from slime.utils.hf_dataset import HFIterableDatasetAdapter + + # Create adapter with NO shuffle for deterministic testing + adapter1 = HFIterableDatasetAdapter( + path=test_jsonl_data_with_ids, + dataset_size=100, + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + label_key="label", + metadata_key="metadata", + seed=42, + num_workers=0, + do_shuffle=False, # Disable shuffle for deterministic test + ) + + # Step 1: Consume 35 samples + _ = adapter1.get_next_batch(35) + assert adapter1.consumed_count == 35 + + # Step 2: Save checkpoint state at position 35 + checkpoint_state = adapter1.get_checkpoint_state() + assert checkpoint_state["consumed_count"] == 35 + + # Step 3: Continue consuming 10 more samples → GOLDEN EXPECTED + golden_samples = adapter1.get_next_batch(10) + golden_ids = [s.metadata.get("sample_id") for s in golden_samples] + + # Step 4: Create new adapter and load checkpoint + adapter2 = HFIterableDatasetAdapter( + path=test_jsonl_data_with_ids, + dataset_size=100, + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + label_key="label", + metadata_key="metadata", + seed=42, + num_workers=0, + do_shuffle=False, + ) + + adapter2.load_checkpoint_state(checkpoint_state) + assert adapter2.consumed_count == 35 # State restored + + # Step 5: Get next 10 samples from resumed adapter + resumed_samples = adapter2.get_next_batch(10) + resumed_ids = [s.metadata.get("sample_id") for s in resumed_samples] + + # CRITICAL VERIFICATION: + # resumed_ids should match golden_ids (from same checkpoint position) + assert resumed_ids == golden_ids, ( + f"HF state_dict did not restore exact position!\n" + f"Golden (from adapter1 after checkpoint): {golden_ids}\n" + f"Resumed 
(from adapter2 after load): {resumed_ids}\n" + f"If these differ, HF state_dict is NOT working and we need manual skipping." + ) + + # Additional check: consumed_count should update correctly after resume + assert adapter2.consumed_count == 45 # 35 + 10 + + def test_consumed_count_consistency(self, test_jsonl_data_with_ids, mock_tokenizer, mock_processor): + """Verify consumed_count matches actual consumption.""" + from slime.utils.hf_dataset import HFIterableDatasetAdapter + + adapter = HFIterableDatasetAdapter( + path=test_jsonl_data_with_ids, + dataset_size=100, + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + label_key="label", + metadata_key="metadata", + seed=42, + num_workers=0, + do_shuffle=False, + ) + + # Initial state + assert adapter.consumed_count == 0 + assert adapter.global_consumed_count == 0 + + # Consume batches and verify counting + adapter.get_next_batch(10) + assert adapter.consumed_count == 10 + assert adapter.global_consumed_count == 10 + + adapter.get_next_batch(25) + assert adapter.consumed_count == 35 + assert adapter.global_consumed_count == 35 + + # Checkpoint should preserve counts + state = adapter.get_checkpoint_state() + assert state["consumed_count"] == 35 + assert state["global_consumed_count"] == 35 + + def test_shuffle_reproducibility_same_seed_epoch(self, test_jsonl_data_with_ids, mock_tokenizer, mock_processor): + """Verify same seed+epoch produces same shuffle order.""" + from slime.utils.hf_dataset import HFIterableDatasetAdapter + + # Create two adapters with same seed + adapter1 = HFIterableDatasetAdapter( + path=test_jsonl_data_with_ids, + dataset_size=100, + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + seed=42, + num_workers=0, + do_shuffle=True, + shuffle_buffer_size=100, # Full buffer for exact shuffle + ) + + adapter2 = HFIterableDatasetAdapter( + path=test_jsonl_data_with_ids, + dataset_size=100, + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + seed=42, + num_workers=0, + do_shuffle=True, + shuffle_buffer_size=100, + ) + + # Get first 20 samples from each + samples1 = adapter1.get_next_batch(20) + samples2 = adapter2.get_next_batch(20) + + ids1 = [s.metadata.get("sample_id") for s in samples1] + ids2 = [s.metadata.get("sample_id") for s in samples2] + + # Same seed+epoch should produce same shuffle + assert ids1 == ids2, ( + f"Same seed+epoch produced different shuffle!\n" + f"Adapter1: {ids1[:5]}...\n" + f"Adapter2: {ids2[:5]}..." 
+ ) + + def test_epoch_boundary_exact_resume(self, test_jsonl_data_with_ids, mock_tokenizer, mock_processor): + """Verify checkpoint/resume at epoch boundary works correctly.""" + from slime.utils.hf_dataset import HFIterableDatasetAdapter + + adapter1 = HFIterableDatasetAdapter( + path=test_jsonl_data_with_ids, + dataset_size=100, + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + label_key="label", + metadata_key="metadata", + seed=42, + num_workers=0, + do_shuffle=False, + ) + + # Consume 95 samples (near epoch boundary) + _ = adapter1.get_next_batch(95) + assert adapter1.consumed_count == 95 + assert adapter1.epoch_id == 0 + + # Save checkpoint + checkpoint_state = adapter1.get_checkpoint_state() + + # Continue consuming - this will cross epoch boundary + golden_samples = adapter1.get_next_batch(10) + golden_ids = [s.metadata.get("sample_id") for s in golden_samples] + # After crossing boundary: epoch_id should be 1, consumed_count reset + assert adapter1.epoch_id >= 1, "Should have crossed epoch boundary" + + # Create new adapter and load checkpoint (at position 95) + adapter2 = HFIterableDatasetAdapter( + path=test_jsonl_data_with_ids, + dataset_size=100, + tokenizer=mock_tokenizer, + processor=mock_processor, + max_length=None, + prompt_key="input", + label_key="label", + metadata_key="metadata", + seed=42, + num_workers=0, + do_shuffle=False, + ) + + adapter2.load_checkpoint_state(checkpoint_state) + + # Resume and cross boundary + resumed_samples = adapter2.get_next_batch(10) + resumed_ids = [s.metadata.get("sample_id") for s in resumed_samples] + + # Verify samples match across epoch boundary + assert resumed_ids == golden_ids, ( + f"Epoch boundary resume failed!\n" + f"Golden: {golden_ids}\n" + f"Resumed: {resumed_ids}" + ) \ No newline at end of file diff --git a/tests/test_qwen3-0.6B_hf_datasets_dapo.py b/tests/test_qwen3-0.6B_hf_datasets_dapo.py new file mode 100644 index 000000000..1e87be285 --- /dev/null +++ b/tests/test_qwen3-0.6B_hf_datasets_dapo.py @@ -0,0 +1,189 @@ +"""E2E test for HF Datasets integration with DAPO RL training. + +This test verifies: +1. HF Datasets streaming mode works with real DAPO training +2. Dataset: zhuzilin/dapo-math-17k (loaded directly via load_dataset) +3. Checkpoint save/restore includes HF adapter state +4. 
Resume continues training without sample duplication +""" + +import os +from pathlib import Path + +import torch + +import slime.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-0.6B" +MODEL_TYPE = "qwen3-0.6B" + + +def prepare(): + """Prepare model (dataset will be auto-loaded by HF Datasets).""" + U.exec_command("mkdir -p /root/models") + + # Download and convert model checkpoint + # Dataset will be auto-loaded via --prompt-data + U.convert_checkpoint( + model_name=f"Qwen/{MODEL_NAME}", + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=2, + ) + + +def execute(): + """Execute DAPO RL training with HF Datasets.""" + ckpt_args = ( + f"--hf-checkpoint /root/models/Qwen/{MODEL_NAME} " + f"--ref-load /root/models/Qwen/{MODEL_NAME} " # Reference model for KL metrics + "--save /root/Qwen3-0.6B_slime/ " + "--save-interval 30 " # Save checkpoint at step 30 + ) + + rollout_args = ( + # HF Datasets configuration + "--prompt-data zhuzilin/dapo-math-17k " # Direct HF dataset name + "--use-hf-datasets " # Enable HF Datasets streaming mode + "--hf-datasets-num-samples 17000 " # Required for proper epoch tracking + "--hf-dataset-buffer-size 100 " + "--hf-dataset-shuffle-buffer 1000 " + "--hf-dataset-num-proc 4 " + # Data keys (specific to dapo-math-17k) + "--input-key prompt " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + # Reward model and sampling + "--rm-type math " # Math reward model for DAPO + f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 60} " + "--rollout-batch-size 32 " + "--n-samples-per-prompt 8 " + "--rollout-max-response-len 2048 " # Reduced for faster testing + "--rollout-temperature 0.8 " + "--over-sampling-batch-size 64 " + "--global-batch-size 256 " + "--balance-data " # Balance tokens between DP ranks + ) + + eval_args = ( + "--eval-interval 30 " # Align with save-interval + "--eval-prompt-data dapo_test zhuzilin/dapo-math-17k " # Use same dataset for eval + "--n-samples-per-eval-prompt 4 " + "--eval-max-response-len 2048 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + sglang_args = "--rollout-num-gpus-per-engine 2 " "--sglang-decode-log-interval 1000 " "--sglang-enable-metrics " + + fsdp_args = "--update-weight-buffer-size 536870912 " # 512MB + + ci_args = "--ci-test " "--ci-disable-kl-checker " + + misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp " + + # Phase 1: Run first 30 rollouts + train_args_phase1 = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{sglang_args} " + f"{U.get_default_wandb_args(__file__, run_name_prefix='hf-dapo-phase1')} " + f"{eval_args} " + f"{fsdp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + print("\n" + "=" * 80) + print("Phase 1: Running first 30 rollouts with HF Datasets") + print("=" * 80 + "\n") + + U.execute_train( + train_args=train_args_phase1, + num_gpus_per_node=2, + megatron_model_type=None, + ) + + # Verify checkpoint was saved + checkpoint_path = Path("/root/Qwen3-0.6B_slime/rollout/global_dataset_state_dict_30.pt") + assert checkpoint_path.exists(), f"Checkpoint not found at {checkpoint_path}" + + # Verify checkpoint contains HF adapter state + state_dict = 
torch.load(checkpoint_path) + assert "hf_adapter_state" in state_dict, "Missing HF adapter state in checkpoint!" + hf_state = state_dict["hf_adapter_state"] + assert "epoch_id" in hf_state, "Missing epoch_id in HF adapter state" + assert "consumed_count" in hf_state, "Missing consumed_count in HF adapter state" + + print("\n" + "=" * 80) + print("Checkpoint verified successfully!") + print(f" - Epoch ID: {hf_state['epoch_id']}") + print(f" - Consumed count: {hf_state['consumed_count']}") + print("=" * 80 + "\n") + + # Phase 2: Resume from checkpoint and continue to 60 rollouts + ckpt_args_phase2 = ( + f"--hf-checkpoint /root/models/Qwen/{MODEL_NAME} " + f"--ref-load /root/models/Qwen/{MODEL_NAME} " # Reference model for KL metrics + "--load /root/Qwen3-0.6B_slime/ " # Load from previous checkpoint + "--save /root/Qwen3-0.6B_slime/ " + "--save-interval 30 " + ) + + train_args_phase2 = ( + f"{ckpt_args_phase2} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{sglang_args} " + f"{U.get_default_wandb_args(__file__, run_name_prefix='hf-dapo-phase2')} " + f"{eval_args} " + f"{fsdp_args} " + f"{ci_args} " + f"{misc_args} " + ) + + print("\n" + "=" * 80) + print("Phase 2: Resuming from checkpoint (30 → 60 rollouts)") + print("=" * 80 + "\n") + + U.execute_train( + train_args=train_args_phase2, + num_gpus_per_node=2, + megatron_model_type=None, + ) + + print("\n" + "=" * 80) + print("E2E RL test (DAPO) completed successfully!") + print("Verified:") + print(" ✓ HF Datasets streaming mode") + print(" ✓ Checkpoint save with HF adapter state") + print(" ✓ Checkpoint resume and continuation") + print("=" * 80 + "\n") + + +if __name__ == "__main__": + prepare() + # Remove proxy settings (may interfere with HF Datasets download) + for key in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"]: + os.environ.pop(key, None) + execute() diff --git a/tests/test_qwen3-0.6B_hf_datasets_sft.py b/tests/test_qwen3-0.6B_hf_datasets_sft.py new file mode 100644 index 000000000..8fdfa0eed --- /dev/null +++ b/tests/test_qwen3-0.6B_hf_datasets_sft.py @@ -0,0 +1,183 @@ +"""E2E test for HF Datasets integration with SFT (Supervised Fine-Tuning). + +This test verifies: +1. HF Datasets streaming mode works with SFT training +2. Dataset: nvidia/Nemotron-Agentic-v1 (loaded directly via load_dataset) +3. Checkpoint save/restore includes HF adapter state +4. 
SFT mode correctly handles messages and tools fields
+"""
+
+import os
+from pathlib import Path
+
+import torch
+
+import slime.utils.external_utils.command_utils as U
+
+MODEL_NAME = "Qwen3-0.6B"
+MODEL_TYPE = "qwen3-0.6B"
+
+
+def prepare():
+    """Prepare model (dataset will be auto-loaded by HF Datasets)."""
+    U.exec_command("mkdir -p /root/models")
+
+    # Download and convert model checkpoint
+    # Dataset will be auto-loaded via --prompt-data
+    U.convert_checkpoint(
+        model_name=f"Qwen/{MODEL_NAME}",
+        megatron_model_type=MODEL_TYPE,
+        num_gpus_per_node=2,
+    )
+
+
+def execute():
+    """Execute SFT training with HF Datasets."""
+    ckpt_args = (
+        f"--hf-checkpoint /root/models/Qwen/{MODEL_NAME} "
+        "--save /root/Qwen3-0.6B_slime_sft/ "
+        "--save-interval 30 "  # Save checkpoint at step 30
+    )
+
+    rollout_args = (
+        # HF Datasets configuration
+        "--prompt-data nvidia/Nemotron-Agentic-v1 "  # Direct HF dataset name
+        "--hf-dataset-split interactive_agent "  # Specify the correct split for this dataset
+        "--use-hf-datasets "  # Enable HF Datasets streaming mode
+        "--hf-datasets-num-samples 335122 "  # Required for epoch tracking
+        "--hf-dataset-buffer-size 100 "
+        "--hf-dataset-shuffle-buffer 1000 "
+        "--hf-dataset-num-proc 4 "
+        # Data keys (Nemotron-Agentic-v1 uses standard messages/tools format)
+        "--input-key messages "
+        "--tool-key tools "
+        "--rollout-shuffle "
+        # SFT-specific settings
+        f"--num-rollout {3000 if U.get_env_enable_infinite_run() else 60} "
+        "--rollout-batch-size 128 "
+        "--n-samples-per-prompt 1 "  # SFT typically uses 1 sample per prompt
+        "--rollout-max-response-len 2048 "
+        "--rollout-temperature 1.0 "  # Nominal; SFT uses dataset responses, not sampled generations
+        "--global-batch-size 128 "
+    )
+
+    eval_args = (
+        "--eval-interval 30 "  # Align with save-interval
+        "--eval-prompt-data nemotron_test nvidia/Nemotron-Agentic-v1 "
+        "--n-samples-per-eval-prompt 1 "
+        "--eval-max-response-len 2048 "
+    )
+
+    # SFT mode: No reward model, no GRPO
+    # Just standard language modeling loss with loss masks
+    sft_args = (
+        "--rollout-function-path slime.rollout.sft_rollout.generate_rollout "
+        "--loss-type sft_loss "
+        "--calculate-per-token-loss "
+        "--disable-compute-advantages-and-returns "
+        "--debug-train-only "
+    )
+
+    # Optimizer settings for SFT
+    optimizer_args = (
+        "--optimizer adam "
+        "--lr 1e-5 "  # Higher LR for SFT
+        "--lr-decay-style cosine "
+        "--weight-decay 0.1 "
+        "--adam-beta1 0.9 "
+        "--adam-beta2 0.999 "
+    )
+
+    fsdp_args = "--update-weight-buffer-size 536870912 "  # 512MB
+
+    ci_args = "--ci-test " "--ci-disable-kl-checker "
+
+    misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp "
+
+    # Phase 1: Run first 30 rollouts
+    train_args_phase1 = (
+        f"{ckpt_args} "
+        f"{rollout_args} "
+        f"{sft_args} "
+        f"{optimizer_args} "
+        f"{U.get_default_wandb_args(__file__, run_name_prefix='hf-sft-phase1')} "
+        f"{eval_args} "
+        f"{fsdp_args} "
+        f"{ci_args} "
+        f"{misc_args} "
+    )
+
+    print("\n" + "=" * 80)
+    print("Phase 1: Running first 30 rollouts with HF Datasets (SFT mode)")
+    print("=" * 80 + "\n")
+
+    U.execute_train(
+        train_args=train_args_phase1,
+        num_gpus_per_node=2,
+        megatron_model_type=None,
+    )
+
+    # Verify checkpoint was saved
+    checkpoint_path = Path("/root/Qwen3-0.6B_slime_sft/rollout/global_dataset_state_dict_30.pt")
+    assert checkpoint_path.exists(), f"Checkpoint not found at {checkpoint_path}"
+
+    # Verify checkpoint contains HF adapter state
+    state_dict = torch.load(checkpoint_path)
+    assert "hf_adapter_state" in state_dict, "Missing HF adapter state in checkpoint!"
+    hf_state = state_dict["hf_adapter_state"]
+    assert "epoch_id" in hf_state, "Missing epoch_id in HF adapter state"
+    assert "consumed_count" in hf_state, "Missing consumed_count in HF adapter state"
+
+    print("\n" + "=" * 80)
+    print("Checkpoint verified successfully!")
+    print(f" - Epoch ID: {hf_state['epoch_id']}")
+    print(f" - Consumed count: {hf_state['consumed_count']}")
+    print("=" * 80 + "\n")
+
+    # Phase 2: Resume from checkpoint and continue to 60 rollouts
+    ckpt_args_phase2 = (
+        f"--hf-checkpoint /root/models/Qwen/{MODEL_NAME} "
+        "--load /root/Qwen3-0.6B_slime_sft/ "  # Load from previous checkpoint
+        "--save /root/Qwen3-0.6B_slime_sft/ "
+        "--save-interval 30 "
+    )
+
+    train_args_phase2 = (
+        f"{ckpt_args_phase2} "
+        f"{rollout_args} "
+        f"{sft_args} "
+        f"{optimizer_args} "
+        f"{U.get_default_wandb_args(__file__, run_name_prefix='hf-sft-phase2')} "
+        f"{eval_args} "
+        f"{fsdp_args} "
+        f"{ci_args} "
+        f"{misc_args} "
+    )
+
+    print("\n" + "=" * 80)
+    print("Phase 2: Resuming from checkpoint (30 → 60 rollouts)")
+    print("=" * 80 + "\n")
+
+    U.execute_train(
+        train_args=train_args_phase2,
+        num_gpus_per_node=2,
+        megatron_model_type=None,
+    )
+
+    print("\n" + "=" * 80)
+    print("E2E SFT test completed successfully!")
+    print("Verified:")
+    print(" ✓ HF Datasets streaming mode with Nemotron-Agentic-v1")
+    print(" ✓ Messages and tools field parsing")
+    print(" ✓ Checkpoint save with HF adapter state")
+    print(" ✓ Checkpoint resume and continuation")
+    print("=" * 80 + "\n")
+
+
+if __name__ == "__main__":
+    prepare()
+    # Remove proxy settings (may interfere with HF Datasets download)
+    for key in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"]:
+        os.environ.pop(key, None)
+    execute()