diff --git a/examples/llm-api/llm_inference_async.py b/examples/llm-api/llm_inference_async.py index 64c81eb198e..4cdc2c5f2d6 100644 --- a/examples/llm-api/llm_inference_async.py +++ b/examples/llm-api/llm_inference_async.py @@ -4,28 +4,45 @@ import asyncio from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm._tmp_utils import (analyze_average_timestamps, + dump_timestamps_to_json, + print_enqueue_statistics) +from tensorrt_llm.llmapi import KvCacheConfig def main(): # model could accept HF model name or a path to local HF model. - llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8, + max_tokens=4096, + enable_block_reuse=True) + + llm = LLM( + #model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + model="/scratch/llm-models/llama-3.2-models/Llama-3.2-3B-Instruct-FP8", + # tensor_parallel_size=2 + max_seq_len=1024, + kv_cache_config=kv_cache_config + # max_batch_size=1, + ) # Sample prompts. prompts = [ "Hello, my name is", "The capital of France is", "The future of AI is", - ] + ] * 1000 # Create a sampling params. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + all_timestamps = [] + # Async based on Python coroutines async def task(prompt: str): output = await llm.generate_async(prompt, sampling_params) - print( - f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}" - ) + + if output.outputs[0].timestamps: + all_timestamps.append(output.outputs[0].timestamps) async def main(): tasks = [task(prompt) for prompt in prompts] @@ -33,6 +50,15 @@ async def main(): asyncio.run(main()) + analyze_average_timestamps(all_timestamps) + dump_timestamps_to_json(all_timestamps, "timestamps_output.json") + + print( + f"executor type = {type(llm._executor)}, has enqueue_timings = {hasattr(llm._executor, 'enqueue_timings')}" + ) + if hasattr(llm._executor, 'enqueue_timings'): + print_enqueue_statistics(llm._executor.enqueue_timings) + # Got output like follows: # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming' # Prompt: 'The capital of France is', Generated text: 'Paris.' diff --git a/examples/ray_orchestrator/llm_inference_async_ray.py b/examples/ray_orchestrator/llm_inference_async_ray.py index ea57975291a..aff5470c1ae 100644 --- a/examples/ray_orchestrator/llm_inference_async_ray.py +++ b/examples/ray_orchestrator/llm_inference_async_ray.py @@ -1,22 +1,29 @@ # Generate text asynchronously with Ray orchestrator. import asyncio +import ray + from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm._tmp_utils import (analyze_average_timestamps, + dump_timestamps_to_json, + print_enqueue_statistics, + print_fetch_statistics) from tensorrt_llm.llmapi import KvCacheConfig def main(): # Configure KV cache memory usage fraction. - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5, + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8, max_tokens=4096, enable_block_reuse=True) # model could accept HF model name or a path to local HF model. 
llm = LLM( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + model="/scratch/llm-models/llama-3.2-models/Llama-3.2-3B-Instruct-FP8", + # model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", kv_cache_config=kv_cache_config, max_seq_len=1024, - max_batch_size=1, + # max_batch_size=1, orchestrator_type="ray", # Enable Ray orchestrator # Enable 2-way tensor parallelism # tensor_parallel_size=2 @@ -27,17 +34,26 @@ def main(): "Hello, my name is", "The capital of France is", "The future of AI is", - ] + ] * 1000 + + #* 100 # Create a sampling params. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + # Collect all timestamps + all_timestamps = [] + # Async based on Python coroutines async def task(prompt: str): output = await llm.generate_async(prompt, sampling_params) - print( - f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}" - ) + + if output.outputs[0].timestamps: + all_timestamps.append(output.outputs[0].timestamps) + + # print( + # f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}" + # ) async def main(): tasks = [task(prompt) for prompt in prompts] @@ -45,6 +61,24 @@ async def main(): asyncio.run(main()) + analyze_average_timestamps(all_timestamps) + dump_timestamps_to_json(all_timestamps, "timestamps_output.json") + + if hasattr(llm._executor, 'enqueue_timings'): + print_enqueue_statistics(llm._executor.enqueue_timings) + + if hasattr(llm._executor, 'workers'): + for i, worker in enumerate(llm._executor.workers): + try: + stats = worker.call_worker_method.remote('get_fetch_statistics') + result = ray.get(stats) + if result: + print_fetch_statistics(result['num_fetched_requests'], + result['fetch_call_count'], + rank=result['rank']) + except Exception as e: + print(f"Could not get fetch statistics from worker {i}: {e}") + # Got output like follows: # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming' # Prompt: 'The capital of France is', Generated text: 'Paris.' diff --git a/tensorrt_llm/_tmp_utils.py b/tensorrt_llm/_tmp_utils.py new file mode 100644 index 00000000000..ff0c8d0961e --- /dev/null +++ b/tensorrt_llm/_tmp_utils.py @@ -0,0 +1,211 @@ +""" +Temporary utilities for timestamp analysis and Ray vs MPI latency comparison. +""" +import json +import os +from collections import Counter + +from tensorrt_llm._utils import mpi_disabled + + +def is_timestamp_debug_enabled(): + return os.environ.get('TIMESTAMP_DEBUG', '0') == '1' + + +def calculate_latencies(timestamps): + """ + Calculate latency metrics from a single set of timestamps. + Returns a dict of latencies in milliseconds, or None if timestamps missing. 
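+    Expected keys (all recorded with time.time(), in seconds):
+    'executor_submit_request', 'worker_enqueue_request', 'request_queued',
+    'request_fetched', 'response_created', 'response_enqueued' and
+    'handle_response', plus per-iteration accumulators that are already in
+    milliseconds ('scheduling_wait_time', 'pre_forward_overhead',
+    'forward_step_time', 'post_processing_time') and 'num_iterations'.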
+ """ + if not timestamps: + return None + + latencies = {} + + latencies['submit_request_to_enqueue'] = ( + timestamps['worker_enqueue_request'] - + timestamps['executor_submit_request']) * 1000 + + # only for the fetch + latencies['queue_wait_time'] = (timestamps['request_fetched'] - + timestamps['request_queued']) * 1000 + + latencies['num_iterations'] = timestamps['num_iterations'] + latencies['scheduling_wait_time'] = timestamps['scheduling_wait_time'] + latencies['pre_forward_overhead'] = timestamps['pre_forward_overhead'] + latencies['forward_step_time'] = timestamps['forward_step_time'] + latencies['post_processing_time'] = timestamps['post_processing_time'] + + latencies['execution_time'] = (timestamps['response_created'] - + timestamps['request_fetched']) * 1000 + + latencies['response_handling'] = (timestamps['response_enqueued'] - + timestamps['response_created']) * 1000 + + latencies['enqueue_response_to_handle'] = ( + timestamps['handle_response'] - timestamps['response_enqueued']) * 1000 + + latencies['total_e2e'] = (timestamps['handle_response'] - + timestamps['executor_submit_request']) * 1000 + + latencies['communication_overhead'] = ( + (timestamps['worker_enqueue_request'] - + timestamps['executor_submit_request']) + + (timestamps['handle_response'] - + timestamps['response_enqueued'])) * 1000 + + return latencies + + +def analyze_average_timestamps(all_timestamps): + if not is_timestamp_debug_enabled(): + return + + if not all_timestamps: + print("No timestamps available") + return + + mode = "[Ray]" if mpi_disabled() else "[MPI]" + # Calculate latencies for each request + all_latencies = [] + for ts in all_timestamps: + latencies = calculate_latencies(ts) + if latencies: + all_latencies.append(latencies) + + if not all_latencies: + print("No valid latencies calculated") + return + + # Calculate averages + print( + f"\n=== [{mode}] Latency Breakdown (milliseconds) - Average over {len(all_timestamps)} request ===" + ) + + # Print first 20 submit_request_to_enqueue values + submit_to_enqueue_values = [ + lat['submit_request_to_enqueue'] for lat in all_latencies + if 'submit_request_to_enqueue' in lat + ] + if submit_to_enqueue_values: + first_20 = ', '.join( + [f"{x:.2f}" for x in submit_to_enqueue_values[:20]]) + print(f" Submit to enqueue (first 20, ms): {first_20}", flush=True) + print(flush=True) + + metrics = [ + ('submit_request_to_enqueue', 'Submit to enqueue'), + ('queue_wait_time', 'Request Queue wait (1st fetch)'), + ('execution_time', 'Time in executor loop (sum of all iterations)'), + ('scheduling_wait_time', ' ├─ Scheduling wait'), + ('pre_forward_overhead', ' ├─ Pre-forward overhead'), + ('forward_step_time', ' ├─ Forward step'), + ('post_processing_time', ' └─ Post-processing'), + ('response_handling', 'Response handling (once)'), + ('enqueue_response_to_handle', 'Enqueue to handle (once)'), + # ('num_iterations', 'Avg iterations per request'), + # ('total_e2e', 'Total E2E latency'), + # ('communication_overhead', 'Total communication overhead'), + ] + + for metric_key, metric_name in metrics: + if metric_key == 'num_iterations': + print("") + if metric_key == 'total_e2e': + print(" " + "-" * 68) + + values = [lat[metric_key] for lat in all_latencies if metric_key in lat] + if values: + avg = sum(values) / len(values) + min_val = min(values) + max_val = max(values) + variance = sum((x - avg)**2 for x in values) / len(values) + + if metric_key == 'num_iterations': + print( + f" {metric_name:48s}: {avg:8.1f} (min: {min_val:8.1f}, max: {max_val:9.1f})" + ) + 
else: + print( + f" {metric_name:48s}: {avg:8.3f} ms (min: {min_val:8.3f}, max: {max_val:9.3f}, var: {variance:10.3f})" + ) + + print("=" * 70) + + +def dump_timestamps_to_json(all_timestamps, + output_file="timestamps_output.json"): + if not is_timestamp_debug_enabled(): + return + + if not all_timestamps: + print("No timestamps to dump") + return + + print( + f"\nDumping {len(all_timestamps)} timestamp records to {output_file}..." + ) + with open(output_file, 'w') as f: + json.dump(all_timestamps, f, indent=2) + print(f"Timestamps saved to {output_file}") + + +def print_fetch_statistics(num_fetched_requests, fetch_call_count, rank=None): + if not is_timestamp_debug_enabled(): + return + + if not num_fetched_requests: + return + + rank_str = f"[Rank {rank}]" if rank is not None else "" + mode = "[Ray]" if mpi_disabled() else "[MPI]" + + print(f"\n=== {mode}{rank_str} Fetch Request Statistics ===") + print(f" Total fetch calls: {fetch_call_count}") + + size_distribution = Counter(num_fetched_requests) + print(f"\n Fetch Size Distribution:") + for size in sorted(size_distribution.keys()): + count = size_distribution[size] + percentage = (count / len(num_fetched_requests)) * 100 + print(f" {size:3d} requests: {count:5d} times ({percentage:5.1f}%)") + + print(f"\n Num fetched requests (all iterations): {num_fetched_requests}") + + print("=" * 70) + + +def print_enqueue_statistics(enqueue_timings): + if not is_timestamp_debug_enabled(): + return + + if not enqueue_timings: + return + + mode = "[Ray]" if mpi_disabled() else "[MPI]" + num_requests = len(enqueue_timings) + + print( + f"\n=== {mode} Enqueue Request Timing Statistics ({num_requests} requests) ===" + ) + first_20_enqueue = ', '.join([f"{x:.2f}" for x in enqueue_timings[:20]]) + print(f" Direct enqueue (first 20, ms): {first_20_enqueue}", flush=True) + + avg = sum(enqueue_timings) / num_requests + min_val = min(enqueue_timings) + max_val = max(enqueue_timings) + + # Calculate percentiles + sorted_timings = sorted(enqueue_timings) + p10 = sorted_timings[int(num_requests * + 0.1)] if num_requests > 1 else sorted_timings[0] + p50 = sorted_timings[num_requests // 2] + p90 = sorted_timings[int(num_requests * 0.9)] + + print(f" Avg: {avg:.2f} ms") + print(f" Min: {min_val:.2f} ms") + print(f" Max: {max_val:.2f} ms") + print(f" P10: {p10:.2f} ms") + print(f" P50: {p50:.2f} ms") + print(f" P90: {p90:.2f} ms") + print("=" * 70) diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 66d95ba869b..5df3abab7f5 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -71,6 +71,10 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool, self._disable_mpi = mpi_disabled() + # DIAGNOSTIC: Track iteration count and timing per rank + # self.iteration_count = 0 + # self.last_iteration_time = None + def _get_from_request_queue( self, timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]: @@ -218,6 +222,7 @@ def _enqueue_impl( with self.enqueue_lock: assert self.active, "PyExecutor has already been shutdown." 
start_time = time.time() + request_queued_time = time.time() for request, query in requests_and_queries: req_id = self._get_request_id() if self.enable_iter_perf_stats: @@ -229,6 +234,14 @@ def _enqueue_impl( request, child_req_ids=child_req_ids, query=query)) + + if hasattr( + request, + 'py_timestamps') and request.py_timestamps is not None: + if 'request_queued' not in request.py_timestamps: + request.py_timestamps[ + 'request_queued'] = request_queued_time + req_ids.append(req_id) return req_ids @@ -268,15 +281,31 @@ def _fetch_and_process_requests( all_ranks_num_active_requests: Optional[List[int]] = None ) -> List[RequestQueueItem]: """Common logic for fetching and processing requests from the queue.""" + # # DIAGNOSTIC: Track iteration timing + # import time as time_module + # fetch_start = time_module.time() + # self.iteration_count += 1 + + # # Track time between iterations + # if self.last_iteration_time is not None: + # iteration_gap_ms = (fetch_start - self.last_iteration_time) * 1000 + # else: + # iteration_gap_ms = 0 + # self.last_iteration_time = fetch_start + # Calculate timeout - idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0 - if idle: - # In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0 - # reaches the broadcast path regularly to prevent trtllm-serve timeout when idle. - timeout = datetime.timedelta( - seconds=1200) if self._disable_mpi else None - else: - timeout = datetime.timedelta(0) + + # Tentatively revert this to rule this out. + timeout = None if (total_num_active_requests == 0) and len( + self.waiting_queue) == 0 else datetime.timedelta(0) + # idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0 + # if idle: + # # In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0 + # # reaches the broadcast path regularly to prevent trtllm-serve timeout when idle. 
+ # timeout = datetime.timedelta( + # seconds=1200) if self._disable_mpi else None + # else: + # timeout = datetime.timedelta(0) # Fetch requests from rank 0 new_requests = [] @@ -284,8 +313,17 @@ def _fetch_and_process_requests( new_requests = self._get_from_request_queue(timeout) # Broadcast requests and handle Python objects + # DIAGNOSTIC: Measure broadcast time + # import time as time_module + # broadcast_start = time_module.time() new_requests, py_request_objects = self._handle_request_broadcasting( new_requests) + # broadcast_end = time_module.time() + # broadcast_duration_ms = (broadcast_end - broadcast_start) * 1000 + # if broadcast_duration_ms > 100: # Log if > 100ms from BOTH ranks + # print( + # f"[BROADCAST_DELAY][Rank {self.dist.rank}][Iter {self.iteration_count}] Broadcast took {broadcast_duration_ms:.2f} ms, num_requests={len(new_requests)}", + # flush=True) # Validate and filter requests new_requests = self._validate_and_filter_requests(new_requests) @@ -307,6 +345,16 @@ def _fetch_and_process_requests( if self.enable_iter_perf_stats and self.dist.rank == 0: self._update_new_active_requests_queue_latency(new_requests) + # DIAGNOSTIC: Log total fetch time + # fetch_end = time_module.time() + # fetch_total_ms = (fetch_end - fetch_start) * 1000 + # if fetch_total_ms > 100 or self.iteration_count % 10 == 0: # Log if > 100ms or every 10 iterations from BOTH ranks + # print( + # f"[FETCH_TIMING][Rank {self.dist.rank}][Iter {self.iteration_count}] " + # f"gap_since_last_iter={iteration_gap_ms:.2f}ms, fetch_took={fetch_total_ms:.2f}ms, " + # f"active_reqs={total_num_active_requests}, fetched={len(new_requests)}, queue_size={self.request_queue.qsize()}", + # flush=True) + return new_requests @nvtx_range("_fetch_new_requests") diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index f5263c991b5..24cd2dba229 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -1,3 +1,4 @@ +import time from copy import deepcopy from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union @@ -5,6 +6,7 @@ import torch import tensorrt_llm.bindings +from tensorrt_llm._tmp_utils import is_timestamp_debug_enabled from tensorrt_llm._torch.shared_tensor import SharedTensorContainer from tensorrt_llm.bindings import executor as tllm_executor from tensorrt_llm.executor.result import TokenLogprobs @@ -403,6 +405,7 @@ class LlmResponse: error_msg: Optional[str] = None result: Optional[LlmResult] = None client_id: Optional[int] = None + timestamps: Optional[Dict[str, float]] = None def has_error(self): return self.error_msg is not None @@ -452,6 +455,11 @@ def __init__( self.py_lora_path: str | None = kwargs.pop("py_lora_path", None) # Multimodal data self.py_multimodal_data = kwargs.pop("py_multimodal_data", None) + + default_timestamps = {} if is_timestamp_debug_enabled() else None + self.py_timestamps: Dict[str, + float] = kwargs.pop("py_timestamps", + default_timestamps) if llm_request is not None: super().__init__(llm_request) else: @@ -588,11 +596,18 @@ def create_response(self, """ result, is_final = super().create_serialized_result( use_fast_logits, mpi_world_rank) - return LlmResponse( - request_id=self.py_request_id - if self.is_child else self.parent_request_id, - result=LlmResult(result, self.py_result, is_final), - client_id=self.py_client_id) if len(result) > 0 else None + + response_timestamps = self.py_timestamps.copy( + ) if self.py_timestamps is not None 
else None + if response_timestamps is not None: + response_timestamps['response_created'] = time.time() + + return LlmResponse(request_id=self.py_request_id + if self.is_child else self.parent_request_id, + result=LlmResult(result, self.py_result, is_final), + client_id=self.py_client_id, + timestamps=response_timestamps if response_timestamps + else None) if len(result) > 0 else None @property def is_dummy(self): @@ -766,6 +781,15 @@ def executor_request_to_llm_request( py_multimodal_data=getattr(executor_request, "py_multimodal_data", None), kv_cache_retention_config=executor_request.kv_cache_retention_config) + py_timestamps=getattr( + executor_request, "py_timestamps", { + 'scheduling_wait_time': 0.0, + 'pre_forward_overhead': 0.0, + 'forward_step_time': 0.0, + 'post_processing_time': 0.0, + 'num_iterations': 0, + 'last_iteration_end': None, + } if is_timestamp_debug_enabled() else None)) if child_req_ids: for child_id in child_req_ids: llm_request.create_child_request(child_id) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 137642209b6..d1b8adcb3d0 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -20,6 +20,10 @@ except ImportError: from cuda import cudart +import time + +from tensorrt_llm._tmp_utils import (is_timestamp_debug_enabled, + print_fetch_statistics) from tensorrt_llm._torch.pyexecutor.resource_manager import ( ResourceManagerType, request_context) from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled, @@ -193,6 +197,13 @@ def __init__(self, self.dist = dist self.disable_overlap_scheduler = disable_overlap_scheduler + if is_timestamp_debug_enabled(): + self.num_fetched_requests = [] + self.fetch_call_count = 0 + else: + self.num_fetched_requests = None + self.fetch_call_count = None + # enqueue and _fetch_new_requests used data self.active = True self.max_beam_width = max_beam_width @@ -814,6 +825,7 @@ def _executor_loop_pp(self): # ensure the context is created, otherwise, some MPI calls will fail. 
CUASSERT(cudart.cudaSetDevice(self.device_id)) microbatch_id = 0 + timestamp_enabled = is_timestamp_debug_enabled() with self._profiler() as profile_step: iter_start_time = time.time() iter_stats = None @@ -822,6 +834,11 @@ def _executor_loop_pp(self): if self.enable_iter_perf_stats: iter_start_time = time.time() new_requests = self._fetch_and_activate_new_requests() + + if self.fetch_call_count is not None: + self.fetch_call_count += 1 + self.num_fetched_requests.append(len(new_requests)) + if self.should_stop_processing: break @@ -839,6 +856,9 @@ def _executor_loop_pp(self): scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( ) + batch_scheduled_time = time.time( + ) if timestamp_enabled else None + if self.kv_cache_transceiver: # For requests that are fitting disagg gen init, also prepare resources for KV cache manager self._prepare_disagg_gen_init( @@ -1024,6 +1044,10 @@ def _executor_loop_pp(self): self.active_requests, previous_batch) + print_fetch_statistics(self.num_fetched_requests, + self.fetch_call_count, + rank=self.dist.rank) + def wait_on_pp_send_handles(self, microbatch_id): if self.send_handles[microbatch_id] is not None: self.send_handles[microbatch_id].wait() @@ -1031,8 +1055,9 @@ def wait_on_pp_send_handles(self, microbatch_id): def _prepare_and_schedule_batch(self): new_requests = self._fetch_and_activate_new_requests() + num_fetched = len(new_requests) if self.should_stop_processing: - return None, None + return None, None, num_fetched if self.kv_cache_transceiver: self._check_disagg_gen_transfer_status() @@ -1097,7 +1122,7 @@ def _prepare_and_schedule_batch(self): f'has {len(self.active_requests)} active_request, ' f'scheduled {len(scheduled_batch.context_requests)} context requests and ' f'{len(scheduled_batch.generation_requests)} generation requests') - return scheduled_batch, iter_stats + return scheduled_batch, iter_stats, num_fetched def _kv_connector_start_batch(self, scheduled_batch): if self.kv_connector_manager: @@ -1130,6 +1155,7 @@ def _executor_loop(self): torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. 
CUASSERT(cudart.cudaSetDevice(self.device_id)) + timestamp_enabled = is_timestamp_debug_enabled() with self._profiler() as profile_step: sample_state = None iter_start_time = time.time() @@ -1139,10 +1165,19 @@ def _executor_loop(self): if self.enable_iter_perf_stats: iter_start_time = time.time() - scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + scheduled_batch, iter_stats, num_fetched = self._prepare_and_schedule_batch( + ) + + if self.fetch_call_count is not None: + self.fetch_call_count += 1 + self.num_fetched_requests.append(num_fetched) + if scheduled_batch is None: break + batch_scheduled_time = time.time( + ) if timestamp_enabled else None + self._pause_requests(scheduled_batch.paused_requests) finished_requests = [] @@ -1186,12 +1221,20 @@ def _executor_loop(self): if hasattr(self.drafter, "guided_decoder"): self.guided_decoder.rollback_draft_tokens() + if timestamp_enabled: + forward_step_start = time.time() + batch_outputs = self._forward_step(scheduled_batch) + + if timestamp_enabled: + forward_step_end = time.time() + if self.guided_decoder is not None: self.guided_decoder.execute(batch_outputs['logits']) sample_state = self._sample_async(scheduled_batch, batch_outputs) + if self.drafter is not None: self.drafter.run_drafter_post(scheduled_batch, self.resource_manager, @@ -1199,6 +1242,42 @@ def _executor_loop(self): self._update_request_states(scheduled_batch) self._update_requests(sample_state, self.resource_manager) + + if timestamp_enabled: + iteration_end = time.time() + for req in scheduled_batch.all_requests(): + if hasattr(req, 'py_timestamps' + ) and req.py_timestamps is not None: + if 'batch_scheduled_time' not in req.py_timestamps: + req.py_timestamps[ + 'batch_scheduled_time'] = batch_scheduled_time + + if req.py_timestamps[ + 'last_iteration_end'] is None: + if 'request_fetched' in req.py_timestamps: + req.py_timestamps[ + 'scheduling_wait_time'] += ( + batch_scheduled_time - req. + py_timestamps['request_fetched'] + ) * 1000 + else: + req.py_timestamps[ + 'scheduling_wait_time'] += ( + batch_scheduled_time - req. + py_timestamps['last_iteration_end'] + ) * 1000 + + req.py_timestamps['pre_forward_overhead'] += ( + forward_step_start - + batch_scheduled_time) * 1000 + req.py_timestamps['forward_step_time'] += ( + forward_step_end - + forward_step_start) * 1000 + req.py_timestamps['post_processing_time'] += ( + iteration_end - forward_step_end) * 1000 + req.py_timestamps['num_iterations'] += 1 + req.py_timestamps[ + 'last_iteration_end'] = iteration_end if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: for req in scheduled_batch.context_requests: if req.is_context_only_request and ( @@ -1245,6 +1324,10 @@ def _executor_loop(self): iter_stats=iter_stats, iter_start_time=iter_start_time)) + print_fetch_statistics(self.num_fetched_requests, + self.fetch_call_count, + rank=self.dist.rank) + def _prepare_draft_requests(self): try: # Set draft tokens here to make the KV cache manager @@ -1274,6 +1357,7 @@ def _executor_loop_overlap(self): torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. 
CUASSERT(cudart.cudaSetDevice(self.device_id)) + timestamp_enabled = is_timestamp_debug_enabled() with self._profiler() as profile_step: iter_start_time = time.time() iter_stats = None @@ -1283,7 +1367,13 @@ def _executor_loop_overlap(self): if self.enable_iter_perf_stats: iter_start_time = time.time() - scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + scheduled_batch, iter_stats, num_fetched = self._prepare_and_schedule_batch( + ) + + if self.fetch_call_count is not None: + self.fetch_call_count += 1 + self.num_fetched_requests.append(num_fetched) + if scheduled_batch is None: break # In gen-only benchmarking mode, wait until the number of scheduled generation @@ -1317,6 +1407,9 @@ def _executor_loop_overlap(self): else: can_forward = True + batch_scheduled_time = time.time( + ) if timestamp_enabled else None + self._pause_requests(scheduled_batch.paused_requests) if scheduled_batch.batch_size > 0: @@ -1368,9 +1461,15 @@ def _executor_loop_overlap(self): else: previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device + if timestamp_enabled: + forward_step_start = time.time() + batch_outputs = self._forward_step(scheduled_batch, previous_tensors_device) + if timestamp_enabled: + forward_step_end = time.time() + if target_inputs is not None: self._process_draft_results(scheduled_batch, draft_outputs, draft_batch) @@ -1396,10 +1495,47 @@ def _executor_loop_overlap(self): sample_state = self._sample_async(scheduled_batch, batch_outputs) + assert sample_state is not None, "Sampling failed" self._update_request_states(scheduled_batch) + if timestamp_enabled: + iteration_end = time.time() + for req in scheduled_batch.all_requests(): + if hasattr(req, 'py_timestamps' + ) and req.py_timestamps is not None: + if 'batch_scheduled_time' not in req.py_timestamps: + req.py_timestamps[ + 'batch_scheduled_time'] = batch_scheduled_time + + if req.py_timestamps[ + 'last_iteration_end'] is None: + if 'request_fetched' in req.py_timestamps: + req.py_timestamps[ + 'scheduling_wait_time'] += ( + batch_scheduled_time - req. + py_timestamps['request_fetched'] + ) * 1000 + else: + req.py_timestamps[ + 'scheduling_wait_time'] += ( + batch_scheduled_time - req. 
+ py_timestamps['last_iteration_end'] + ) * 1000 + + req.py_timestamps['pre_forward_overhead'] += ( + forward_step_start - + batch_scheduled_time) * 1000 + req.py_timestamps['forward_step_time'] += ( + forward_step_end - + forward_step_start) * 1000 + req.py_timestamps['post_processing_time'] += ( + iteration_end - forward_step_end) * 1000 + req.py_timestamps['num_iterations'] += 1 + req.py_timestamps[ + 'last_iteration_end'] = iteration_end + ctx_transmission_reqs = self._send_disagg_ctx_cache( scheduled_batch.context_requests ) if self.kv_cache_transceiver else [] @@ -1423,6 +1559,10 @@ def _executor_loop_overlap(self): self._kv_connector_terminate_requests() + print_fetch_statistics(self.num_fetched_requests, + self.fetch_call_count, + rank=self.dist.rank) + def _process_previous_batch(self): if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs: for req in self.previous_batch.ctx_transmission_reqs: @@ -1507,6 +1647,17 @@ def _respond_if_invalid(request: LlmRequest) -> bool: if not _respond_if_invalid(request) ] + if is_timestamp_debug_enabled(): + request_fetched = time.time() + for request in validated_requests: + if hasattr( + request, + 'py_timestamps') and request.py_timestamps is not None: + if 'request_fetched' not in request.py_timestamps: + # only record the first fetch time of each request + request.py_timestamps[ + 'request_fetched'] = request_fetched + self.active_requests.extend(validated_requests) return validated_requests @@ -2099,6 +2250,9 @@ def _enqueue_responses(self, responses: Iterable[Tuple[int, LlmResponse]]): resp ) == LlmResponse and req_id in self.result_wait_queues and self.result_wait_queues[ req_id] is not None: + # Timestamp: Before Ray RPC to send response + if hasattr(resp, 'timestamps') and resp.timestamps: + resp.timestamps['response_enqueued'] = time.time() self.result_wait_queues[req_id].put_response.remote( resp.client_id, resp) self.response_cv.notify_all() diff --git a/tensorrt_llm/bench/benchmark/throughput.py b/tensorrt_llm/bench/benchmark/throughput.py index 6406b755c76..c30fe3016a4 100755 --- a/tensorrt_llm/bench/benchmark/throughput.py +++ b/tensorrt_llm/bench/benchmark/throughput.py @@ -6,6 +6,7 @@ from pathlib import Path import click +import ray from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup, optgroup) from huggingface_hub import snapshot_download @@ -20,6 +21,12 @@ from tensorrt_llm.bench.benchmark.utils.general import ( get_settings_from_engine, get_settings, ALL_SUPPORTED_BACKENDS) # isort: on +from tensorrt_llm._tmp_utils import (analyze_average_timestamps, + dump_timestamps_to_json, + is_timestamp_debug_enabled, + print_enqueue_statistics, + print_fetch_statistics) +from tensorrt_llm._utils import mpi_disabled from tensorrt_llm.bench.benchmark.utils.general import ( generate_warmup_dataset, update_sampler_args_with_extra_options) from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig @@ -472,6 +479,49 @@ def throughput_command( options.request_json, partial(report_utility.get_request_info, tokenizer)) report_utility.report_statistics() + + if is_timestamp_debug_enabled() and hasattr( + statistics, 'all_timestamps') and statistics.all_timestamps: + logger.info("\n") + analyze_average_timestamps(statistics.all_timestamps) + dump_timestamps_to_json(statistics.all_timestamps, + "timestamps_output.json") + + # Collect and print fetch statistics + try: + if hasattr(llm, '_executor'): + executor = llm._executor + if mpi_disabled(): + # Ray mode + worker = executor.workers[0] + 
stats = worker.call_worker_method.remote( + 'get_fetch_statistics') + result = ray.get(stats) + if result: + print_fetch_statistics( + result['num_fetched_requests'], + result['fetch_call_count'], + rank=result['rank']) + else: + # MPI mode + stats = executor.get_fetch_statistics() + if stats: + print_fetch_statistics( + stats['num_fetched_requests'], + stats['fetch_call_count'], + rank=stats['rank']) + except Exception as e: + logger.debug(f"Could not retrieve fetch statistics: {e}") + + # Print enqueue timing statistics + try: + if hasattr(llm, '_executor'): + executor = llm._executor + if hasattr(executor, + 'enqueue_timings') and executor.enqueue_timings: + print_enqueue_statistics(executor.enqueue_timings) + except Exception as e: + logger.debug(f"Could not retrieve enqueue statistics: {e}") except KeyboardInterrupt: logger.info("Keyboard interrupt, exiting benchmark...") except Exception: diff --git a/tensorrt_llm/bench/benchmark/utils/asynchronous.py b/tensorrt_llm/bench/benchmark/utils/asynchronous.py index 68f3323fab8..de0bc578e11 100644 --- a/tensorrt_llm/bench/benchmark/utils/asynchronous.py +++ b/tensorrt_llm/bench/benchmark/utils/asynchronous.py @@ -44,6 +44,7 @@ def __init__(self, self.streaming = streaming self.request_seen = asyncio.Event() self.modality = modality + self.all_timestamps: List = [] # Collect timestamps for analysis async def process_request(self, request: InferenceRequest, sampling_params: SamplingParams, @@ -74,6 +75,10 @@ async def process_request(self, request: InferenceRequest, response_end_timestamp = time.perf_counter_ns() + # Collect timestamps for detailed latency analysis + if response.outputs and response.outputs[0].timestamps: + self.all_timestamps.append(response.outputs[0].timestamps) + # Mark that the response returned. Construct a record to send to statistics. 
tokens = list(chain(*(beam.token_ids for beam in response.outputs))) request_perf_item = PerfItemTuple( @@ -292,6 +297,8 @@ async def async_benchmark( assert finished_requests == len(requests), "Benchmark failed" logger.info("Benchmark complete.") + # Attach collected timestamps to statistics for optional analysis + statistics.all_timestamps = backend.all_timestamps return statistics finally: diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index a170d75e486..d1db0922d9e 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -2,6 +2,7 @@ import datetime import enum import json +import time import weakref from pathlib import Path from queue import Queue @@ -35,6 +36,7 @@ compute_logprobs) from .utils import (ErrorResponse, IntraProcessQueue, RequestError, is_llm_response) +from tensorrt_llm._utils import nvtx_range __all__ = [ "BaseWorker", @@ -305,6 +307,7 @@ def _load_prompt_adapter(self, model_config=self._runtime_model_config, uids=[str(prompt_adapter_request.adapter_id)]) + @nvtx_range("base_worker.enqueue_request") def _enqueue_request(self, request: GenerationRequest, result_wait_queue=None) -> int: @@ -502,6 +505,13 @@ def _deduce_max_tokens(request: GenerationRequest, if request.arrival_time is not None: executor_request.py_arrival_time = request.arrival_time + if self._is_pytorch_backend and hasattr( + request, 'timestamps') and request.timestamps is not None: + executor_request.py_timestamps = request.timestamps + + if request.timestamps is not None: + request.timestamps['worker_enqueue_request'] = time.time() + if request.query_token_ids is not None: # pytorch star attention workflow # a workaround to avoid public interface update @@ -550,7 +560,7 @@ def submit(self, request: GenerationRequest) -> GenerationResult: self._results[client_id] = result request_id = self._enqueue_request(request) - # request_id returned from backend is necessary for the abort_request method. 
+ self._client_id_to_request_id[client_id] = request_id self._handle_background_error() @@ -593,6 +603,16 @@ def _pop_result(self, client_id: int): self._results.pop(client_id, None) self._client_id_to_request_id.pop(client_id, None) + def get_fetch_statistics(self): + if hasattr(self.engine, 'num_fetched_requests') and hasattr( + self.engine, 'fetch_call_count'): + return { + 'num_fetched_requests': self.engine.num_fetched_requests, + 'fetch_call_count': self.engine.fetch_call_count, + 'rank': self.rank + } + return None + def __enter__(self): return self @@ -826,6 +846,13 @@ def _send_rsp( # if postproc_batches is set, append to batch instead of putting to IpcQueue if worker.result_queue is not None: + if hasattr(response, 'timestamps') and response.timestamps: + response.timestamps['response_enqueued'] = time.time() + elif isinstance(response, ResponseWrapper) and hasattr( + response._response, + 'timestamps') and response._response.timestamps: + response._response.timestamps['response_enqueued'] = time.time() + if rsp_batch is not None: rsp_batch.append(response) else: @@ -845,6 +872,13 @@ def _send_rsp( pid = response.client_id % worker.postproc_config.num_postprocess_workers + if hasattr(response, 'timestamps') and response.timestamps: + response.timestamps['response_enqueued'] = time.time() + elif isinstance(response, ResponseWrapper) and hasattr( + response._response, + 'timestamps') and response._response.timestamps: + response._response.timestamps['response_enqueued'] = time.time() + if not postproc_batches: # Group the responses into buckets for the postprocessing steps. # Bucketing is used instead of random dispatching because the diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index 13ff28023ef..4f142f43ff1 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -11,7 +11,7 @@ from tensorrt_llm.logger import logger -from .._utils import customized_gc_thresholds, mpi_rank, nvtx_range_debug +from .._utils import customized_gc_thresholds, mpi_rank, nvtx_range_debug, nvtx_range from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession, RemoteMpiCommSessionClient) from ..llmapi.tracer import enable_llm_tracer, get_tracer, global_tracer @@ -85,6 +85,9 @@ def __init__( self.model_world_size = model_world_size + # Track enqueue timings for analysis + self.enqueue_timings = [] + self.garbage_collection_gen0_threshold = worker_kwargs[ "llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get( "llm_args", None) is not None else None @@ -414,6 +417,7 @@ def shutdown(self): if enable_llm_debug(): print_alive_threads() + @nvtx_range("proxy.submit") def submit(self, request: GenerationRequest) -> GenerationResult: """ Low-level API to the executor. 
Return a "future" GenerationResult @@ -434,8 +438,17 @@ def submit(self, request: GenerationRequest) -> GenerationResult: logprob_params=logprob_params) self._results[request.id] = result + if request.timestamps is not None: + request.timestamps['executor_submit_request'] = time.time() + + enqueue_start = time.perf_counter() with nvtx_range_debug("request_queue.put"): self.request_queue.put(request) + enqueue_elapsed = (time.perf_counter() - enqueue_start) * 1000 + + from tensorrt_llm._tmp_utils import is_timestamp_debug_enabled + if is_timestamp_debug_enabled(): + self.enqueue_timings.append(enqueue_elapsed) self._handle_background_error() diff --git a/tensorrt_llm/executor/ray_executor.py b/tensorrt_llm/executor/ray_executor.py index 5d87fdc9bfc..63f1104cd2d 100644 --- a/tensorrt_llm/executor/ray_executor.py +++ b/tensorrt_llm/executor/ray_executor.py @@ -7,21 +7,26 @@ e.msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator""" raise +import time + from ray.util.placement_group import (PlacementGroup, PlacementGroupSchedulingStrategy, get_current_placement_group, placement_group) from tensorrt_llm._ray_utils import unwrap_ray_errors +from tensorrt_llm._tmp_utils import is_timestamp_debug_enabled from tensorrt_llm._utils import get_free_port from tensorrt_llm.logger import logger from .._utils import nvtx_range_debug from .executor import GenerationExecutor +from .ipc import IpcQueue from .postproc_worker import PostprocWorkerConfig from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper from .request import GenerationRequest from .result import GenerationResult, RayAsyncQueue, RaySyncQueue +from tensorrt_llm._utils import nvtx_range __all__ = [ "RayExecutor", @@ -93,6 +98,13 @@ def __init__(self, self.response_queue.warmup.remote() self.response_sync_queue.warmup.remote() + # Setup IPC queue for request passing if enabled + self.request_queue = None + if self._use_ipc_queue(): + print("==== Use IPC queue ====") + self.request_queue = IpcQueue(is_server=True, + name="ray_request_queue") + worker_kwargs = dict(**worker_kwargs, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor) @@ -113,6 +125,10 @@ def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle): def use_ray_queue(self) -> bool: return True + @staticmethod + def _use_ipc_queue() -> bool: + return os.environ.get("RAY_DEBUG_USE_IPC", "0") == "1" + def create_workers(self, worker_cls, worker_kwargs): # When set to be a fraction, it allows Ray to schedule # multiple actors on a single GPU for colocate use cases. 
@@ -126,10 +142,21 @@ def create_workers(self, worker_cls, worker_kwargs): "MASTER_ADDR": self.master_address, # head-IP for NCCL/Gloo "MASTER_PORT": str(self.master_port) }) + + # Pass IPC queue address to workers if using IPC mode + if self._use_ipc_queue() and self.request_queue is not None: + runtime_env["env_vars"]["RAY_IPC_REQUEST_QUEUE_ADDR"] = self.request_queue.address[0] + print(f"==== IPC queue address: {self.request_queue.address[0]} ====") + if self.request_queue.address[1] is not None: + # Pass HMAC key as hex string + runtime_env["env_vars"]["RAY_IPC_REQUEST_QUEUE_KEY"] = self.request_queue.address[1].hex() + print(f"==== IPC queue key: {self.request_queue.address[1].hex()} ====") self.placement_group, self.bundle_indices = self._get_placement_group( tp_size=self.tp_size) + self.enqueue_timings = [] + self.workers = [ RayWorkerWrapper.options( num_gpus=num_gpus, @@ -188,6 +215,7 @@ def collective_rpc(self, **kwargs)) return refs if non_block else ray.get(refs) + @nvtx_range("ray_executor.submit") def submit(self, request: GenerationRequest) -> GenerationResult: """ Low-level API to the executor. Return a "future" GenerationResult @@ -204,12 +232,30 @@ def submit(self, request: GenerationRequest) -> GenerationResult: disaggregated_params=request.disaggregated_params, logprob_params=logprob_params) - with nvtx_range_debug("request_queue.put"): - self.call_all_ray_workers("enqueue_request", - leader_only=True, - request=request, - async_call=False, - result_wait_queue=result.queue) + if request.timestamps is not None: + request.timestamps['executor_submit_request'] = time.time() + + enqueue_start = time.perf_counter() + + if self._use_ipc_queue(): + # Use IPC queue path (similar to MPI) + with nvtx_range_debug("request_queue.put"): + # Store result queue in request for worker to access + request._result_queue = result.queue + self.request_queue.put(request) + else: + # Use original Ray RPC path + with nvtx_range_debug("request_queue.put"): + self.call_all_ray_workers("enqueue_request", + leader_only=True, + request=request, + async_call=True, + result_wait_queue=result.queue) + + enqueue_elapsed = (time.perf_counter() - enqueue_start) * 1000 + + if is_timestamp_debug_enabled(): + self.enqueue_timings.append(enqueue_elapsed) return result @@ -226,7 +272,11 @@ def abort_request(self, request_id: int) -> None: request_id=request_id) def shutdown(self): - # Release actors + # Close IPC queue if it was created + if hasattr(self, 'request_queue') and self.request_queue is not None: + self.request_queue.close() + self.request_queue = None + self.response_queue = None self.response_sync_queue = None self.async_response_queue_weakref = None diff --git a/tensorrt_llm/executor/ray_gpu_worker.py b/tensorrt_llm/executor/ray_gpu_worker.py index 8b11cfc0de1..f1477b18ea8 100644 --- a/tensorrt_llm/executor/ray_gpu_worker.py +++ b/tensorrt_llm/executor/ray_gpu_worker.py @@ -13,9 +13,12 @@ from ..llmapi.tokenizer import TokenizerBase from ..sampling_params import BatchedLogitsProcessor from .base_worker import BaseWorker +from .ipc import IpcQueue from .postproc_worker import PostprocWorkerConfig -from .request import GenerationRequest +from .request import GenerationRequest, CancellingRequest from .result import GenerationResult +from .utils import RequestError +from tensorrt_llm._utils import nvtx_range __all__ = [ "RayGPUWorker", @@ -183,6 +186,12 @@ def __init__( if self.global_rank > 1: logger.set_rank(self.global_rank) + # Setup IPC queue for request reading if enabled (leader only) + 
self.request_queue = None + self.request_reader_thread = None + if self._use_ipc_queue() and self.global_rank == 0: + self._setup_ipc_queue() + self.setup_engine() def _get_comm_ranks_device_id(self): @@ -196,6 +205,86 @@ def _get_comm_ranks_device_id(self): torch.distributed.all_gather_object(device_ids, self.device_id) return comm_ranks, device_ids + @staticmethod + def _use_ipc_queue() -> bool: + """Check if IPC queue should be used instead of Ray RPC for enqueue.""" + return os.environ.get("RAY_DEBUG_USE_IPC", "0") == "1" + + def _setup_ipc_queue(self): + """Setup IPC queue client connection to receive requests.""" + queue_addr = os.environ.get("RAY_IPC_REQUEST_QUEUE_ADDR") + queue_key_hex = os.environ.get("RAY_IPC_REQUEST_QUEUE_KEY") + + # print(f"===Setting up IPC queue on Worker {self.global_rank}: Queue Address: {queue_addr}, Queue Key: {queue_key_hex}===") + + if queue_addr is None: + raise RuntimeError("RAY_DEBUG_USE_IPC=1 but RAY_IPC_REQUEST_QUEUE_ADDR not set") + + # Reconstruct the HMAC key from hex if present + queue_key = bytes.fromhex(queue_key_hex) if queue_key_hex else None + + self.request_queue = IpcQueue( + address=(queue_addr, queue_key), + is_server=False, + name=f"ray_worker_{self.global_rank}_request_queue" + ) + + # Start thread to read from queue - using regular Thread like MPI path + # (not ManagedThread to reduce overhead) + import threading + self.request_reader_thread = threading.Thread( + target=self._request_reader_task, + daemon=True, + name=f"ray_worker_{self.global_rank}_request_reader" + ) + self.request_reader_thread.start() + + def _request_reader_task(self): + """Thread task to read requests from IPC queue. + + EXACT MPI REPLICATION: Pure tight blocking loop, no batching, no logic. + Just continuously drain IPC → enqueue to ExecutorRequestQueue. 
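+        The loop exits when request_queue.get() returns None (e.g., after the
+        IPC queue is closed during shutdown); any other error is logged with a
+        traceback and re-raised.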
+ """ + import time + try: + logger.info(f"[Rank {self.global_rank}] Starting IPC queue reader (pure MPI replication)") + + # Instrumentation to measure bottleneck + drain_count = 0 + drain_start = None + + # EXACT copy of MPI worker.py:409-421 + while (req := self.request_queue.get()) is not None: + if drain_start is None: + drain_start = time.perf_counter() + + drain_count += 1 + + if isinstance(req, CancellingRequest): + self.abort_request(req.id) + elif isinstance(req, GenerationRequest): + try: + result_wait_queue = getattr(req, '_result_queue', None) + self._enqueue_request(req, result_wait_queue) + except RequestError as e: + logger.error(f"[Rank {self.global_rank}] enqueue_request failed: {e}") + else: + logger.error(f"[Rank {self.global_rank}] Unknown request type: {type(req)}") + + # Log every 100 requests to see drain rate + if drain_count % 100 == 0: + elapsed = time.perf_counter() - drain_start + rate = drain_count / elapsed + logger.info(f"[Rank {self.global_rank}] IPC reader drained {drain_count} requests in {elapsed:.3f}s ({rate:.1f} req/s)") + + logger.info(f"[Rank {self.global_rank}] Received None from IPC queue, stopping reader thread") + except Exception as e: + logger.error(f"[Rank {self.global_rank}] Error in request reader task: {e}") + import traceback + logger.error(traceback.format_exc()) + raise + + @nvtx_range("ray_gpu_worker.enqueue_request") def enqueue_request(self, request: GenerationRequest, result_wait_queue: Queue | None = None) -> int: @@ -214,6 +303,18 @@ def shutdown(self): logger.debug(f'Worker {self.rank} shutting down...') + # Close IPC queue if it exists + # This will cause the reader thread's blocking get() to return None or raise exception + if self.request_queue is not None: + self.request_queue.close() + self.request_queue = None + + # Wait for reader thread to finish + if self.request_reader_thread is not None and self.request_reader_thread.is_alive(): + logger.info(f"[Rank {self.global_rank}] Waiting for IPC queue reader thread to finish") + self.request_reader_thread.join(timeout=5.0) + self.request_reader_thread = None + if self.engine is not None: self.engine.shutdown() self.engine = None diff --git a/tensorrt_llm/executor/request.py b/tensorrt_llm/executor/request.py index c7c7abea74f..40a84911465 100644 --- a/tensorrt_llm/executor/request.py +++ b/tensorrt_llm/executor/request.py @@ -8,6 +8,7 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams +from .._tmp_utils import is_timestamp_debug_enabled from ..disaggregated_params import DisaggregatedParams from ..llmapi.llm_utils import KvCacheRetentionConfig from ..sampling_params import SamplingParams @@ -130,6 +131,18 @@ def __init__( self.cache_salt_id = cache_salt_id self.arrival_time = arrival_time + if is_timestamp_debug_enabled(): + self.timestamps = { + 'scheduling_wait_time': 0.0, + 'pre_forward_overhead': 0.0, + 'forward_step_time': 0.0, + 'post_processing_time': 0.0, + 'num_iterations': 0, + 'last_iteration_end': None, + } + else: + self.timestamps = None + def set_id(self, id): assert self.id is None, f"Request ID is already set: {self.id}" self.id = id diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 195255775fc..28d2214c15d 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -19,6 +19,8 @@ except ModuleNotFoundError: from tensorrt_llm import ray_stub as ray +from tensorrt_llm._tmp_utils import is_timestamp_debug_enabled + from .._ray_utils import unwrap_ray_errors from .._utils import mpi_disabled, 
nvtx_range_debug from ..bindings import executor as tllm @@ -28,6 +30,7 @@ from ..metrics import MetricNames, MetricsCollector, RequestEventTiming from ..sampling_params import LogprobParams, SamplingParams from .utils import ErrorResponse, has_event_loop, is_llm_response +from tensorrt_llm._utils import nvtx_range if TYPE_CHECKING: from .executor import GenerationExecutor @@ -132,6 +135,8 @@ class CompletionOutput: additional_generation_outputs: Optional[Dict[str, torch.Tensor]] = None disaggregated_params: Optional[DisaggregatedParams] = None request_perf_metrics: Optional[tllm.RequestPerfMetrics] = None + timestamps: Optional[Dict[str, float]] = field( + default_factory=lambda: {} if is_timestamp_debug_enabled() else None) # hidden fields for tracking the diffs _last_text_len: int = field(default=0, init=False, repr=False) @@ -280,8 +285,9 @@ def __init__(self, else: self.queue = ray_queue self.aqueue = None - with unwrap_ray_errors(): - ray.get(self.queue.register.remote(id)) + with nvtx_range("result.ray_queue.register_id"): + with unwrap_ray_errors(): + ray.get(self.queue.register.remote(id)) else: if has_event_loop(): self.aqueue = AsyncQueue() @@ -519,6 +525,10 @@ def _handle_response(self, response_result.sequence_index, logprobs_result, req_perf_metrics_dict) + if hasattr(response, 'timestamps') and response.timestamps: + response.timestamps['handle_response'] = time.time() + self._outputs[0].timestamps = response.timestamps + if response_result.context_logits is not None: self._context_logits = response_result.context_logits @@ -845,6 +855,7 @@ def clear_logprob_params(self) -> None: def _handle_ray_response(self, response: Any): return response + @nvtx_range("result._result_step") def _result_step(self, timeout: Optional[float] = None): if mpi_disabled(): with unwrap_ray_errors():