@@ -23,10 +23,10 @@ def __init__(
         prompt: str,
         *,
         debug: bool = False,
-        max_array_length: int = 128,
-        max_number_tokens: int = 2048,
+        max_array_length: int = 10,
+        max_number_tokens: int = 6,
         temperature: float = 1.0,
-        max_string_token_length: int = 1024,
+        max_string_token_length: int = 10,
     ):
         self.model = model
         self.tokenizer = tokenizer
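The hunk above tightens the default generation budgets sharply: arrays are capped at 10 elements (was 128), numbers at 6 tokens (was 2048), and strings at 10 tokens (was 1024). A minimal usage sketch, assuming the surrounding class is jsonformer's `Jsonformer` and a HuggingFace causal LM; the model name and schema here are illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model choice; any causal LM / tokenizer pair should work.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

former = Jsonformer(  # class under edit; earlier parameters are not shown in this hunk
    model,
    tokenizer,
    {"type": "object", "properties": {"age": {"type": "number"}}},
    "Generate a person's age",
    max_number_tokens=6,         # new default (was 2048)
    max_array_length=10,         # new default (was 128)
    max_string_token_length=10,  # new default (was 1024)
)
```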
@@ -71,15 +71,14 @@ def generate_number(self, temperature: Union[float, None] = None, iterations=0):
         )
         response = self.tokenizer.decode(response[0], skip_special_tokens=True)
 
-        response = response[len(prompt) :]
+        response = response[len(prompt):]
         response = response.strip().rstrip(".")
         self.debug("[generate_number]", response)
         try:
             return float(response)
         except ValueError:
             if iterations > 3:
                 raise ValueError("Failed to generate a valid number")
-
             return self.generate_number(temperature=self.temperature * 1.3, iterations=iterations + 1)
 
     def generate_boolean(self) -> bool:
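On a parse failure, `generate_number` re-samples at 1.3× the base temperature and gives up after a handful of attempts. A standalone sketch of that retry pattern, with `sample_text` standing in for the model call above:

```python
def parse_number_with_retry(sample_text, base_temperature=1.0, iterations=0) -> float:
    # sample_text(temperature) -> str is a stand-in for the model call.
    # Retries all use 1.3x the *base* temperature, mirroring the method above,
    # which always passes self.temperature * 1.3.
    text = sample_text(base_temperature if iterations == 0 else base_temperature * 1.3)
    text = text.strip().rstrip(".")
    try:
        return float(text)
    except ValueError:
        if iterations > 3:
            raise ValueError("Failed to generate a valid number")
        return parse_number_with_retry(sample_text, base_temperature, iterations + 1)
```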
@@ -90,16 +89,11 @@ def generate_boolean(self) -> bool:
         output = self.model.forward(input_tensor.to(self.model.device))
         logits = output.logits[0, -1]
 
-        # todo: this assumes that "true" and "false" are both tokenized to a single token
-        # this is probably not true for all tokenizers
-        # this can be fixed by looking at only the first token of both "true" and "false"
         true_token_id = self.tokenizer.convert_tokens_to_ids("true")
         false_token_id = self.tokenizer.convert_tokens_to_ids("false")
 
         result = logits[true_token_id] > logits[false_token_id]
-
         self.debug("[generate_boolean]", result)
-
         return result.item()
 
     def generate_string(self) -> str:
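The deleted TODO noted that `convert_tokens_to_ids("true")` assumes each literal maps to a single token, and suggested comparing only the first token of each word. A hedged sketch of that variant (not what the code above does), reusing `tokenizer` and `logits` from the method:

```python
# Compare the logits of the *first* sub-token of each literal, so the check
# still works when a tokenizer splits "true" or "false" into several tokens.
true_first = tokenizer.encode("true", add_special_tokens=False)[0]
false_first = tokenizer.encode("false", add_special_tokens=False)[0]
result = logits[true_first] > logits[false_first]
```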
@@ -120,8 +114,6 @@ def generate_string(self) -> str:
             pad_token_id=self.tokenizer.eos_token_id,
         )
 
-        # Some models output the prompt as part of the response
-        # This removes the prompt from the response if it is present
         if (
             len(response[0]) >= len(input_tokens[0])
             and (response[0][: len(input_tokens[0])] == input_tokens).all()
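The deleted comment explained the check that follows: some models echo the prompt at the start of their output, and the branch strips it when the response begins with the exact input token sequence. A toy illustration of that prefix test with `torch` tensors:

```python
import torch

input_tokens = torch.tensor([[101, 102, 103]])    # encoded prompt
response = torch.tensor([[101, 102, 103, 7, 8]])  # echoed prompt + answer
echoed = (
    len(response[0]) >= len(input_tokens[0])
    and (response[0][: len(input_tokens[0])] == input_tokens).all()
)
assert echoed  # the method slices the prompt tokens off in this case
```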
@@ -131,7 +123,6 @@ def generate_string(self) -> str:
             response = response[0]
 
         response = self.tokenizer.decode(response, skip_special_tokens=True)
-
         self.debug("[generate_string]", "|" + response + "|")
 
         if response.count('"') < 1:
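The quote count that closes this hunk is the string-termination heuristic: if no closing quote appears in the decoded text, the raw response is returned as-is. The continuation is not shown in this hunk; in the upstream project it takes everything before the first quote, roughly:

```python
# Assumed continuation (not part of this hunk): cut at the first closing quote.
response = 'Alice", "age": 30'
value = response.split('"')[0].strip()  # -> 'Alice'
```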
@@ -153,6 +144,12 @@ def generate_value(
         obj: Union[Dict[str, Any], List[Any]],
         key: Union[str, None] = None,
     ) -> Any:
+        if "anyOf" in schema:
+            options = [option for option in schema["anyOf"] if "type" in option]
+            if options:
+                schema = options[0]
+            else:
+                raise ValueError("No valid type in anyOf for key: " + str(key))
         schema_type = schema["type"]
         if schema_type == "number":
             if key:
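The new `anyOf` branch resolves a union schema by taking the first alternative that declares a `"type"`, and fails loudly when none does. For example:

```python
schema = {
    "anyOf": [
        {"$ref": "#/$defs/thing"},  # skipped: no "type" key
        {"type": "string"},         # first option with a "type": selected
        {"type": "number"},
    ]
}
options = [option for option in schema["anyOf"] if "type" in option]
assert options[0] == {"type": "string"}
```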
@@ -174,7 +171,10 @@ def generate_value(
             return self.generate_string()
         elif schema_type == "array":
             new_array = []
-            obj[key] = new_array
+            if key:
+                obj[key] = new_array
+            else:
+                obj.append(new_array)
             return self.generate_array(schema["items"], new_array)
         elif schema_type == "object":
             new_obj = {}
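Previously `obj[key] = new_array` failed for nested arrays, where `generate_value` is invoked on a list with `key=None`; the new branch appends instead. The shape of the fix in isolation:

```python
# key is None while generating the elements of an outer array, e.g. for a
# schema like {"type": "array", "items": {"type": "array", ...}}.
def attach(obj, key, new_array):
    if key:
        obj[key] = new_array   # object property: {"scores": [...]}
    else:
        obj.append(new_array)  # array element: [[...], ...]
    return new_array

outer = []
inner = attach(outer, None, [])
assert outer == [[]] and inner is outer[0]
```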
@@ -186,60 +186,7 @@ def generate_value(
         else:
             raise ValueError(f"Unsupported schema type: {schema_type}")
 
-    def generate_array(self, item_schema: Dict[str, Any], obj: Dict[str, Any]) -> list:
-        for _ in range(self.max_array_length):
-            # forces array to have at least one element
-            element = self.generate_value(item_schema, obj)
-            obj[-1] = element
-
-            obj.append(self.generation_marker)
-            input_prompt = self.get_prompt()
-            obj.pop()
-            input_tensor = self.tokenizer.encode(input_prompt, return_tensors="pt")
-            output = self.model.forward(input_tensor.to(self.model.device))
-            logits = output.logits[0, -1]
-
-
-            top_indices = logits.topk(30).indices
-            sorted_token_ids = top_indices[logits[top_indices].argsort(descending=True)]
-
-            found_comma = False
-            found_close_bracket = False
-
-            for token_id in sorted_token_ids:
-                decoded_token = self.tokenizer.decode(token_id)
-                if ',' in decoded_token:
-                    found_comma = True
-                    break
-                if ']' in decoded_token:
-                    found_close_bracket = True
-                    break
-
-            if found_close_bracket or not found_comma:
-                break
-
-        return obj
-
-    def get_prompt(self):
-        template = """{prompt}\nOutput result in the following JSON schema format:\n{schema}\nResult: {progress}"""
-        progress = json.dumps(self.value)
-        gen_marker_index = progress.find(f'"{self.generation_marker}"')
-        if gen_marker_index != -1:
-            progress = progress[:gen_marker_index]
-        else:
-            raise ValueError("Failed to find generation marker")
-
-        prompt = template.format(
-            prompt=self.prompt,
-            schema=json.dumps(self.json_schema),
-            progress=progress,
-        )
-
-        return prompt
-
     def __call__(self) -> Dict[str, Any]:
         self.value = {}
-        generated_data = self.generate_object(
-            self.json_schema["properties"], self.value
-        )
-        return generated_data
+        generated_data = self.generate_object(self.json_schema["properties"], self.value)
+        return generated_data
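With `__call__` compacted onto one line, the entry point is behaviorally unchanged: it resets `self.value` and delegates to `generate_object` over the schema's properties. Typical use, reusing the hypothetical `former` instance from the first sketch:

```python
result = former()  # fills former.value in place and returns it
print(result)      # e.g. {"age": 42.0} (model-dependent)
```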