@@ -625,6 +625,53 @@ def filter(self, indices):
         return self
 
 
+class LogitBiasProcessor:
+    """Process logits with logit biases."""
+
+    def __init__(
+        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+    ):
+        self.tokenizer = tokenizer
+        self.logit_biases = logit_biases or {}
+
+        # Pre-compute token IDs for each token string
+        self.token_id_mapping = {}
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # If no logit biases, return scores unchanged
+        if not self.logit_biases:
+            return scores
+
+        # Apply bias to the corresponding scores
+        for token_str, bias_value in self.logit_biases.items():
+            # Get token ID, either from cache or by computing it
+            if token_str not in self.token_id_mapping:
+                if token_str.isdigit():
+                    # If the token string is already a numeric ID
+                    token_id = int(token_str)
+                else:
+                    # Otherwise, use the tokenizer to get the ID
+                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
+                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
+
+                self.token_id_mapping[token_str] = token_id
+
+            token_id = self.token_id_mapping[token_str]
+
+            # Apply bias if token ID is valid
+            if 0 <= token_id < scores.size(-1):
+                scores[:, token_id] += bias_value
+
+        return scores
+
+    def filter(self, indices):
+        """Keep only the logit biases for the specified indices."""
+        new_logit_biases = {
+            k: self.logit_biases[k] for k in indices if k in self.logit_biases
+        }
+        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
+
+
 class HeterogeneousLogitBiasProcessor:
     """Process logits with different logit biases for each sequence in the batch."""
 
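For context, a minimal usage sketch (not part of the diff): it assumes LogitBiasProcessor follows the standard transformers LogitsProcessor calling convention and is importable from the module this diff patches; the model name and bias values below are placeholders.

# Usage sketch (assumption): plug LogitBiasProcessor into a generate() call
# via LogitsProcessorList. Import path, model, and bias values are illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Keys may be token strings or numeric token IDs given as strings.
logit_biases = {" Paris": 5.0, "50256": -100.0}
processor = LogitBiasProcessor(logit_biases, tokenizer)

inputs = tokenizer("The capital of France is", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=5,
    logits_processor=LogitsProcessorList([processor]),
)
print(tokenizer.decode(output[0], skip_special_tokens=True))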