Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,10 @@ def forward_step(
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, model)
output = forward_step_func(data_iterator, model)
output_tensor, loss_func, num_empty_bins = output
else:
output_tensor, loss_func = forward_step_func(
output_tensor, loss_func, num_empty_bins = forward_step_func(
data_iterator, model, checkpoint_activations_microbatch
)
output_tensor, num_tokens = forward_step_calc_loss(
Expand All @@ -421,8 +422,8 @@ def forward_step(
)

if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
return output_tensor, num_tokens, num_empty_bins
return [output_tensor], num_tokens, num_empty_bins


def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
Expand Down Expand Up @@ -576,6 +577,7 @@ def forward_backward_no_pipelining(
total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda")

if config.overlap_moe_expert_parallel_comm and not forward_only:
num_empty_bins = 0
forward_data_store, total_num_tokens = combined_1f1b_schedule_for_no_pipelining(
forward_step_func,
data_iterator,
Expand All @@ -595,7 +597,7 @@ def forward_backward_no_pipelining(
else:
with no_sync_func():
for i in range(num_microbatches - 1):
output_tensor, num_tokens = forward_step(
output_tensor, num_tokens, num_empty_bins = forward_step(
forward_step_func,
data_iterator,
model,
Expand All @@ -615,7 +617,7 @@ def forward_backward_no_pipelining(
)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor, num_tokens = forward_step(
output_tensor, num_tokens, num_empty_bins = forward_step(
forward_step_func,
data_iterator,
model,
Expand Down Expand Up @@ -655,6 +657,8 @@ def forward_backward_no_pipelining(
):
create_cudagraphs()

forward_data_store.append(num_empty_bins)

return forward_data_store


Expand Down Expand Up @@ -1209,7 +1213,7 @@ def forward_step_helper(virtual_microbatch_id, checkpoint_activations_microbatch
virtual_microbatch_id, model_chunk_id, microbatch_id
)

output_tensor, num_tokens = forward_step(
output_tensor, num_tokens, num_empty_bins = forward_step(
forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
Expand Down Expand Up @@ -1351,6 +1355,7 @@ def forward_backward_helper_wrapper(
return forward_output_tensor, backward_input_tensor_grad

# ==============================main logic=========================================
num_empty_bins = 0
_is_vp_first_stage = partial(
is_vp_first_stage, vp_size=config.virtual_pipeline_model_parallel_size
)
Expand Down Expand Up @@ -1920,6 +1925,8 @@ def pp_post_backward(input_tensor_grad, vp_stage=None):
create_cudagraphs()
nvtx_range_pop(suffix="misc")

forward_data_store.append(num_empty_bins)

return forward_data_store


Expand Down Expand Up @@ -1980,6 +1987,8 @@ def forward_backward_pipelining_without_interleaving(
data_iterator = data_iterator[0]

config = get_model_config(model)
num_empty_bins = 0

if config.overlap_p2p_comm:
raise ValueError(
"Non-interleaved pipeline parallelism does not support overlapping p2p communication"
Expand Down Expand Up @@ -2135,7 +2144,7 @@ def enable_grad_sync():
input_tensor = p2p_communicator.recv_forward(
recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group)
)
output_tensor, num_tokens = forward_step(
output_tensor, num_tokens, num_empty_bins = forward_step(
forward_step_func,
data_iterator,
model,
Expand Down Expand Up @@ -2178,7 +2187,7 @@ def enable_grad_sync():
else:
checkpoint_activations_microbatch = None

output_tensor, num_tokens = forward_step(
output_tensor, num_tokens, num_empty_bins = forward_step(
forward_step_func,
data_iterator,
model,
Expand Down Expand Up @@ -2303,4 +2312,6 @@ def enable_grad_sync():
):
create_cudagraphs()

forward_data_store.append(num_empty_bins)

return forward_data_store
4 changes: 3 additions & 1 deletion megatron/rl/rl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ def prepare_data_for_update(
(mpu.get_data_parallel_rank() + 1) * data_split_size,
)
rollouts = rollouts[data_split_range[0] : data_split_range[1]]

# First we calculate them on a global level, and then we split and recalculate them on a local level.
# Sequence packing and reporting need them global, but the non-packing path wants them local.
rewards = torch.tensor([[r.reward for r in group] for group in rollouts], device='cpu')
Expand All @@ -894,7 +895,7 @@ def prepare_data_for_update(
if args.rl_use_sequence_packing:
with nvtx_range("sequence_packing", time=True):
runtime_state.packing_context = packing_context = pack_all_trajectories(
trajs,
trajs,
generation_masks,
inference_logprobs,
global_advantages,
Expand Down Expand Up @@ -968,6 +969,7 @@ def logprobs_forward_step(data_iterator, model):
packed_seq_params=b_packed_seq_params,
),
None,
0 # These tokens do not count toward the tokens/second calculation
)

dtype = (
Expand Down
185 changes: 99 additions & 86 deletions megatron/rl/sequence_packing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class PackingContext:
original_inference_logprobs: Optional[torch.Tensor] = None
bin_advantages: List[torch.Tensor] = field(default_factory=list)
cached_packed_seq_params: List[Optional[PackedSeqParams]] = field(default_factory=list)
stats: Optional[dict] = None


def load_packed_data_by_index(bin_idx: int, packing_context: PackingContext, logprobs_is_correction: bool):
Expand Down Expand Up @@ -156,7 +157,6 @@ def log_packing_efficiency(packing_context: PackingContext):
packing_efficiency = my_tokens / total_capacity if total_capacity > 0 else 0
avg_seq_length = total_tokens / len(packing_info.seq_lengths)
rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()

log_single_rank(logger, logging.INFO, f"[Sequence Packing] Statistics:")
log_single_rank(
Expand Down Expand Up @@ -189,98 +189,110 @@ def log_packing_efficiency(packing_context: PackingContext):
)

# Add detailed per-rank sequence distribution analysis
if torch.distributed.is_initialized():
# Count the sequences in each local bin (the cross-rank all_gather happens further below)
seq_counts_per_bin = [len(indices) for indices in my_bin_seq_indices]
non_empty_bins = [c for c in seq_counts_per_bin if c > 0]

# Create tensor with rank statistics
rank_stats = torch.tensor(
[
float(rank),
float(len(my_bin_seq_indices)), # total bins
float(len(non_empty_bins)), # non-empty bins
float(my_sequences), # total sequences
(
float(min(non_empty_bins)) if non_empty_bins else 0.0
), # min sequences per bin
(
float(max(non_empty_bins)) if non_empty_bins else 0.0
), # max sequences per bin
(
float(my_sequences / len(non_empty_bins)) if non_empty_bins else 0.0
), # avg sequences per non-empty bin
],
device='cuda',
)

# Gather from all ranks
world_size = mpu.get_data_parallel_world_size()
all_rank_stats = [torch.zeros_like(rank_stats) for _ in range(world_size)]
torch.distributed.all_gather(
all_rank_stats, rank_stats, group=mpu.get_data_parallel_group()
)
# Count the sequences in each local bin (the cross-rank all_gather happens further below)
seq_counts_per_bin = [len(indices) for indices in my_bin_seq_indices]
non_empty_bins = [c for c in seq_counts_per_bin if c > 0]
empty_bins_on_rank = len(my_bin_seq_indices) - len(non_empty_bins)

# Create tensor with rank statistics
rank_stats = torch.tensor(
[
float(rank),
float(len(my_bin_seq_indices)), # total bins
float(len(non_empty_bins)), # non-empty bins
float(my_sequences), # total sequences
(
float(min(non_empty_bins)) if non_empty_bins else 0.0
), # min sequences per bin
(
float(max(non_empty_bins)) if non_empty_bins else 0.0
), # max sequences per bin
(
float(my_sequences / len(non_empty_bins)) if non_empty_bins else 0.0
), # avg sequences per non-empty bin
float(empty_bins_on_rank), # empty bins on each rank
],
device='cuda',
)

# Print detailed statistics for each rank
if rank == 0:
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] Per-rank distribution ({packing_info.packing_algo} algorithm):",
)
log_single_rank(
logger,
logging.INFO,
"[Sequence Packing] Rank | Total Bins | Non-empty | Sequences | Min/Bin | Max/Bin | Avg/Bin",
)
log_single_rank(
logger,
logging.INFO,
"[Sequence Packing] -----|------------|-----------|-----------|---------|---------|--------",
)
for stats in all_rank_stats:
r = int(stats[0].item())
total_bins = int(stats[1].item())
non_empty = int(stats[2].item())
sequences = int(stats[3].item())
min_seq = int(stats[4].item())
max_seq = int(stats[5].item())
avg_seq = stats[6].item()
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] {r:3d} | {total_bins:10d} | {non_empty:9d} | {sequences:9d} | {min_seq:7d} | {max_seq:7d} | {avg_seq:6.1f}",
)
# Gather from all ranks
world_size = mpu.get_data_parallel_world_size()
all_rank_stats = [torch.zeros_like(rank_stats) for _ in range(world_size)]
torch.distributed.all_gather(
all_rank_stats, rank_stats, group=mpu.get_data_parallel_group()
)
all_rank_stats_tensor = torch.stack(all_rank_stats, dim=0)

# Also show first few bins for rank 0 as example
# Print detailed statistics for each rank
if rank == 0:
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] Per-rank distribution ({packing_info.packing_algo} algorithm):",
)
log_single_rank(
logger,
logging.INFO,
"[Sequence Packing] Rank | Total Bins | Non-empty | Sequences | Min/Bin | Max/Bin | Avg/Bin",
)
log_single_rank(
logger,
logging.INFO,
"[Sequence Packing] -----|------------|-----------|-----------|---------|---------|--------",
)
for stats in all_rank_stats:
r = int(stats[0].item())
total_bins = int(stats[1].item())
non_empty = int(stats[2].item())
sequences = int(stats[3].item())
min_seq = int(stats[4].item())
max_seq = int(stats[5].item())
avg_seq = stats[6].item()
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] Example (Rank 0 first 10 bins): {seq_counts_per_bin[:10]}",
f"[Sequence Packing] {r:3d} | {total_bins:10d} | {non_empty:9d} | {sequences:9d} | {min_seq:7d} | {max_seq:7d} | {avg_seq:6.1f}",
)

# Show the improvement from round-robin
total_seqs_all_ranks = sum(int(stats[3].item()) for stats in all_rank_stats)
avg_seqs_per_rank = total_seqs_all_ranks / world_size
max_deviation = max(
abs(int(stats[3].item()) - avg_seqs_per_rank)
for stats in all_rank_stats
)
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] Round-robin distribution quality:",
)
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] - Average sequences per rank: {avg_seqs_per_rank:.1f}",
)
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] - Max deviation from average: {max_deviation:.0f} sequences ({max_deviation/avg_seqs_per_rank*100:.1f}%)",
)
# Also show first few bins for rank 0 as example
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] Example (Rank 0 first 10 bins): {seq_counts_per_bin[:10]}",
)

# Show the improvement from round-robin
total_seqs_all_ranks = sum(int(stats[3].item()) for stats in all_rank_stats)
avg_seqs_per_rank = total_seqs_all_ranks / world_size
max_deviation = max(
abs(int(stats[3].item()) - avg_seqs_per_rank)
for stats in all_rank_stats
)
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] Round-robin distribution quality:",
)
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] - Average sequences per rank: {avg_seqs_per_rank:.1f}",
)
log_single_rank(
logger,
logging.INFO,
f"[Sequence Packing] - Max deviation from average: {max_deviation:.0f} sequences ({max_deviation/avg_seqs_per_rank*100:.1f}%)",
)

result = {
"total_num_bins": int(torch.sum(all_rank_stats_tensor[:, 1]).item()),
"total_non_empty_bins": int(torch.sum(all_rank_stats_tensor[:, 2]).item()),
"total_empty_bins": int(torch.sum(all_rank_stats_tensor[:, 7]).item()),
"total_sequences": int(torch.sum(all_rank_stats_tensor[:, 3]).item()),
}

return result

def get_actual_sequence_lengths(sequences: torch.Tensor, pad_token: int) -> List[int]:
"""Get actual sequence lengths for pre-padded sequences.
Expand Down Expand Up @@ -1058,7 +1070,8 @@ def pack_all_trajectories(trajs, generation_masks, inference_logprobs, global_ad
cached_packed_seq_params=cached_packed_seq_params,
)

log_packing_efficiency(packing_context)
stats_aggregated_over_all_ranks = log_packing_efficiency(packing_context)
packing_context.stats = stats_aggregated_over_all_ranks

return packing_context

Expand Down
2 changes: 2 additions & 0 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,6 +1996,8 @@ def _add_logging_args(parser):
help='Path to save the wandb results locally.')
group.add_argument('--logging-level', type=int, default=None,
help='Set default logging level')
group.add_argument('--log-tokens-per-second', default=False, action="store_true",
help='Whether to log tokens per second.')
return parser


Expand Down
Loading
Loading