
Commit aab9533

fix: adjust the NextTokenChooser logit bias processor
1 parent: 4479480

File tree

2 files changed: +53 additions, 0 deletions


server/text_generation_server/utils/logits_process.py

Lines changed: 47 additions & 0 deletions
@@ -625,6 +625,53 @@ def filter(self, indices):
         return self


+class LogitBiasProcessor:
+    """Process logits with logit biases."""
+
+    def __init__(
+        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+    ):
+        self.tokenizer = tokenizer
+        self.logit_biases = logit_biases or {}
+
+        # Pre-compute token IDs for each token string
+        self.token_id_mapping = {}
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # If no logit biases, return scores unchanged
+        if not self.logit_biases:
+            return scores
+
+        # Apply bias to the corresponding scores
+        for token_str, bias_value in self.logit_biases.items():
+            # Get token ID, either from cache or by computing it
+            if token_str not in self.token_id_mapping:
+                if token_str.isdigit():
+                    # If the token string is already a numeric ID
+                    token_id = int(token_str)
+                else:
+                    # Otherwise, use the tokenizer to get the ID
+                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
+                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
+
+                self.token_id_mapping[token_str] = token_id
+
+            token_id = self.token_id_mapping[token_str]
+
+            # Apply bias if token ID is valid
+            if 0 <= token_id < scores.size(-1):
+                scores[:, token_id] += bias_value
+
+        return scores
+
+    def filter(self, indices):
+        """Keep only the logit biases for the specified indices."""
+        new_logit_biases = {
+            k: self.logit_biases[k] for k in indices if k in self.logit_biases
+        }
+        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
+
+
 class HeterogeneousLogitBiasProcessor:
     """Process logits with different logit biases for each sequence in the batch."""

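A minimal usage sketch (not part of the commit) of the new processor in isolation. The bias keys here are numeric token-ID strings, so the tokenizer lookup branch is never taken and None can stand in for the tokenizer:

    import torch

    from text_generation_server.utils.logits_process import LogitBiasProcessor

    # Bias token 42 upward and effectively ban token 7
    processor = LogitBiasProcessor({"42": 5.0, "7": -100.0}, tokenizer=None)

    scores = torch.zeros(1, 50)            # batch of 1 over a 50-token vocabulary
    input_ids = torch.tensor([[1, 2, 3]])  # unused by this processor

    scores = processor(input_ids, scores)
    print(scores[0, 42].item())  # 5.0
    print(scores[0, 7].item())   # -100.0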
server/text_generation_server/utils/tokens.py

Lines changed: 6 additions & 0 deletions
@@ -7,6 +7,7 @@
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
     GrammarLogitProcessor,
+    LogitBiasProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -59,6 +60,11 @@ def __init__(
             if grammar != ""
             else None
         )
+        self.logit_bias_processor = (
+            LogitBiasProcessor(logit_bias, tokenizer)
+            if logit_bias is not None and len(logit_bias) > 0
+            else None
+        )
         self.tokenizer = tokenizer
         self.logit_bias = logit_bias

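NextTokenChooser.__call__ is not part of this diff; only the field assignment above is. A hypothetical sketch of how the new field might be chained into the existing processor pipeline, assuming the usual pattern of applying each optional processor in turn:

    # Hypothetical: the method body is not shown in this commit
    def __call__(self, input_ids, scores):
        if self.logit_bias_processor is not None:
            scores = self.logit_bias_processor(input_ids, scores)
        # ... repetition/frequency penalties, grammar processing, and
        # warpers would be applied here as in the existing pipeline ...
        return scores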