# std imports
import argparse
import json
import os
import sys
import time
from tqdm import tqdm
import torch

# tpl imports
from vllm import LLM, SamplingParams

# local imports
from utils import BalancedBracketsCriteria, PromptDataset, clean_output, get_inference_config
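
# Example invocation (script name, model path, and file names are illustrative):
#   python generate_vllm.py --prompts prompts.json --model bigcode/starcoderbase \
#       --output output.json --cache output.jsonl --do_sample --prompted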

""" Parse command line arguments """
parser = argparse.ArgumentParser(description='Generate code with vLLM')
parser.add_argument('--prompts', required=True, help='Path to the prompt JSON file')
parser.add_argument('--model', required=True, help='Path to the language model')
parser.add_argument('--output', required=True, help='Path to the output JSON file')
parser.add_argument('--restart', action='store_true', help='Restart generation from scratch (default: False)')
parser.add_argument('--cache', help='JSONL file to cache intermediate results in. Will be restored from if it ' +
    'already exists and --restart is not specified')
parser.add_argument('--restore_from', help='JSON file to restore old results from. Will be restored from ' +
    'if it already exists and --restart is not specified. Is different from --cache in that it is a JSON file, not a ' +
    'JSONL file, and it is only used to restore old results where the prompt is equivalent. Cached results are ' +
    'prioritized over restored results.')
parser.add_argument('--max_new_tokens', type=int, default=1024, help='Maximum number of new tokens to generate (default: 1024)')
parser.add_argument('--num_samples_per_prompt', type=int, default=50, help='Number of code samples to generate (default: 50)')
parser.add_argument('--temperature', type=float, default=0.2, help='Temperature for controlling randomness (default: 0.2)')
parser.add_argument('--top_p', type=float, default=0.95, help='Top p value for nucleus sampling (default: 0.95)')
parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)')
parser.add_argument('--prompted', action='store_true', help='Use prompted generation. See StarCoder paper (default: False)')
args = parser.parse_args()

""" Load prompts """
with open(args.prompts, 'r') as json_file:
    prompts = json.load(json_file)

""" Load existing responses if they exist """
if not args.restart and args.cache is not None and os.path.exists(args.cache):
    with open(args.cache, 'r') as jsonl_file:
        responses = [json.loads(line) for line in jsonl_file]

    # remove a prompt if the cache already holds a matching response with all num_samples_per_prompt outputs
    original_len = len(prompts)
    prompts = [p for p in prompts if
                not any(p["name"] == r["name"] and
                        p["parallelism_model"] == r["parallelism_model"] and
                        p["prompt"] == r["prompt"] and
                        args.temperature == r["temperature"] and
                        args.prompted == r["prompted"] and
                        args.num_samples_per_prompt == len(r["outputs"])
                        for r in responses)]
    print(f"[cache] Skipping {original_len - len(prompts)} prompts that already have cached responses.")

""" Restore responses from a previous run if they exist """
if not args.restart and args.restore_from and os.path.exists(args.restore_from):
    with open(args.restore_from, 'r') as json_file:
        restored_responses = json.load(json_file)

    # remove a prompt if the restored results already hold a matching response with all num_samples_per_prompt outputs
    original_len = len(prompts)
    responses_to_keep = []
    prompts_without_existing_responses = []
    for p in prompts:
        for r in restored_responses:
            if p["name"] == r["name"] and \
                p["parallelism_model"] == r["parallelism_model"] and \
                p["prompt"] == r["prompt"] and \
                args.temperature == r["temperature"] and \
                args.prompted == r["prompted"] and \
                args.num_samples_per_prompt == len(r["outputs"]):
                responses_to_keep.append(r)
                break
        else:
            prompts_without_existing_responses.append(p)
    prompts = prompts_without_existing_responses
    print(f"[restore_from] Skipping {original_len - len(prompts)} prompts that already have responses. " +
        f"{len(prompts)} prompts left to generate.")

    # write restored responses to cache
    if args.cache is not None:
        with open(args.cache, 'a') as jsonl_file:
            for response in responses_to_keep:
                jsonl_file.write(json.dumps(response) + "\n")
            print(f"[restore_from] Wrote {len(responses_to_keep)} restored responses to the cache.")

""" Initialize inference config """
inference_config = get_inference_config(args.model, prompted=args.prompted)

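# Each prompt is repeated num_samples_per_prompt times so that all samples can be
# generated in a single batched llm.generate call below (one completion per copy).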
prompts_repeated = [p for p in prompts for _ in range(args.num_samples_per_prompt)]

""" Initialize vLLM engine """
llm = LLM(model=args.model, tensor_parallel_size=torch.cuda.device_count())

# Configure sampling parameters
sampling_params = SamplingParams(
    temperature=args.temperature if args.do_sample else 0,  # temperature 0 => greedy decoding
    top_p=args.top_p if args.do_sample else 1.0,            # top_p 1.0 disables nucleus filtering
    max_tokens=args.max_new_tokens,
    n=1,  # We handle multiple samples manually via prompts_repeated
)

""" Generate code """
if not args.restart and args.cache is not None and os.path.exists(args.cache):
    with open(args.cache, 'r') as jsonl_file:
        responses = [json.loads(line) for line in jsonl_file]
        responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted
                        and args.num_samples_per_prompt == len(r["outputs"])
                        and any(p["name"] == r["name"] and p["prompt"] == r["prompt"] and p["parallelism_model"] == r["parallelism_model"] for p in prompts)]
else:
    responses = []

cur_prompt = None
start_time = time.time()
total_tokens = 0

# Format all prompts
formatted_prompts = [inference_config.format_prompt(p["prompt"]) for p in prompts_repeated]

# Generate all outputs at once
outputs = llm.generate(formatted_prompts, sampling_params)

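# `outputs` is a flat list with num_samples_per_prompt consecutive generations per
# prompt; the loop below regroups them into a single response record per prompt.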
# Process outputs
for idx, (prompt, output) in enumerate(zip(prompts_repeated, outputs)):
    if idx % args.num_samples_per_prompt == 0:
        cur_prompt = prompt.copy()
        cur_prompt.update({
            "temperature": args.temperature,
            "top_p": args.top_p,
            "do_sample": args.do_sample,
            "max_new_tokens": args.max_new_tokens,
            "prompted": args.prompted
        })
        cur_prompt["outputs"] = []
        cur_prompt["raw_outputs"] = []
        prompt_str = cur_prompt["prompt"]

    # Count tokens and clean output
    # FIXME: This is to keep the same behavior as generate.py
    huggingface_style_output = output.prompt + output.outputs[0].text
    total_tokens += len(llm.get_tokenizer().encode(huggingface_style_output))
    cleaned_output = inference_config.clean_output(huggingface_style_output, prompt_str)
    cur_prompt["outputs"].append(cleaned_output)
    cur_prompt["raw_outputs"].append(huggingface_style_output)

    if idx % args.num_samples_per_prompt == args.num_samples_per_prompt - 1:
        responses.append(cur_prompt)

        if not args.restart and args.cache is not None:
            with open(args.cache, 'a') as jsonl_file:
                jsonl_file.write(json.dumps(cur_prompt) + "\n")

end_time = time.time()
tokens_per_second = total_tokens / (end_time - start_time)
print(f"Generated {len(responses)} responses in {end_time - start_time:.2f} seconds ({tokens_per_second:.2f} tokens/second)")

""" Save responses to JSON file """
with open(args.output, 'w') as output_file:
    json.dump(responses, output_file, indent=4)