17 changes: 11 additions & 6 deletions megatron/rl/rl_utils.py
@@ -799,13 +799,18 @@ def get_logprobs(model, tokens, position_ids, attention_mask, no_grad=False, pac
tp_world_size = get_tensor_model_parallel_world_size()
if actual_len % tp_world_size != 0:
actual_len = ((actual_len + tp_world_size - 1) // tp_world_size) * tp_world_size
# Update cu_seqlens to match the padded length.
# The last entry of cu_seqlens must equal the tensor's sequence dimension.
# Without this, TE attention/rotary ops see mismatched dimensions.
# We need to add padding tokens to make the sequence divisible by TP.
Contributor

If the padded length needs to be divisible by the TP world size, we should have a unit test that checks this and prevents the bug from recurring.
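
A minimal sketch of what such a test might look like. It uses plain Python lists in place of torch tensors so it runs without a GPU or Megatron installed. `round_up_to_multiple` and `append_padding_sequence` are hypothetical helpers that mirror the arithmetic in this hunk; they are not existing functions in `rl_utils.py`.

```python
from __future__ import annotations


def round_up_to_multiple(length: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= `length` (same formula as the hunk)."""
    return ((length + multiple - 1) // multiple) * multiple


def append_padding_sequence(cu_seqlens: list[int], padded_len: int) -> list[int]:
    """Hypothetical helper: add a trailing padding pseudo-sequence when needed.

    Returns a new list and never mutates the input, since the original
    params may be cached.
    """
    if padded_len < cu_seqlens[-1]:
        raise ValueError("padded_len must not be smaller than the real length")
    if cu_seqlens[-1] == padded_len:
        return list(cu_seqlens)
    return list(cu_seqlens) + [padded_len]


def test_padded_cu_seqlens_divisible_by_tp() -> None:
    real = [0, 3, 7, 10]  # three packed sequences, 10 real tokens
    for tp_world_size in (1, 2, 4, 8):
        padded_len = round_up_to_multiple(real[-1], tp_world_size)
        out = append_padding_sequence(real, padded_len)
        # The last entry must match the padded tensor length and divide by TP.
        assert out[-1] == padded_len
        assert out[-1] % tp_world_size == 0
        # Real sequence boundaries are untouched.
        assert out[: len(real)] == real
        # Cumulative lengths stay non-decreasing.
        assert all(a <= b for a, b in zip(out, out[1:]))
        # The input list was not modified.
        assert real == [0, 3, 7, 10]
```

A real test would presumably call `get_logprobs` (or a factored-out helper) with a `PackedSeqParams` whose total length is not a multiple of the TP world size, then make the same assertions on the resulting `cu_seqlens_q`.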

# if we extend cu_seqlens[-1] to include padding in the
# last real sequence, this would cause that sequence to attend to
# padding tokens (garbage values), leading to numerical instability/NaN.
# Instead, we add a new "padding sequence" entry to cu_seqlens.
if packed_seq_params.cu_seqlens_q[-1].item() != actual_len:
# Clone to avoid modifying cached params
new_cu_seqlens = packed_seq_params.cu_seqlens_q.clone()
new_cu_seqlens[-1] = actual_len
# Add a new entry for the padding pseudo-sequence
# cu_seqlens goes from [0, ..., total_real] to [0, ..., total_real, padded_len]
new_cu_seqlens = torch.cat([
packed_seq_params.cu_seqlens_q,
torch.tensor([actual_len], dtype=torch.int32, device=packed_seq_params.cu_seqlens_q.device)
])
packed_seq_params = PackedSeqParams(
qkv_format=packed_seq_params.qkv_format,
cu_seqlens_q=new_cu_seqlens,
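
To make the difference between the two approaches concrete, here is a small illustrative sketch, again with plain Python lists rather than tensors. `segment_ids` and `same_sequence` are hypothetical names introduced only for this example. It relies on the convention that `cu_seqlens` marks sequence boundaries and that variable-length attention restricts each token to positions in its own sequence.

```python
from __future__ import annotations


def segment_ids(cu_seqlens: list[int]) -> list[int]:
    """Map every token position to the index of the packed sequence it belongs to."""
    ids: list[int] = []
    for seg, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        ids.extend([seg] * (end - start))
    return ids


def same_sequence(cu_seqlens: list[int], a: int, b: int) -> bool:
    """True if positions `a` and `b` fall inside the same packed sequence."""
    ids = segment_ids(cu_seqlens)
    return ids[a] == ids[b]


# 10 real tokens in three sequences, padded to 12 for a TP world size of 4.
# Positions 0..9 are real tokens; positions 10 and 11 are padding.

# Old approach: overwrite the last entry, stretching the last real sequence.
extended = [0, 3, 7, 12]
# segment_ids(extended) == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]

# New approach: append an entry so the padding forms its own pseudo-sequence.
appended = [0, 3, 7, 10, 12]
# segment_ids(appended) == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3]

# With `extended`, the last real token (9) shares a sequence with padding (10).
# With `appended`, the real tokens and the padding are kept apart.
shares_old = same_sequence(extended, 9, 10)  # True
shares_new = same_sequence(appended, 9, 10)  # False
```

In both cases the last entry equals the padded length of 12, so the dimension check is satisfied either way. Only the appended form keeps the padding out of the last real sequence.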