
Commit 43b8dd1

add more enqueue checks
1 parent 6e9802a commit 43b8dd1

File tree

11 files changed: +226 -76 lines changed


examples/llm-api/llm_inference_async.py
Lines changed: 19 additions & 3 deletions

@@ -5,22 +5,32 @@
 
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._tmp_utils import (analyze_average_timestamps,
-                                     dump_timestamps_to_json)
+                                     dump_timestamps_to_json,
+                                     print_enqueue_statistics)
+from tensorrt_llm.llmapi import KvCacheConfig
 
 
 def main():
     # model could accept HF model name or a path to local HF model.
+    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8,
+                                    max_tokens=4096,
+                                    enable_block_reuse=True)
+
     llm = LLM(
         #model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
         model="/scratch/llm-models/llama-3.2-models/Llama-3.2-3B-Instruct-FP8",
-        tensor_parallel_size=2)
+        # tensor_parallel_size=2
+        max_seq_len=1024,
+        kv_cache_config=kv_cache_config
+        # max_batch_size=1,
+    )
 
     # Sample prompts.
     prompts = [
         "Hello, my name is",
         "The capital of France is",
         "The future of AI is",
-    ] * 100
+    ] * 1000
 
     # Create a sampling params.
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
@@ -43,6 +53,12 @@ async def main():
     analyze_average_timestamps(all_timestamps)
     dump_timestamps_to_json(all_timestamps, "timestamps_output.json")
 
+    print(
+        f"executor type = {type(llm._executor)}, has enqueue_timings = {hasattr(llm._executor, 'enqueue_timings')}"
+    )
+    if hasattr(llm._executor, 'enqueue_timings'):
+        print_enqueue_statistics(llm._executor.enqueue_timings)
+
     # Got output like follows:
     # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
     # Prompt: 'The capital of France is', Generated text: 'Paris.'
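The hasattr guard above exists because not every executor implementation collects enqueue_timings, so the script only prints the statistics when the attribute is present. A minimal, self-contained sketch of that collect-then-report pattern, using a hypothetical ToyExecutor rather than any real TensorRT-LLM class:

import time


class ToyExecutor:
    """Hypothetical executor that records per-request enqueue latency in ms."""

    def __init__(self):
        self.enqueue_timings = []

    def enqueue_request(self, request):
        start = time.perf_counter()
        # ... real work would hand the request to a worker queue here ...
        self.enqueue_timings.append((time.perf_counter() - start) * 1000)


executor = ToyExecutor()
for i in range(5):
    executor.enqueue_request({"prompt": f"request {i}"})

# The hasattr guard mirrors the example: other executor types may not expose
# enqueue_timings, so the reporting degrades gracefully instead of raising.
if hasattr(executor, 'enqueue_timings'):
    print([f"{t:.3f}" for t in executor.enqueue_timings])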

examples/ray_orchestrator/llm_inference_async_ray.py
Lines changed: 14 additions & 4 deletions

@@ -6,13 +6,14 @@
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm._tmp_utils import (analyze_average_timestamps,
                                      dump_timestamps_to_json,
+                                     print_enqueue_statistics,
                                      print_fetch_statistics)
 from tensorrt_llm.llmapi import KvCacheConfig
 
 
 def main():
     # Configure KV cache memory usage fraction.
-    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
+    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8,
                                     max_tokens=4096,
                                     enable_block_reuse=True)
 
@@ -22,18 +23,20 @@ def main():
         # model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
         kv_cache_config=kv_cache_config,
         max_seq_len=1024,
-        max_batch_size=1,
+        # max_batch_size=1,
         orchestrator_type="ray",  # Enable Ray orchestrator
         # Enable 2-way tensor parallelism
-        tensor_parallel_size=2
+        # tensor_parallel_size=2
     )
 
     # Sample prompts.
     prompts = [
         "Hello, my name is",
         "The capital of France is",
         "The future of AI is",
-    ]
+    ] * 1000
+
+    #* 100
 
     # Create a sampling params.
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
@@ -48,6 +51,10 @@ async def task(prompt: str):
         if output.outputs[0].timestamps:
             all_timestamps.append(output.outputs[0].timestamps)
 
+        # print(
+        #     f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
+        # )
+
     async def main():
         tasks = [task(prompt) for prompt in prompts]
         await asyncio.gather(*tasks)
@@ -57,6 +64,9 @@ async def main():
     analyze_average_timestamps(all_timestamps)
     dump_timestamps_to_json(all_timestamps, "timestamps_output.json")
 
+    if hasattr(llm._executor, 'enqueue_timings'):
+        print_enqueue_statistics(llm._executor.enqueue_timings)
+
     if hasattr(llm._executor, 'workers'):
         for i, worker in enumerate(llm._executor.workers):
            try:
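Both examples fan the prompts out as one coroutine per prompt and wait on them with asyncio.gather, which is what makes the enqueue path worth timing once the prompt list is multiplied by 1000. A stripped-down, runnable sketch of that fan-out pattern; the asyncio.sleep stands in for llm.generate_async and the timing fields are illustrative only:

import asyncio
import random
import time

prompts = ["Hello, my name is", "The capital of France is",
           "The future of AI is"] * 4
all_timestamps = []


async def task(prompt: str):
    # Stand-in for: output = await llm.generate_async(prompt, sampling_params)
    start = time.time()
    await asyncio.sleep(random.uniform(0.01, 0.05))
    all_timestamps.append({"prompt": prompt, "e2e_s": time.time() - start})


async def main():
    # One task per prompt, all submitted concurrently and gathered at once.
    tasks = [task(prompt) for prompt in prompts]
    await asyncio.gather(*tasks)


asyncio.run(main())
print(f"collected {len(all_timestamps)} timestamp records")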

tensorrt_llm/_tmp_utils.py
Lines changed: 61 additions & 9 deletions

@@ -37,23 +37,22 @@ def calculate_latencies(timestamps):
     latencies['post_processing_time'] = timestamps['post_processing_time']
 
     latencies['execution_time'] = (timestamps['response_created'] -
-                                  timestamps['request_fetched']) * 1000
+                                   timestamps['request_fetched']) * 1000
 
     latencies['response_handling'] = (timestamps['response_enqueued'] -
-                                     timestamps['response_created']) * 1000
+                                      timestamps['response_created']) * 1000
 
     latencies['enqueue_response_to_handle'] = (
-        timestamps['handle_response'] -
-        timestamps['response_enqueued']) * 1000
+        timestamps['handle_response'] - timestamps['response_enqueued']) * 1000
 
     latencies['total_e2e'] = (timestamps['handle_response'] -
                               timestamps['executor_submit_request']) * 1000
 
     latencies['communication_overhead'] = (
         (timestamps['worker_enqueue_request'] -
-        timestamps['executor_submit_request']) +
+         timestamps['executor_submit_request']) +
         (timestamps['handle_response'] -
-        timestamps['response_enqueued'])) * 1000
+         timestamps['response_enqueued'])) * 1000
 
     return latencies
 
@@ -79,7 +78,20 @@ def analyze_average_timestamps(all_timestamps):
         return
 
     # Calculate averages
-    print(f"\n=== [{mode}] Latency Breakdown (milliseconds) - Average over {len(all_timestamps)} request ===")
+    print(
+        f"\n=== [{mode}] Latency Breakdown (milliseconds) - Average over {len(all_timestamps)} request ==="
+    )
+
+    # Print first 20 submit_request_to_enqueue values
+    submit_to_enqueue_values = [
+        lat['submit_request_to_enqueue'] for lat in all_latencies
+        if 'submit_request_to_enqueue' in lat
+    ]
+    if submit_to_enqueue_values:
+        first_20 = ', '.join(
+            [f"{x:.2f}" for x in submit_to_enqueue_values[:20]])
+        print(f"  Submit to enqueue (first 20, ms): {first_20}", flush=True)
+        print(flush=True)
 
     metrics = [
         ('submit_request_to_enqueue', 'Submit to enqueue'),
@@ -108,9 +120,11 @@ def analyze_average_timestamps(all_timestamps):
         min_val = min(values)
         max_val = max(values)
         variance = sum((x - avg)**2 for x in values) / len(values)
-
+
         if metric_key == 'num_iterations':
-            print(f"  {metric_name:48s}: {avg:8.1f} (min: {min_val:8.1f}, max: {max_val:9.1f})")
+            print(
+                f"  {metric_name:48s}: {avg:8.1f} (min: {min_val:8.1f}, max: {max_val:9.1f})"
+            )
         else:
             print(
                 f"  {metric_name:48s}: {avg:8.3f} ms (min: {min_val:8.3f}, max: {max_val:9.3f}, var: {variance:10.3f})"
@@ -156,4 +170,42 @@ def print_fetch_statistics(num_fetched_requests, fetch_call_count, rank=None):
         percentage = (count / len(num_fetched_requests)) * 100
         print(f"  {size:3d} requests: {count:5d} times ({percentage:5.1f}%)")
 
+    print(f"\n  Num fetched requests (all iterations): {num_fetched_requests}")
+
+    print("=" * 70)
+
+
+def print_enqueue_statistics(enqueue_timings):
+    if not is_timestamp_debug_enabled():
+        return
+
+    if not enqueue_timings:
+        return
+
+    mode = "[Ray]" if mpi_disabled() else "[MPI]"
+    num_requests = len(enqueue_timings)
+
+    print(
+        f"\n=== {mode} Enqueue Request Timing Statistics ({num_requests} requests) ==="
+    )
+    first_20_enqueue = ', '.join([f"{x:.2f}" for x in enqueue_timings[:20]])
+    print(f"  Direct enqueue (first 20, ms): {first_20_enqueue}", flush=True)
+
+    avg = sum(enqueue_timings) / num_requests
+    min_val = min(enqueue_timings)
+    max_val = max(enqueue_timings)
+
+    # Calculate percentiles
+    sorted_timings = sorted(enqueue_timings)
+    p10 = sorted_timings[int(num_requests *
+                             0.1)] if num_requests > 1 else sorted_timings[0]
+    p50 = sorted_timings[num_requests // 2]
+    p90 = sorted_timings[int(num_requests * 0.9)]
+
+    print(f"  Avg: {avg:.2f} ms")
+    print(f"  Min: {min_val:.2f} ms")
+    print(f"  Max: {max_val:.2f} ms")
+    print(f"  P10: {p10:.2f} ms")
+    print(f"  P50: {p50:.2f} ms")
+    print(f"  P90: {p90:.2f} ms")
     print("=" * 70)

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
Lines changed: 5 additions & 2 deletions

@@ -235,9 +235,12 @@ def _enqueue_impl(
                     child_req_ids=child_req_ids,
                     query=query))
 
-            if hasattr(request, 'py_timestamps') and request.py_timestamps is not None:
+            if hasattr(
+                    request,
+                    'py_timestamps') and request.py_timestamps is not None:
                 if 'request_queued' not in request.py_timestamps:
-                    request.py_timestamps['request_queued'] = request_queued_time
+                    request.py_timestamps[
+                        'request_queued'] = request_queued_time
 
             req_ids.append(req_id)
         return req_ids
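The guarded write in _enqueue_impl records 'request_queued' only when the key is absent, so a request that passes through the queue more than once keeps its first timestamp. A self-contained illustration of that write-once rule, using a hypothetical FakeRequest and dict.setdefault as shorthand for the explicit membership check:

import time


class FakeRequest:
    """Hypothetical stand-in for an executor request carrying py_timestamps."""

    def __init__(self):
        self.py_timestamps = {}


def record_queued(request, request_queued_time):
    # Same effect as the explicit 'request_queued' not-in check: write only once.
    if getattr(request, 'py_timestamps', None) is not None:
        request.py_timestamps.setdefault('request_queued', request_queued_time)


req = FakeRequest()
record_queued(req, time.time())
first = req.py_timestamps['request_queued']
record_queued(req, first + 10.0)  # a later call must not overwrite the first value
assert req.py_timestamps['request_queued'] == first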

tensorrt_llm/_torch/pyexecutor/llm_request.py
Lines changed: 11 additions & 10 deletions

@@ -593,7 +593,8 @@ def create_response(self,
         result, is_final = super().create_serialized_result(
             use_fast_logits, mpi_world_rank)
 
-        response_timestamps = self.py_timestamps.copy() if self.py_timestamps is not None else None
+        response_timestamps = self.py_timestamps.copy(
+        ) if self.py_timestamps is not None else None
         if response_timestamps is not None:
             response_timestamps['response_created'] = time.time()
 
@@ -775,15 +776,15 @@ def executor_request_to_llm_request(
         arrival_time=getattr(executor_request, "py_arrival_time", None),
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None),
-        py_timestamps=getattr(executor_request, "py_timestamps",
-                              {
-                                  'scheduling_wait_time': 0.0,
-                                  'pre_forward_overhead': 0.0,
-                                  'forward_step_time': 0.0,
-                                  'post_processing_time': 0.0,
-                                  'num_iterations': 0,
-                                  'last_iteration_end': None,
-                              } if is_timestamp_debug_enabled() else None))
+        py_timestamps=getattr(
+            executor_request, "py_timestamps", {
+                'scheduling_wait_time': 0.0,
+                'pre_forward_overhead': 0.0,
+                'forward_step_time': 0.0,
+                'post_processing_time': 0.0,
+                'num_iterations': 0,
+                'last_iteration_end': None,
+            } if is_timestamp_debug_enabled() else None))
     if child_req_ids:
         for child_id in child_req_ids:
             llm_request.create_child_request(child_id)
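The reformatted line in create_response still copies py_timestamps before stamping 'response_created', presumably so the response carries a snapshot rather than a live reference to the request's dict. A tiny illustration of the aliasing that a plain assignment would allow, using plain dicts only and no TensorRT-LLM types:

import time

# Live, request-side timestamp dict that keeps being updated during execution.
request_timestamps = {'scheduling_wait_time': 0.0, 'num_iterations': 3}

# Copy first, then stamp the response-only field on the copy.
response_timestamps = request_timestamps.copy()
response_timestamps['response_created'] = time.time()

# Later request-side updates do not leak into the already-built response,
# and the response-only key never appears on the request object.
request_timestamps['num_iterations'] = 4
assert response_timestamps['num_iterations'] == 3
assert 'response_created' not in request_timestamps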
