Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions megatron/rl/rl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@
get_sequence_packing_tensorboard_metrics,
get_sequence_packing_log_info,
get_default_packed_seq_params,
get_packing_actual_tokens,
get_packing_compute_tokens,
get_packing_efficiency,
get_packing_avg_seq_length,
update_microbatch_calculator,
)
from megatron.rl.agent.api import (
Expand Down Expand Up @@ -301,11 +305,26 @@ def __init__(self):
self.last_collection_iteration = 0
self.sequences_this_iteration_on_rank = 0
self.latest_batch_num_sequences = 0
# Derived throughput metrics (set by training_log, read by RLProfiler)
self.tokens_per_sec = None
self.tokens_per_sec_per_gpu = None
self.compute_tokens_per_sec = None
self.compute_tokens_per_sec_per_gpu = None
self.actual_tokens_per_sec = None
self.actual_tokens_per_sec_per_gpu = None
self.packing_efficiency = None

def reset_iteration_counters(self, iteration):
    """Clear all per-iteration counters and remember *iteration* as the last collection point."""
    self.sequences_this_iteration_on_rank = 0
    self.last_collection_iteration = iteration
    # Derived throughput metrics are set by training_log each iteration;
    # wipe the stale values from the previous iteration so readers
    # (e.g. RLProfiler) never observe outdated numbers.
    for metric_name in (
        'tokens_per_sec',
        'tokens_per_sec_per_gpu',
        'compute_tokens_per_sec',
        'compute_tokens_per_sec_per_gpu',
        'actual_tokens_per_sec',
        'actual_tokens_per_sec_per_gpu',
        'packing_efficiency',
    ):
        setattr(self, metric_name, None)

def increment_sequences(self, count):
"""Increment the sequence counter."""
Expand Down
58 changes: 58 additions & 0 deletions megatron/rl/sequence_packing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,3 +1173,61 @@ def get_sequence_packing_tensorboard_metrics(args):
metrics['bin-batch-size'] = bin_batch_size
metrics['consumed-bins'] = args.consumed_train_bins
return metrics


def get_packing_actual_tokens(packing_context: PackingContext) -> int:
    """Count the non-padding tokens packed into this rank's bins.

    Args:
        packing_context: The PackingContext containing packing information.

    Returns:
        Sum of the lengths of every sequence referenced by the bins on
        this rank (padding excluded).
    """
    info = packing_context.packing_info
    total = 0
    for bin_indices in info.bin_seq_indices:
        for seq_idx in bin_indices:
            total += info.seq_lengths[seq_idx]
    return total


def get_packing_compute_tokens(packing_context: PackingContext) -> int:
    """Total token slots (actual + padding) occupied by this rank's packed bins.

    Args:
        packing_context: The PackingContext containing packing information.

    Returns:
        num_bins * bin_size for this rank, i.e. the full compute footprint
        of the packed batch including padding.
    """
    num_bins = packing_context.packed_trajs.shape[0]
    bin_size = packing_context.packed_trajs.shape[1]
    return num_bins * bin_size


def get_packing_efficiency(packing_context: PackingContext) -> float:
    """Fraction of packed-bin capacity filled by real tokens across all DP ranks.

    Args:
        packing_context: The PackingContext containing packing information.

    Returns:
        actual_tokens / total_capacity in [0, 1]; 0.0 when capacity is zero.
    """
    # seq_lengths covers every packed sequence, so this is the global
    # actual-token count.
    actual_tokens = sum(packing_context.packing_info.seq_lengths)
    # distribute_packed_bins guarantees each DP rank holds the same number
    # of bins, so global capacity is bins_per_rank * bin_size * num_ranks.
    bins_per_rank = packing_context.packed_trajs.shape[0]
    bin_size = packing_context.packed_trajs.shape[1]
    capacity = bins_per_rank * bin_size * mpu.get_data_parallel_world_size()

    return actual_tokens / capacity if capacity else 0.0


def get_packing_avg_seq_length(packing_context: PackingContext) -> float:
    """Mean length over every sequence tracked by the packing context.

    Returns 0.0 when no sequences are present.
    """
    lengths = packing_context.packing_info.seq_lengths
    return sum(lengths) / len(lengths) if lengths else 0.0
86 changes: 86 additions & 0 deletions megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2165,6 +2165,92 @@ def training_log(
total_loss_dict[skipped_iters_key]
)
log_string += ' number of nan iterations: {:3d} |'.format(total_loss_dict[nan_iters_key])

# RL token throughput metrics.
if args.perform_rl_step:
tokens_per_sec = None
tokens_per_sec_per_gpu = None
compute_tokens_per_sec = None
compute_tokens_per_sec_per_gpu = None
actual_tokens_per_sec = None
actual_tokens_per_sec_per_gpu = None
packing_efficiency = None

if args.seq_length > 0:
tokens_per_iteration = batch_size * args.seq_length
tokens_per_sec = tokens_per_iteration / elapsed_time_per_iteration
tokens_per_sec_per_gpu = tokens_per_sec / args.world_size

# For sequence packing, break down into compute vs actual tokens
if args.rl_use_sequence_packing:
runtime_state = rl_utils.get_rl_runtime_state()
if runtime_state.packing_context is not None:
dp_world_size = mpu.get_data_parallel_world_size()

compute_tokens = rl_utils.get_packing_compute_tokens(runtime_state.packing_context)
all_ranks_compute_tokens = compute_tokens * dp_world_size
compute_tokens_per_sec = all_ranks_compute_tokens / elapsed_time_per_iteration
compute_tokens_per_sec_per_gpu = compute_tokens_per_sec / args.world_size

actual_tokens = rl_utils.get_packing_actual_tokens(runtime_state.packing_context)
all_ranks_actual_tokens = actual_tokens * dp_world_size
actual_tokens_per_sec = all_ranks_actual_tokens / elapsed_time_per_iteration
actual_tokens_per_sec_per_gpu = actual_tokens_per_sec / args.world_size

packing_efficiency = rl_utils.get_packing_efficiency(runtime_state.packing_context)

# Add tokens/sec to log string
log_string += f' toks/s: {tokens_per_sec:.0f} |'
log_string += f' toks/s/gpu: {tokens_per_sec_per_gpu:.0f} |'
if compute_tokens_per_sec is not None:
log_string += f' compute_toks/s: {compute_tokens_per_sec:.0f} |'
log_string += f' compute_toks/s/gpu: {compute_tokens_per_sec_per_gpu:.0f} |'
if actual_tokens_per_sec is not None:
log_string += f' actual_toks/s: {actual_tokens_per_sec:.0f} |'
log_string += f' actual_toks/s/gpu: {actual_tokens_per_sec_per_gpu:.0f} |'
log_string += f' packing_eff: {packing_efficiency:.1%} |'

# Log throughput metrics to wandb
if wandb_writer is not None:
if tokens_per_sec is not None:
wandb_writer.log({
'throughput/tokens_per_sec': tokens_per_sec,
'throughput/tokens_per_sec_per_gpu': tokens_per_sec_per_gpu,
}, iteration)
if compute_tokens_per_sec is not None:
wandb_writer.log({
'throughput/compute_tokens_per_sec': compute_tokens_per_sec,
'throughput/compute_tokens_per_sec_per_gpu': compute_tokens_per_sec_per_gpu,
}, iteration)
if actual_tokens_per_sec is not None:
wandb_writer.log({
'throughput/actual_tokens_per_sec': actual_tokens_per_sec,
'throughput/actual_tokens_per_sec_per_gpu': actual_tokens_per_sec_per_gpu,
'throughput/packing_efficiency': packing_efficiency,
}, iteration)

# Store derived throughput metrics on RLRuntimeState so that
# downstream consumers (e.g. RLProfiler) can read them.
runtime_state = rl_utils.get_rl_runtime_state()
runtime_state.tokens_per_sec = tokens_per_sec
runtime_state.tokens_per_sec_per_gpu = tokens_per_sec_per_gpu
runtime_state.compute_tokens_per_sec = compute_tokens_per_sec
runtime_state.compute_tokens_per_sec_per_gpu = compute_tokens_per_sec_per_gpu
runtime_state.actual_tokens_per_sec = actual_tokens_per_sec
runtime_state.actual_tokens_per_sec_per_gpu = actual_tokens_per_sec_per_gpu
runtime_state.packing_efficiency = packing_efficiency

# Log average sequence length. With packing this shows real sequence
# lengths; without packing it equals seq_length as a baseline.
packing_ctx = runtime_state.packing_context
if args.rl_use_sequence_packing and packing_ctx is not None:
avg_seq_length = rl_utils.get_packing_avg_seq_length(packing_ctx)
log_string += f' avg_seq_len: {avg_seq_length:.1f} |'
if wandb_writer is not None:
wandb_writer.log({'throughput/avg_seq_length': avg_seq_length}, iteration)
elif args.log_throughput:
log_string += f' avg_seq_len: {args.seq_length} |'

if should_reset:
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
Expand Down
45 changes: 45 additions & 0 deletions tests/unit_tests/rl/test_sequence_packing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,51 @@ def test_compute_packed_inference_logprobs_stats_shape_mismatch():
assert group_stats.mean_piold_to_inf_prob is None


def test_packing_observability_metrics():
    """Test various observability metrics related to sequence packing."""

    # 4 sequences with known lengths packed into 2 bins of size 16.
    # Bin 0 holds seqs 0 (len 5) and 1 (len 3) -> 8 actual tokens
    # Bin 1 holds seqs 2 (len 10) and 3 (len 4) -> 14 actual tokens
    bin_count = 2
    bin_capacity = 16
    lengths = [5, 3, 10, 4]
    dp_world_size = 4

    info = sequence_packing_utils.PackingInfo(
        bin_seq_indices=[[0, 1], [2, 3]],
        seq_starts={0: [0, 5], 1: [0, 10]},
        seq_lengths=lengths,
        seq_to_bin_idx=[0, 0, 1, 1],
        packing_algo='fifo',
    )
    context = sequence_packing_utils.PackingContext(
        bin_size=bin_capacity,
        packer=None,
        packing_info=info,
        original_generation_masks=None,
        original_trajs=None,
        packed_trajs=torch.zeros(bin_count, bin_capacity, dtype=torch.long),
        packed_position_ids=None,
        packed_attention_mask=None,
        packed_loss_mask=None,
    )

    expected_actual = sum(lengths)  # 22

    # actual tokens = sum of all seq_lengths referenced by bin_seq_indices
    assert sequence_packing_utils.get_packing_actual_tokens(context) == expected_actual

    # compute tokens = num_bins * bin_size
    assert sequence_packing_utils.get_packing_compute_tokens(context) == bin_count * bin_capacity

    # avg seq length = mean of seq_lengths
    assert sequence_packing_utils.get_packing_avg_seq_length(context) == pytest.approx(
        expected_actual / len(lengths)
    )

    # efficiency = total_actual / (bins_per_rank * bin_size * num_ranks)
    with patch('megatron.core.mpu.get_data_parallel_world_size', return_value=dp_world_size):
        # total_actual = 22, capacity = 2 * 16 * 4 = 128
        assert sequence_packing_utils.get_packing_efficiency(context) == pytest.approx(
            expected_actual / (bin_count * bin_capacity * dp_world_size)
        )


@pytest.mark.parametrize("num_sequences", [1, 10, 48, 49, 50])
def test_cu_seqlens_size(num_sequences):
"""Test that cu_seqlens always has a fixed size regardless of how many sequences are packed."""
Expand Down
Loading