# std imports
import argparse
import json
import os
import sys
import time
from tqdm import tqdm
import torch

# tpl imports
from vllm import LLM, SamplingParams

# local imports
from utils import BalancedBracketsCriteria, PromptDataset, clean_output, get_inference_config

""" Parse command line arguments """
parser = argparse.ArgumentParser(description='Generate code with vLLM')
parser.add_argument('--prompts', required=True, help='Path to the prompt JSON file')
parser.add_argument('--model', required=True, help='Path to the language model')
parser.add_argument('--output', required=True, help='Path to the output JSON file')
parser.add_argument('--restart', action='store_true', help='Restart generation from scratch (default: False)')
parser.add_argument('--cache', help='JSONL file to cache intermediate results in. Will be restored from if it ' +
    'already exists and --restart is not specified')
parser.add_argument('--restore_from', help='JSON file to restore old results from. Will be restored from ' +
    'if it already exists and --restart is not specified. Is different from --cache in that it is a JSON file, not a ' +
    'JSONL file, and it is only used to restore old results where the prompt is equivalent. Cached results are ' +
    'prioritized over restored results.')
parser.add_argument('--max_new_tokens', type=int, default=1024, help='Maximum number of new tokens to generate (default: 1024)')
parser.add_argument('--num_samples_per_prompt', type=int, default=50, help='Number of code samples to generate (default: 50)')
parser.add_argument('--temperature', type=float, default=0.2, help='Temperature for controlling randomness (default: 0.2)')
parser.add_argument('--top_p', type=float, default=0.95, help='Top p value for nucleus sampling (default: 0.95)')
parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)')
parser.add_argument('--prompted', action='store_true', help='Use prompted generation. See StarCoder paper (default: False)')
args = parser.parse_args()
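# Example invocation (script name and paths below are illustrative only):
#   python generate_vllm.py --prompts prompts.json --model /path/to/model \
#       --output output.json --cache cache.jsonl --do_sample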

""" Load prompts """
with open(args.prompts, 'r') as json_file:
    prompts = json.load(json_file)

""" Load cached responses if they exist """
if not args.restart and args.cache is not None and os.path.exists(args.cache):
    with open(args.cache, 'r') as jsonl_file:
        responses = [json.loads(line) for line in jsonl_file]

    # skip prompts that already have a full set of cached outputs for these generation settings
    original_len = len(prompts)
    prompts = [p for p in prompts if
               not any(p["name"] == r["name"] and
                       p["parallelism_model"] == r["parallelism_model"] and
                       p["prompt"] == r["prompt"] and
                       args.temperature == r["temperature"] and
                       args.prompted == r["prompted"] and
                       args.num_samples_per_prompt == len(r["outputs"])
                       for r in responses)]
    print(f"[cache] Skipping {original_len - len(prompts)} prompts that already have responses")

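# Note: cached results are handled above, before --restore_from, so cached results
# take priority over restored ones (as stated in the --restore_from help text).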
""" Restore old results from a previous results JSON file, if provided """
if not args.restart and args.restore_from and os.path.exists(args.restore_from):
    with open(args.restore_from, 'r') as json_file:
        restored_responses = json.load(json_file)

    # keep restored responses that fully cover a prompt with the same generation settings;
    # all remaining prompts are left in `prompts` to be generated below
    original_len = len(prompts)
    responses_to_keep = []
    prompts_without_existing_responses = []
    for p in prompts:
        for r in restored_responses:
            if p["name"] == r["name"] and \
               p["parallelism_model"] == r["parallelism_model"] and \
               p["prompt"] == r["prompt"] and \
               args.temperature == r["temperature"] and \
               args.prompted == r["prompted"] and \
               args.num_samples_per_prompt == len(r["outputs"]):
                responses_to_keep.append(r)
                break
        else:
            prompts_without_existing_responses.append(p)
    prompts = prompts_without_existing_responses
    print(f"[restore_from] Skipping {original_len - len(prompts)} prompts that already have responses. " +
          f"{len(prompts)} prompts left.")

    # write restored responses to the cache so later runs can pick them up from there
    if args.cache is not None:
        with open(args.cache, 'a') as jsonl_file:
            for response in responses_to_keep:
                jsonl_file.write(json.dumps(response) + "\n")
        print(f"[restore_from] Wrote {len(responses_to_keep)} restored responses to cache")

""" Initialize inference config """
inference_config = get_inference_config(args.model, prompted=args.prompted)

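# Repeat each prompt num_samples_per_prompt times: each repeated prompt is sampled
# once (n=1 below), and the per-prompt samples are regrouped in the processing loop.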
prompts_repeated = [p for p in prompts for _ in range(args.num_samples_per_prompt)]

""" Initialize vLLM engine """
llm = LLM(model=args.model, tensor_parallel_size=torch.cuda.device_count())

# Configure sampling parameters
sampling_params = SamplingParams(
    temperature=args.temperature if args.do_sample else 0,
    top_p=args.top_p if args.do_sample else 1.0,
    max_tokens=args.max_new_tokens,
    n=1,  # We handle multiple samples manually
)
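# Without --do_sample, temperature=0 and top_p=1.0 make vLLM decode greedily,
# so every sample for a given prompt will be identical.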

""" Generate code """
if not args.restart and args.cache is not None and os.path.exists(args.cache):
    with open(args.cache, 'r') as jsonl_file:
        responses = [json.loads(line) for line in jsonl_file]
        responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted
                     and args.num_samples_per_prompt == len(r["outputs"])
                     and any(p["name"] == r["name"] and p["prompt"] == r["prompt"] and p["parallelism_model"] == r["parallelism_model"] for p in prompts)]
else:
    responses = []

cur_prompt = None
start_time = time.time()
total_tokens = 0

# Format all prompts
formatted_prompts = [inference_config.format_prompt(p["prompt"]) for p in prompts_repeated]

# Generate all outputs at once
outputs = llm.generate(formatted_prompts, sampling_params)
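# vLLM batches and schedules all of the repeated prompts internally and returns results
# in input order, which the grouping logic below relies on (outputs[i] <-> prompts_repeated[i]).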

# Process outputs
for idx, (prompt, output) in enumerate(zip(prompts_repeated, outputs)):
    if idx % args.num_samples_per_prompt == 0:
        cur_prompt = prompt.copy()
        cur_prompt.update({
            "temperature": args.temperature,
            "top_p": args.top_p,
            "do_sample": args.do_sample,
            "max_new_tokens": args.max_new_tokens,
            "prompted": args.prompted
        })
        cur_prompt["outputs"] = []
        cur_prompt["raw_outputs"] = []
        prompt_str = cur_prompt["prompt"]

    # Count tokens and clean output
    # FIXME: This is to keep the same behavior as generate.py
    huggingface_style_output = output.prompt + output.outputs[0].text
    total_tokens += len(llm.get_tokenizer().encode(huggingface_style_output))
    cleaned_output = inference_config.clean_output(huggingface_style_output, prompt_str)
    cur_prompt["outputs"].append(cleaned_output)
    cur_prompt["raw_outputs"].append(huggingface_style_output)

    if idx % args.num_samples_per_prompt == args.num_samples_per_prompt - 1:
        responses.append(cur_prompt)

        if not args.restart and args.cache is not None:
            with open(args.cache, 'a') as jsonl_file:
                jsonl_file.write(json.dumps(cur_prompt) + "\n")

end_time = time.time()
tokens_per_second = total_tokens / (end_time - start_time)
print(f"Generated {len(responses)} code samples in {end_time - start_time:.2f} seconds ({tokens_per_second:.2f} tokens per second)")

""" Save responses to JSON file """
with open(args.output, 'w') as output_file:
    json.dump(responses, output_file, indent=4)
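# The output JSON is a list of prompt records; each record carries the prompt metadata,
# the generation settings used, and "outputs"/"raw_outputs" lists with one entry per sample.
# A downstream script could reload it with, e.g.:
#   with open('output.json', 'r') as f:
#       results = json.load(f)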