Commit 67b4221

[Core][5/N] Fully working chunked prefill e2e (vllm-project#3884)
1 parent 63e7176 commit 67b4221

26 files changed: +927 -315 lines
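
The commit threads two engine arguments, `enable_chunked_prefill` and `max_num_batched_tokens`, through the benchmarks and tests below. As a rough sketch of how the two knobs are used together from the `LLM` entrypoint (argument names taken from the `run_vllm` change in this diff; the example itself is illustrative and not part of the commit):

```python
# Illustrative only: enable chunked prefill via the same keyword arguments
# that run_vllm() forwards to LLM() in the diff below.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,
    # Per-iteration token budget; prompts longer than this are prefilled
    # in chunks across several scheduler steps.
    max_num_batched_tokens=2048,
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```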

.buildkite/test-pipeline.yaml (+2)

@@ -29,6 +29,8 @@ steps:
   - pytest -v -s test_pynccl.py
   - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
 
 - label: Engine Test
   command: pytest -v -s engine tokenization test_sequence.py test_config.py

benchmarks/benchmark_latency.py (+1 -2)

@@ -177,8 +177,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
                         help='block size of key/value cache')
     parser.add_argument(
         '--enable-chunked-prefill',
-        type=bool,
-        default=False,
+        action='store_true',
         help='If True, the prefill requests can be chunked based on the '
         'max_num_batched_tokens')
     parser.add_argument(
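
The removed `type=bool` pattern is a known argparse pitfall: argparse calls `bool()` on the raw string, and any non-empty string (including `"False"`) is truthy, so the old flag could never actually be turned off from the command line. A small standalone demonstration of the difference (not part of the commit):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--old-flag', type=bool, default=False)  # removed pattern
parser.add_argument('--new-flag', action='store_true')       # added pattern

# bool('False') is True, so the "old" flag is on no matter what is passed.
print(parser.parse_args(['--old-flag', 'False']).old_flag)  # True
# store_true flags behave as expected: on when present, off otherwise.
print(parser.parse_args(['--new-flag']).new_flag)           # True
print(parser.parse_args([]).new_flag)                       # False
```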

benchmarks/benchmark_throughput.py (+38 -24)

@@ -74,25 +74,31 @@ def run_vllm(
     quantization_param_path: Optional[str],
     device: str,
     enable_prefix_caching: bool,
+    enable_chunked_prefill: bool,
+    max_num_batched_tokens: int,
     gpu_memory_utilization: float = 0.9,
     download_dir: Optional[str] = None,
 ) -> float:
     from vllm import LLM, SamplingParams
-    llm = LLM(model=model,
-              tokenizer=tokenizer,
-              quantization=quantization,
-              tensor_parallel_size=tensor_parallel_size,
-              seed=seed,
-              trust_remote_code=trust_remote_code,
-              dtype=dtype,
-              max_model_len=max_model_len,
-              gpu_memory_utilization=gpu_memory_utilization,
-              enforce_eager=enforce_eager,
-              kv_cache_dtype=kv_cache_dtype,
-              quantization_param_path=quantization_param_path,
-              device=device,
-              enable_prefix_caching=enable_prefix_caching,
-              download_dir=download_dir)
+    llm = LLM(
+        model=model,
+        tokenizer=tokenizer,
+        quantization=quantization,
+        tensor_parallel_size=tensor_parallel_size,
+        seed=seed,
+        trust_remote_code=trust_remote_code,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        quantization_param_path=quantization_param_path,
+        device=device,
+        enable_prefix_caching=enable_prefix_caching,
+        download_dir=download_dir,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+    )
 
     # Add the requests to the engine.
     for prompt, _, output_len in requests:

@@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
                                    args.output_len)
 
     if args.backend == "vllm":
-        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
-                                args.quantization, args.tensor_parallel_size,
-                                args.seed, args.n, args.use_beam_search,
-                                args.trust_remote_code, args.dtype,
-                                args.max_model_len, args.enforce_eager,
-                                args.kv_cache_dtype,
-                                args.quantization_param_path, args.device,
-                                args.enable_prefix_caching,
-                                args.gpu_memory_utilization, args.download_dir)
+        elapsed_time = run_vllm(
+            requests, args.model, args.tokenizer, args.quantization,
+            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
+            args.trust_remote_code, args.dtype, args.max_model_len,
+            args.enforce_eager, args.kv_cache_dtype,
+            args.quantization_param_path, args.device,
+            args.enable_prefix_caching, args.enable_chunked_prefill,
+            args.max_num_batched_tokens, args.gpu_memory_utilization,
+            args.download_dir)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,

@@ -335,6 +341,14 @@ def main(args: argparse.Namespace):
         "--enable-prefix-caching",
         action='store_true',
         help="enable automatic prefix caching for vLLM backend.")
+    parser.add_argument("--enable-chunked-prefill",
+                        action='store_true',
+                        help="enable chunked prefill for vLLM backend.")
+    parser.add_argument('--max-num-batched-tokens',
+                        type=int,
+                        default=None,
+                        help='maximum number of batched tokens per '
+                        'iteration')
     parser.add_argument('--download-dir',
                         type=str,
                         default=None,

New file (+70)

@@ -0,0 +1,70 @@
+"""Compare the outputs of HF and vLLM when using greedy sampling.
+
+It tests chunked prefill. Chunked prefill can be enabled by
+enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
+prefill requests are chunked.
+
+Run `pytest tests/models/test_chunked_prefill.py`.
+"""
+import pytest
+
+MODELS = [
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-hf",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset distributed env properly. Use a value > 1 just when you test.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    chunked_prefill_token_size: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
+            and not enforce_eager):
+        pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
+                    "for high TP to save testing time.")
+    max_num_seqs = min(chunked_prefill_token_size, 256)
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_batched_tokens = chunked_prefill_token_size
+
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    del hf_model
+
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        max_num_batched_tokens=max_num_batched_tokens,
+        enable_chunked_prefill=enable_chunked_prefill,
+        tensor_parallel_size=tensor_parallel_size,
+        enforce_eager=enforce_eager,
+        max_num_seqs=max_num_seqs,
+    )
+    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+    print(vllm_outputs[0])
+
+    for i in range(len(example_prompts)):
+        hf_output_ids, hf_output_str = hf_outputs[i]
+        vllm_output_ids, vllm_output_str = vllm_outputs[i]
+        assert hf_output_str == vllm_output_str, (
+            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
+        assert hf_output_ids == vllm_output_ids, (
+            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

tests/core/test_chunked_prefill_scheduler.py (+8 -8)

@@ -104,10 +104,10 @@ def test_chunk():
     # One chunked prefill, and one decoding.
     seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
     assert set(get_sequence_groups(out)) == set(running)
-    # The first one is decoding.
-    assert seq_group_meta[0].token_chunk_size == 1
+    # The first one is prefill. Scheduler guarantees ordering.
+    assert seq_group_meta[0].token_chunk_size == 56
     # The second one is a chunked prefill.
-    assert seq_group_meta[1].token_chunk_size == 56
+    assert seq_group_meta[1].token_chunk_size == 1
     assert out.num_prefill_groups == 1
     assert out.num_batched_tokens == 57
 
@@ -157,12 +157,12 @@ def test_complex():
     # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
     seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
     assert len(get_sequence_groups(out)) == 3
-    # The first one is decoding.
-    assert seq_group_meta[0].token_chunk_size == 1
-    # The second one is a chunked prefill.
+    # The first one is the first chunked prefill.
+    assert seq_group_meta[0].token_chunk_size == 7
+    # The second one is the second new chunked prefill.
     assert seq_group_meta[1].token_chunk_size == 56
-    # The third one is also chunked.
-    assert seq_group_meta[2].token_chunk_size == 7
+    # The last one is decode.
+    assert seq_group_meta[2].token_chunk_size == 1
     # Two of them are in chunked prefill.
     assert out.num_prefill_groups == 2
     assert out.num_batched_tokens == 64
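
The reordered assertions reflect two things visible in this diff: the scheduler now lists prefill groups before decode groups, and `num_batched_tokens` is the sum of the per-group `token_chunk_size` values (56 + 1 = 57 in the first hunk, 7 + 56 + 1 = 64 in the second). A tiny illustrative check of that accounting, using a hypothetical stand-in class rather than the scheduler's real types:

```python
from dataclasses import dataclass

@dataclass
class ScheduledGroup:  # hypothetical stand-in, not vLLM's SequenceGroupMetadata
    token_chunk_size: int
    is_prefill: bool

# Mirrors the second hunk: two chunked prefills followed by one decode.
batch = [
    ScheduledGroup(token_chunk_size=7, is_prefill=True),
    ScheduledGroup(token_chunk_size=56, is_prefill=True),
    ScheduledGroup(token_chunk_size=1, is_prefill=False),
]
assert sum(g.token_chunk_size for g in batch) == 64  # num_batched_tokens
assert sum(g.is_prefill for g in batch) == 2         # num_prefill_groups
```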

tests/distributed/test_basic_distributed_correctness.py (+6 -1)

@@ -33,11 +33,16 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
+
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
 
-    vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2)
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        tensor_parallel_size=2,
+    )
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
 

tests/distributed/test_chunked_prefill_distributed.py (new file, +66)

@@ -0,0 +1,66 @@
+"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
+vLLM will allocate all the available memory, so we need to run the tests one
+by one. The solution is to pass arguments (model name) by environment
+variables.
+
+Run:
+```sh
+TEST_DIST_MODEL=facebook/opt-125m pytest \
+    test_chunked_prefill_distributed.py
+TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
+    test_chunked_prefill_distributed.py
+```
+"""
+import os
+
+import pytest
+import torch
+
+MODELS = [
+    os.environ["TEST_DIST_MODEL"],
+]
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("chunked_prefill_token_size", [16])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    chunked_prefill_token_size: int,
+) -> None:
+    # Add a chunked prefill config.
+    max_num_seqs = min(chunked_prefill_token_size, 256)
+    assert chunked_prefill_token_size != -1
+    enable_chunked_prefill = True
+    max_num_batched_tokens = chunked_prefill_token_size
+
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    del hf_model
+
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        tensor_parallel_size=2,
+        max_num_seqs=max_num_seqs,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+    )
+    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    del vllm_model
+
+    for i in range(len(example_prompts)):
+        hf_output_ids, hf_output_str = hf_outputs[i]
+        vllm_output_ids, vllm_output_str = vllm_outputs[i]
+        assert hf_output_str == vllm_output_str, (
+            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
+        assert hf_output_ids == vllm_output_ids, (
+            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

tests/entrypoints/test_openai_server.py (+1 -1)

@@ -141,7 +141,7 @@ def server(zephyr_lora_files):
         "--max-cpu-loras",
         "2",
         "--max-num-seqs",
-        "128"
+        "128",
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner

tests/models/test_models.py (+1 -1)

@@ -12,7 +12,7 @@
     "gpt2",
     "bigcode/tiny_starcoder_py",
     "EleutherAI/pythia-70m",
-    "bigscience/bloom-560m",
+    "bigscience/bloom-560m",  # Testing alibi slopes.
     "microsoft/phi-2",
     "stabilityai/stablelm-3b-4e1t",
     # "allenai/OLMo-1B", # Broken
