
Commit 42c78db

req/response latency breakdown

Signed-off-by: Erin Ho <[email protected]>

update

Signed-off-by: Erin Ho <[email protected]>

change to sum over iterations

1 parent 7b6803b commit 42c78db

File tree: 13 files changed, +510 / -28 lines

examples/llm-api/llm_inference_async.py

Lines changed: 15 additions & 5 deletions
@@ -4,35 +4,45 @@
 import asyncio
 
 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm._tmp_utils import (analyze_average_timestamps,
+                                     dump_timestamps_to_json)
 
 
 def main():
     # model could accept HF model name or a path to local HF model.
-    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    llm = LLM(
+        #model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        model="/scratch/llm-models/llama-3.2-models/Llama-3.2-3B-Instruct-FP8",
+        tensor_parallel_size=2)
 
     # Sample prompts.
     prompts = [
         "Hello, my name is",
         "The capital of France is",
         "The future of AI is",
-    ]
+    ] * 100
 
     # Create a sampling params.
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
+    all_timestamps = []
+
     # Async based on Python coroutines
     async def task(prompt: str):
        output = await llm.generate_async(prompt, sampling_params)
-        print(
-            f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
-        )
+
+        if output.outputs[0].timestamps:
+            all_timestamps.append(output.outputs[0].timestamps)
 
     async def main():
         tasks = [task(prompt) for prompt in prompts]
         await asyncio.gather(*tasks)
 
     asyncio.run(main())
 
+    analyze_average_timestamps(all_timestamps)
+    dump_timestamps_to_json(all_timestamps, "timestamps_output.json")
+
     # Got output like follows:
     # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
     # Prompt: 'The capital of France is', Generated text: 'Paris.'
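
Note: the new _tmp_utils helpers are gated on the TIMESTAMP_DEBUG environment variable (see tensorrt_llm/_tmp_utils.py below), so this example only prints a breakdown when that flag is set. A minimal sketch of the same collection path using the synchronous API; the model name is a placeholder and it assumes outputs[0].timestamps is populated by this commit's instrumentation:

    # Sketch only: set the debug flag before constructing the LLM so the
    # analysis helpers are active; the model below is a placeholder.
    import os
    os.environ["TIMESTAMP_DEBUG"] = "1"

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm._tmp_utils import (analyze_average_timestamps,
                                         dump_timestamps_to_json)

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    all_timestamps = []
    for output in llm.generate(["Hello, my name is"] * 8, sampling_params):
        # timestamps is assumed to be filled in by this commit's changes
        if output.outputs[0].timestamps:
            all_timestamps.append(output.outputs[0].timestamps)

    analyze_average_timestamps(all_timestamps)   # prints the averaged breakdown
    dump_timestamps_to_json(all_timestamps, "timestamps_output.json")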

examples/ray_orchestrator/llm_inference_async_ray.py

Lines changed: 29 additions & 5 deletions
@@ -1,7 +1,12 @@
 # Generate text asynchronously with Ray orchestrator.
 import asyncio
 
+import ray
+
 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm._tmp_utils import (analyze_average_timestamps,
+                                     dump_timestamps_to_json,
+                                     print_fetch_statistics)
 from tensorrt_llm.llmapi import KvCacheConfig
 
 
@@ -13,13 +18,14 @@ def main():
 
     # model could accept HF model name or a path to local HF model.
     llm = LLM(
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        model="/scratch/llm-models/llama-3.2-models/Llama-3.2-3B-Instruct-FP8",
+        # model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
         kv_cache_config=kv_cache_config,
         max_seq_len=1024,
         max_batch_size=1,
         orchestrator_type="ray",  # Enable Ray orchestrator
         # Enable 2-way tensor parallelism
-        # tensor_parallel_size=2
+        tensor_parallel_size=2
     )
 
     # Sample prompts.
@@ -32,19 +38,37 @@ def main():
     # Create a sampling params.
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
+    # Collect all timestamps
+    all_timestamps = []
+
     # Async based on Python coroutines
     async def task(prompt: str):
         output = await llm.generate_async(prompt, sampling_params)
-        print(
-            f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
-        )
+
+        if output.outputs[0].timestamps:
+            all_timestamps.append(output.outputs[0].timestamps)
 
     async def main():
         tasks = [task(prompt) for prompt in prompts]
         await asyncio.gather(*tasks)
 
     asyncio.run(main())
 
+    analyze_average_timestamps(all_timestamps)
+    dump_timestamps_to_json(all_timestamps, "timestamps_output.json")
+
+    if hasattr(llm._executor, 'workers'):
+        for i, worker in enumerate(llm._executor.workers):
+            try:
+                stats = worker.call_worker_method.remote('get_fetch_statistics')
+                result = ray.get(stats)
+                if result:
+                    print_fetch_statistics(result['num_fetched_requests'],
+                                           result['fetch_call_count'],
+                                           rank=result['rank'])
+            except Exception as e:
+                print(f"Could not get fetch statistics from worker {i}: {e}")
+
     # Got output like follows:
     # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
     # Prompt: 'The capital of France is', Generated text: 'Paris.'
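
The Ray-only block above pulls per-worker fetch counters through call_worker_method('get_fetch_statistics') and forwards them to print_fetch_statistics. A hypothetical illustration of the payload shape that call is expected to return; the numbers are invented, not taken from a real run:

    from tensorrt_llm._tmp_utils import print_fetch_statistics

    # Invented example payload mirroring the keys read above: one list entry
    # per fetch call, recording how many requests that call returned.
    result = {
        'rank': 0,
        'fetch_call_count': 4,
        'num_fetched_requests': [1, 1, 2, 4],
    }

    # With TIMESTAMP_DEBUG=1 this prints the total number of fetch calls and a
    # fetch-size histogram, e.g. "1 requests: 2 times ( 50.0%)".
    print_fetch_statistics(result['num_fetched_requests'],
                           result['fetch_call_count'],
                           rank=result['rank'])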

tensorrt_llm/_tmp_utils.py

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
+"""
+Temporary utilities for timestamp analysis and Ray vs MPI latency comparison.
+"""
+import json
+import os
+from collections import Counter
+
+from tensorrt_llm._utils import mpi_disabled
+
+
+def is_timestamp_debug_enabled():
+    return os.environ.get('TIMESTAMP_DEBUG', '0') == '1'
+
+
+def calculate_latencies(timestamps):
+    """
+    Calculate latency metrics from a single set of timestamps.
+    Returns a dict of latencies in milliseconds, or None if timestamps missing.
+    """
+    if not timestamps:
+        return None
+
+    latencies = {}
+
+    latencies['submit_request_to_enqueue'] = (
+        timestamps['worker_enqueue_request'] -
+        timestamps['executor_submit_request']) * 1000
+
+    # only for the first fetch
+    latencies['queue_wait_time'] = (timestamps['request_fetched'] -
+                                    timestamps['request_queued']) * 1000
+
+    latencies['num_iterations'] = timestamps['num_iterations']
+    latencies['scheduling_wait_time'] = timestamps['scheduling_wait_time']
+    latencies['pre_forward_overhead'] = timestamps['pre_forward_overhead']
+    latencies['forward_step_time'] = timestamps['forward_step_time']
+    latencies['post_processing_time'] = timestamps['post_processing_time']
+
+    latencies['execution_time'] = (timestamps['response_created'] -
+                                   timestamps['request_fetched']) * 1000
+
+    latencies['response_handling'] = (timestamps['response_enqueued'] -
+                                      timestamps['response_created']) * 1000
+
+    latencies['enqueue_response_to_handle'] = (
+        timestamps['handle_response'] -
+        timestamps['response_enqueued']) * 1000
+
+    latencies['total_e2e'] = (timestamps['handle_response'] -
+                              timestamps['executor_submit_request']) * 1000
+
+    latencies['communication_overhead'] = (
+        (timestamps['worker_enqueue_request'] -
+         timestamps['executor_submit_request']) +
+        (timestamps['handle_response'] -
+         timestamps['response_enqueued'])) * 1000
+
+    return latencies
+
+
+def analyze_average_timestamps(all_timestamps):
+    if not is_timestamp_debug_enabled():
+        return
+
+    if not all_timestamps:
+        print("No timestamps available")
+        return
+
+    mode = "[Ray]" if mpi_disabled() else "[MPI]"
+    # Calculate latencies for each request
+    all_latencies = []
+    for ts in all_timestamps:
+        latencies = calculate_latencies(ts)
+        if latencies:
+            all_latencies.append(latencies)
+
+    if not all_latencies:
+        print("No valid latencies calculated")
+        return
+
+    # Calculate averages
+    print(f"\n=== {mode} Latency Breakdown (milliseconds) - Average over {len(all_timestamps)} requests ===")
+
+    metrics = [
+        ('submit_request_to_enqueue', 'Submit to enqueue'),
+        ('queue_wait_time', 'Request Queue wait (1st fetch)'),
+        ('execution_time', 'Time in executor loop (sum of all iterations)'),
+        ('scheduling_wait_time', ' ├─ Scheduling wait'),
+        ('pre_forward_overhead', ' ├─ Pre-forward overhead'),
+        ('forward_step_time', ' ├─ Forward step'),
+        ('post_processing_time', ' └─ Post-processing'),
+        ('response_handling', 'Response handling (once)'),
+        ('enqueue_response_to_handle', 'Enqueue to handle (once)'),
+        # ('num_iterations', 'Avg iterations per request'),
+        # ('total_e2e', 'Total E2E latency'),
+        # ('communication_overhead', 'Total communication overhead'),
+    ]
+
+    for metric_key, metric_name in metrics:
+        if metric_key == 'num_iterations':
+            print("")
+        if metric_key == 'total_e2e':
+            print(" " + "-" * 68)
+
+        values = [lat[metric_key] for lat in all_latencies if metric_key in lat]
+        if values:
+            avg = sum(values) / len(values)
+            min_val = min(values)
+            max_val = max(values)
+            variance = sum((x - avg)**2 for x in values) / len(values)
+
+            if metric_key == 'num_iterations':
+                print(f" {metric_name:48s}: {avg:8.1f} (min: {min_val:8.1f}, max: {max_val:9.1f})")
+            else:
+                print(
+                    f" {metric_name:48s}: {avg:8.3f} ms (min: {min_val:8.3f}, max: {max_val:9.3f}, var: {variance:10.3f})"
+                )
+
+    print("=" * 70)
+
+
+def dump_timestamps_to_json(all_timestamps,
+                            output_file="timestamps_output.json"):
+    if not is_timestamp_debug_enabled():
+        return
+
+    if not all_timestamps:
+        print("No timestamps to dump")
+        return
+
+    print(
+        f"\nDumping {len(all_timestamps)} timestamp records to {output_file}..."
+    )
+    with open(output_file, 'w') as f:
+        json.dump(all_timestamps, f, indent=2)
+    print(f"Timestamps saved to {output_file}")
+
+
+def print_fetch_statistics(num_fetched_requests, fetch_call_count, rank=None):
+    if not is_timestamp_debug_enabled():
+        return
+
+    if not num_fetched_requests:
+        return
+
+    rank_str = f"[Rank {rank}]" if rank is not None else ""
+    mode = "[Ray]" if mpi_disabled() else "[MPI]"
+
+    print(f"\n=== {mode}{rank_str} Fetch Request Statistics ===")
+    print(f" Total fetch calls: {fetch_call_count}")
+
+    size_distribution = Counter(num_fetched_requests)
+    print("\n Fetch Size Distribution:")
+    for size in sorted(size_distribution.keys()):
+        count = size_distribution[size]
+        percentage = (count / len(num_fetched_requests)) * 100
+        print(f" {size:3d} requests: {count:5d} times ({percentage:5.1f}%)")
+
+    print("=" * 70)

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 53 additions & 8 deletions
@@ -71,6 +71,10 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
 
         self._disable_mpi = mpi_disabled()
 
+        # DIAGNOSTIC: Track iteration count and timing per rank
+        # self.iteration_count = 0
+        # self.last_iteration_time = None
+
     def _get_from_request_queue(
             self,
             timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
@@ -218,6 +222,7 @@ def _enqueue_impl(
         with self.enqueue_lock:
             assert self.active, "PyExecutor has already been shutdown."
             start_time = time.time()
+            request_queued_time = time.time()
             for request, query in requests_and_queries:
                 req_id = self._get_request_id()
                 if self.enable_iter_perf_stats:
@@ -229,6 +234,11 @@ def _enqueue_impl(
                         request,
                         child_req_ids=child_req_ids,
                         query=query))
+
+                if hasattr(request, 'py_timestamps') and request.py_timestamps is not None:
+                    if 'request_queued' not in request.py_timestamps:
+                        request.py_timestamps['request_queued'] = request_queued_time
+
                 req_ids.append(req_id)
             return req_ids
 
@@ -268,24 +278,49 @@ def _fetch_and_process_requests(
         all_ranks_num_active_requests: Optional[List[int]] = None
     ) -> List[RequestQueueItem]:
         """Common logic for fetching and processing requests from the queue."""
+        # # DIAGNOSTIC: Track iteration timing
+        # import time as time_module
+        # fetch_start = time_module.time()
+        # self.iteration_count += 1
+
+        # # Track time between iterations
+        # if self.last_iteration_time is not None:
+        #     iteration_gap_ms = (fetch_start - self.last_iteration_time) * 1000
+        # else:
+        #     iteration_gap_ms = 0
+        # self.last_iteration_time = fetch_start
+
         # Calculate timeout
-        idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0
-        if idle:
-            # In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0
-            # reaches the broadcast path regularly to prevent trtllm-serve timeout when idle.
-            timeout = datetime.timedelta(
-                seconds=1200) if self._disable_mpi else None
-        else:
-            timeout = datetime.timedelta(0)
+
+        # Tentatively revert this to rule this out.
+        timeout = None if (total_num_active_requests == 0) and len(
+            self.waiting_queue) == 0 else datetime.timedelta(0)
+        # idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0
+        # if idle:
+        #     # In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0
+        #     # reaches the broadcast path regularly to prevent trtllm-serve timeout when idle.
+        #     timeout = datetime.timedelta(
+        #         seconds=1200) if self._disable_mpi else None
+        # else:
+        #     timeout = datetime.timedelta(0)
 
         # Fetch requests from rank 0
         new_requests = []
         if self.dist.rank == 0:
             new_requests = self._get_from_request_queue(timeout)
 
         # Broadcast requests and handle Python objects
+        # DIAGNOSTIC: Measure broadcast time
+        # import time as time_module
+        # broadcast_start = time_module.time()
         new_requests, py_request_objects = self._handle_request_broadcasting(
             new_requests)
+        # broadcast_end = time_module.time()
+        # broadcast_duration_ms = (broadcast_end - broadcast_start) * 1000
+        # if broadcast_duration_ms > 100:  # Log if > 100ms from BOTH ranks
+        #     print(
+        #         f"[BROADCAST_DELAY][Rank {self.dist.rank}][Iter {self.iteration_count}] Broadcast took {broadcast_duration_ms:.2f} ms, num_requests={len(new_requests)}",
+        #         flush=True)
 
         # Validate and filter requests
         new_requests = self._validate_and_filter_requests(new_requests)
@@ -307,6 +342,16 @@ def _fetch_and_process_requests(
         if self.enable_iter_perf_stats and self.dist.rank == 0:
             self._update_new_active_requests_queue_latency(new_requests)
 
+        # DIAGNOSTIC: Log total fetch time
+        # fetch_end = time_module.time()
+        # fetch_total_ms = (fetch_end - fetch_start) * 1000
+        # if fetch_total_ms > 100 or self.iteration_count % 10 == 0:  # Log if > 100ms or every 10 iterations from BOTH ranks
+        #     print(
+        #         f"[FETCH_TIMING][Rank {self.dist.rank}][Iter {self.iteration_count}] "
+        #         f"gap_since_last_iter={iteration_gap_ms:.2f}ms, fetch_took={fetch_total_ms:.2f}ms, "
+        #         f"active_reqs={total_num_active_requests}, fetched={len(new_requests)}, queue_size={self.request_queue.qsize()}",
+        #         flush=True)
+
         return new_requests
 
     @nvtx_range("_fetch_new_requests")
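
The 'request_queued' stamp recorded in _enqueue_impl is the enqueue-side half of the "Request Queue wait (1st fetch)" metric; the matching 'request_fetched' stamp is recorded elsewhere in this commit and is not shown above. A standalone sketch of how the pair combines, using an invented py_timestamps dict:

    import time

    # Invented stand-in for request.py_timestamps, only to show how the two
    # stamps become queue_wait_time in calculate_latencies().
    py_timestamps = {}
    py_timestamps['request_queued'] = time.time()   # stamped in _enqueue_impl
    time.sleep(0.01)                                # request waits in the queue
    py_timestamps['request_fetched'] = time.time()  # stamped at first fetch

    queue_wait_ms = (py_timestamps['request_fetched'] -
                     py_timestamps['request_queued']) * 1000
    print(f"queue_wait_time = {queue_wait_ms:.2f} ms")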
