from typing import List, Union, Dict, Any

from jsonformer.logits_processors import (
    NumberStoppingCriteria,
    OutputNumbersTokens,
    StringStoppingCriteria,
)
from termcolor import cprint
from transformers import PreTrainedModel, PreTrainedTokenizer
import json

GENERATION_MARKER = "|GENERATION|"

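# Illustrative example of the kind of json_schema this class consumes (the field
# names below are made up for the sketch): generate_value() dispatches on "type"
# and supports "number", "boolean", "string", "array", and "object", and it falls
# back to the first typed option inside an "anyOf" list.
#
# example_schema = {
#     "type": "object",
#     "properties": {
#         "name": {"type": "string"},
#         "age": {"type": "number"},
#         "is_student": {"type": "boolean"},
#         "courses": {"type": "array", "items": {"type": "string"}},
#     },
# }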
class Jsonformer:
    """Generates a JSON object conforming to a JSON schema by querying the model
    for one primitive value at a time."""

    value: Dict[str, Any] = {}

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        json_schema: Dict[str, Any],
        prompt: str,
        *,
        debug: bool = False,
        max_array_length: int = 128,
        max_number_tokens: int = 2048,
        temperature: float = 1.0,
        max_string_token_length: int = 1024,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.json_schema = json_schema
        self.prompt = prompt

        self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt)

        self.generation_marker = "|GENERATION|"
        self.debug_on = debug
        self.max_array_length = max_array_length
        self.max_number_tokens = max_number_tokens
        self.temperature = temperature
        self.max_string_token_length = max_string_token_length

    def debug(self, caller: str, value: str, is_prompt: bool = False):
        if self.debug_on:
            if is_prompt:
                cprint(caller, "green", end=" ")
                cprint(value, "yellow")
            else:
                cprint(caller, "green", end=" ")
                cprint(value, "blue")

    def generate_number(self, temperature: Union[float, None] = None, iterations=0):
        prompt = self.get_prompt()
        self.debug("[generate_number]", prompt, is_prompt=True)
        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(
            self.model.device
        )
        response = self.model.generate(
            input_tokens,
            max_new_tokens=self.max_number_tokens,
            num_return_sequences=1,
            logits_processor=[self.number_logit_processor],
            stopping_criteria=[
                NumberStoppingCriteria(self.tokenizer, len(input_tokens[0]))
            ],
            temperature=temperature or self.temperature,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(response[0], skip_special_tokens=True)

        response = response[len(prompt):]
        response = response.strip().rstrip(".")
        self.debug("[generate_number]", response)
        try:
            return float(response)
        except ValueError:
            if iterations > 3:
                raise ValueError("Failed to generate a valid number")

            # Retry with a slightly higher temperature when parsing fails.
            return self.generate_number(
                temperature=self.temperature * 1.3, iterations=iterations + 1
            )

    def generate_boolean(self) -> bool:
        prompt = self.get_prompt()
        self.debug("[generate_boolean]", prompt, is_prompt=True)

        input_tensor = self.tokenizer.encode(prompt, return_tensors="pt")
        output = self.model.forward(input_tensor.to(self.model.device))
        logits = output.logits[0, -1]

        # TODO: this assumes that "true" and "false" are each tokenized to a single
        # token, which is probably not true for all tokenizers; it could be fixed by
        # comparing only the first token of "true" and "false".
        true_token_id = self.tokenizer.convert_tokens_to_ids("true")
        false_token_id = self.tokenizer.convert_tokens_to_ids("false")

        result = logits[true_token_id] > logits[false_token_id]

        self.debug("[generate_boolean]", result)

        return result.item()

    def generate_string(self) -> str:
        prompt = self.get_prompt() + '"'
        self.debug("[generate_string]", prompt, is_prompt=True)
        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(
            self.model.device
        )

        response = self.model.generate(
            input_tokens,
            max_new_tokens=self.max_string_token_length,
            num_return_sequences=1,
            temperature=self.temperature,
            stopping_criteria=[
                StringStoppingCriteria(self.tokenizer, len(input_tokens[0]))
            ],
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Some models output the prompt as part of the response; remove it if present.
        if (
            len(response[0]) >= len(input_tokens[0])
            and (response[0][: len(input_tokens[0])] == input_tokens).all()
        ):
            response = response[0][len(input_tokens[0]):]
        if response.shape[0] == 1:
            response = response[0]

        response = self.tokenizer.decode(response, skip_special_tokens=True)

        self.debug("[generate_string]", "|" + response + "|")

        if response.count('"') < 1:
            return response

        return response.split('"')[0].strip()

    def generate_object(
        self, properties: Dict[str, Any], obj: Dict[str, Any]
    ) -> Dict[str, Any]:
        for key, schema in properties.items():
            self.debug("[generate_object] generating value for", key)
            obj[key] = self.generate_value(schema, obj, key)
        return obj

    def generate_value(
        self,
        schema: Dict[str, Any],
        obj: Union[Dict[str, Any], List[Any]],
        key: Union[str, None] = None,
    ) -> Any:
        if "anyOf" in schema:
            # When the schema offers alternatives, fall back to the first option
            # that declares a concrete type.
            options = [option for option in schema["anyOf"] if "type" in option]
            if options:
                schema = options[0]
            else:
                raise ValueError("No valid type in anyOf for key: " + str(key))
        schema_type = schema["type"]
        if schema_type == "number":
            if key:
                obj[key] = self.generation_marker
            else:
                obj.append(self.generation_marker)
            return self.generate_number()
        elif schema_type == "boolean":
            if key:
                obj[key] = self.generation_marker
            else:
                obj.append(self.generation_marker)
            return self.generate_boolean()
        elif schema_type == "string":
            if key:
                obj[key] = self.generation_marker
            else:
                obj.append(self.generation_marker)
            return self.generate_string()
        elif schema_type == "array":
            new_array = []
            if key:
                obj[key] = new_array
            else:
                obj.append(new_array)
            return self.generate_array(schema["items"], new_array)
        elif schema_type == "object":
            new_obj = {}
            if key:
                obj[key] = new_obj
            else:
                obj.append(new_obj)
            return self.generate_object(schema["properties"], new_obj)
        else:
            raise ValueError(f"Unsupported schema type: {schema_type}")

    def generate_array(self, item_schema: Dict[str, Any], obj: List[Any]) -> list:
        for _ in range(self.max_array_length):
            # forces array to have at least one element
            element = self.generate_value(item_schema, obj)
            obj[-1] = element

            obj.append(self.generation_marker)
            input_prompt = self.get_prompt()
            obj.pop()
            input_tensor = self.tokenizer.encode(input_prompt, return_tensors="pt")
            output = self.model.forward(input_tensor.to(self.model.device))
            logits = output.logits[0, -1]

            top_indices = logits.topk(30).indices
            sorted_token_ids = top_indices[logits[top_indices].argsort(descending=True)]

            found_comma = False
            found_close_bracket = False

            for token_id in sorted_token_ids:
                decoded_token = self.tokenizer.decode(token_id)
                if "," in decoded_token:
                    found_comma = True
                    break
                if "]" in decoded_token:
                    found_close_bracket = True
                    break

            if found_close_bracket or not found_comma:
                break

        return obj

    def get_prompt(self):
        template = """{prompt}\nOutput result in the following JSON schema format:\n{schema}\nResult: {progress}"""
        progress = json.dumps(self.value)
        gen_marker_index = progress.find(f'"{self.generation_marker}"')
        if gen_marker_index != -1:
            progress = progress[:gen_marker_index]
        else:
            raise ValueError("Failed to find generation marker")

        prompt = template.format(
            prompt=self.prompt,
            schema=json.dumps(self.json_schema),
            progress=progress,
        )

        return prompt

    def __call__(self) -> Dict[str, Any]:
        self.value = {}
        generated_data = self.generate_object(
            self.json_schema["properties"], self.value
        )
        return generated_data
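
if __name__ == "__main__":
    # Minimal usage sketch, not part of the library proper: any Hugging Face
    # causal LM with its matching tokenizer should work; the checkpoint name
    # below is only an assumed example.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b")
    tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b")
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    jsonformer = Jsonformer(model, tokenizer, schema, "Generate a person record:")
    print(jsonformer())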