From a891783f53c0ae094464fbc981db772208d3d609 Mon Sep 17 00:00:00 2001
From: Ryan <18477649+Ryul0rd@users.noreply.github.com>
Date: Sun, 14 May 2023 21:51:36 -0700
Subject: [PATCH 1/7] add integer type

---
 jsonformer/logits_processors.py | 51 +++++++++++++++++++++++++++++++++
 jsonformer/main.py              | 39 +++++++++++++++++++++++++
 2 files changed, 90 insertions(+)

diff --git a/jsonformer/logits_processors.py b/jsonformer/logits_processors.py
index db288d3..c1088ce 100644
--- a/jsonformer/logits_processors.py
+++ b/jsonformer/logits_processors.py
@@ -82,3 +82,54 @@ def __call__(self, _, scores):
         scores[~mask] = -float("inf")

         return scores
+
+class IntegerStoppingCriteria(StoppingCriteria):
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        prompt_length: int,
+        max_digits: int = 15,
+    ):
+        self.tokenizer = tokenizer
+        self.prompt_length = prompt_length
+        self.max_digits = max_digits
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        scores: torch.FloatTensor,
+    ) -> bool:
+        decoded = self.tokenizer.decode(
+            input_ids[0][self.prompt_length :], skip_special_tokens=True
+        )
+
+        if len(decoded.strip()) > self.max_digits:
+            return True
+
+        if (
+            len(decoded) > 1
+            and any(c.isdigit() for c in decoded)
+            and decoded[-1] in [" ", "\n"]
+        ):
+            return True
+
+        return False
+
+class OutputIntegersTokens(LogitsWarper):
+    def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
+        self.tokenizer = tokenizer
+        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
+        vocab_size = len(tokenizer)
+        self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)
+
+        for _, token_id in tokenizer.get_vocab().items():
+            token_str = tokenizer.decode(token_id).strip()
+
+            if token_str == "" or all(c.isdigit() for c in token_str):
+                self.allowed_mask[token_id] = True
+
+    def __call__(self, _, scores):
+        mask = self.allowed_mask.expand_as(scores)
+        scores[~mask] = -float("inf")
+
+        return scores
diff --git a/jsonformer/main.py b/jsonformer/main.py
index dd867d4..8b5ac8e 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -3,6 +3,8 @@
 from jsonformer.logits_processors import (
     NumberStoppingCriteria,
     OutputNumbersTokens,
+    IntegerStoppingCriteria,
+    OutputIntegersTokens,
     StringStoppingCriteria,
 )
 from termcolor import cprint
@@ -34,6 +36,7 @@ def __init__(
         self.prompt = prompt

         self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt)
+        self.integer_logit_processor = OutputIntegersTokens(self.tokenizer, self.prompt)

         self.generation_marker = "|GENERATION|"
         self.debug_on = debug
@@ -82,6 +85,36 @@ def generate_number(self, temperature: Union[float, None] = None, iterations=0):

         return self.generate_number(temperature=self.temperature * 1.3)

+    def generate_integer(self, temperature: Union[float, None] = None, iterations=0):
+        prompt = self.get_prompt()
+        self.debug("[generate_integer]", prompt, is_prompt=True)
+        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(
+            self.model.device
+        )
+        response = self.model.generate(
+            input_tokens,
+            max_new_tokens=self.max_number_tokens,
+            num_return_sequences=1,
+            logits_processor=[self.integer_logit_processor],
+            stopping_criteria=[
+                IntegerStoppingCriteria(self.tokenizer, len(input_tokens[0]))
+            ],
+            temperature=temperature or self.temperature,
+            pad_token_id=self.tokenizer.eos_token_id,
+        )
+        response = self.tokenizer.decode(response[0], skip_special_tokens=True)
+
+        response = response[len(prompt) :]
+        response = response.strip()
+        self.debug("[generate_integer]", response)
+        try:
+            return int(response)
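With this first patch applied, a schema property can be declared as "integer": generate_value() routes it to generate_integer(), which masks sampling down to digit-only tokens and retries at a higher temperature when parsing fails. A minimal usage sketch (the model checkpoint and schema fields below are illustrative, not part of the patch):

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from jsonformer import Jsonformer

    model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b")
    tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b")

    schema = {
        "type": "object",
        "properties": {
            "age": {"type": "integer"},      # new: generated by generate_integer()
            "height_m": {"type": "number"},  # existing float path
        },
    }

    jsonformer = Jsonformer(model, tokenizer, schema, "Generate a person's details:")
    print(jsonformer())  # e.g. {"age": 28, "height_m": 1.75}; "age" is an int, not 28.0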
        except ValueError:
+            if iterations > 3:
+                raise ValueError("Failed to generate a valid integer")
+
+            return self.generate_integer(temperature=self.temperature * 1.3)
+
     def generate_boolean(self) -> bool:
         prompt = self.get_prompt()
         self.debug("[generate_boolean]", prompt, is_prompt=True)
@@ -160,6 +193,12 @@ def generate_value(
             else:
                 obj.append(self.generation_marker)
             return self.generate_number()
+        elif schema_type == "integer":
+            if key:
+                obj[key] = self.generation_marker
+            else:
+                obj.append(self.generation_marker)
+            return self.generate_integer()
         elif schema_type == "boolean":
             if key:
                 obj[key] = self.generation_marker

From 76bec08f7fd15bd33cde2d62ccdb0049d5dec195 Mon Sep 17 00:00:00 2001
From: Ryan <18477649+Ryul0rd@users.noreply.github.com>
Date: Mon, 15 May 2023 00:27:07 -0700
Subject: [PATCH 2/7] add enum type

---
 jsonformer/main.py | 35 ++++++++++++++++++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/jsonformer/main.py b/jsonformer/main.py
index 8b5ac8e..2dfab48 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -1,4 +1,4 @@
-from typing import List, Union, Dict, Any
+from typing import List, Set, Union, Dict, Any

 from jsonformer.logits_processors import (
     NumberStoppingCriteria,
@@ -10,6 +10,7 @@
 from termcolor import cprint
 from transformers import PreTrainedModel, PreTrainedTokenizer
 import json
+import torch

 GENERATION_MARKER = "|GENERATION|"

@@ -172,6 +173,32 @@ def generate_string(self) -> str:

         return response.split('"')[0].strip()

+    def generate_enum(self, enum_values: Set[str]) -> str:
+        prompt = self.get_prompt()
+        self.debug("[generate_enum]", prompt, is_prompt=True)
+        prompt_tokens = self.tokenizer.encode(prompt, return_tensors="pt")
+
+        highest_probability = 0.0
+        best_option = None
+        for option in enum_values:
+            option_tokens = self.tokenizer.encode(f'"{option}"', return_tensors="pt")
+            n_option_tokens = option_tokens.shape[1]
+            prompt_option_tokens = torch.concat([prompt_tokens, option_tokens], dim=1)
+
+            with torch.no_grad():
+                logits = self.model.forward(prompt_option_tokens.to(self.model.device)).logits[0, -n_option_tokens-1:-1]
+                probabilities = torch.softmax(logits, dim=1)
+                option_token_probabilities = probabilities[torch.arange(probabilities.shape[0]), option_tokens]
+                option_probability = torch.prod(option_token_probabilities).item()
+
+            if option_probability > highest_probability:
+                best_option = option
+                highest_probability = option_probability
+
+        self.debug("[generate_enum]", best_option)
+
+        return best_option
+
     def generate_object(
         self, properties: Dict[str, Any], obj: Dict[str, Any]
     ) -> Dict[str, Any]:
@@ -211,6 +238,12 @@ def generate_value(
             else:
                 obj.append(self.generation_marker)
             return self.generate_string()
+        elif schema_type == "enum":
+            if key:
+                obj[key] = self.generation_marker
+            else:
+                obj.append(self.generation_marker)
+            return self.generate_enum(set(schema["values"]))
         elif schema_type == "array":
             new_array = []
             obj[key] = new_array

From ab4ad40a4c8ed74bd93e4e48f5a09972108160f9 Mon Sep 17 00:00:00 2001
From: Ryan <18477649+Ryul0rd@users.noreply.github.com>
Date: Mon, 15 May 2023 00:30:47 -0700
Subject: [PATCH 3/7] Fix tokenizer edgecase todo in generate_boolean

---
 jsonformer/main.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/jsonformer/main.py b/jsonformer/main.py
index 2dfab48..61c8ef8 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -124,11 +124,8 @@ def generate_boolean(self) -> bool:
         output = self.model.forward(input_tensor.to(self.model.device))
         logits = output.logits[0, -1]

-        # todo: this assumes that "true" and "false" are both tokenized to a single token
-        # this is probably not true for all tokenizers
-        # this can be fixed by looking at only the first token of both "true" and "false"
-        true_token_id = self.tokenizer.convert_tokens_to_ids("true")
-        false_token_id = self.tokenizer.convert_tokens_to_ids("false")
+        true_token_id = self.tokenizer.encode("true", return_tensors="pt")[0, 0]
+        false_token_id = self.tokenizer.encode("false", return_tensors="pt")[0, 0]

         result = logits[true_token_id] > logits[false_token_id]
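Two notes on the patches above. The "enum" type added in patch 2 is jsonformer's own convention rather than JSON Schema's "enum" keyword: the options come from a "values" list, and generate_enum() picks the option the model assigns the highest total probability instead of sampling token by token. Patch 3 applies exactly the fix the removed todo comment suggested: it compares only the first token of each spelling of "true"/"false", which is enough to tell them apart however the tokenizer splits the full words. A sketch of the schema shape the enum path expects (field and option names are made up):

    schema = {
        "type": "object",
        "properties": {
            "color": {
                "type": "enum",                      # jsonformer convention, not JSON Schema
                "values": ["red", "green", "blue"],  # read via schema["values"]
            },
        },
    }
    # generate_value() dispatches this to generate_enum(set(schema["values"])).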
From e34793be84c16288cb1f016853602a489f38ccac Mon Sep 17 00:00:00 2001
From: Ryan <18477649+Ryul0rd@users.noreply.github.com>
Date: Wed, 17 May 2023 03:29:35 -0700
Subject: [PATCH 4/7] Fix bug in enum generation

---
 jsonformer/main.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/jsonformer/main.py b/jsonformer/main.py
index 61c8ef8..6169ea2 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -178,12 +178,12 @@ def generate_enum(self, enum_values: Set[str]) -> str:
         highest_probability = 0.0
         best_option = None
         for option in enum_values:
-            option_tokens = self.tokenizer.encode(f'"{option}"', return_tensors="pt")
-            n_option_tokens = option_tokens.shape[1]
-            prompt_option_tokens = torch.concat([prompt_tokens, option_tokens], dim=1)
+            n_option_tokens = self.tokenizer.encode(f'"{option}"', add_special_tokens=False, return_tensors="pt").shape[1]
+            prompt_tokens = self.tokenizer.encode(prompt + f'"{option}"', return_tensors="pt")
+            option_tokens = prompt_tokens[0, -n_option_tokens:]

             with torch.no_grad():
-                logits = self.model.forward(prompt_option_tokens.to(self.model.device)).logits[0, -n_option_tokens-1:-1]
+                logits = self.model.forward(prompt_tokens[:, :-1].to(self.model.device)).logits[0, -n_option_tokens:]
                 probabilities = torch.softmax(logits, dim=1)
                 option_token_probabilities = probabilities[torch.arange(probabilities.shape[0]), option_tokens]
                 option_probability = torch.prod(option_token_probabilities).item()

From c0e51b367b66b2fc4bb2d520f361e3df6cbf45eb Mon Sep 17 00:00:00 2001
From: Ryan <18477649+Ryul0rd@users.noreply.github.com>
Date: Wed, 17 May 2023 23:25:53 -0700
Subject: [PATCH 5/7] Fix bug in enum generation

---
 jsonformer/main.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/jsonformer/main.py b/jsonformer/main.py
index 6169ea2..b7bc5b4 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -173,20 +173,26 @@ def generate_string(self) -> str:
     def generate_enum(self, enum_values: Set[str]) -> str:
         prompt = self.get_prompt()
         self.debug("[generate_enum]", prompt, is_prompt=True)
-        prompt_tokens = self.tokenizer.encode(prompt, return_tensors="pt")
+
+        # These are necessary because we don't know if we're at the end or middle of an object/array
+        terminal_tokens = torch.concat([
+            self.tokenizer.encode(s, add_special_tokens=False, return_tensors="pt")[:, 0]
+            for s in ('", "', '"}', '"]dsdsf')
+        ])

         highest_probability = 0.0
         best_option = None
         for option in enum_values:
-            n_option_tokens = self.tokenizer.encode(f'"{option}"', add_special_tokens=False, return_tensors="pt").shape[1]
-            prompt_tokens = self.tokenizer.encode(prompt + f'"{option}"', return_tensors="pt")
+            n_option_tokens = self.tokenizer.encode(f'"{option}', add_special_tokens=False, return_tensors="pt").shape[1]
+            prompt_tokens = self.tokenizer.encode(prompt + f'"{option}', return_tensors="pt")
             option_tokens = prompt_tokens[0, -n_option_tokens:]

             with torch.no_grad():
-                logits = self.model.forward(prompt_tokens[:, :-1].to(self.model.device)).logits[0, -n_option_tokens:]
+                logits = self.model.forward(prompt_tokens.to(self.model.device)).logits[0, -n_option_tokens-1:]
                 probabilities = torch.softmax(logits, dim=1)
-                option_token_probabilities = probabilities[torch.arange(probabilities.shape[0]), option_tokens]
-                option_probability = torch.prod(option_token_probabilities).item()
+                option_token_probabilities = probabilities[:-1][torch.arange(probabilities.shape[0]-1), option_tokens]
+                termination_probability = torch.max(probabilities[-1, terminal_tokens])
+                option_probability = torch.prod(option_token_probabilities) * termination_probability

             if option_probability > highest_probability:
                 best_option = option
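Patches 4 and 5 rework the enum scoring. The final version runs the model once over prompt + '"option' (the closing quote deliberately left off), reads each option token's probability under teacher forcing, and multiplies in the probability of the best terminator ('", "', '"}' or '"]'), since any of those can legally follow depending on where the value sits in the enclosing object or array. The same idea restated as a standalone function, a sketch without the Jsonformer plumbing:

    import torch

    def score_option(model, tokenizer, prompt, option, terminal_tokens):
        # Length of '"option' on its own; no closing quote, because the
        # terminator token is expected to supply it.
        n = tokenizer.encode(
            f'"{option}', add_special_tokens=False, return_tensors="pt"
        ).shape[1]
        tokens = tokenizer.encode(prompt + f'"{option}', return_tensors="pt").to(model.device)
        option_tokens = tokens[0, -n:]

        with torch.no_grad():
            # Row i predicts the token after position i, so this slice scores
            # every option token plus the slot where the terminator would go.
            logits = model(tokens).logits[0, -n - 1 :]
        probs = torch.softmax(logits, dim=-1)

        token_probs = probs[:-1][torch.arange(n, device=probs.device), option_tokens]
        termination = probs[-1, terminal_tokens].max()  # P(best terminator)
        return (torch.prod(token_probs) * termination).item()

Here terminal_tokens is a 1-D tensor of token ids like the one patch 5 builds; the caller would take the argmax of score_option over all enum values.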
From 58e8270537fcd98557a0768a85887c2d2c038485 Mon Sep 17 00:00:00 2001
From: Ryan <18477649+Ryul0rd@users.noreply.github.com>
Date: Thu, 18 May 2023 02:33:46 -0700
Subject: [PATCH 6/7] Fix bug resulting in overly long numbers/ints

---
 jsonformer/logits_processors.py | 39 +++++++++++++++++++++++++++------
 jsonformer/main.py              |  8 +++++--
 2 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/jsonformer/logits_processors.py b/jsonformer/logits_processors.py
index c1088ce..7b09ea2 100644
--- a/jsonformer/logits_processors.py
+++ b/jsonformer/logits_processors.py
@@ -48,14 +48,21 @@ def __call__(

         if (
             decoded.count(".") == 1
-            and len(decoded.strip().split(".")[1]) > self.precision
+            and len(decoded.replace(" ", "").split(".")[1]) > self.precision
+        ):
+            return True
+
+        if (
+            len(decoded) > 1
+            and "," in decoded
+            and any(c.isdigit() for c in decoded.split(",")[0])
         ):
             return True

         if (
             len(decoded) > 1
             and any(c.isdigit() for c in decoded)
-            and decoded[-1] in [" ", "\n"]
+            and ("," in decoded or decoded[-1] in (" ", "\n"))
         ):
             return True

@@ -71,9 +78,16 @@ def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
         for _, token_id in tokenizer.get_vocab().items():
             token_str = tokenizer.decode(token_id).strip()

-            if token_str == "" or (
-                all(c.isdigit() or c == "." for c in token_str)
-                and token_str.count(".") <= 1
+            if (
+                token_str == ""
+                or (
+                    all(c.isdigit() or c == "." for c in token_str)
+                    and token_str.count(".") <= 1
+                ) or (
+                    "," in token_str
+                    and all(c.isdigit() or c == "." for c in token_str.split(",")[0])
+                    and token_str.count(".") <= 1
+                )
             ):
                 self.allowed_mask[token_id] = True

@@ -106,10 +120,17 @@ def __call__(
         if len(decoded.strip()) > self.max_digits:
             return True

+        if (
+            len(decoded) > 1
+            and "," in decoded
+            and any(c.isdigit() for c in decoded.split(",")[0])
+        ):
+            return True
+
         if (
             len(decoded) > 1
             and any(c.isdigit() for c in decoded)
-            and decoded[-1] in [" ", "\n"]
+            and decoded[-1] in (" ", "\n")
         ):
             return True

@@ -125,7 +146,11 @@ def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
         for _, token_id in tokenizer.get_vocab().items():
             token_str = tokenizer.decode(token_id).strip()

-            if token_str == "" or all(c.isdigit() for c in token_str):
+            if (
+                token_str == ""
+                or all(c.isdigit() for c in token_str)
+                or "," in token_str and all(c.isdigit() for c in token_str.split(",")[0])
+            ):
                 self.allowed_mask[token_id] = True

     def __call__(self, _, scores):
diff --git a/jsonformer/main.py b/jsonformer/main.py
index b7bc5b4..f9e0280 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -76,7 +76,9 @@ def generate_number(self, temperature: Union[float, None] = None, iterations=0):
         response = self.tokenizer.decode(response[0], skip_special_tokens=True)

         response = response[len(prompt) :]
-        response = response.strip().rstrip(".")
+        if "," in response:
+            response = response.split(",")[0]
+        response = response.replace(" ", "").rstrip(".")
         self.debug("[generate_number]", response)
         try:
             return float(response)
@@ -106,7 +108,9 @@ def generate_integer(self, temperature: Union[float, None] = None, iterations=0):
         response = self.tokenizer.decode(response[0], skip_special_tokens=True)

         response = response[len(prompt) :]
-        response = response.strip()
+        if "," in response:
+            response = response.split(",")[0]
+        response = response.replace(" ", "")
         self.debug("[generate_integer]", response)
         try:
             return int(response)
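The bug patch 6 targets: a number inside a JSON object is usually followed by a comma, and many vocabularies fuse digits with that comma into single tokens such as "42,". The old masks banned every comma-containing token, removing the model's natural way to end a value mid-object, so numbers tended to run long. The new rules admit a token when the text before its first comma is numeric, stop once a comma appears, and parse only what precedes it. Roughly, for hypothetical raw completions:

    raw = "1234, "                     # model emitted the separator with the digits
    if "," in raw:
        raw = raw.split(",")[0]        # keep the number only: "1234"
    print(int(raw.replace(" ", "")))   # -> 1234

    raw = "3.5,"                       # same treatment for floats in generate_number()
    print(float(raw.split(",")[0].replace(" ", "").rstrip(".")))  # -> 3.5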
From 3756cc068d2e08e573114f69980ae47e9085b92d Mon Sep 17 00:00:00 2001
From: Ryan <18477649+Ryul0rd@users.noreply.github.com>
Date: Thu, 18 May 2023 02:38:02 -0700
Subject: [PATCH 7/7] Remove leftover garbage from testing

---
 jsonformer/main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jsonformer/main.py b/jsonformer/main.py
index f9e0280..4bb2d5d 100644
--- a/jsonformer/main.py
+++ b/jsonformer/main.py
@@ -181,7 +181,7 @@ def generate_enum(self, enum_values: Set[str]) -> str:
         # These are necessary because we don't know if we're at the end or middle of an object/array
         terminal_tokens = torch.concat([
             self.tokenizer.encode(s, add_special_tokens=False, return_tensors="pt")[:, 0]
-            for s in ('", "', '"}', '"]dsdsf')
+            for s in ('", "', '"}', '"]')
         ])

         highest_probability = 0.0
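Worth noting why the '"]dsdsf' left over from patch 5 apparently never broke anything: the terminal-token list keeps only the first token id of each string ([:, 0]), so the junk suffix is discarded whenever the tokenizer starts '"]dsdsf' with the same token as '"]'. That holds for many BPE vocabularies but is tokenizer-dependent, so treat this check as an assumption rather than a guarantee:

    ids_clean = tokenizer.encode('"]', add_special_tokens=False)
    ids_junk = tokenizer.encode('"]dsdsf', add_special_tokens=False)
    assert ids_junk[0] == ids_clean[0]  # junk only ever affected tokens after the first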