Skip to content

Commit d8023b8

Browse files
committed
Fixed anyOf handling in Jsonformer
1 parent ec91cec commit d8023b8

File tree

1 file changed

+16
-69
lines changed

1 file changed

+16
-69
lines changed

jsonformer/main.py

Lines changed: 16 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ def __init__(
2323
prompt: str,
2424
*,
2525
debug: bool = False,
26-
max_array_length: int = 128,
27-
max_number_tokens: int = 2048,
26+
max_array_length: int = 10,
27+
max_number_tokens: int = 6,
2828
temperature: float = 1.0,
29-
max_string_token_length: int = 1024,
29+
max_string_token_length: int = 10,
3030
):
3131
self.model = model
3232
self.tokenizer = tokenizer
@@ -71,15 +71,14 @@ def generate_number(self, temperature: Union[float, None] = None, iterations=0):
7171
)
7272
response = self.tokenizer.decode(response[0], skip_special_tokens=True)
7373

74-
response = response[len(prompt) :]
74+
response = response[len(prompt):]
7575
response = response.strip().rstrip(".")
7676
self.debug("[generate_number]", response)
7777
try:
7878
return float(response)
7979
except ValueError:
8080
if iterations > 3:
8181
raise ValueError("Failed to generate a valid number")
82-
8382
return self.generate_number(temperature=self.temperature * 1.3, iterations=iterations+1)
8483

8584
def generate_boolean(self) -> bool:
@@ -90,16 +89,11 @@ def generate_boolean(self) -> bool:
9089
output = self.model.forward(input_tensor.to(self.model.device))
9190
logits = output.logits[0, -1]
9291

93-
# todo: this assumes that "true" and "false" are both tokenized to a single token
94-
# this is probably not true for all tokenizers
95-
# this can be fixed by looking at only the first token of both "true" and "false"
9692
true_token_id = self.tokenizer.convert_tokens_to_ids("true")
9793
false_token_id = self.tokenizer.convert_tokens_to_ids("false")
9894

9995
result = logits[true_token_id] > logits[false_token_id]
100-
10196
self.debug("[generate_boolean]", result)
102-
10397
return result.item()
10498

10599
def generate_string(self) -> str:
@@ -120,8 +114,6 @@ def generate_string(self) -> str:
120114
pad_token_id=self.tokenizer.eos_token_id,
121115
)
122116

123-
# Some models output the prompt as part of the response
124-
# This removes the prompt from the response if it is present
125117
if (
126118
len(response[0]) >= len(input_tokens[0])
127119
and (response[0][: len(input_tokens[0])] == input_tokens).all()
@@ -131,7 +123,6 @@ def generate_string(self) -> str:
131123
response = response[0]
132124

133125
response = self.tokenizer.decode(response, skip_special_tokens=True)
134-
135126
self.debug("[generate_string]", "|" + response + "|")
136127

137128
if response.count('"') < 1:
@@ -153,6 +144,12 @@ def generate_value(
153144
obj: Union[Dict[str, Any], List[Any]],
154145
key: Union[str, None] = None,
155146
) -> Any:
147+
if "anyOf" in schema:
148+
options = [option for option in schema["anyOf"] if "type" in option]
149+
if options:
150+
schema = options[0]
151+
else:
152+
raise ValueError("No valid type in anyOf for key: " + str(key))
156153
schema_type = schema["type"]
157154
if schema_type == "number":
158155
if key:
@@ -174,7 +171,10 @@ def generate_value(
174171
return self.generate_string()
175172
elif schema_type == "array":
176173
new_array = []
177-
obj[key] = new_array
174+
if key:
175+
obj[key] = new_array
176+
else:
177+
obj.append(new_array)
178178
return self.generate_array(schema["items"], new_array)
179179
elif schema_type == "object":
180180
new_obj = {}
@@ -186,60 +186,7 @@ def generate_value(
186186
else:
187187
raise ValueError(f"Unsupported schema type: {schema_type}")
188188

189-
def generate_array(self, item_schema: Dict[str, Any], obj: Dict[str, Any]) -> list:
190-
for _ in range(self.max_array_length):
191-
# forces array to have at least one element
192-
element = self.generate_value(item_schema, obj)
193-
obj[-1] = element
194-
195-
obj.append(self.generation_marker)
196-
input_prompt = self.get_prompt()
197-
obj.pop()
198-
input_tensor = self.tokenizer.encode(input_prompt, return_tensors="pt")
199-
output = self.model.forward(input_tensor.to(self.model.device))
200-
logits = output.logits[0, -1]
201-
202-
203-
top_indices = logits.topk(30).indices
204-
sorted_token_ids = top_indices[logits[top_indices].argsort(descending=True)]
205-
206-
found_comma = False
207-
found_close_bracket = False
208-
209-
for token_id in sorted_token_ids:
210-
decoded_token = self.tokenizer.decode(token_id)
211-
if ',' in decoded_token:
212-
found_comma = True
213-
break
214-
if ']' in decoded_token:
215-
found_close_bracket = True
216-
break
217-
218-
if found_close_bracket or not found_comma:
219-
break
220-
221-
return obj
222-
223-
def get_prompt(self):
224-
template = """{prompt}\nOutput result in the following JSON schema format:\n{schema}\nResult: {progress}"""
225-
progress = json.dumps(self.value)
226-
gen_marker_index = progress.find(f'"{self.generation_marker}"')
227-
if gen_marker_index != -1:
228-
progress = progress[:gen_marker_index]
229-
else:
230-
raise ValueError("Failed to find generation marker")
231-
232-
prompt = template.format(
233-
prompt=self.prompt,
234-
schema=json.dumps(self.json_schema),
235-
progress=progress,
236-
)
237-
238-
return prompt
239-
240189
def __call__(self) -> Dict[str, Any]:
241190
self.value = {}
242-
generated_data = self.generate_object(
243-
self.json_schema["properties"], self.value
244-
)
245-
return generated_data
191+
generated_data = self.generate_object(self.json_schema["properties"], self.value)
192+
return generated_data

0 commit comments

Comments (0)