4 changes: 4 additions & 0 deletions roll/configs/base_config.py
@@ -82,6 +82,10 @@ class BaseConfig:
default=50,
metadata={"help": "Save checkpoint every X update steps."}
)
max_ckpt_to_keep: int = field(
default=0,
metadata={"help": "Maximum number of checkpoints to keep. 0 means keep all checkpoints."}
)
logging_steps: int = field(
default=1,
metadata={"help": "Number of steps between logging information."}
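As an illustration of the retention rule behind max_ckpt_to_keep (a minimal sketch with made-up step values, not part of the diff): 0 keeps every checkpoint, any positive N keeps only the N newest.

saved_steps = [50, 100, 150, 200]   # checkpoints written with save_steps=50 (hypothetical)
max_ckpt_to_keep = 2
kept = saved_steps if max_ckpt_to_keep <= 0 else sorted(saved_steps, reverse=True)[:max_ckpt_to_keep]
print(kept)  # [200, 150]; with max_ckpt_to_keep=0 the full list is kept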
26 changes: 26 additions & 0 deletions roll/distributed/strategy/deepspeed_strategy.py
@@ -28,6 +28,7 @@
from roll.utils.context_parallel import get_ulysses_group, set_upg_manager
from roll.utils.deepspeed_utils import get_optimizer_grouped_parameters
from roll.utils.functionals import append_to_dict, entropy_from_logits, log_probs_from_logits
from roll.utils.constants import IGNORE_INDEX
from roll.utils.logging import get_logger
from roll.utils.offload_states import OffloadStateType
from roll.platforms import current_platform
@@ -398,6 +399,31 @@ def initialize(self, model_provider):
logger.info(f"{self.model}")
dist.barrier()

def op_compute_language_loss(self, logits: torch.Tensor, labels: torch.Tensor):
"""
Override for DeepSpeed strategy: compute language loss from logits.

In the DeepSpeed strategy with HuggingFace models, the model returns logits
(unlike the Megatron strategy, where labels are passed to the model and the loss is returned).

Note: DataCollatorForSFT already shifts labels (shift_feature=True by default),
so logits and labels are already aligned. Do NOT shift again here.

Args:
logits: Model output logits [batch_size, seq_len, vocab_size]
labels: Pre-shifted labels [batch_size, seq_len], already aligned with logits

Returns:
loss: Scalar loss tensor
"""
# Labels already shifted by DataCollator, directly compute cross-entropy
loss = torch.nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
ignore_index=IGNORE_INDEX
)
return loss

def train_step(
self,
batch: DataProto,
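A toy check of the alignment assumption documented in op_compute_language_loss above: since DataCollatorForSFT is described as pre-shifting labels (shift_feature=True), position t of the logits already scores labels[:, t], so the loss is a plain masked cross-entropy. Sketch only; shapes are made up and IGNORE_INDEX is assumed to be the conventional -100.

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed value; the real constant comes from roll.utils.constants

batch, seq_len, vocab = 2, 6, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))  # already shifted by the collator
labels[:, :2] = IGNORE_INDEX                        # e.g. prompt tokens are masked out

# No extra shift here -- exactly what op_compute_language_loss does.
loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), ignore_index=IGNORE_INDEX)
print(loss.item())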
52 changes: 52 additions & 0 deletions roll/pipeline/base_pipeline.py
@@ -1,4 +1,6 @@
import os
import re
import shutil
from collections import defaultdict
from concurrent import futures
from typing import List, Any, Dict
@@ -91,9 +93,59 @@ def do_checkpoint(self, global_step):
self.state.save_rng_state(save_dir=save_dir, tag="pipeline")
self.checkpoint_manager.upload(ckpt_id=ckpt_id, local_state_path=pipeline_save_dir)

# Clean up old checkpoints if max_ckpt_to_keep is set
self._cleanup_old_checkpoints()

futures.wait(self.resume_futures)
self.resume_futures.clear()

def _cleanup_old_checkpoints(self):
"""Remove old checkpoints if max_ckpt_to_keep is set."""
max_ckpt = getattr(self.pipeline_config, 'max_ckpt_to_keep', 0)
if max_ckpt <= 0:
return

output_dir = self.pipeline_config.output_dir
if not os.path.exists(output_dir):
return

# Pattern to match checkpoint directories: checkpoint-{step}
ckpt_pattern = re.compile(r'^checkpoint-(\d+)$')

# Collect all checkpoint steps across all subdirectories
all_ckpt_steps = set()
for subdir in os.listdir(output_dir):
subdir_path = os.path.join(output_dir, subdir)
if not os.path.isdir(subdir_path):
continue
for item in os.listdir(subdir_path):
match = ckpt_pattern.match(item)
if match:
all_ckpt_steps.add(int(match.group(1)))

# Sort steps and determine which to delete
sorted_steps = sorted(all_ckpt_steps, reverse=True)
steps_to_delete = sorted_steps[max_ckpt:]

if not steps_to_delete:
return

logger.info(f"Cleaning up old checkpoints. Keeping {max_ckpt}, deleting steps: {steps_to_delete}")

# Delete old checkpoints from all subdirectories
for subdir in os.listdir(output_dir):
subdir_path = os.path.join(output_dir, subdir)
if not os.path.isdir(subdir_path):
continue
for step in steps_to_delete:
ckpt_dir = os.path.join(subdir_path, f"checkpoint-{step}")
if os.path.exists(ckpt_dir):
try:
shutil.rmtree(ckpt_dir)
logger.info(f"Deleted old checkpoint: {ckpt_dir}")
except Exception as e:
logger.warning(f"Failed to delete checkpoint {ckpt_dir}: {e}")

def download_models(self, *clusters: Cluster):
node2worker: Dict[str, Any] = {}
node2model_names: Dict[str, set[str]] = defaultdict(set)
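For orientation, a minimal sketch of the layout _cleanup_old_checkpoints walks (subdirectory names are hypothetical): each worker subdirectory under output_dir holds checkpoint-{step} folders, and the same set of old steps is pruned from all of them.

import re

# Hypothetical state of output_dir with max_ckpt_to_keep = 2:
#   output_dir/actor_train/checkpoint-50      output_dir/pipeline_state/checkpoint-50
#   output_dir/actor_train/checkpoint-100     output_dir/pipeline_state/checkpoint-100
#   output_dir/actor_train/checkpoint-150     output_dir/pipeline_state/checkpoint-150
ckpt_pattern = re.compile(r"^checkpoint-(\d+)$")
names = ["checkpoint-50", "checkpoint-100", "checkpoint-150"]

steps = sorted({int(m.group(1)) for n in names if (m := ckpt_pattern.match(n))}, reverse=True)
steps_to_delete = steps[2:]   # keep the 2 newest
print(steps_to_delete)        # [50] -> checkpoint-50 is rmtree'd in every subdirectory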
14 changes: 11 additions & 3 deletions roll/pipeline/sft/sft_pipeline.py
@@ -6,6 +6,7 @@
import torch
from codetiming import Timer
from torch.utils.data import DataLoader
from tqdm import tqdm

from roll.datasets.chat_template import get_chat_template
from roll.datasets.collator import DataCollatorForSFT
@@ -152,7 +153,7 @@ def __init__(self, pipeline_config: SFTConfig):
self.dataloader = DataLoader(
dataset=self.dataset,
batch_size=global_train_batch_size,
shuffle=False,
shuffle=True, # Enable shuffle for better training
drop_last=True,
num_workers=self.pipeline_config.sft_train.training_args.dataloader_num_workers,
collate_fn=data_collator,
@@ -181,11 +182,14 @@ def __init__(self, pipeline_config: SFTConfig):
def run(self):
global_step = 0
metrics_mgr = MetricsManager()
num_epochs = self.pipeline_config.sft_train.training_args.num_train_epochs
total_steps = num_epochs * len(self.dataloader)

for epoch in range(self.pipeline_config.sft_train.training_args.num_train_epochs):
for epoch in range(num_epochs):
logger.info(f"epoch {epoch} start...")

for batch_dict in self.dataloader:
pbar = tqdm(self.dataloader, desc=f"Epoch {epoch}/{num_epochs}")
for batch_dict in pbar:
# for continual training
if global_step <= self.state.step:
global_step += 1
@@ -213,6 +217,10 @@ def run(self):
metrics = metrics_mgr.get_metrics()
metrics = {k: float(v) for k, v in metrics.items()}
logger.info(f"metrics: {metrics}")

# Update tqdm progress bar
loss = metrics.get("sft_train/loss", 0)
pbar.set_postfix({"loss": f"{loss:.4f}", "step": f"{global_step}/{total_steps}"})

self.state.step = global_step
self.state.log_history.append(metrics)
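The progress-bar totals introduced above follow directly from drop_last=True on the dataloader; a worked example with made-up sizes:

dataset_size = 10_000                 # hypothetical
global_train_batch_size = 64          # hypothetical
num_train_epochs = 3                  # hypothetical

steps_per_epoch = dataset_size // global_train_batch_size   # drop_last=True -> 156
total_steps = num_train_epochs * steps_per_epoch             # 468, the figure shown in the tqdm postfix
print(steps_per_epoch, total_steps)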
4 changes: 3 additions & 1 deletion roll/pipeline/sft/sft_worker.py
@@ -68,4 +68,6 @@ def do_checkpoint(self, global_step):

def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
labels = data.batch["labels"]
return self.strategy.op_compute_language_loss(output_tensor, labels)
loss = self.strategy.op_compute_language_loss(output_tensor, labels)
metrics = {f"{self.worker_config.name}/loss": loss.detach().float().unsqueeze(0)}
return loss, metrics
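The unsqueeze(0) turns the scalar loss into a 1-element tensor so per-micro-batch metrics can be collected and averaged later; a minimal sketch of that style of aggregation (the strategy's actual reduction is not part of this diff):

import torch
from collections import defaultdict

collected = defaultdict(list)
for micro_loss in (torch.tensor(1.25), torch.tensor(0.75)):     # pretend per-micro-batch losses
    metrics = {"sft_train/loss": micro_loss.detach().float().unsqueeze(0)}
    for key, value in metrics.items():
        collected[key].append(value)

averaged = {key: torch.cat(values).mean().item() for key, values in collected.items()}
print(averaged)  # {'sft_train/loss': 1.0}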
3 changes: 2 additions & 1 deletion roll/utils/tracking.py
@@ -59,11 +59,12 @@ def __init__(self, config: dict, **kwargs):
notes = kwargs.pop("notes", None)
log_dir = kwargs.pop("log_dir", None)
api_key = kwargs.pop("api_key", None)
mode = kwargs.pop("mode", None)
settings = kwargs.pop("settings", {"console": "off"})
import wandb
if api_key:
wandb.login(key=api_key)
self.run = wandb.init(project=project, tags=tags, name=name, notes=notes, dir=log_dir, settings=settings)
self.run = wandb.init(project=project, tags=tags, name=name, notes=notes, dir=log_dir, mode=mode, settings=settings)

self.run.config.update(config, allow_val_change=True)

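With mode now popped from the kwargs and forwarded, a run can be forced offline or disabled from the tracker configuration; a hedged sketch (the keys other than mode mirror the ones already popped above, the mode values are wandb's own):

tracker_kwargs = {
    "project": "roll-sft",      # hypothetical
    "name": "sft-debug-run",    # hypothetical
    "mode": "offline",          # "online", "offline", or "disabled" -- passed straight to wandb.init
}
# Effectively: wandb.init(project="roll-sft", name="sft-debug-run", mode="offline",
#                         settings={"console": "off"})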
9 changes: 9 additions & 0 deletions setup.py
@@ -0,0 +1,9 @@
from setuptools import setup, find_packages

setup(
name="roll",
version="0.1.0",
description="ROLL - Reinforcement Learning Optimization for Large-Scale Learning",
packages=find_packages(include=["roll", "roll.*"]),
python_requires=">=3.10",
)
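With this minimal setup.py, the project can be installed in editable mode via pip install -e . (Python 3.10+); find_packages(include=["roll", "roll.*"]) limits the distribution to the roll package tree, and no dependencies or entry points are declared here.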