@@ -748,19 +748,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # Convert logits to probabilities
         probs = torch.softmax(scores, dim=-1)
         # Get the probability of the top token for each sequence in the batch
-        top_probs, _ = probs.max(dim=-1, keepdim=True)
+        top_probs = probs.amax(dim=-1, keepdim=True)
         # Calculate the actual min_p threshold by scaling min_p with the top token's probability
         scaled_min_p = self.min_p * top_probs
         # Create a mask for tokens that have a probability less than the scaled min_p
         tokens_to_remove = probs < scaled_min_p

-        sorted_indices = torch.argsort(scores, descending=True, dim=-1)
-        sorted_indices_to_remove = torch.gather(tokens_to_remove, dim=-1, index=sorted_indices)
-        # Keep at least min_tokens_to_keep
-        sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False
+        # Keep at least min_tokens_to_keep tokens (clip k to vocab size if needed, avoids index out of range)
+        k = min(self.min_tokens_to_keep, probs.shape[-1])
+        sorted_indices = torch.topk(probs, k, dim=-1).indices
+        tokens_to_remove.scatter_(-1, sorted_indices, False)

-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
+        scores_processed = scores.masked_fill(tokens_to_remove, self.filter_value)
         return scores_processed

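For readers skimming the change, here is a minimal, self-contained sketch of the min-p filtering semantics the new code implements, written as a standalone function rather than the processor class (the function name min_p_filter is hypothetical; min_p, min_tokens_to_keep, and filter_value mirror the attributes used in the diff). The topk plus in-place scatter_ pattern replaces the earlier full argsort and gather/scatter round trip, and clipping k to the vocab size guards against min_tokens_to_keep exceeding it.

import torch

def min_p_filter(
    scores: torch.FloatTensor,
    min_p: float = 0.1,
    min_tokens_to_keep: int = 1,
    filter_value: float = -float("inf"),
) -> torch.FloatTensor:
    probs = torch.softmax(scores, dim=-1)
    top_probs = probs.amax(dim=-1, keepdim=True)   # per-sequence top probability
    tokens_to_remove = probs < min_p * top_probs   # below the scaled threshold
    k = min(min_tokens_to_keep, probs.shape[-1])   # never ask topk for more than vocab size
    keep = torch.topk(probs, k, dim=-1).indices
    tokens_to_remove.scatter_(-1, keep, False)     # always retain the top-k tokens
    return scores.masked_fill(tokens_to_remove, filter_value)

# Toy batch of 1 with a 4-token vocab: only tokens whose probability is at least
# min_p * max(probs) survive, and the single top token is always kept.
logits = torch.tensor([[4.0, 3.0, 1.0, -2.0]])
print(min_p_filter(logits, min_p=0.2))
# tensor([[4., 3., -inf, -inf]])  -- the two low-probability tokens are masked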