Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/ar/qwen3_drpo_4b_base_dapo_sglang.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,13 @@ sync:
flush_cache: true

stack:
_target_: unirl.train.stack.TrainStack
_target_: unirl.train.stack.LLMTrainStack
micro_batch_size: 1
# verl ppo_max_token_len_per_gpu parity: pack each mini-batch into micro-batches
# under a token budget (length-sorted bin packing) instead of fixed-count
# slicing; with mbs=1 every sequence ran its own forward/backward. null keeps
# the legacy count-based behavior.
micro_token_budget: 10240
max_grad_norm: 1.0
# 4 disjoint mini-batch optimizer steps per rollout, 128 trajectories each —
# matches the RELEASED verl run script (run_qwen3_4b.sh: ppo_mini_batch_size=16
Expand Down
2 changes: 1 addition & 1 deletion examples/ar/qwen3_grpo_4b_base_dapo_sglang.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ sync:
flush_cache: true

stack:
_target_: unirl.train.stack.TrainStack
_target_: unirl.train.stack.LLMTrainStack
micro_batch_size: 1
max_grad_norm: 1.0
# 4 disjoint mini-batch optimizer steps per rollout, 128 trajectories each —
Expand Down
2 changes: 1 addition & 1 deletion examples/ar/qwen_vl_grpo_geo3k_mc_4x8.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ algorithm:
sampling_temperature: 0.7

stack:
_target_: unirl.train.stack.TrainStack
_target_: unirl.train.stack.LLMTrainStack
# Per-worker shard is 512/32 = 16 samples. Slice it into micro-batches of 4:
# replay() materializes out.logits [mb, full_seq_len, vocab=152064] in bf16
# and keeps it in the autograd graph, so a 16-sample micro-batch at long
Expand Down
2 changes: 1 addition & 1 deletion examples/ar/qwen_vl_grpo_geo3k_mc_4x8_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ algorithm:
sampling_temperature: 0.7

stack:
_target_: unirl.train.stack.TrainStack
_target_: unirl.train.stack.LLMTrainStack
micro_batch_size: 4
max_grad_norm: 1.0

Expand Down
2 changes: 1 addition & 1 deletion examples/ar/qwen_vl_grpo_geo3k_mc_sglang_4x8.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ algorithm:
sampling_temperature: 0.7

stack:
_target_: unirl.train.stack.TrainStack
_target_: unirl.train.stack.LLMTrainStack
micro_batch_size: 2
max_grad_norm: 1.0

Expand Down
2 changes: 1 addition & 1 deletion examples/ar/qwen_vl_grpo_geo3k_mc_sglang_4x8_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ algorithm:
sampling_temperature: 0.7

stack:
_target_: unirl.train.stack.TrainStack
_target_: unirl.train.stack.LLMTrainStack
micro_batch_size: 2
max_grad_norm: 1.0

Expand Down
237 changes: 237 additions & 0 deletions tests/train/test_packing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
"""Unit tests for micro-batch planning (unirl/train/stack/_packing.py) + the
LLMTrainStack seq-mean guard. Pure / CPU-only — no GPU, no FSDP backend.

Run with ``pytest tests/train/test_packing.py`` or directly:
``python tests/train/test_packing.py``.
"""

from __future__ import annotations

import types

import pytest
import torch

from unirl.train.stack import LLMTrainStack
from unirl.train.stack._packing import (
_build_micro_batch_slices,
_count_plan,
_micro_indices,
_pack_micros,
_pack_micros_2d,
_pack_micros_sum,
_partition_into_k,
_plan_packed,
_sync_micro_count,
_update_ranges,
)


# --------------------------------------------------------------------------- #
# helpers
# --------------------------------------------------------------------------- #
def _flatten(bins):
return [i for b in bins for i in b]


def _covers(bins, indices):
"""Every index appears exactly once across the bins."""
flat = _flatten(bins)
return sorted(flat) == sorted(indices) and len(flat) == len(indices)


def _fake_track(lengths):
"""Minimal RolloutTrack stand-in: just batch_size + segment.lengths."""
seg = types.SimpleNamespace(lengths=torch.tensor(lengths, dtype=torch.long))
return types.SimpleNamespace(batch_size=len(lengths), segment=seg, conditions={})


# --------------------------------------------------------------------------- #
# update / count partitioning
# --------------------------------------------------------------------------- #
def test_update_ranges_even():
assert _update_ranges(total_size=8, num_updates=2) == ((0, 4), (4, 8))
assert _update_ranges(total_size=12, num_updates=4) == ((0, 3), (3, 6), (6, 9), (9, 12))


def test_update_ranges_requires_divisibility():
with pytest.raises(ValueError):
_update_ranges(total_size=10, num_updates=3)


def test_count_micro_slices_cover():
sl = _build_micro_batch_slices(total_size=10, micro_batch_size=4)
assert sl == ((0, 4), (4, 8), (8, 10)) # last partial, full coverage


def test_count_plan_structure():
# 8 samples, 2 updates, micro_batch_size 1 -> 2 updates x 4 single-sample micros
plan = _count_plan(total=8, num_updates=2, micro_batch_size=1)
assert len(plan) == 2
assert [_micro_indices(m) for m in plan[0]] == [[0], [1], [2], [3]]
assert [_micro_indices(m) for m in plan[1]] == [[4], [5], [6], [7]]


# --------------------------------------------------------------------------- #
# token-budget packing — coverage + budget invariants
# --------------------------------------------------------------------------- #
DENSE_LENGTHS = [4000, 3500, 300, 250, 200, 180, 150, 120]


def test_pack_micros_dense_covers_and_respects_budget():
bins = _pack_micros(indices=list(range(8)), lengths=DENSE_LENGTHS, token_budget=10240)
assert _covers(bins, list(range(8)))
for b in bins:
cost = max(DENSE_LENGTHS[i] for i in b) * len(b)
assert cost <= 10240 or len(b) == 1 # single oversize seq allowed its own bin
# the two long seqs cannot share with the shorts under 10240
assert all(b for b in bins) # no empty bins


def test_pack_micros_sum_covers_and_respects_budget():
bins = _pack_micros_sum(indices=list(range(8)), lengths=DENSE_LENGTHS, token_budget=8000)
assert _covers(bins, list(range(8)))
for b in bins:
cost = sum(DENSE_LENGTHS[i] for i in b)
assert cost <= 8000 or len(b) == 1


def test_pack_micros_2d_covers_and_respects_budget():
prompt = [1000, 50, 900, 40, 30, 20, 10, 5]
resp = [50, 1000, 40, 900, 200, 180, 150, 120]
bins = _pack_micros_2d(indices=list(range(8)), prompt_lens=prompt, resp_lens=resp, token_budget=4096)
assert _covers(bins, list(range(8)))
for b in bins:
cost = (max(prompt[i] for i in b) + max(resp[i] for i in b)) * len(b)
assert cost <= 4096 or len(b) == 1


def test_oversize_sequence_gets_its_own_bin():
# one seq longer than the whole budget must still be placed (never dropped)
bins = _pack_micros(indices=[0, 1, 2], lengths=[50, 99999, 60], token_budget=1024)
assert _covers(bins, [0, 1, 2])
big = next(b for b in bins if 1 in b)
assert big == [1]


def test_pack_micros_rejects_nonpositive_budget():
with pytest.raises(ValueError):
_pack_micros(indices=[0, 1], lengths=[10, 20], token_budget=0)


# --------------------------------------------------------------------------- #
# exact-K re-partition (NCCL micro-count parity)
# --------------------------------------------------------------------------- #
def test_partition_into_k_exact_and_covers():
idx = list(range(8))
for k in (1, 2, 3, 5, 8):
bins = _partition_into_k(indices=idx, lengths=DENSE_LENGTHS, k=k)
assert len(bins) == k
assert all(len(b) >= 1 for b in bins) # every bin non-empty
assert _covers(bins, idx)


def test_partition_into_k_out_of_range():
with pytest.raises(ValueError):
_partition_into_k(indices=[0, 1, 2], lengths=[1, 2, 3], k=0)
with pytest.raises(ValueError):
_partition_into_k(indices=[0, 1, 2], lengths=[1, 2, 3], k=4) # k > n


def test_sync_micro_count_noop_without_dist():
# torch.distributed not initialized in a unit test -> returns the local count
assert _sync_micro_count(7) == 7


# --------------------------------------------------------------------------- #
# plan equivalence: packing only regroups, never changes which samples an update trains on
# --------------------------------------------------------------------------- #
def test_packed_and_count_plans_select_same_samples_per_update():
lengths = [4000, 3500, 300, 250, 200, 180, 150, 120]
track = _fake_track(lengths)
packed = _plan_packed(track, num_updates=2, token_budget=10240, cost_model="dense")
count = _count_plan(total=8, num_updates=2, micro_batch_size=1)
assert packed is not None and len(packed) == len(count) == 2
for u, (p_update, c_update) in enumerate(zip(packed, count)):
p_samples = sorted(i for m in p_update for i in _micro_indices(m))
c_samples = sorted(i for m in c_update for i in _micro_indices(m))
assert p_samples == c_samples # identical sample set per update
assert p_samples == list(range(u * 4, u * 4 + 4)) # the contiguous update range


def test_sample_share_weights_sum_to_one_per_update():
lengths = [4000, 3500, 300, 250, 200, 180, 150, 120]
track = _fake_track(lengths)
packed = _plan_packed(track, num_updates=2, token_budget=10240, cost_model="dense")
for update in packed:
update_total = sum(len(_micro_indices(m)) for m in update)
weights = [len(_micro_indices(m)) / update_total for m in update]
assert update_total == 4
assert abs(sum(weights) - 1.0) < 1e-12


def test_plan_packed_falls_back_when_no_lengths():
track = types.SimpleNamespace(batch_size=4, segment=None, conditions={})
assert _plan_packed(track, num_updates=2, token_budget=1024, cost_model="dense") is None


def test_plan_packed_picks_up_prompt_from_conditions_dict():
# review #42 B2: conditions is a Dict, so prompt lengths must be read via dict
# access — otherwise the budget counts response tokens only. With prompt=50,
# resp=100 the 2D cost is (50+100)*count<=300 -> <=2 per micro; if the prompt
# were ignored it'd be 100*count<=300 -> 3 per micro. The cap distinguishes them.
seg = types.SimpleNamespace(lengths=torch.tensor([100, 100, 100, 100], dtype=torch.long))
prompt = types.SimpleNamespace(attention_mask=torch.ones(4, 50, dtype=torch.long))
track = types.SimpleNamespace(batch_size=4, segment=seg, conditions={"prompt": prompt})
plan = _plan_packed(track, num_updates=1, token_budget=300, cost_model="dense")
assert plan is not None
assert sorted(i for m in plan[0] for i in _micro_indices(m)) == [0, 1, 2, 3]
assert all(len(_micro_indices(m)) <= 2 for m in plan[0]) # prompt counted -> 2D cap


# --------------------------------------------------------------------------- #
# LLMTrainStack seq-mean guard
# --------------------------------------------------------------------------- #
def _algo(mode):
return types.SimpleNamespace(loss_agg_mode=mode)


@pytest.mark.parametrize("mode", ["seq-mean-token-sum-norm", "seq-mean-token-mean"])
def test_guard_allows_seq_mean(mode):
LLMTrainStack._require_seq_mean(_algo(mode)) # must not raise


@pytest.mark.parametrize("mode", ["token-mean", "something-else", None])
def test_guard_rejects_non_seq_mean(mode):
with pytest.raises(ValueError):
LLMTrainStack._require_seq_mean(_algo(mode))


def test_guard_rejects_algo_without_agg_mode():
with pytest.raises(ValueError):
LLMTrainStack._require_seq_mean(types.SimpleNamespace()) # no loss_agg_mode attr


# --------------------------------------------------------------------------- #
# direct runner (no pytest required)
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
failures = 0
for name, fn in sorted(globals().items()):
if not name.startswith("test_") or not callable(fn):
continue
marks = getattr(fn, "pytestmark", [])
param_sets = None
for m in marks:
if m.name == "parametrize":
param_sets = m.args[1]
cases = [(v,) for v in param_sets] if param_sets is not None else [()]
for args in cases:
try:
fn(*args)
print(f"PASS {name}{args if args else ''}")
except Exception as exc: # noqa: BLE001
failures += 1
print(f"FAIL {name}{args if args else ''}: {type(exc).__name__}: {exc}")
print(f"\n{'OK' if failures == 0 else f'{failures} FAILURE(S)'}")
raise SystemExit(1 if failures else 0)
17 changes: 17 additions & 0 deletions unirl/algorithms/drpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,15 @@ def compute_loss_and_backward(
return AlgorithmStepResult(loss=0.0, metrics={}, num_steps_or_tokens=0, has_backward=False)

typed_conds = typed_conditions(conditions, self.conditions_cls)
# Per-micro fwd/bwd wall-clock (cuda-synced) -> train/replay_time_s and
# train/backward_time_s via metrics aggregation: attributes the train
# phase without an external profiler. Sync cost is negligible next to
# multi-second micros.
import time as _time

if torch.cuda.is_available():
torch.cuda.synchronize()
_t0 = _time.perf_counter()
new_logp = self.stage.replay(
typed_conds, segment=segment, temperature=self.sampling_temperature
) # [total_tokens]
Expand Down Expand Up @@ -328,11 +337,19 @@ def compute_loss_and_backward(
loss = torch.stack([p.sum() for p in parts]).mean() / float(self.horizon)
else:
loss = loss_per_elem.mean()
if torch.cuda.is_available():
torch.cuda.synchronize()
_t1 = _time.perf_counter()
(loss * loss_scale).backward()
if torch.cuda.is_available():
torch.cuda.synchronize()
_t2 = _time.perf_counter()

metrics: Dict[str, Any] = {
"policy_loss": float(loss.detach().item()),
"drpo_epsilon": self.drpo_epsilon,
"replay_time_s": _t1 - _t0,
"backward_time_s": _t2 - _t1,
**rollout_replay_logp_absdiff(new_logp, old_logp),
**{k: float(v.item()) for k, v in ratio_metrics.items()},
}
Expand Down
21 changes: 21 additions & 0 deletions unirl/train/stack/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Train stack package: family-agnostic base + diffusion / LLM micro-batching.

``unirl.train.stack`` used to be a single module; it is now a package. The public
surface is unchanged — ``TrainStack`` (the diffusion stack) and ``TrainStepResult``
import from here exactly as before. ``LLMTrainStack`` is the token-budget packed
variant (verl dynamic-bsz parity) for varlen LLM training.
"""

from unirl.train.stack._packing import _build_micro_batch_slices
from unirl.train.stack.base import AbstractTrainStack, TrainStepResult
from unirl.train.stack.diffusion import DiffusionTrainStack, TrainStack
from unirl.train.stack.llm import LLMTrainStack

__all__ = [
"AbstractTrainStack",
"DiffusionTrainStack",
"LLMTrainStack",
"TrainStack",
"TrainStepResult",
"_build_micro_batch_slices",
]
Loading
Loading