Skip to content

Commit 04fab7d

Browse files
committed
Restored anyOf support in Jsonformer
1 parent ec91cec commit 04fab7d

File tree

1 file changed

+58
-138
lines changed

1 file changed

+58
-138
lines changed

jsonformer/main.py

Lines changed: 58 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -1,211 +1,141 @@
11
from typing import List, Union, Dict, Any
2-
3-
from jsonformer.logits_processors import (
4-
NumberStoppingCriteria,
5-
OutputNumbersTokens,
6-
StringStoppingCriteria,
7-
)
2+
from jsonformer.logits_processors import NumberStoppingCriteria, OutputNumbersTokens, StringStoppingCriteria
83
from termcolor import cprint
94
from transformers import PreTrainedModel, PreTrainedTokenizer
105
import json
11-
12-
GENERATION_MARKER = "|GENERATION|"
13-
6+
GENERATION_MARKER = '|GENERATION|'
147

158
class Jsonformer:
    """Generate JSON that conforms to a schema by steering a language model's
    token generation one value at a time.

    The partially built result is serialized into the prompt on every step so
    the model only ever has to produce the next scalar value.
    """

    # In-progress JSON object; reset to a fresh dict at the start of __call__.
    value: Dict[str, Any] = {}

18-
def __init__(
19-
self,
20-
model: PreTrainedModel,
21-
tokenizer: PreTrainedTokenizer,
22-
json_schema: Dict[str, Any],
23-
prompt: str,
24-
*,
25-
debug: bool = False,
26-
max_array_length: int = 128,
27-
max_number_tokens: int = 2048,
28-
temperature: float = 1.0,
29-
max_string_token_length: int = 1024,
30-
):
11+
def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, json_schema: Dict[str, Any], prompt: str, *, debug: bool=False, max_array_length: int=128, max_number_tokens: int=2048, temperature: float=1.0, max_string_token_length: int=1024):
3112
self.model = model
3213
self.tokenizer = tokenizer
3314
self.json_schema = json_schema
3415
self.prompt = prompt
35-
3616
self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt)
37-
38-
self.generation_marker = "|GENERATION|"
17+
self.generation_marker = '|GENERATION|'
3918
self.debug_on = debug
4019
self.max_array_length = max_array_length
41-
4220
self.max_number_tokens = max_number_tokens
4321
self.temperature = temperature
4422
self.max_string_token_length = max_string_token_length
4523

46-
def debug(self, caller: str, value: str, is_prompt: bool = False):
24+
def debug(self, caller: str, value: str, is_prompt: bool=False):
4725
if self.debug_on:
4826
if is_prompt:
49-
cprint(caller, "green", end=" ")
50-
cprint(value, "yellow")
27+
cprint(caller, 'green', end=' ')
28+
cprint(value, 'yellow')
5129
else:
52-
cprint(caller, "green", end=" ")
53-
cprint(value, "blue")
30+
cprint(caller, 'green', end=' ')
31+
cprint(value, 'blue')
5432

55-
def generate_number(self, temperature: Union[float, None] = None, iterations=0):
33+
def generate_number(self, temperature: Union[float, None]=None, iterations=0):
5634
prompt = self.get_prompt()
57-
self.debug("[generate_number]", prompt, is_prompt=True)
58-
input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(
59-
self.model.device
60-
)
61-
response = self.model.generate(
62-
input_tokens,
63-
max_new_tokens=self.max_number_tokens,
64-
num_return_sequences=1,
65-
logits_processor=[self.number_logit_processor],
66-
stopping_criteria=[
67-
NumberStoppingCriteria(self.tokenizer, len(input_tokens[0]))
68-
],
69-
temperature=temperature or self.temperature,
70-
pad_token_id=self.tokenizer.eos_token_id,
71-
)
35+
self.debug('[generate_number]', prompt, is_prompt=True)
36+
input_tokens = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
37+
response = self.model.generate(input_tokens, max_new_tokens=self.max_number_tokens, num_return_sequences=1, logits_processor=[self.number_logit_processor], stopping_criteria=[NumberStoppingCriteria(self.tokenizer, len(input_tokens[0]))], temperature=temperature or self.temperature, pad_token_id=self.tokenizer.eos_token_id)
7238
response = self.tokenizer.decode(response[0], skip_special_tokens=True)
73-
74-
response = response[len(prompt) :]
75-
response = response.strip().rstrip(".")
76-
self.debug("[generate_number]", response)
39+
response = response[len(prompt):]
40+
response = response.strip().rstrip('.')
41+
self.debug('[generate_number]', response)
7742
try:
7843
return float(response)
7944
except ValueError:
8045
if iterations > 3:
81-
raise ValueError("Failed to generate a valid number")
82-
83-
return self.generate_number(temperature=self.temperature * 1.3, iterations=iterations+1)
46+
raise ValueError('Failed to generate a valid number')
47+
return self.generate_number(temperature=self.temperature * 1.3, iterations=iterations + 1)
8448

8549
    def generate_boolean(self) -> bool:
        """Generate a boolean by comparing next-token logits for 'true' vs 'false'.

        Runs a single forward pass (no sampling) and returns True when the
        'true' token outscores the 'false' token at the next position.
        """
        prompt = self.get_prompt()
        self.debug('[generate_boolean]', prompt, is_prompt=True)
        input_tensor = self.tokenizer.encode(prompt, return_tensors='pt')
        output = self.model.forward(input_tensor.to(self.model.device))
        # Logits for the token following the prompt.
        logits = output.logits[0, -1]
        # NOTE(review): assumes 'true' and 'false' each map to a single token;
        # not guaranteed for all tokenizers (convert_tokens_to_ids may return
        # the unk id) — TODO confirm for the target tokenizer.
        true_token_id = self.tokenizer.convert_tokens_to_ids('true')
        false_token_id = self.tokenizer.convert_tokens_to_ids('false')
        result = logits[true_token_id] > logits[false_token_id]
        self.debug('[generate_boolean]', result)
        # .item() converts the 0-d comparison tensor to a Python bool.
        return result.item()
10460

10561
    def generate_string(self) -> str:
        """Generate a string value for the current generation marker.

        Appends an opening quote to the prompt so the model starts inside the
        string, stops generation via StringStoppingCriteria, and returns the
        text before the first closing quote (stripped).
        """
        prompt = self.get_prompt() + '"'
        self.debug('[generate_string]', prompt, is_prompt=True)
        input_tokens = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
        response = self.model.generate(input_tokens, max_new_tokens=self.max_string_token_length, num_return_sequences=1, temperature=self.temperature, stopping_criteria=[StringStoppingCriteria(self.tokenizer, len(input_tokens[0]))], pad_token_id=self.tokenizer.eos_token_id)
        # Some models echo the prompt at the start of the output; strip it
        # when the generated ids begin with the input ids.
        if len(response[0]) >= len(input_tokens[0]) and (response[0][:len(input_tokens[0])] == input_tokens).all():
            response = response[0][len(input_tokens[0]):]
        # Collapse a remaining batch dimension of size 1 before decoding.
        if response.shape[0] == 1:
            response = response[0]
        response = self.tokenizer.decode(response, skip_special_tokens=True)
        self.debug('[generate_string]', '|' + response + '|')
        # No closing quote produced: return the raw text as-is.
        if response.count('"') < 1:
            return response
        return response.split('"')[0].strip()
14175

142-
def generate_object(
143-
self, properties: Dict[str, Any], obj: Dict[str, Any]
144-
) -> Dict[str, Any]:
145-
for key, schema in properties.items():
146-
self.debug("[generate_object] generating value for", key)
76+
def generate_object(self, properties: Dict[str, Any], obj: Dict[str, Any]) -> Dict[str, Any]:
77+
for (key, schema) in properties.items():
78+
self.debug('[generate_object] generating value for', key)
14779
obj[key] = self.generate_value(schema, obj, key)
14880
return obj
14981

150-
def generate_value(
151-
self,
152-
schema: Dict[str, Any],
153-
obj: Union[Dict[str, Any], List[Any]],
154-
key: Union[str, None] = None,
155-
) -> Any:
156-
schema_type = schema["type"]
157-
if schema_type == "number":
82+
def generate_value(self, schema: Dict[str, Any], obj: Union[Dict[str, Any], List[Any]], key: Union[str, None]=None) -> Any:
83+
if 'anyOf' in schema:
84+
options = [option for option in schema['anyOf'] if 'type' in option]
85+
if options:
86+
schema = options[0]
87+
else:
88+
raise ValueError('No valid type in anyOf for key: ' + str(key))
89+
schema_type = schema['type']
90+
if schema_type == 'number':
15891
if key:
15992
obj[key] = self.generation_marker
16093
else:
16194
obj.append(self.generation_marker)
16295
return self.generate_number()
163-
elif schema_type == "boolean":
96+
elif schema_type == 'boolean':
16497
if key:
16598
obj[key] = self.generation_marker
16699
else:
167100
obj.append(self.generation_marker)
168101
return self.generate_boolean()
169-
elif schema_type == "string":
102+
elif schema_type == 'string':
170103
if key:
171104
obj[key] = self.generation_marker
172105
else:
173106
obj.append(self.generation_marker)
174107
return self.generate_string()
175-
elif schema_type == "array":
108+
elif schema_type == 'array':
176109
new_array = []
177-
obj[key] = new_array
178-
return self.generate_array(schema["items"], new_array)
179-
elif schema_type == "object":
110+
if key:
111+
obj[key] = new_array
112+
else:
113+
obj.append(new_array)
114+
return self.generate_array(schema['items'], new_array)
115+
elif schema_type == 'object':
180116
new_obj = {}
181117
if key:
182118
obj[key] = new_obj
183119
else:
184120
obj.append(new_obj)
185-
return self.generate_object(schema["properties"], new_obj)
121+
return self.generate_object(schema['properties'], new_obj)
186122
else:
187-
raise ValueError(f"Unsupported schema type: {schema_type}")
123+
raise ValueError(f'Unsupported schema type: {schema_type}')
188124

189125
def generate_array(self, item_schema: Dict[str, Any], obj: Dict[str, Any]) -> list:
190126
for _ in range(self.max_array_length):
191-
# forces array to have at least one element
192127
element = self.generate_value(item_schema, obj)
193128
obj[-1] = element
194-
195129
obj.append(self.generation_marker)
196130
input_prompt = self.get_prompt()
197131
obj.pop()
198-
input_tensor = self.tokenizer.encode(input_prompt, return_tensors="pt")
132+
input_tensor = self.tokenizer.encode(input_prompt, return_tensors='pt')
199133
output = self.model.forward(input_tensor.to(self.model.device))
200134
logits = output.logits[0, -1]
201-
202-
203135
top_indices = logits.topk(30).indices
204136
sorted_token_ids = top_indices[logits[top_indices].argsort(descending=True)]
205-
206137
found_comma = False
207138
found_close_bracket = False
208-
209139
for token_id in sorted_token_ids:
210140
decoded_token = self.tokenizer.decode(token_id)
211141
if ',' in decoded_token:
@@ -214,32 +144,22 @@ def generate_array(self, item_schema: Dict[str, Any], obj: Dict[str, Any]) -> li
214144
if ']' in decoded_token:
215145
found_close_bracket = True
216146
break
217-
218147
if found_close_bracket or not found_comma:
219148
break
220-
221149
return obj
222150

223151
def get_prompt(self):
224-
template = """{prompt}\nOutput result in the following JSON schema format:\n{schema}\nResult: {progress}"""
152+
template = '{prompt}\nOutput result in the following JSON schema format:\n{schema}\nResult: {progress}'
225153
progress = json.dumps(self.value)
226154
gen_marker_index = progress.find(f'"{self.generation_marker}"')
227155
if gen_marker_index != -1:
228156
progress = progress[:gen_marker_index]
229157
else:
230-
raise ValueError("Failed to find generation marker")
231-
232-
prompt = template.format(
233-
prompt=self.prompt,
234-
schema=json.dumps(self.json_schema),
235-
progress=progress,
236-
)
237-
158+
raise ValueError('Failed to find generation marker')
159+
prompt = template.format(prompt=self.prompt, schema=json.dumps(self.json_schema), progress=progress)
238160
return prompt
239161

240162
def __call__(self) -> Dict[str, Any]:
241163
self.value = {}
242-
generated_data = self.generate_object(
243-
self.json_schema["properties"], self.value
244-
)
245-
return generated_data
164+
generated_data = self.generate_object(self.json_schema['properties'], self.value)
165+
return generated_data

0 commit comments

Comments
 (0)