-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconstrainers.py
More file actions
56 lines (50 loc) · 2.04 KB
/
Copy pathconstrainers.py
File metadata and controls
56 lines (50 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from dataclasses import dataclass
from abc import ABC, abstractmethod
from transformers_cfg.token_grammar_recognizer import AbsTokenRecognizer, check_token_acceptance_in_trie
from transformers_cfg.recognizer import AcceptState
from copy import deepcopy
from functools import lru_cache
import numpy as np
class Constrainer(ABC):
    """Abstract interface for objects that veto candidate next tokens.

    Subclasses implement :meth:`get_blocklist`; :meth:`ask_token` is derived
    from it.
    """

    @abstractmethod
    def get_blocklist(self, tokens: list[int]) -> list[int]:
        """Return the token ids that must not follow the prefix ``tokens``."""

    def ask_token(self, tokens: list[int], token: int) -> bool:
        """Return ``True`` when ``token`` is absent from the blocklist for ``tokens``.

        The blocklist is obtained from a fresh :meth:`get_blocklist` call on
        every invocation.
        """
        blocked = self.get_blocklist(tokens)
        return token not in blocked
@dataclass(frozen=True)
class TokenRecognizerConstrainer(Constrainer):
    """Constrainer that derives its blocklist from a token recognizer.

    The token prefix is decoded to text, the recognizer's string recognizer
    is run over that text to obtain an accept state, and every token id that
    no stack of that state accepts is reported as blocked.

    NOTE(review): the ``lru_cache``-decorated methods include ``self`` in the
    cache key, so the caches hold references to instances for as long as
    the entries live.
    """

    # Recognizer supplying the tokenizer, byte trie, string recognizer and
    # end-of-sequence token id used by the methods below.
    recognizer: AbsTokenRecognizer
    # Stored on the instance; not read by any method of this class.
    max_gen_length: int

    @lru_cache
    def get_length(self):
        """Return ``len(self.recognizer.tokenizer)`` (memoised)."""
        tokenizer = self.recognizer.tokenizer
        return len(tokenizer)

    def get_accept_state(self, tokens: list[int]) -> AcceptState:
        """Parse the decoded text of ``tokens`` starting from a fresh state.

        A deep copy of the string recognizer's initial parsing state is used,
        so the recognizer's own initial state object is not handed to the
        update call.
        """
        fresh_state = deepcopy(
            self.recognizer.string_recognizer.get_initial_parsing_state()
        )
        decoded_text = self.recognizer.tokenizer.decode(tokens)
        string_recognizer = self.recognizer.string_recognizer
        return string_recognizer._update_state_with_string(decoded_text, fresh_state)

    @lru_cache(maxsize=1024)
    def get_token_acceptance_array_for_stack(self, stack: tuple[int]):
        """Return the per-token acceptance mask for a single parser stack.

        ``stack`` must be hashable because the result is memoised. When the
        trie walk accepts no token at all, the end-of-sequence token is
        marked as the only accepted one.
        """
        empty_mask = np.zeros(self.get_length(), dtype=bool)
        mask = check_token_acceptance_in_trie(
            self.recognizer.byte_trie.root,
            [list(stack)],
            self.recognizer.string_recognizer,
            self.recognizer.eos_token_id,
            empty_mask,
        )
        # The trie walk is expected never to mark end-of-sequence itself.
        assert not mask[self.recognizer.eos_token_id]
        if mask.any():
            return mask
        # Nothing else is acceptable for this stack: allow only EOS.
        mask[self.recognizer.eos_token_id] = True
        return mask

    def get_blocklist(self, tokens: list[int]) -> list[int]:
        """Return the ids of all tokens that no current stack accepts.

        If the accept state has no stacks, every id except the
        end-of-sequence token is blocked.
        """
        # get_accept_state is called with a tuple, exactly as before.
        state = self.get_accept_state(tuple(tokens))
        if not state.stacks:
            return [
                token_id
                for token_id in range(self.get_length())
                if token_id != self.recognizer.eos_token_id
            ]

        per_stack_masks = []
        for stack in state.stacks:
            per_stack_masks.append(
                self.get_token_acceptance_array_for_stack(tuple(stack))
            )

        # A token is allowed when at least one stack accepts it.
        allowed = np.stack(per_stack_masks).any(axis=0)
        blocked_ids = np.nonzero(~allowed)[0]
        return blocked_ids.tolist()