From 4c5410d8c0878013392217a18e55933e1d2be317 Mon Sep 17 00:00:00 2001
From: tanzelin430
Date: Mon, 12 Jan 2026 21:40:09 +0800
Subject: [PATCH] feat: Add DeepSpeed SFT support and quality-of-life improvements

## DeepSpeed SFT Support

The SFT pipeline was originally designed for the Megatron strategy, where the
model receives labels and returns per-token losses directly. In the DeepSpeed
strategy with HuggingFace models, the model returns logits instead.

**Solution:** Override `op_compute_language_loss` in `DeepSpeedTrainStrategy`:
- Compute cross-entropy directly from logits
- DataCollatorForSFT already shifts labels (shift_feature=True by default),
  so logits and labels are already aligned - no double-shift needed

This follows ROLL's design pattern where the Strategy handles backend
differences, keeping the Worker code generic.

## Quality-of-Life Features

1. **Checkpoint Cleanup (`max_ckpt_to_keep`)**
   - Automatically delete old checkpoints to prevent disk exhaustion
   - Usage: `max_ckpt_to_keep: 3` keeps only the latest 3 checkpoints

2. **Wandb Offline Mode**
   - Added `mode` parameter to WandbTracker
   - Usage: `tracker_kwargs: {mode: offline}`

3. **SFT Training Improvements**
   - Enable data shuffling in the DataLoader (previously False)
   - Add a tqdm progress bar for training visualization

4. **pip Installation Support**
   - Added setup.py for `pip install -e .`

(Illustrative sketches of the label/logit alignment and of the new config
options follow after the diff.)

Co-Authored-By: Claude Opus 4.5
---
 roll/configs/base_config.py                       |  4 ++
 .../strategy/deepspeed_strategy.py                | 26 ++++++++++
 roll/pipeline/base_pipeline.py                    | 52 +++++++++++++++++++
 roll/pipeline/sft/sft_pipeline.py                 | 14 +++-
 roll/pipeline/sft/sft_worker.py                   |  4 +-
 roll/utils/tracking.py                            |  3 +-
 setup.py                                          |  9 ++++
 7 files changed, 107 insertions(+), 5 deletions(-)
 create mode 100644 setup.py

diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py
index aa1da4c2..dd918673 100644
--- a/roll/configs/base_config.py
+++ b/roll/configs/base_config.py
@@ -82,6 +82,10 @@ class BaseConfig:
         default=50,
         metadata={"help": "Save checkpoint every X update steps."}
     )
+    max_ckpt_to_keep: int = field(
+        default=0,
+        metadata={"help": "Maximum number of checkpoints to keep. 0 means keep all checkpoints."}
+    )
     logging_steps: int = field(
         default=1, metadata={"help": "Number of steps between logging information."}
     )
diff --git a/roll/distributed/strategy/deepspeed_strategy.py b/roll/distributed/strategy/deepspeed_strategy.py
index e2803951..0053fb4c 100644
--- a/roll/distributed/strategy/deepspeed_strategy.py
+++ b/roll/distributed/strategy/deepspeed_strategy.py
@@ -28,6 +28,7 @@
 from roll.utils.context_parallel import get_ulysses_group, set_upg_manager
 from roll.utils.deepspeed_utils import get_optimizer_grouped_parameters
 from roll.utils.functionals import append_to_dict, entropy_from_logits, log_probs_from_logits
+from roll.utils.constants import IGNORE_INDEX
 from roll.utils.logging import get_logger
 from roll.utils.offload_states import OffloadStateType
 from roll.platforms import current_platform
@@ -398,6 +399,31 @@ def initialize(self, model_provider):
         logger.info(f"{self.model}")
         dist.barrier()
 
+    def op_compute_language_loss(self, logits: torch.Tensor, labels: torch.Tensor):
+        """
+        Override for DeepSpeed strategy: compute language loss from logits.
+
+        In DeepSpeed strategy with HuggingFace models, the model returns logits
+        (not loss like in Megatron strategy where labels are passed to the model).
+
+        Note: DataCollatorForSFT already shifts labels (shift_feature=True by default),
+        so logits and labels are already aligned. Do NOT shift again here.
+
+        Args:
+            logits: Model output logits [batch_size, seq_len, vocab_size]
+            labels: Pre-shifted labels [batch_size, seq_len], already aligned with logits
+
+        Returns:
+            loss: Scalar loss tensor
+        """
+        # Labels already shifted by DataCollator, directly compute cross-entropy
+        loss = torch.nn.functional.cross_entropy(
+            logits.view(-1, logits.size(-1)),
+            labels.view(-1),
+            ignore_index=IGNORE_INDEX
+        )
+        return loss
+
     def train_step(
         self,
         batch: DataProto,
diff --git a/roll/pipeline/base_pipeline.py b/roll/pipeline/base_pipeline.py
index 958ce369..ad12f91a 100644
--- a/roll/pipeline/base_pipeline.py
+++ b/roll/pipeline/base_pipeline.py
@@ -1,4 +1,6 @@
 import os
+import re
+import shutil
 from collections import defaultdict
 from concurrent import futures
 from typing import List, Any, Dict
@@ -91,9 +93,59 @@ def do_checkpoint(self, global_step):
             self.state.save_rng_state(save_dir=save_dir, tag="pipeline")
             self.checkpoint_manager.upload(ckpt_id=ckpt_id, local_state_path=pipeline_save_dir)
 
+        # Clean up old checkpoints if max_ckpt_to_keep is set
+        self._cleanup_old_checkpoints()
+
         futures.wait(self.resume_futures)
         self.resume_futures.clear()
 
+    def _cleanup_old_checkpoints(self):
+        """Remove old checkpoints if max_ckpt_to_keep is set."""
+        max_ckpt = getattr(self.pipeline_config, 'max_ckpt_to_keep', 0)
+        if max_ckpt <= 0:
+            return
+
+        output_dir = self.pipeline_config.output_dir
+        if not os.path.exists(output_dir):
+            return
+
+        # Pattern to match checkpoint directories: checkpoint-{step}
+        ckpt_pattern = re.compile(r'^checkpoint-(\d+)$')
+
+        # Collect all checkpoint steps across all subdirectories
+        all_ckpt_steps = set()
+        for subdir in os.listdir(output_dir):
+            subdir_path = os.path.join(output_dir, subdir)
+            if not os.path.isdir(subdir_path):
+                continue
+            for item in os.listdir(subdir_path):
+                match = ckpt_pattern.match(item)
+                if match:
+                    all_ckpt_steps.add(int(match.group(1)))
+
+        # Sort steps and determine which to delete
+        sorted_steps = sorted(all_ckpt_steps, reverse=True)
+        steps_to_delete = sorted_steps[max_ckpt:]
+
+        if not steps_to_delete:
+            return
+
+        logger.info(f"Cleaning up old checkpoints. Keeping {max_ckpt}, deleting steps: {steps_to_delete}")
+
+        # Delete old checkpoints from all subdirectories
+        for subdir in os.listdir(output_dir):
+            subdir_path = os.path.join(output_dir, subdir)
+            if not os.path.isdir(subdir_path):
+                continue
+            for step in steps_to_delete:
+                ckpt_dir = os.path.join(subdir_path, f"checkpoint-{step}")
+                if os.path.exists(ckpt_dir):
+                    try:
+                        shutil.rmtree(ckpt_dir)
+                        logger.info(f"Deleted old checkpoint: {ckpt_dir}")
+                    except Exception as e:
+                        logger.warning(f"Failed to delete checkpoint {ckpt_dir}: {e}")
+
     def download_models(self, *clusters: Cluster):
         node2worker: Dict[str, Any] = {}
         node2model_names: Dict[str, set[str]] = defaultdict(set)
diff --git a/roll/pipeline/sft/sft_pipeline.py b/roll/pipeline/sft/sft_pipeline.py
index 163ea275..8bdc0b36 100644
--- a/roll/pipeline/sft/sft_pipeline.py
+++ b/roll/pipeline/sft/sft_pipeline.py
@@ -6,6 +6,7 @@
 import torch
 from codetiming import Timer
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 
 from roll.datasets.chat_template import get_chat_template
 from roll.datasets.collator import DataCollatorForSFT
@@ -152,7 +153,7 @@ def __init__(self, pipeline_config: SFTConfig):
         self.dataloader = DataLoader(
             dataset=self.dataset,
             batch_size=global_train_batch_size,
-            shuffle=False,
+            shuffle=True,  # Enable shuffle for better training
             drop_last=True,
             num_workers=self.pipeline_config.sft_train.training_args.dataloader_num_workers,
             collate_fn=data_collator,
@@ -181,11 +182,14 @@ def __init__(self, pipeline_config: SFTConfig):
     def run(self):
         global_step = 0
         metrics_mgr = MetricsManager()
+        num_epochs = self.pipeline_config.sft_train.training_args.num_train_epochs
+        total_steps = num_epochs * len(self.dataloader)
 
-        for epoch in range(self.pipeline_config.sft_train.training_args.num_train_epochs):
+        for epoch in range(num_epochs):
             logger.info(f"epoch {epoch} start...")
 
-            for batch_dict in self.dataloader:
+            pbar = tqdm(self.dataloader, desc=f"Epoch {epoch}/{num_epochs}")
+            for batch_dict in pbar:
                 # for continual training
                 if global_step <= self.state.step:
                     global_step += 1
@@ -213,6 +217,10 @@ def run(self):
                 metrics = metrics_mgr.get_metrics()
                 metrics = {k: float(v) for k, v in metrics.items()}
                 logger.info(f"metrics: {metrics}")
+
+                # Update tqdm progress bar
+                loss = metrics.get("sft_train/loss", 0)
+                pbar.set_postfix({"loss": f"{loss:.4f}", "step": f"{global_step}/{total_steps}"})
 
                 self.state.step = global_step
                 self.state.log_history.append(metrics)
diff --git a/roll/pipeline/sft/sft_worker.py b/roll/pipeline/sft/sft_worker.py
index 8d63bd51..aedc73ae 100644
--- a/roll/pipeline/sft/sft_worker.py
+++ b/roll/pipeline/sft/sft_worker.py
@@ -68,4 +68,6 @@ def do_checkpoint(self, global_step):
 
     def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
         labels = data.batch["labels"]
-        return self.strategy.op_compute_language_loss(output_tensor, labels)
\ No newline at end of file
+        loss = self.strategy.op_compute_language_loss(output_tensor, labels)
+        metrics = {f"{self.worker_config.name}/loss": loss.detach().float().unsqueeze(0)}
+        return loss, metrics
\ No newline at end of file
diff --git a/roll/utils/tracking.py b/roll/utils/tracking.py
index c9c0c1ec..dafda415 100644
--- a/roll/utils/tracking.py
+++ b/roll/utils/tracking.py
@@ -59,11 +59,12 @@ def __init__(self, config: dict, **kwargs):
         notes = kwargs.pop("notes", None)
         log_dir = kwargs.pop("log_dir", None)
         api_key = kwargs.pop("api_key", None)
+        mode = kwargs.pop("mode", None)
         settings = kwargs.pop("settings", {"console": "off"})
         import wandb
         if api_key:
             wandb.login(key=api_key)
-        self.run = wandb.init(project=project, tags=tags, name=name, notes=notes, dir=log_dir, settings=settings)
+        self.run = wandb.init(project=project, tags=tags, name=name, notes=notes, dir=log_dir, mode=mode, settings=settings)
 
         self.run.config.update(config, allow_val_change=True)
 
diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..d41c5fa2
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,9 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="roll",
+    version="0.1.0",
+    description="ROLL - Reinforcement Learning Optimization for Large-Scale Learning",
+    packages=find_packages(include=["roll", "roll.*"]),
+    python_requires=">=3.10",
+)
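
The "do NOT shift again" note in `op_compute_language_loss` can be checked with a
small standalone sketch. This is an editor's illustration, not part of the patch:
`build_shifted_labels` is a hypothetical stand-in for DataCollatorForSFT with
shift_feature=True (the real collator also masks prompt tokens), and
IGNORE_INDEX = -100 is an assumed value of `roll.utils.constants.IGNORE_INDEX`.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumption: same value as roll.utils.constants.IGNORE_INDEX

def build_shifted_labels(input_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a pre-shifting collator: labels[t] = input_ids[t + 1]."""
    labels = input_ids.clone()
    labels[:, :-1] = input_ids[:, 1:]   # each position's target is the next token
    labels[:, -1] = IGNORE_INDEX        # nothing to predict at the final position
    return labels

input_ids = torch.tensor([[3, 7, 2, 5]])   # toy batch: 1 sequence, 4 tokens
logits = torch.randn(1, 4, 10)             # [batch_size, seq_len, vocab_size]
labels = build_shifted_labels(input_ids)

# Same computation as the override: flatten both tensors and let ignore_index
# drop the masked slots; logits[:, t] and labels[:, t] are already aligned.
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    labels.view(-1),
    ignore_index=IGNORE_INDEX,
)
print(loss.item())
```

Shifting logits or labels a second time here would pair each position with the
token two steps ahead, which is exactly the double shift the docstring warns against.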
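For the quality-of-life options, a minimal config sketch. The key names and values
come from the commit message above; where they nest inside a real ROLL pipeline
YAML is an assumption and may differ:

```yaml
max_ckpt_to_keep: 3     # keep only the 3 most recent checkpoints (0 keeps all)
tracker_kwargs:
  mode: offline         # log wandb runs locally, without network access
```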