
Commit c51c0bf

zpqiu and claude authored
fix: use seq_length instead of padded_seq_length for topk output padding (#1929)
Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 336803f commit c51c0bf

1 file changed (1 addition, 1 deletion):

nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line numberDiff line numberDiff line change
```diff
@@ -939,7 +939,7 @@ def collection_fn(_):
         for out in list_of_outputs:
             tk = out["topk_logits"]
             ti = out["topk_indices"]
-            pad_len = padded_seq_length - tk.shape[1]
+            pad_len = seq_length - tk.shape[1]
             if pad_len > 0:
                 tk = torch.nn.functional.pad(tk, (0, 0, 0, pad_len), value=0.0)
                 ti = torch.nn.functional.pad(ti, (0, 0, 0, pad_len), value=0)
```
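The fix changes the padding target from `padded_seq_length` to `seq_length`, so the top-k outputs line up with the rest of the batch tensors along the sequence dimension. The padding logic itself can be sketched in plain Python; the function name `pad_positions` and the list-of-lists layout here are illustrative stand-ins for the tensor operation, not code from the repository:

```python
def pad_positions(topk_rows, seq_length, value=0.0):
    """Right-pad a [num_positions][k] list of top-k entries to seq_length
    positions, mirroring torch.nn.functional.pad(tk, (0, 0, 0, pad_len)).
    """
    # The fix: compute pad_len against seq_length, not padded_seq_length.
    pad_len = seq_length - len(topk_rows)
    if pad_len > 0:
        k = len(topk_rows[0]) if topk_rows else 0
        # Append pad_len rows of the fill value, one per missing position.
        topk_rows = topk_rows + [[value] * k for _ in range(pad_len)]
    return topk_rows
```

For example, padding a 2-position top-2 output to length 4 appends two rows of zeros; if the output already spans `seq_length`, `pad_len` is zero and the rows pass through unchanged, matching the `if pad_len > 0` guard in the diff.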
