We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent af6e19f commit c2ed069Copy full SHA for c2ed069
vllm/v1/sample/ops/penalties.py
@@ -21,6 +21,14 @@ def apply_all_penalties(
21
"""
22
_, vocab_size = logits.shape
23
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device)
24
+
25
+ # In the async scheduling case, rows that won't have penalties applied may contain
26
+ # -1 placeholder token ids. We must replace these with valid token ids so that the
27
+ # scatter done in apply_penalties is valid.
28
+ # NOTE(nick): The penalties implementation is currently quite inefficient and
29
+ # will be reworked anyhow.
30
+ output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size)
31
32
return apply_penalties(
33
logits,
34
prompt_token_ids,
0 commit comments