diff --git a/jsonformer/logits_processors.py b/jsonformer/logits_processors.py index db288d3..9217e7b 100644 --- a/jsonformer/logits_processors.py +++ b/jsonformer/logits_processors.py @@ -1,5 +1,4 @@ -from typing import List -from transformers import PreTrainedTokenizer, LogitsWarper, StoppingCriteria +from transformers import PreTrainedTokenizer, LogitsProcessor, StoppingCriteria import torch class StringStoppingCriteria(StoppingCriteria): @@ -61,7 +60,7 @@ def __call__( return False -class OutputNumbersTokens(LogitsWarper): +class OutputNumbersTokens(LogitsProcessor): def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str): self.tokenizer = tokenizer self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")