@@ -74,25 +74,31 @@ def run_vllm(
     quantization_param_path: Optional[str],
     device: str,
     enable_prefix_caching: bool,
+    enable_chunked_prefill: bool,
+    max_num_batched_tokens: int,
     gpu_memory_utilization: float = 0.9,
     download_dir: Optional[str] = None,
 ) -> float:
     from vllm import LLM, SamplingParams
-    llm = LLM(model=model,
-              tokenizer=tokenizer,
-              quantization=quantization,
-              tensor_parallel_size=tensor_parallel_size,
-              seed=seed,
-              trust_remote_code=trust_remote_code,
-              dtype=dtype,
-              max_model_len=max_model_len,
-              gpu_memory_utilization=gpu_memory_utilization,
-              enforce_eager=enforce_eager,
-              kv_cache_dtype=kv_cache_dtype,
-              quantization_param_path=quantization_param_path,
-              device=device,
-              enable_prefix_caching=enable_prefix_caching,
-              download_dir=download_dir)
+    llm = LLM(
+        model=model,
+        tokenizer=tokenizer,
+        quantization=quantization,
+        tensor_parallel_size=tensor_parallel_size,
+        seed=seed,
+        trust_remote_code=trust_remote_code,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        quantization_param_path=quantization_param_path,
+        device=device,
+        enable_prefix_caching=enable_prefix_caching,
+        download_dir=download_dir,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+    )
 
     # Add the requests to the engine.
     for prompt, _, output_len in requests:
@@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
                                    args.output_len)
 
     if args.backend == "vllm":
-        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
-                                args.quantization, args.tensor_parallel_size,
-                                args.seed, args.n, args.use_beam_search,
-                                args.trust_remote_code, args.dtype,
-                                args.max_model_len, args.enforce_eager,
-                                args.kv_cache_dtype,
-                                args.quantization_param_path, args.device,
-                                args.enable_prefix_caching,
-                                args.gpu_memory_utilization, args.download_dir)
+        elapsed_time = run_vllm(
+            requests, args.model, args.tokenizer, args.quantization,
+            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
+            args.trust_remote_code, args.dtype, args.max_model_len,
+            args.enforce_eager, args.kv_cache_dtype,
+            args.quantization_param_path, args.device,
+            args.enable_prefix_caching, args.enable_chunked_prefill,
+            args.max_num_batched_tokens, args.gpu_memory_utilization,
+            args.download_dir)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -335,6 +341,14 @@ def main(args: argparse.Namespace):
         "--enable-prefix-caching",
         action='store_true',
         help="enable automatic prefix caching for vLLM backend.")
+    parser.add_argument("--enable-chunked-prefill",
+                        action='store_true',
+                        help="enable chunked prefill for vLLM backend.")
+    parser.add_argument('--max-num-batched-tokens',
+                        type=int,
+                        default=None,
+                        help='maximum number of batched tokens per '
+                        'iteration')
     parser.add_argument('--download-dir',
                         type=str,
                         default=None,
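
For reference, a minimal sketch (not part of the diff) of the same options driven through vLLM's Python API, mirroring the keyword arguments the patch passes to the LLM constructor; the model name and token budget below are illustrative assumptions.

    # Sketch only: exercises the options this diff wires into the benchmark.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",    # assumed example model
        enable_chunked_prefill=True,  # option surfaced by --enable-chunked-prefill
        max_num_batched_tokens=4096,  # per-iteration token budget (assumed value)
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))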