Commit f25c023

[test] Add LLM stats tests for pipeline parallel mode
- Introduced new test cases for LLM stats to validate behavior with multiple pipeline parallel configurations.
- Added micro batch ID tracking to LLM stats and verified it in the new test cases.
- Used the new test cases to verify pipeline parallel behavior with chunked prefill enabled.

Signed-off-by: Robin Kobus <[email protected]>
1 parent 61a4b32 commit f25c023

3 files changed: 103 additions, 33 deletions
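
For context on the change itself: the new micro batch id is surfaced through the per-iteration statistics that the LLM API already returns. Below is a minimal sketch of reading it back; it is not part of this commit, the model path, parallel sizes, and the enable_iter_perf_stats flag are placeholder assumptions mirroring the stats test harness, and the stats keys follow the updated validate_stats test further down.

from tensorrt_llm import LLM, SamplingParams

# Placeholder model and parallel sizes; enabling per-iteration perf stats is
# assumed to work via the LLM constructor, as in llm_get_stats_test_harness.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
          tensor_parallel_size=1,
          pipeline_parallel_size=2,
          enable_iter_perf_stats=True)
llm.generate(["Hello world"], SamplingParams(max_tokens=8))

# get_stats(timeout) is used the same way in the test harness below.
for i, result in enumerate(llm.get_stats(2)):
    ifb = result["inflightBatchingStats"]
    print(f"iter {i}: microBatchId={ifb['microBatchId']}, "
          f"context={ifb['numContextRequests']}, gen={ifb['numGenRequests']}")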

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 17 additions & 9 deletions
@@ -697,7 +697,7 @@ def get_queued_req_stats(request_id: int) -> RequestStats:
             return req_stats
 
     def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
-                           scheduled_batch) -> IterationStats:
+                           scheduled_batch, micro_batch_id) -> IterationStats:
         stats.iter_latency_ms = iter_latency_ms
 
         stats.num_queued_requests = self.executor_request_queue.get_request_queue_size(
@@ -738,7 +738,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
         stats.inflight_batching_stats.num_paused_requests = len(
             scheduled_batch.paused_requests)
         stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
-        stats.inflight_batching_stats.micro_batch_id = 0
+        stats.inflight_batching_stats.micro_batch_id = micro_batch_id
         if stats.specdec_stats is not None:
             stats.specdec_stats.draft_overhead = 0.0 if iter_latency_ms <= 0.0 else float(
                 stats.specdec_stats.iter_latency_ms) / float(iter_latency_ms)
@@ -751,9 +751,13 @@ def _append_iter_stats(self,
         with self.stats_lock:
             self.stats.append((stats, req_stats))
 
-    def _process_iter_stats(self, finished_requests: list[LlmRequest],
-                            active_requests: List[LlmRequest],
-                            batch_state: BatchState):
+    def _process_iter_stats(
+        self,
+        finished_requests: list[LlmRequest],
+        active_requests: List[LlmRequest],
+        batch_state: BatchState,
+        micro_batch_id: int = 0,
+    ):
         iter_end_time = time.time()
         iter_latency_ms = (iter_end_time - batch_state.iter_start_time) * 1e3
         if batch_state.iter_stats is None:
@@ -766,9 +770,10 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest],
                        and self.enable_iter_perf_stats) else None
 
         self._append_iter_stats(
-            self._update_iter_stats(
-                batch_state.iter_stats, iter_latency_ms, len(finished_requests),
-                batch_state.sample_state.scheduled_requests), req_stats)
+            self._update_iter_stats(batch_state.iter_stats, iter_latency_ms,
+                                    len(finished_requests),
+                                    batch_state.sample_state.scheduled_requests,
+                                    micro_batch_id), req_stats)
 
     def _executor_loop_cleanup(self):
 
@@ -828,6 +833,7 @@ def _executor_loop_pp(self):
                 self.num_scheduled_requests = scheduled_batch.batch_size
 
                 logger.debug(
+                    f'iteration {self.iter_counter}, microbatch {microbatch_id}, '
                     f'has {len(self.active_requests)} active_requests, '
                     f'scheduled {len(scheduled_batch.context_requests)} context requests and '
                     f'{len(scheduled_batch.generation_requests)} generation requests'
@@ -1008,9 +1014,11 @@ def _executor_loop_pp(self):
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches
 
                 if self.enable_iter_perf_stats and previous_batch is not None:
+                    sample_state = previous_batch.sample_state
+                    sample_state.scheduled_requests.context_requests = previous_batch.scheduled_ctx_reqs
                     self._process_iter_stats(finished_requests,
                                              self.active_requests,
-                                             previous_batch)
+                                             previous_batch, microbatch_id)
 
                 self.iter_counter += 1
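
A note on the last hunk above: the micro batch id is advanced with a modulo before the previous batch's stats are emitted, so the id handed to _process_iter_stats is the one the loop has just rotated to. Below is a simplified, standalone sketch of only that wrap-around, not the executor loop itself; num_micro_batches = 2 is an assumed example value.

# Standalone illustration of the round-robin micro batch id rotation used above.
num_micro_batches = 2  # assumed example value

microbatch_id = 0
for iteration in range(6):
    # ... schedule and forward the current micro batch here ...
    microbatch_id = (microbatch_id + 1) % num_micro_batches
    # The previous batch's stats are then processed with this already-advanced
    # id, matching the call order in _executor_loop_pp above.
    print(f"iteration {iteration}: micro_batch_id passed to stats = {microbatch_id}")
# Prints 1, 0, 1, 0, 1, 0 for num_micro_batches = 2.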

tests/unittest/llmapi/test_llm.py

Lines changed: 43 additions & 24 deletions
@@ -2050,59 +2050,76 @@ def validate_stats(
     results,
     pytorch_backend,
     max_tokens,
+    pp_size=1,
     use_overlap=False,
     enable_chunked_prefill=False,
     enable_iter_req_stats=False,
 ):
     assert results
+    for iter, result in enumerate(results):
+        ifbStats = result["inflightBatchingStats"]
+        print(f"iter: {iter}, ifbStats: {ifbStats}")
+
     expected_num_results = max_tokens if pytorch_backend else max_tokens + 1
     if enable_chunked_prefill:
         expected_num_results += 1
     assert len(results) == expected_num_results
 
     context_iterations = 2 if enable_chunked_prefill else 1
     generation_iterations = max_tokens - 1
+    microbatch_id = 0
     for iter, result in enumerate(results):
         ifbStats = result["inflightBatchingStats"]
 
         if iter < context_iterations:
-            assert ifbStats["numScheduledRequests"] == 1
-            assert ifbStats["numContextRequests"] == 1
-            assert ifbStats["numGenRequests"] == 0
-            assert result["numActiveRequests"] == 1
+            assert ifbStats["numScheduledRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["numContextRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["numGenRequests"] == 0, f"iter: {iter}"
+            assert result["numActiveRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["microBatchId"] == microbatch_id, f"iter: {iter}"
         elif iter < (context_iterations + generation_iterations):
-            assert ifbStats["numScheduledRequests"] == 1
-            assert ifbStats["numContextRequests"] == 0
-            assert ifbStats["numGenRequests"] == 1
-            assert result["numActiveRequests"] == 1
+            assert ifbStats["numScheduledRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["numContextRequests"] == 0, f"iter: {iter}"
+            assert ifbStats["numGenRequests"] == 1, f"iter: {iter}"
+            assert result["numActiveRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["microBatchId"] == microbatch_id, f"iter: {iter}"
         else:
-            assert ifbStats["numScheduledRequests"] == 0
-            assert ifbStats["numContextRequests"] == 0
-            assert ifbStats["numGenRequests"] == 0
-            assert result["numActiveRequests"] == 0
+            assert ifbStats["numScheduledRequests"] == 0, f"iter: {iter}"
+            assert ifbStats["numContextRequests"] == 0, f"iter: {iter}"
+            assert ifbStats["numGenRequests"] == 0, f"iter: {iter}"
+            assert result["numActiveRequests"] == 0, f"iter: {iter}"
+            assert ifbStats["microBatchId"] == microbatch_id, f"iter: {iter}"
+
+        # In pipeline parallel mode, increment microbatch_id for each context iteration except the last one,
+        # since the context chunks can be scheduled in each iteration.
+        if pp_size > 1 and iter < context_iterations - 1:
+            microbatch_id += 1
 
         if enable_iter_req_stats:
-            assert "requestStats" in result
+            assert "requestStats" in result, f"iter: {iter}"
             req_stats = result["requestStats"]
-            assert len(req_stats) == 1
+            assert len(req_stats) == 1, f"iter: {iter}"
             req_stat = req_stats[0]
             if iter < (context_iterations - 1):
                 # If use_overlap, the stats are one iteration ahead
                 assert req_stat[
-                    "stage"] == "GENERATION_IN_PROGRESS" if use_overlap else "CONTEXT_IN_PROGRESS"
+                    "stage"] == "GENERATION_IN_PROGRESS" if use_overlap else "CONTEXT_IN_PROGRESS", f"iter: {iter}"
                 assert req_stat[
-                    "contextPrefillPosition"] == 54 if use_overlap else 32
-                assert req_stat["numGeneratedTokens"] == 0
+                    "contextPrefillPosition"] == 54 if use_overlap else 32, f"iter: {iter}"
+                assert req_stat["numGeneratedTokens"] == 0, f"iter: {iter}"
             elif iter < (context_iterations - 1 + generation_iterations):
-                assert req_stat["stage"] == "GENERATION_IN_PROGRESS"
-                assert req_stat["contextPrefillPosition"] == 54
+                assert req_stat[
+                    "stage"] == "GENERATION_IN_PROGRESS", f"iter: {iter}"
+                assert req_stat["contextPrefillPosition"] == 54, f"iter: {iter}"
                 assert req_stat["numGeneratedTokens"] == iter - (
-                    context_iterations - 1) + 1
+                    context_iterations - 1) + 1, f"iter: {iter}"
             else:
-                assert req_stat["stage"] == "GENERATION_COMPLETE"
-                assert req_stat["contextPrefillPosition"] == 54
-                assert req_stat["numGeneratedTokens"] == max_tokens
-                assert req_stat["scheduled"] == True
+                assert req_stat[
+                    "stage"] == "GENERATION_COMPLETE", f"iter: {iter}"
+                assert req_stat["contextPrefillPosition"] == 54, f"iter: {iter}"
+                assert req_stat[
+                    "numGeneratedTokens"] == max_tokens, f"iter: {iter}"
+                assert req_stat["scheduled"] == True, f"iter: {iter}"
 
         expected_num_completed = 1 if iter == len(results) - 1 else 0
 
@@ -2178,6 +2195,7 @@ def llm_get_stats_test_harness(tp_size: int = 1,
     results = llm.get_stats(2)
 
     validate_stats(results=results,
+                   pp_size=pp_size,
                    pytorch_backend=pytorch_backend,
                    max_tokens=max_tokens,
                    use_overlap=use_overlap,
@@ -2328,6 +2346,7 @@ async def task1():
         assert results
         if not use_overlap:
             validate_stats(results=results,
+                           pp_size=pp_size,
                            pytorch_backend=pytorch_backend,
                            max_tokens=max_tokens,
                            use_overlap=use_overlap,

tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py

Lines changed: 43 additions & 0 deletions
@@ -11,6 +11,7 @@
     check_llama_7b_multi_lora_from_request_test_harness,
     check_phi3_lora_fused_modules_output_tp2_identical_to_tp1)
 from .test_llm import (_test_llm_capture_request_error, llama_model_path,
+                       llm_get_stats_test_harness,
                        llm_return_logprobs_test_harness,
                        tinyllama_logits_processor_test_harness)
 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
@@ -125,3 +126,45 @@ def test_llm_return_logprobs_streaming_tp2(prompt_logprobs, logprobs,
                                            streaming=True,
                                            backend="pytorch",
                                            tp_size=2)
+
+
+@skip_ray
+@pytest.mark.gpu2
+@pytest.mark.parametrize(
+    "return_context_logits, enable_chunked_prefill, enable_iter_req_stats",
+    [
+        (False, False, True),
+        (False, True, True),
+    ],
+)
+def test_llm_get_stats_pp2(return_context_logits, enable_chunked_prefill,
+                           enable_iter_req_stats):
+    llm_get_stats_test_harness(
+        tp_size=1,
+        pp_size=2,
+        return_context_logits=return_context_logits,
+        pytorch_backend=True,
+        enable_chunked_prefill=enable_chunked_prefill,
+        enable_iter_req_stats=enable_iter_req_stats,
+    )
+
+
+@skip_ray
+@pytest.mark.gpu4
+@pytest.mark.parametrize(
+    "return_context_logits, enable_chunked_prefill, enable_iter_req_stats",
+    [
+        (False, False, True),
+        (False, True, True),
+    ],
+)
+def test_llm_get_stats_pp4(return_context_logits, enable_chunked_prefill,
+                           enable_iter_req_stats):
+    llm_get_stats_test_harness(
+        tp_size=1,
+        pp_size=4,
+        return_context_logits=return_context_logits,
+        pytorch_backend=True,
+        enable_chunked_prefill=enable_chunked_prefill,
+        enable_iter_req_stats=enable_iter_req_stats,
+    )
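
To try the new cases locally, plain pytest selection should be enough; this is a sketch that assumes at least two visible GPUs for the pp2 case (four for pp4) and that the repo's unittest requirements are installed.

# Run only the new pipeline-parallel stats tests; adjust the -k filter as needed.
import pytest

exit_code = pytest.main([
    "tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py",
    "-k", "test_llm_get_stats_pp2",  # or test_llm_get_stats_pp4 with four GPUs
    "-v",
])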
