diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py
index df19145bf2d..8f914dd11ff 100644
--- a/megatron/rl/rl_utils.py
+++ b/megatron/rl/rl_utils.py
@@ -799,13 +799,18 @@ def get_logprobs(model, tokens, position_ids, attention_mask, no_grad=False, pac
     tp_world_size = get_tensor_model_parallel_world_size()
     if actual_len % tp_world_size != 0:
         actual_len = ((actual_len + tp_world_size - 1) // tp_world_size) * tp_world_size
-    # Update cu_seqlens to match the padded length.
-    # The last entry of cu_seqlens must equal the tensor's sequence dimension.
-    # Without this, TE attention/rotary ops see mismatched dimensions.
+    # We need to add padding tokens to make the sequence length divisible by TP.
+    # If we extended cu_seqlens[-1] to include the padding in the last real
+    # sequence, that sequence would attend to the padding tokens (garbage
+    # values), leading to numerical instability/NaN.
+    # Instead, we add a new "padding sequence" entry to cu_seqlens.
     if packed_seq_params.cu_seqlens_q[-1].item() != actual_len:
-        # Clone to avoid modifying cached params
-        new_cu_seqlens = packed_seq_params.cu_seqlens_q.clone()
-        new_cu_seqlens[-1] = actual_len
+        # Add a new entry for the padding pseudo-sequence:
+        # cu_seqlens goes from [0, ..., total_real] to [0, ..., total_real, padded_len]
+        new_cu_seqlens = torch.cat([
+            packed_seq_params.cu_seqlens_q,
+            torch.tensor([actual_len], dtype=torch.int32, device=packed_seq_params.cu_seqlens_q.device),
+        ])
         packed_seq_params = PackedSeqParams(
             qkv_format=packed_seq_params.qkv_format,
             cu_seqlens_q=new_cu_seqlens,
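
For context, a minimal standalone sketch of the cu_seqlens manipulation this hunk performs; the pad_cu_seqlens helper name and the concrete sequence lengths are illustrative assumptions, not part of the patch:

import torch

def pad_cu_seqlens(cu_seqlens: torch.Tensor, tp_world_size: int):
    # Illustrative sketch, not the patched function: round the packed length
    # up to a multiple of tp_world_size and, when padding is needed, append
    # it as a separate pseudo-sequence so no real sequence attends to it.
    actual_len = cu_seqlens[-1].item()
    padded_len = ((actual_len + tp_world_size - 1) // tp_world_size) * tp_world_size
    if padded_len != actual_len:
        cu_seqlens = torch.cat([
            cu_seqlens,
            torch.tensor([padded_len], dtype=torch.int32, device=cu_seqlens.device),
        ])
    return cu_seqlens, padded_len

# Two packed sequences of lengths 5 and 6 (cu_seqlens [0, 5, 11]) with TP=4:
# 11 rounds up to 12, so a 1-token padding pseudo-sequence [11, 12) is added.
cu, padded = pad_cu_seqlens(torch.tensor([0, 5, 11], dtype=torch.int32), tp_world_size=4)
print(cu.tolist(), padded)  # -> [0, 5, 11, 12] 12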