Commit 61a4b32

[feat] Enhance LLM stats tests with chunked prefill support
- Updated test parameters to include `enable_chunked_prefill` for both synchronous and asynchronous LLM stats tests.
- Modified the `validate_stats` function to account for chunked prefill behavior in result validation.
- Improved the test harnesses to handle the new parameter and ensure correct behavior with chunked prefill enabled.

Signed-off-by: Robin Kobus <[email protected]>
1 parent 1da1751 commit 61a4b32
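
For reference, a minimal sketch (not part of this commit) of the iteration accounting that the updated `validate_stats` encodes for the chunked-prefill case. It assumes the values used by the harness in the diff below: a prompt that prefills to position 54, `max_num_tokens=32` (so the prefill runs in two chunks), `max_tokens=5`, and the PyTorch backend. The helper name `expected_iterations` is hypothetical.

import math


def expected_iterations(prompt_len=54, max_num_tokens=32, max_tokens=5,
                        enable_chunked_prefill=True):
    # With chunked prefill, the context (prefill) phase takes one scheduler
    # iteration per chunk; the test below fixes this to 2 (54 tokens split
    # into 32-token chunks).
    context_iters = (math.ceil(prompt_len / max_num_tokens)
                     if enable_chunked_prefill else 1)
    # The last context chunk produces the first token, leaving
    # max_tokens - 1 generation-only iterations.
    generation_iters = max_tokens - 1
    return context_iters, generation_iters, context_iters + generation_iters


if __name__ == "__main__":
    ctx, gen, total = expected_iterations()
    print("context iterations:", ctx)                 # 2
    print("generation iterations:", gen)              # 4
    print("stats records (PyTorch backend):", total)  # 6 == max_tokens + 1

This is why `expected_num_results` in the diff gains one extra record when `enable_chunked_prefill` is set.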

File tree

2 files changed: +133, -70 lines

tests/unittest/llmapi/test_llm.py

Lines changed: 117 additions & 58 deletions
@@ -2045,39 +2045,64 @@ def batch_task():
     batch_task()


-def validate_stats(results,
-                   pytorch_backend,
-                   max_tokens,
-                   enable_iter_req_stats=False):
+def validate_stats(
+    *,
+    results,
+    pytorch_backend,
+    max_tokens,
+    use_overlap=False,
+    enable_chunked_prefill=False,
+    enable_iter_req_stats=False,
+):
     assert results
-    assert len(results) == max_tokens if pytorch_backend else max_tokens + 1
+    expected_num_results = max_tokens if pytorch_backend else max_tokens + 1
+    if enable_chunked_prefill:
+        expected_num_results += 1
+    assert len(results) == expected_num_results
+
+    context_iterations = 2 if enable_chunked_prefill else 1
+    generation_iterations = max_tokens - 1
     for iter, result in enumerate(results):
         ifbStats = result["inflightBatchingStats"]
-        expected_num_scheduled = 1 if (iter < max_tokens) else 0
-        assert ifbStats["numScheduledRequests"] == expected_num_scheduled
-        if iter == 0:
+
+        if iter < context_iterations:
+            assert ifbStats["numScheduledRequests"] == 1
             assert ifbStats["numContextRequests"] == 1
             assert ifbStats["numGenRequests"] == 0
             assert result["numActiveRequests"] == 1
-        elif iter == max_tokens:
-            assert ifbStats["numContextRequests"] == 0
-            assert ifbStats["numGenRequests"] == 0
-            assert result["numActiveRequests"] == 0
-        else:
+        elif iter < (context_iterations + generation_iterations):
+            assert ifbStats["numScheduledRequests"] == 1
             assert ifbStats["numContextRequests"] == 0
             assert ifbStats["numGenRequests"] == 1
             assert result["numActiveRequests"] == 1
+        else:
+            assert ifbStats["numScheduledRequests"] == 0
+            assert ifbStats["numContextRequests"] == 0
+            assert ifbStats["numGenRequests"] == 0
+            assert result["numActiveRequests"] == 0

         if enable_iter_req_stats:
             assert "requestStats" in result
             req_stats = result["requestStats"]
             assert len(req_stats) == 1
             req_stat = req_stats[0]
-            assert req_stat["numGeneratedTokens"] == iter + 1
+            if iter < (context_iterations - 1):
+                # If use_overlap, the stats are one iteration ahead
+                assert req_stat[
+                    "stage"] == "GENERATION_IN_PROGRESS" if use_overlap else "CONTEXT_IN_PROGRESS"
+                assert req_stat[
+                    "contextPrefillPosition"] == 54 if use_overlap else 32
+                assert req_stat["numGeneratedTokens"] == 0
+            elif iter < (context_iterations - 1 + generation_iterations):
+                assert req_stat["stage"] == "GENERATION_IN_PROGRESS"
+                assert req_stat["contextPrefillPosition"] == 54
+                assert req_stat["numGeneratedTokens"] == iter - (
+                    context_iterations - 1) + 1
+            else:
+                assert req_stat["stage"] == "GENERATION_COMPLETE"
+                assert req_stat["contextPrefillPosition"] == 54
+                assert req_stat["numGeneratedTokens"] == max_tokens
             assert req_stat["scheduled"] == True
-            assert req_stat[
-                "stage"] == "GENERATION_IN_PROGRESS" if iter + 1 < max_tokens else "GENERATION_COMPLETE"
-            assert req_stat["contextPrefillPosition"] == 4

         expected_num_completed = 1 if iter == len(results) - 1 else 0

@@ -2087,9 +2112,11 @@ def validate_stats(results,


 def llm_get_stats_test_harness(tp_size: int = 1,
+                               pp_size: int = 1,
                                return_context_logits: bool = False,
                                pytorch_backend: bool = False,
                                use_overlap: bool = False,
+                               enable_chunked_prefill: bool = False,
                                enable_iter_req_stats: bool = False):

     if return_context_logits and pytorch_backend:
@@ -2103,6 +2130,7 @@ def llm_get_stats_test_harness(tp_size: int = 1,
     print("return_context_logits: ", return_context_logits)
     print("pytorch_backend: ", pytorch_backend)
     print("use_overlap: ", use_overlap)
+    print("enable_chunked_prefill: ", enable_chunked_prefill)
     print("enable_iter_req_stats: ", enable_iter_req_stats)
     print("-------------")

@@ -2113,6 +2141,10 @@ def llm_get_stats_test_harness(tp_size: int = 1,
         llm_args_extra["gather_generation_logits"] = True
         sampling_args_extra["return_context_logits"] = True

+    if enable_chunked_prefill:
+        llm_args_extra["enable_chunked_prefill"] = True
+        llm_args_extra["max_num_tokens"] = 32
+
     if pytorch_backend:
         llm_args_extra.update(
             dict(enable_iter_perf_stats=True,
@@ -2125,27 +2157,38 @@ def llm_get_stats_test_harness(tp_size: int = 1,
     if not pytorch_backend:
         llm_args_extra["fast_build"] = True

-    llm = LLM_CLASS(model=llama_model_path,
-                    kv_cache_config=global_kvcache_config,
-                    tensor_parallel_size=tp_size,
-                    **llm_args_extra)
+    with LLM_CLASS(model=llama_model_path,
+                   kv_cache_config=global_kvcache_config,
+                   tensor_parallel_size=tp_size,
+                   pipeline_parallel_size=pp_size,
+                   **llm_args_extra) as llm:

-    max_tokens = 5
-    sampling_params = SamplingParams(max_tokens=max_tokens,
-                                     **sampling_args_extra)
+        max_tokens = 5
+        sampling_params = SamplingParams(max_tokens=max_tokens,
+                                         **sampling_args_extra)

-    for output in llm.generate(prompts, sampling_params=sampling_params):
-        print(output)
+        long_prompts = [
+            "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z " * 2
+        ]

-    results = llm.get_stats(2)
+        for output in llm.generate(long_prompts,
+                                   sampling_params=sampling_params):
+            print(output)
+
+        results = llm.get_stats(2)

-    validate_stats(results, pytorch_backend, max_tokens, enable_iter_req_stats)
+        validate_stats(results=results,
+                       pytorch_backend=pytorch_backend,
+                       max_tokens=max_tokens,
+                       use_overlap=use_overlap,
+                       enable_chunked_prefill=enable_chunked_prefill,
+                       enable_iter_req_stats=enable_iter_req_stats)

-    assert not llm.get_stats(2)
+        assert not llm.get_stats(2)

-    # test that IterationResult()._done is properly set
-    _ = llm.generate(prompts, sampling_params=sampling_params)
-    assert llm.get_stats(2)
+        # test that IterationResult()._done is properly set
+        _ = llm.generate(prompts, sampling_params=sampling_params)
+        assert llm.get_stats(2)


 @pytest.mark.parametrize("return_context_logits", [True, False])
@@ -2213,9 +2256,11 @@ def test_llm_get_queued_stats():


 def llm_get_stats_async_test_harness(tp_size: int = 1,
+                                     pp_size: int = 1,
                                      return_context_logits: bool = False,
                                      pytorch_backend: bool = False,
                                      use_overlap: bool = False,
+                                     enable_chunked_prefill: bool = False,
                                      enable_iter_req_stats: bool = False):

     if return_context_logits and pytorch_backend:
@@ -2229,6 +2274,7 @@ def llm_get_stats_async_test_harness(tp_size: int = 1,
     print("return_context_logits: ", return_context_logits)
     print("pytorch_backend: ", pytorch_backend)
     print("use_overlap: ", use_overlap)
+    print("enable_chunked_prefill: ", enable_chunked_prefill)
     print("enable_iter_req_stats: ", enable_iter_req_stats)
     print("-------------")

@@ -2238,6 +2284,10 @@ def llm_get_stats_async_test_harness(tp_size: int = 1,
         llm_args_extra["build_config"] = BuildConfig(gather_context_logits=True)
         sampling_args_extra["return_context_logits"] = True

+    if enable_chunked_prefill:
+        llm_args_extra["enable_chunked_prefill"] = True
+        llm_args_extra["max_num_tokens"] = 32
+
     if pytorch_backend:
         llm_args_extra.update(
             dict(enable_iter_perf_stats=True,
@@ -2248,38 +2298,47 @@ def llm_get_stats_async_test_harness(tp_size: int = 1,
         LLM_CLASS = LLM
         llm_args_extra["fast_build"] = True

-    llm = LLM_CLASS(model=llama_model_path,
-                    kv_cache_config=global_kvcache_config,
-                    tensor_parallel_size=tp_size,
-                    **llm_args_extra)
+    with LLM_CLASS(model=llama_model_path,
+                   kv_cache_config=global_kvcache_config,
+                   tensor_parallel_size=tp_size,
+                   **llm_args_extra) as llm:

-    max_tokens = 6
-    sampling_params = SamplingParams(max_tokens=max_tokens,
-                                     **sampling_args_extra)
+        max_tokens = 6
+        sampling_params = SamplingParams(max_tokens=max_tokens,
+                                         **sampling_args_extra)

-    async def task0():
-        async for output in llm.generate_async(prompts[0],
-                                               streaming=True,
-                                               sampling_params=sampling_params):
-            print(output)
+        long_prompts = [
+            "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z " * 2
+        ]

-    async def task1():
-        results = []
-        await asyncio.sleep(
-            3)  # ensure there's stats to collect for the assertion
-        async for stats in llm.get_stats_async(timeout=2):
-            results.append(stats)
+        async def task0():
+            async for output in llm.generate_async(
+                    long_prompts[0],
+                    streaming=True,
+                    sampling_params=sampling_params):
+                print(output)

-        assert results
-        if not use_overlap:
-            validate_stats(results, pytorch_backend, max_tokens,
-                           enable_iter_req_stats)
+        async def task1():
+            results = []
+            await asyncio.sleep(
+                3)  # ensure there's stats to collect for the assertion
+            async for stats in llm.get_stats_async(timeout=2):
+                results.append(stats)
+
+            assert results
+            if not use_overlap:
+                validate_stats(results=results,
+                               pytorch_backend=pytorch_backend,
+                               max_tokens=max_tokens,
+                               use_overlap=use_overlap,
+                               enable_chunked_prefill=enable_chunked_prefill,
+                               enable_iter_req_stats=enable_iter_req_stats)

-    async def main():
-        for i in range(2):  # test recurrent usage
-            await asyncio.gather(task0(), task1())
+        async def main():
+            for i in range(2):  # test recurrent usage
+                await asyncio.gather(task0(), task1())

-    asyncio.run(main())
+        asyncio.run(main())


 @pytest.mark.parametrize("return_context_logits", [True, False])

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 16 additions & 12 deletions
@@ -54,36 +54,40 @@ def test_tinyllama_logits_processor(enable_chunked_prefill):

 @skip_ray
 @pytest.mark.parametrize(
-    "return_context_logits, use_overlap, enable_iter_req_stats", [
-        (False, False, False),
-        (False, False, True),
-        (False, True, False),
-        (False, True, True),
+    "return_context_logits, use_overlap, enable_chunked_prefill, enable_iter_req_stats",
+    [
+        (False, False, False, True),
+        (False, False, True, True),
+        (False, True, False, True),
+        (False, True, True, True),
     ])
 def test_llm_get_stats(return_context_logits, use_overlap,
-                       enable_iter_req_stats):
+                       enable_chunked_prefill, enable_iter_req_stats):
     llm_get_stats_test_harness(tp_size=1,
+                               pp_size=1,
                                return_context_logits=return_context_logits,
                                pytorch_backend=True,
                                use_overlap=use_overlap,
+                               enable_chunked_prefill=enable_chunked_prefill,
                                enable_iter_req_stats=enable_iter_req_stats)


 @skip_ray
 @pytest.mark.parametrize(
-    "return_context_logits, use_overlap, enable_iter_req_stats", [
-        (False, False, False),
-        (False, False, True),
-        (False, True, False),
-        (False, True, True),
+    "return_context_logits, use_overlap, enable_chunked_prefill, enable_iter_req_stats",
+    [
+        (False, False, False, True),
+        (False, True, False, True),
     ])
 def test_llm_get_stats_async(return_context_logits, use_overlap,
-                             enable_iter_req_stats):
+                             enable_chunked_prefill, enable_iter_req_stats):
     llm_get_stats_async_test_harness(
         tp_size=1,
+        pp_size=1,
         return_context_logits=return_context_logits,
         pytorch_backend=True,
         use_overlap=use_overlap,
+        enable_chunked_prefill=enable_chunked_prefill,
         enable_iter_req_stats=enable_iter_req_stats)

