From 59d2bf2726bd80a3df4dcab75a4f23542eef6986 Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Thu, 18 Dec 2025 04:56:10 -0600 Subject: [PATCH 1/7] Reorganize code to slice tensors in context class --- .../inference/contexts/dynamic_context.py | 22 +++++++ .../text_generation_controller.py | 60 +++++++------------ 2 files changed, 43 insertions(+), 39 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 0c172d0b47f..1b994ae3079 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -758,6 +758,17 @@ def initialize_all_tensors(self) -> None: self.token_to_position_in_request = torch.empty_like(self.token_to_input_ids) self.token_to_local_position_within_kv_block = torch.empty_like(self.token_to_input_ids) + # Static tensor addresses of active slices to enable fast inference kernels. + self.active_request_metadata: Dict[str, Tensor] = {} + for label, _, on_gpu in self.request_metadata_types: + if on_gpu: + tensor = torch.empty_like(self.request_metadata[label]) + else: + tensor = torch.empty_like( + self.request_metadata[label], device="cpu", pin_memory=True + ) + self.active_request_metadata[label] = tensor + # NOTE: Need to build this outside the UVM / TMS context to avoid IMA. if self.is_hybrid_model: self.mamba_metadata = MambaMetadata( @@ -952,6 +963,17 @@ def get_active_request_count(self): """Returns the current number of active requests.""" return self.total_request_count - self.paused_request_count + def build_active_slices(self): + """Build the active slices of specific tensors. This is run on every forward step.""" + active_slice = slice(self.paused_request_count, self.total_request_count) + batch_size = self.total_request_count - self.paused_request_count + + # Request metadata all needs to be sliced. + for label, _, _ in self.request_metadata_types: + self.active_request_metadata[label][:batch_size].copy_( + self.request_metadata[label][active_slice], non_blocking=True + ) + def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) -> None: """Append to KV cache. diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 61272c07b36..810fad76639 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -125,16 +125,6 @@ def _init_dynamic_sampling_tensors(self): # Last accepted sequence indices for serial MTP computation self._last_accepted_seq_indices = None - # Keep track of request metadata. - self._request_metadata: Dict[str, Tensor] = {} - for label, dtype, on_gpu in context.request_metadata_types: - tensor = context.request_metadata[label] - if not on_gpu: - # Create pinned tensors for request metadata that lives on CPU. - # This is metadata which requires D2H copies, such as top_k for torch sampling. - tensor = torch.empty_like(tensor, device="cpu", pin_memory=True) - self._request_metadata[label] = tensor - # Used for inefficient torch sampling. if self._sampling_backend == "torch": self._torch_sampling_buckets: List[Tuple] = [] @@ -556,12 +546,14 @@ def _dynamic_step_context_init( position_ids (Tensor): The active position IDs. 
""" context = self.inference_wrapped_model.inference_context - active_request_slice = slice(context.paused_request_count, context.total_request_count) # Remove Float16Module wrapper if it exists unwrapped_model = unwrap_model(self.inference_wrapped_model.model) model_config = get_model_config(unwrapped_model) + # Build active slices of all relevant tensors. + context.build_active_slices() + # Initialize attention state. context.initialize_attention_state( construct_graph_dimensions=construct_graph_dimensions, @@ -597,14 +589,6 @@ def _dynamic_step_context_init( # Turn off symmetric all reduces for prefill unwrapped_model.set_symmetric_ar(None) - # Get request metadata for this step. - for label, dtype, on_gpu in context.request_metadata_types: - if not on_gpu: - # We need a D2H copy from the context to the pinned memory buffer. - self._request_metadata[label].copy_( - context.request_metadata[label], non_blocking=True - ) - # Get flat tokens, position ids. # If we are running a dummy forward step we want to use the token count agreed upon # by all EP ranks rather than the minimum number of tokens. @@ -660,7 +644,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) def _dynamic_step_sample_bookkeeping(self): """Perform bookkeeping necessary to sample logits for dynamic batching.""" context = self.inference_wrapped_model.inference_context - active_request_slice = slice(context.paused_request_count, context.total_request_count) + active_request_count = context.total_request_count - context.paused_request_count if self._sampling_backend == "torch": # Bucketize the core sampling parameters. @@ -668,9 +652,9 @@ def _dynamic_step_sample_bookkeeping(self): bucket_map = defaultdict(list) # Shorthands for the dictionary comprehension. - temp = self._request_metadata["temperature"][active_request_slice].tolist() - top_k = self._request_metadata["top_k"][active_request_slice].tolist() - top_p = self._request_metadata["top_p"][active_request_slice].tolist() + temp = context.active_request_metadata["temperature"][:active_request_count].tolist() + top_k = context.active_request_metadata["top_k"][:active_request_count].tolist() + top_p = context.active_request_metadata["top_p"][:active_request_count].tolist() for request_index, (t, k, p) in enumerate(zip(temp, top_k, top_p)): sampling_params = (t, k, p) @@ -1201,12 +1185,12 @@ def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: return_log_probs (bool): Whether to return the sampled log_probs. """ context = self.inference_wrapped_model.inference_context - active_request_slice = slice(context.paused_request_count, context.total_request_count) - - return_log_probs = self._request_metadata["return_log_probs"][active_request_slice] - top_n_log_probs = self._request_metadata["top_n_logprobs"][active_request_slice] > 0 + active_request_count = context.total_request_count - context.paused_request_count - return return_log_probs.any(), top_n_log_probs.any() + return ( + (context.active_request_metadata["return_log_probs"][:active_request_count]).any(), + (context.active_request_metadata["top_n_logprobs"][:active_request_count] > 0).any(), + ) def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]: """Collect and map routing indices per request for MoE router recording. 
@@ -1416,7 +1400,7 @@ def _dynamic_step_calculate_top_n_logprobs_speculative( num_decode_requests, self.num_speculative_tokens + 1, -1 ) accepted_counts = self._accepted_token_counts_per_request[:num_decode_requests] - top_n_per_request = self._request_metadata["top_n_logprobs"][active_request_slice][ + top_n_per_request = context.active_request_metadata["top_n_logprobs"][ :num_decode_requests ] max_top_n = int(top_n_per_request.max().item()) @@ -1450,12 +1434,12 @@ def _dynamic_step_calculate_top_n_logprobs_speculative( for i in range(num_prefill_requests): req_idx = num_decode_requests + i top_n = int( - self._request_metadata["top_n_logprobs"][active_request_slice][req_idx].item() + context.active_request_metadata["top_n_logprobs"][req_idx].item() ) if top_n > 0: request_lp = prefill_log_probs_per_request[i] skip_prompt = bool( - self._request_metadata["skip_prompt_log_probs"][req_idx].item() + context.active_request_metadata["skip_prompt_log_probs"][req_idx].item() ) if skip_prompt and request_lp.size(0) > 1: @@ -1506,9 +1490,7 @@ def _dynamic_step_calculate_top_n_logprobs( top_n_results = {} for req_idx in range(active_request_count): - top_n = int( - self._request_metadata["top_n_logprobs"][active_request_slice][req_idx].item() - ) + top_n = int(context.active_request_metadata["top_n_logprobs"][req_idx].item()) if top_n > 0: # Get top-n logprobs and indices for this request (single token) top_n_logits = torch.topk(log_probs[req_idx], k=top_n) @@ -1530,14 +1512,14 @@ def _dynamic_step_calculate_top_n_logprobs( top_n_results = {} for req_idx in range(active_request_count): - top_n = int( - self._request_metadata["top_n_logprobs"][active_request_slice][req_idx].item() - ) + top_n = int(context.active_request_metadata["top_n_logprobs"][req_idx].item()) if top_n > 0: request_log_probs = log_probs_per_request[ req_idx ] # [num_tokens_for_request, vocab_size] - skip_prompt = bool(self._request_metadata["skip_prompt_log_probs"][req_idx].item()) + skip_prompt = bool( + context.active_request_metadata["skip_prompt_log_probs"][req_idx].item() + ) # If skip_prompt_log_probs is True, only compute for last token if skip_prompt and request_log_probs.size(0) > 1: @@ -1698,7 +1680,7 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: # Note: termination_id tensor has per-request termination IDs from mixed sampling active_request_mask = ( self._sampled_tokens_cuda[:active_request_count] - != self._request_metadata["termination_id"][active_request_slice] + != context.active_request_metadata["termination_id"][:active_request_count] ).byte() & torch.less(active_sequence_lengths, max_sequence_lengths).byte() # Mark requests as finished if they hit stop words (detected in previous step's post_process_requests) From 5c15f2d3a80d7d704eb0aaac6eba9f825382d612 Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Thu, 18 Dec 2025 05:28:42 -0600 Subject: [PATCH 2/7] Slice additional tensors in context class --- .../inference/contexts/dynamic_context.py | 63 ++++++++++++------- .../text_generation_controller.py | 50 ++++++--------- 2 files changed, 59 insertions(+), 54 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 1b994ae3079..6ed7b7f77ff 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -769,6 +769,15 @@ def initialize_all_tensors(self) -> None: ) self.active_request_metadata[label] = tensor + self.active_request_ids = 
torch.empty_like(self.request_ids, dtype=torch.int64) + self.active_request_query_lengths = torch.empty_like(self.request_query_lengths) + self.active_request_output_lengths = torch.empty_like(self.request_output_lengths) + self.active_request_kv_length_offsets = torch.empty_like(self.request_kv_length_offsets) + self.active_request_to_kv_block_ids = torch.empty_like(self.request_to_kv_block_ids) + + self.active_sequence_lengths = torch.empty_like(self.request_query_lengths) + self.active_request_last_token_idxs = torch.empty_like(self.request_query_lengths) + # NOTE: Need to build this outside the UVM / TMS context to avoid IMA. if self.is_hybrid_model: self.mamba_metadata = MambaMetadata( @@ -949,24 +958,13 @@ def cu_kv_lengths(self) -> Tuple[Tensor, Tensor, int]: self.active_attn_metadata["mha_metadata"].state_data["max_seqlen_k"], ) - def get_active_sequence_lengths(self) -> Tensor: - """Total sequence length (query + key) for active requests.""" - lengths = self.request_kv_length_offsets + self.request_query_lengths - lengths = lengths[self.paused_request_count : self.total_request_count] - return lengths - - def get_max_sequence_lengths(self) -> Tensor: - """Maximum sequence length for active requests.""" - return self.request_output_lengths[self.paused_request_count : self.total_request_count] - def get_active_request_count(self): """Returns the current number of active requests.""" return self.total_request_count - self.paused_request_count - def build_active_slices(self): + def build_active_slices(self, batch_size: int): """Build the active slices of specific tensors. This is run on every forward step.""" active_slice = slice(self.paused_request_count, self.total_request_count) - batch_size = self.total_request_count - self.paused_request_count # Request metadata all needs to be sliced. for label, _, _ in self.request_metadata_types: @@ -974,6 +972,27 @@ def build_active_slices(self): self.request_metadata[label][active_slice], non_blocking=True ) + # The following tensor slices are used in various kernels. + self.active_request_ids[:batch_size].copy_(self.request_ids[active_slice]) + self.active_request_query_lengths[:batch_size].copy_( + self.request_query_lengths[active_slice] + ) + self.active_request_output_lengths[:batch_size].copy_( + self.request_output_lengths[active_slice] + ) + self.active_request_kv_length_offsets[:batch_size].copy_( + self.request_kv_length_offsets[active_slice] + ) + self.active_request_to_kv_block_ids[:batch_size].copy_( + self.request_to_kv_block_ids[active_slice] + ) + + self.active_sequence_lengths[:batch_size].copy_( + (self.active_request_query_lengths + self.active_request_kv_length_offsets)[:batch_size] + ) + graph_scratch_space = torch.cumsum(self.active_request_query_lengths[:batch_size], dim=0) + self.active_request_last_token_idxs[:batch_size].copy_(graph_scratch_space - 1) + def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) -> None: """Append to KV cache. @@ -1609,11 +1628,9 @@ def initialize_attention_state( else self.non_graph_attn_metadata # type: ignore[assignment] ) - # Update cu_query_seq_lengths, max_seqlen_q. - active_slice = slice(self.paused_request_count, self.total_request_count) - query_lengths_view = self.request_query_lengths[active_slice] - request_kv_length_offsets_view = self.request_kv_length_offsets[active_slice] - request_to_kv_block_ids_view = self.request_to_kv_block_ids[active_slice] + # Build active slices of all relevant tensors. 
+ batch_size = self.total_request_count - self.paused_request_count + self.build_active_slices(batch_size) attn_dimensions = batch_dimensions if self.using_cuda_graph_this_step(): @@ -1630,15 +1647,16 @@ def initialize_attention_state( assert self.active_attn_metadata is not None self.active_attn_metadata["mha_metadata"].update( - request_query_lengths=query_lengths_view, - request_kv_length_offsets=request_kv_length_offsets_view, - request_to_kv_block_ids=request_to_kv_block_ids_view, + request_query_lengths=self.active_request_query_lengths[:batch_size], + request_kv_length_offsets=self.active_request_kv_length_offsets[:batch_size], + request_to_kv_block_ids=self.active_request_to_kv_block_ids[:batch_size], batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, num_speculative_tokens=self.num_speculative_tokens, ) if self.is_hybrid_model: + active_slice = slice(self.paused_request_count, self.total_request_count) active_mamba_indices_view = self.mamba_metadata.request_to_mamba_state_idx[active_slice] token_to_request_idx_view = self.token_to_request_idx[: self.active_token_count] cu_seqlens = self.active_attn_metadata["mha_metadata"].state_data[ @@ -1796,9 +1814,8 @@ def last_token_logits(self, logits: Tensor) -> Tensor: f"logits.size(1) ({tuple(logits.shape)}) != " f"padded_active_token_count ({self.padded_active_token_count})." ) - logits_2d = logits.squeeze(0) - last_token_idxs = torch.cumsum(query_lengths, dim=0) - 1 - return logits_2d[last_token_idxs, :] + active_request_count = self.total_request_count - self.paused_request_count + return logits.squeeze(0)[self.active_request_last_token_idxs[:active_request_count], :] def _compute_prefix_match( self, req: DynamicInferenceRequest, prefill_chunk_length: int diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 810fad76639..91632291606 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -551,9 +551,6 @@ def _dynamic_step_context_init( unwrapped_model = unwrap_model(self.inference_wrapped_model.model) model_config = get_model_config(unwrapped_model) - # Build active slices of all relevant tensors. - context.build_active_slices() - # Initialize attention state. 
context.initialize_attention_state( construct_graph_dimensions=construct_graph_dimensions, @@ -1066,9 +1063,7 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_id request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ context.paused_request_count : context.total_request_count ] - request_query_lengths = context.request_query_lengths[ - context.paused_request_count : context.total_request_count - ] + request_query_lengths = context.active_request_query_lengths[:active_request_count] num_prefill_requests = request_in_prefill_status_tensor.sum().item() num_decode_requests = active_request_count - num_prefill_requests @@ -1220,9 +1215,9 @@ def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]: return None # Get active request info from context - active_request_slice = slice(context.paused_request_count, context.total_request_count) - active_request_ids = context.request_ids[active_request_slice].tolist() - active_query_lengths = context.request_query_lengths[active_request_slice].tolist() + active_request_count = context.total_request_count - context.paused_request_count + active_request_ids = context.active_request_ids[:active_request_count].tolist() + active_query_lengths = context.active_request_query_lengths[:active_request_count].tolist() active_token_count = context.active_token_count # Get TP group for all-gather if using sequence parallelism @@ -1291,9 +1286,7 @@ def _dynamic_step_calculate_log_probs_speculative( request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ context.paused_request_count : context.total_request_count ] - request_query_lengths = context.request_query_lengths[ - context.paused_request_count : context.total_request_count - ] + request_query_lengths = context.active_request_query_lengths[:active_request_count] num_prefill_requests = request_in_prefill_status_tensor.sum().item() num_decode_requests = active_request_count - num_prefill_requests @@ -1380,14 +1373,11 @@ def _dynamic_step_calculate_top_n_logprobs_speculative( """ context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count - active_request_slice = slice(context.paused_request_count, context.total_request_count) request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ context.paused_request_count : context.total_request_count ] - request_query_lengths = context.request_query_lengths[ - context.paused_request_count : context.total_request_count - ] + request_query_lengths = context.active_request_query_lengths[:active_request_count] num_prefill_requests = request_in_prefill_status_tensor.sum().item() num_decode_requests = active_request_count - num_prefill_requests @@ -1480,7 +1470,6 @@ def _dynamic_step_calculate_top_n_logprobs( context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count - active_request_slice = slice(context.paused_request_count, context.total_request_count) # Handle decode-only mode (only last token) if context.config.materialize_only_last_token_logits or context.is_decode_only(): @@ -1504,7 +1493,7 @@ def _dynamic_step_calculate_top_n_logprobs( # Note: logits may be padded, so we only take the first active_token_count tokens log_probs = log_probs_tensor[: context.active_token_count] - active_query_lengths = context.request_query_lengths[active_request_slice] + active_query_lengths = 
context.active_request_query_lengths[:active_request_count] # Split log_probs across request boundaries # log_probs has shape [active_token_count, vocab_size] @@ -1662,28 +1651,27 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: """ context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count - active_request_slice = slice(context.paused_request_count, context.total_request_count) - - # Active sequence lengths. - active_request_ids = context.request_ids[active_request_slice].long() - active_sequence_lengths = context.get_active_sequence_lengths() + # Request finished if termination_id or length >= max_sequence_length. + # Note: termination_id tensor has per-request termination IDs from mixed sampling if self.num_speculative_tokens > 0: - active_sequence_lengths += ( + seq_len_increment = ( self._accepted_token_counts_per_request[:active_request_count] + 1 ) else: - active_sequence_lengths += 1 - max_sequence_lengths = context.get_max_sequence_lengths() - - # Request finished if termination_id or length >= max_sequence_length. - # Note: termination_id tensor has per-request termination IDs from mixed sampling + seq_len_increment = 1 active_request_mask = ( self._sampled_tokens_cuda[:active_request_count] != context.active_request_metadata["termination_id"][:active_request_count] - ).byte() & torch.less(active_sequence_lengths, max_sequence_lengths).byte() + ).byte() & torch.less( + context.active_sequence_lengths[:active_request_count] + seq_len_increment, + context.active_request_output_lengths[:active_request_count], + ).byte() + + active_request_ids = context.active_request_ids[:active_request_count] - # Mark requests as finished if they hit stop words (detected in previous step's post_process_requests) + # Mark requests as finished if they hit stop words + # (detected in previous step's post_process_requests) if self._get_stop_word_finished_ids_callback is not None: request_ids_list = active_request_ids.tolist() stop_word_finished_ids = self._get_stop_word_finished_ids_callback(request_ids_list) From 6f9b3b8a8cf96044c99487de09115d6eaec83148 Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Thu, 18 Dec 2025 05:44:39 -0600 Subject: [PATCH 3/7] Slice tensors by padded_active_request_count --- .../inference/contexts/dynamic_context.py | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 6ed7b7f77ff..37794840c05 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -963,28 +963,41 @@ def get_active_request_count(self): return self.total_request_count - self.paused_request_count def build_active_slices(self, batch_size: int): - """Build the active slices of specific tensors. This is run on every forward step.""" - active_slice = slice(self.paused_request_count, self.total_request_count) + """Build the active slices of specific tensors. This is run on every forward step. + + If the context is reordered to active -> paused -> finished, this can be graphed. + """ + padded_slice = slice(self.paused_request_count, self.paused_request_count + batch_size) # Request metadata all needs to be sliced. 
for label, _, _ in self.request_metadata_types: self.active_request_metadata[label][:batch_size].copy_( - self.request_metadata[label][active_slice], non_blocking=True + self.request_metadata[label][padded_slice], non_blocking=True ) # The following tensor slices are used in various kernels. - self.active_request_ids[:batch_size].copy_(self.request_ids[active_slice]) + self.active_request_ids[:batch_size].copy_(self.request_ids[padded_slice]) self.active_request_query_lengths[:batch_size].copy_( - self.request_query_lengths[active_slice] + self.request_query_lengths[padded_slice] ) self.active_request_output_lengths[:batch_size].copy_( - self.request_output_lengths[active_slice] + self.request_output_lengths[padded_slice] ) self.active_request_kv_length_offsets[:batch_size].copy_( - self.request_kv_length_offsets[active_slice] + self.request_kv_length_offsets[padded_slice] ) self.active_request_to_kv_block_ids[:batch_size].copy_( - self.request_to_kv_block_ids[active_slice] + self.request_to_kv_block_ids[padded_slice] + ) + + self.active_request_output_lengths[:batch_size].copy_( + self.request_output_lengths[padded_slice] + ) + self.active_request_kv_length_offsets[:batch_size].copy_( + self.request_kv_length_offsets[padded_slice] + ) + self.active_request_to_kv_block_ids[:batch_size].copy_( + self.request_to_kv_block_ids[padded_slice] ) self.active_sequence_lengths[:batch_size].copy_( @@ -1611,6 +1624,9 @@ def initialize_attention_state( self.padded_active_request_count = self.padded_batch_dimensions.req_count self.padding_slice = slice(self.active_token_count, self.padded_active_token_count) + self.build_active_slices(self.padded_active_request_count) + batch_size = self.total_request_count - self.paused_request_count + # Update token position indexes. self.token_to_block_idx[self.active_token_count : self.padded_active_token_count] = ( self.kv_block_allocator.dummy_block_idx @@ -1628,10 +1644,6 @@ def initialize_attention_state( else self.non_graph_attn_metadata # type: ignore[assignment] ) - # Build active slices of all relevant tensors. - batch_size = self.total_request_count - self.paused_request_count - self.build_active_slices(batch_size) - attn_dimensions = batch_dimensions if self.using_cuda_graph_this_step(): # Treat some decode requests as prefill requests to fit the cuda graph batch dimension. From 6a2270235b08f875a59d3832d4eba5e0833ce3ce Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Thu, 25 Dec 2025 17:40:57 -0600 Subject: [PATCH 4/7] Move context tensor padding into dedicated method --- .../inference/contexts/dynamic_context.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 37794840c05..a369e14c586 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1006,6 +1006,21 @@ def build_active_slices(self, batch_size: int): graph_scratch_space = torch.cumsum(self.active_request_query_lengths[:batch_size], dim=0) self.active_request_last_token_idxs[:batch_size].copy_(graph_scratch_space - 1) + def pad_active_slices(self): + """Pad the active slices of specific tensors.""" + # Some tensors need to be padded at the token level. 
+ padding_token_slice = slice(self.active_token_count, self.padded_active_token_count) + + self.token_to_block_idx[padding_token_slice] = self.kv_block_allocator.dummy_block_idx + self.token_to_local_position_within_kv_block[padding_token_slice] = 0 + self.token_to_position_in_request[padding_token_slice] = 0 + + # Other tensors need to be padded at the request level. + padding_request_slice = slice( + self.total_request_count - self.paused_request_count, + self.padded_active_request_count, + ) + def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) -> None: """Append to KV cache. @@ -1620,23 +1635,6 @@ def initialize_attention_state( prefill_req_count=padded_prefill_req_count, decode_req_count=padded_decode_req_count, ) - self.padded_active_token_count = self.padded_batch_dimensions.token_count - self.padded_active_request_count = self.padded_batch_dimensions.req_count - self.padding_slice = slice(self.active_token_count, self.padded_active_token_count) - - self.build_active_slices(self.padded_active_request_count) - batch_size = self.total_request_count - self.paused_request_count - - # Update token position indexes. - self.token_to_block_idx[self.active_token_count : self.padded_active_token_count] = ( - self.kv_block_allocator.dummy_block_idx - ) - self.token_to_local_position_within_kv_block[ - self.active_token_count : self.padded_active_token_count - ] = 0 - self.token_to_position_in_request[ - self.active_token_count : self.padded_active_token_count - ] = 0 self.active_attn_metadata = ( self.graph_attn_metadata # type: ignore[assignment] @@ -1657,6 +1655,14 @@ def initialize_attention_state( decode_req_count=adjusted_decode_req_count, ) + self.padded_active_token_count = self.padded_batch_dimensions.token_count + self.padded_active_request_count = self.padded_batch_dimensions.req_count + self.padding_slice = slice(self.active_token_count, self.padded_active_token_count) + + self.build_active_slices(self.padded_active_request_count) + self.pad_active_slices() + + batch_size = self.total_request_count - self.paused_request_count assert self.active_attn_metadata is not None self.active_attn_metadata["mha_metadata"].update( request_query_lengths=self.active_request_query_lengths[:batch_size], From 60bcc924c3b1963a674a99e39be38bec6af457e3 Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Sun, 30 Nov 2025 13:33:11 -0800 Subject: [PATCH 5/7] Store logit output in static tensor --- .../text_generation_controller.py | 84 ++++++++++++------- 1 file changed, 54 insertions(+), 30 deletions(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 91632291606..343f3b50d96 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -109,6 +109,10 @@ def _init_dynamic_sampling_tensors(self): """Initialize tensors needed for dynamic sampling.""" context = self.inference_wrapped_model.inference_context max_requests = context.max_requests + if context.config.materialize_only_last_token_logits: + max_logits = max_requests + else: + max_logits = context.max_tokens # Callback to get request IDs that should be marked as finished due to stop words self._get_stop_word_finished_ids_callback = None @@ -117,6 +121,15 @@ def _init_dynamic_sampling_tensors(self): logits_dtype = self.inference_wrapped_model.config.params_dtype 
self._sampling_backend = "torch" + self._enable_cuda_graph = False + + # Initialize bookkeeping tensors. + if self._enable_cuda_graph: + self._all_logits_cuda = torch.empty( + (1, max_logits, self.vocab_size), dtype=logits_dtype, device=device + ) + else: + self._all_logits_cuda = None self._sampled_tokens_cuda = torch.empty(max_requests, dtype=torch.int64, device=device) # Speculative tokens tensor will be allocated later when num_speculative_tokens is set by the engine self._accepted_tokens_per_request = None @@ -596,7 +609,7 @@ def _dynamic_step_context_init( else: return context.current_input_and_position_ids() - def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) -> Tensor: + def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor): """Forward step the model to get logits for dynamic batching. This also handles logits-broadcasting for pipeline parallelism. @@ -607,6 +620,11 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) """ context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count + logits_seq_len = ( + active_request_count + if context.config.materialize_only_last_token_logits + else context.padded_active_token_count + ) with torch.inference_mode(): logits = self.inference_wrapped_model.run_one_forward_step( @@ -619,6 +637,12 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) # will be computed serially after verification to ensure they are # conditioned on verified tokens only. + assert logits_seq_len == ( + active_request_count + if context.config.materialize_only_last_token_logits + else input_ids.shape[1] + ) + if self.model_is_pipeline_parallel: if context.config.materialize_only_last_token_logits: logits_seq_len = active_request_count @@ -636,7 +660,11 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) pp_group=self.pp_group, ) - return logits + # Copy logits to contiguous buffer. + if self._enable_cuda_graph: + self._all_logits_cuda[:, :logits_seq_len, :].copy_(logits) + else: + self._all_logits_cuda = logits def _dynamic_step_sample_bookkeeping(self): """Perform bookkeeping necessary to sample logits for dynamic batching.""" @@ -1053,7 +1081,7 @@ def _verify_speculative_tokens( return last_one_indices, accepted_tokens_mask, input_tokens_required - def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_ids: Tensor): + def _dynamic_step_sample_logits_and_verify_tokens(self, input_ids: Tensor): """ Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. """ @@ -1069,6 +1097,7 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_id num_decode_requests = active_request_count - num_prefill_requests # Get the logit indices for tokens that need sampling. + logits = self._all_logits_cuda required_logit_indices = self._get_required_logit_indices( request_in_prefill_status_tensor, request_query_lengths, @@ -1132,24 +1161,22 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_id dim=1 ) - def _dynamic_step_sample_logits(self, logits: Tensor): - """Sample tokens from logits for dynamic batching. - - Args: - logits (Tensor): The logits from the forward pass. 
- """ + def _dynamic_step_sample_logits(self): + """Sample tokens from logits for dynamic batching.""" # TODO(ksanthanam): Evaluate whether it makes more sense to sample on 1 rank # and then broadcast the sampled tokens rather than broadcasting the raw logits. # Last token logits. context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + if context.config.materialize_only_last_token_logits: # When materialize_only_last_token_logits is true, last_token_logits is # already called in the forward pass of GPT. - required_token_logits = logits.squeeze(0) + required_token_logits = self._all_logits_cuda.squeeze(0)[:active_request_count, :] else: # todo : Should do verification here and get approrpiate las token logits - required_token_logits = context.last_token_logits(logits) + required_token_logits = context.last_token_logits(self._all_logits_cuda) if self._sampling_backend == "torch": # Concatenate the outputs once to prevent repeated small writes. @@ -1247,19 +1274,24 @@ def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]: return routing_indices_per_request - def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]: + def _dynamic_step_calculate_log_probs(self) -> Optional[Tensor]: """Calculate log probs from logits.""" context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count + logits_seq_len = ( + active_request_count + if context.config.materialize_only_last_token_logits + else context.padded_active_token_count + ) return context.calculate_log_probs( - logits, + self._all_logits_cuda[:, :logits_seq_len, :], self._sampled_tokens_cuda[:active_request_count], only_last_token_logits=context.config.materialize_only_last_token_logits, ) def _dynamic_step_calculate_log_probs_speculative( - self, logits: Tensor + self, ) -> Tuple[List[List[float]], Tensor]: """Calculate log probs from logits for speculative decoding. @@ -1271,9 +1303,6 @@ def _dynamic_step_calculate_log_probs_speculative( - log_prob(accepted_token[j]) comes from logits at position j - log_prob(newly_sampled_token) comes from logits at position accepted_count - Args: - logits (Tensor): The main model logits [1, seq_len, vocab_size]. - Returns: Tuple of (log_probs_list, log_probs_tensor): log_probs_list: List of lists, one per active request, containing @@ -1291,7 +1320,7 @@ def _dynamic_step_calculate_log_probs_speculative( num_prefill_requests = request_in_prefill_status_tensor.sum().item() num_decode_requests = active_request_count - num_prefill_requests - logits_squeezed = logits.squeeze(0).float() + logits_squeezed = self._all_logits_cuda.squeeze(0).float() log_probs_tensor = F.log_softmax(logits_squeezed[: context.active_token_count], dim=-1) log_probs_list_decode = [] @@ -1449,12 +1478,11 @@ def _dynamic_step_calculate_top_n_logprobs_speculative( return top_n_results if top_n_results else None def _dynamic_step_calculate_top_n_logprobs( - self, logits: Tensor, log_probs_tensor: Optional[Tensor] = None + self, log_probs_tensor: Optional[Tensor] = None ) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]: """Calculate top-n log probs from logits for dynamic batching. Args: - logits (Tensor): The logits to compute top-n log probs from. log_probs_tensor (Optional[Tensor]): Pre-computed log probabilities tensor. If provided, avoids recomputing log_softmax. Should be the tensor returned by calculate_log_probs. 
@@ -1743,7 +1771,7 @@ async def async_generate_output_tokens_dynamic_batch( # Forward pass produces only base logits. When speculative decoding is # active, MTP logits are computed serially after verification. - logits = self._dynamic_step_forward_logits(input_ids, position_ids) + self._dynamic_step_forward_logits(input_ids, position_ids) # Commit Mamba intermediate states before update_requests, which # may swap request indices. The Python lists tracking EOS block IDs @@ -1769,7 +1797,7 @@ async def async_generate_output_tokens_dynamic_batch( if self.num_speculative_tokens > 0: # Phase 1: Verify speculative tokens using base logits only. - self._dynamic_step_sample_logits_and_verify_tokens(logits, input_ids) + self._dynamic_step_sample_logits_and_verify_tokens(input_ids) # Phase 2: Rewind KV cache for rejected tokens. self._rewind_kv_cache() @@ -1781,25 +1809,21 @@ async def async_generate_output_tokens_dynamic_batch( # Phase 3: Compute MTP serially with correct (verified) inputs. self._compute_serial_mtp_and_sample() else: - self._dynamic_step_sample_logits(logits) + self._dynamic_step_sample_logits() log_probs = None top_n_logprobs = None if return_log_probs or return_top_n_logprobs: if self.num_speculative_tokens > 0: - log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs_speculative( - logits - ) + log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs_speculative() if return_top_n_logprobs: top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs_speculative( log_probs_tensor ) else: - log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits) + log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs() if return_top_n_logprobs: - top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs( - logits, log_probs_tensor - ) + top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs(log_probs_tensor) if skip_bookkeeping: request_bookkeeping = {} From 37c081a9d779f10b95ee4f8b9990bfac1da31de7 Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Tue, 23 Dec 2025 10:11:48 -0600 Subject: [PATCH 6/7] Syntactic sugar for CG and awaiting --- megatron/core/inference/utils.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/megatron/core/inference/utils.py b/megatron/core/inference/utils.py index 7fb62dbd06a..c8cd0f411fe 100644 --- a/megatron/core/inference/utils.py +++ b/megatron/core/inference/utils.py @@ -5,7 +5,6 @@ import multiprocessing import sys import time - import torch from megatron.core.transformer.moe.moe_layer import MoELayer @@ -218,6 +217,31 @@ def tensor_swap(x, src_idxs, dst_idxs): x[dst_idxs], x[src_idxs] = x[src_idxs], x[dst_idxs] +def use_cuda_graph(graph_cache: dict, graph_key, fn): + """Record-or-replay a CUDA graph for fn(). + + On first call for a given graph_key, captures fn() into a CUDA graph. + On subsequent calls with the same key, replays the cached graph. + fn must be a zero-argument callable operating on static-address tensors. 
+ """ + if graph_key in graph_cache: + graph_cache[graph_key].replay() + else: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + graph_cache[graph_key] = g + + +async def torch_awaitable(stream: torch.cuda.Stream | None = None): + """Syntactic sugar for returning an awaitable handle for non-distributed torch.""" + if stream is None: + stream = torch.cuda.current_stream() + event = stream.record_event() + while not event.query(): + await asyncio.sleep(0) + + async def await_process_call(call, process: multiprocessing.Process, timeout: float = 1.0): """Repeatedly wait for a multiprocessing callable to resolve, aborting upon process failure. From b34746e5dfa80d9eb954d2ca6eb463e469609c1b Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Thu, 19 Mar 2026 17:25:46 -0500 Subject: [PATCH 7/7] debug merge --- .../inference/contexts/dynamic_context.py | 12 ++++++ .../core/inference/engines/dynamic_engine.py | 42 +++++++++++++++++++ .../text_generation_controller.py | 35 ++++++++++++++++ 3 files changed, 89 insertions(+) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index a369e14c586..8bccf7e7fd3 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1659,11 +1659,22 @@ def initialize_attention_state( self.padded_active_request_count = self.padded_batch_dimensions.req_count self.padding_slice = slice(self.active_token_count, self.padded_active_token_count) + import os, sys + _rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + _dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a") + def _dbg(msg): + _dbg_f.write(f"[ATTN] {msg}\n"); _dbg_f.flush() + print(f"[rank{_rank}] [ATTN] {msg}", flush=True, file=sys.stderr) + + _dbg(f"build_active_slices start (padded_req={self.padded_active_request_count}, padded_tok={self.padded_active_token_count}, paused={self.paused_request_count}, total={self.total_request_count})") self.build_active_slices(self.padded_active_request_count) + _dbg("build_active_slices done") self.pad_active_slices() + _dbg("pad_active_slices done") batch_size = self.total_request_count - self.paused_request_count assert self.active_attn_metadata is not None + _dbg(f"mha_metadata.update start (batch_size={batch_size})") self.active_attn_metadata["mha_metadata"].update( request_query_lengths=self.active_request_query_lengths[:batch_size], request_kv_length_offsets=self.active_request_kv_length_offsets[:batch_size], @@ -1672,6 +1683,7 @@ def initialize_attention_state( padded_batch_dimensions=self.padded_batch_dimensions, num_speculative_tokens=self.num_speculative_tokens, ) + _dbg("mha_metadata.update done") if self.is_hybrid_model: active_slice = slice(self.paused_request_count, self.total_request_count) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 35d66878451..e03a613d4de 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -334,7 +334,16 @@ def create_cuda_graphs(self, reset_context: bool = True): reset_context (bool): Whether to reset the context after building cuda graphs. 
""" + import sys + _rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + _dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a") + def _dbg(msg): + _dbg_f.write(f"[CG] {msg}\n"); _dbg_f.flush() + print(f"[rank{_rank}] [CG] {msg}", flush=True, file=sys.stderr) + + _dbg(f"create_cuda_graphs start (impl={self.cuda_graph_impl})") if self.cuda_graph_impl != "local": + _dbg("skipping (not local)") return if ( @@ -393,12 +402,15 @@ def create_cuda_graphs(self, reset_context: bool = True): ) tbar = enumerate(context.cuda_graph_batch_dimensions_list) + _dbg(f"warmup loop start ({len(context.cuda_graph_batch_dimensions_list)} graphs)") if HAVE_TQDM: tbar = tqdm(tbar, total=len(context.cuda_graph_batch_dimensions_list)) for tbar_idx, cuda_graph_batch_dimension in tbar: + _dbg(f"warmup iter {tbar_idx}: context_init start ({cuda_graph_batch_dimension})") input_ids, position_ids = self.controller._dynamic_step_context_init( construct_graph_dimensions=cuda_graph_batch_dimension ) + _dbg(f"warmup iter {tbar_idx}: context_init done") # Progress. tbar_str = f"cuda graph warmup - {cuda_graph_batch_dimension}" if HAVE_TQDM: @@ -1630,12 +1642,23 @@ async def async_forward(self) -> Tuple[Dict, Dict, float]: step_time (float): How long this step took. """ + import sys + _rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + _dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a") + def _dbg(msg): + _dbg_f.write(f"[ENGINE] {msg}\n"); _dbg_f.flush() + print(f"[rank{_rank}] [ENGINE] {msg}", flush=True, file=sys.stderr) + + _dbg("async_forward enter") + # If suspended, no stepping. if self.state in (EngineState.SUSPENDED, EngineState.SUSPENDING): raise EngineSuspendedError(self.context.step_count) # schedule requests + _dbg("schedule_waiting_requests start") self.schedule_waiting_requests() + _dbg(f"schedule_waiting_requests done (total={self.context.total_request_count}, paused={self.context.paused_request_count}, tokens={self.context.active_token_count})") # Saving pre-step state, for printing output below. 
is_decode_only = self.context.is_decode_only() @@ -1654,7 +1677,9 @@ async def async_forward(self) -> Tuple[Dict, Dict, float]: self.is_decode_only = is_decode_only self.step_start_event.record() + _dbg("async_generate_output_tokens_dynamic_batch start") result = await self.controller.async_generate_output_tokens_dynamic_batch() + _dbg("async_generate_output_tokens_dynamic_batch done") self.step_end_event.record() self.step_end_event.synchronize() step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3 @@ -2283,17 +2308,30 @@ async def run_engine_with_coordinator( self._loop = get_asyncio_loop(loop) self.use_coordinator = True + import sys + _rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + _dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a") + _iter = 0 + def _dbg(msg): + _dbg_f.write(f"[COORD iter={_iter}] {msg}\n"); _dbg_f.flush() + print(f"[rank{_rank}] [COORD iter={_iter}] {msg}", flush=True, file=sys.stderr) + try: while True: + _iter += 1 + _dbg(f"loop top (state={self.state})") self.schedule_requests() + _dbg(f"schedule done (active={self.context.get_active_request_count()}, waiting={len(self.waiting_request_ids)})") if self.state in (EngineState.RUNNING, EngineState.PAUSING): local_pending = self.context.get_active_request_count() + len( self.waiting_request_ids ) + _dbg(f"ep_consensus start (local_pending={local_pending})") global_work, all_pausing = await self._ep_establish_consensus( local_pending, signal_consensus=(self.state == EngineState.PAUSING) ) + _dbg(f"ep_consensus done (global_work={global_work}, all_pausing={all_pausing})") if all_pausing: # All EP peers are PAUSING: pause immediately. @@ -2303,15 +2341,19 @@ async def run_engine_with_coordinator( elif global_work > 0: # At least one EP peer has work: all must participate. if local_pending > 0: + _dbg("async_step start") await self.async_step() + _dbg("async_step done") else: # Dummy forward to participate in the EP collective. + _dbg("dummy_forward start") self.step_start_event.record() self.controller.dummy_forward() self.step_end_event.record() self.step_end_event.synchronize() self.context.step_count += 1 self.context.prefix_cache_lru_clock += 1 + _dbg("dummy_forward done") else: # No work, but not all pausing: idle. 
await asyncio.sleep(0.02) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 343f3b50d96..fc2e4076363 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -107,6 +107,11 @@ def set_stop_word_finished_ids_callback(self, callback): def _init_dynamic_sampling_tensors(self): """Initialize tensors needed for dynamic sampling.""" + import sys + _rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + _dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a") + _dbg_f.write("[INIT] _init_dynamic_sampling_tensors start\n"); _dbg_f.flush() + print(f"[rank{_rank}] [INIT] _init_dynamic_sampling_tensors start", flush=True, file=sys.stderr) context = self.inference_wrapped_model.inference_context max_requests = context.max_requests if context.config.materialize_only_last_token_logits: @@ -143,6 +148,8 @@ def _init_dynamic_sampling_tensors(self): self._torch_sampling_buckets: List[Tuple] = [] self._init_mtp_sampling_tensor() + _dbg_f.write("[INIT] _init_dynamic_sampling_tensors done\n"); _dbg_f.flush() + print(f"[rank{_rank}] [INIT] _init_dynamic_sampling_tensors done", flush=True, file=sys.stderr) def _init_mtp_sampling_tensor(self): """Initialize the MTP sampling tensor after num_speculative_tokens is set.""" @@ -626,11 +633,20 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor): else context.padded_active_token_count ) + import os, sys + _rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + _dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a") + def _dbg(msg): + _dbg_f.write(f"[FWD] {msg}\n"); _dbg_f.flush() + print(f"[rank{_rank}] [FWD] {msg}", flush=True, file=sys.stderr) + + _dbg(f"run_one_forward_step start (logits_seq_len={logits_seq_len})") with torch.inference_mode(): logits = self.inference_wrapped_model.run_one_forward_step( {"tokens": input_ids, "position_ids": position_ids, "attention_mask": None} ) # logits shape: [1, seq_len, vocab_size] + _dbg(f"run_one_forward_step done (logits={'None' if logits is None else tuple(logits.shape)})") # Note: When speculative decoding is active (num_speculative_tokens > 0), # the model skips MTP computation during the forward pass. MTP logits @@ -653,12 +669,14 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor): if is_pipeline_last_stage(self.pp_group): assert logits is not None and torch.Size(logits_shape) == logits.shape + _dbg("broadcast_from_last_pipeline_stage start") logits = broadcast_from_last_pipeline_stage( logits_shape, dtype=self.model_config.params_dtype, tensor=logits, pp_group=self.pp_group, ) + _dbg("broadcast_from_last_pipeline_stage done") # Copy logits to contiguous buffer. if self._enable_cuda_graph: @@ -1754,11 +1772,20 @@ async def async_generate_output_tokens_dynamic_batch( context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count + import os, sys + _rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + _dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a") + def _dbg(msg): + _dbg_f.write(f"[STEP] {msg}\n"); _dbg_f.flush() + print(f"[rank{_rank}] [STEP] {msg}", flush=True, file=sys.stderr) + # No tokens and no active requests? 
if context.active_token_count == 0 and active_request_count == 0: return None + _dbg(f"context_init start (tokens={context.active_token_count}, reqs={active_request_count})") input_ids, position_ids = self._dynamic_step_context_init() + _dbg(f"context_init done (input_ids.shape={tuple(input_ids.shape)})") cuda_graph_request_count = ( context.padded_active_request_count if context.is_decode_only() else None @@ -1771,7 +1798,9 @@ async def async_generate_output_tokens_dynamic_batch( # Forward pass produces only base logits. When speculative decoding is # active, MTP logits are computed serially after verification. + _dbg("forward_logits start") self._dynamic_step_forward_logits(input_ids, position_ids) + _dbg("forward_logits done") # Commit Mamba intermediate states before update_requests, which # may swap request indices. The Python lists tracking EOS block IDs @@ -1790,10 +1819,13 @@ async def async_generate_output_tokens_dynamic_batch( # asynchronous. # Todo [Siddharth]: Can we condition the sleep on a cuda event? # NOTE [TDE]: This will be moved once CPU and GPU methods are separated. + _dbg("yield start") await asyncio.sleep(0) + _dbg("yield done") return_log_probs, return_top_n_logprobs = self._dynamic_step_log_probs_bookkeeping() self._dynamic_step_sample_bookkeeping() + _dbg("sample_logits start") if self.num_speculative_tokens > 0: # Phase 1: Verify speculative tokens using base logits only. @@ -1810,6 +1842,7 @@ async def async_generate_output_tokens_dynamic_batch( self._compute_serial_mtp_and_sample() else: self._dynamic_step_sample_logits() + _dbg("sample_logits done") log_probs = None top_n_logprobs = None @@ -1825,10 +1858,12 @@ async def async_generate_output_tokens_dynamic_batch( if return_top_n_logprobs: top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs(log_probs_tensor) + _dbg("bookkeeping start") if skip_bookkeeping: request_bookkeeping = {} else: request_bookkeeping = self._dynamic_step_context_bookkeeping() + _dbg("bookkeeping done") ret = { # Clone needed: _sampled_tokens_cuda is a reused buffer overwritten each step.
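The core pattern across patches 1-3 (and the logits buffer added in patch 5) is to copy the per-request state of the currently active slice (request metadata, request ids, query lengths, KV offsets and block ids) into preallocated active_* buffers whose addresses never change, so attention-metadata updates and, later, CUDA-graph replay always read from the same pointers. Below is a minimal sketch of that pattern using placeholder sizes and names rather than the real DynamicInferenceContext fields:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    MAX_REQUESTS = 8  # stand-in for the context's max_requests

    # Full per-request tensor, indexed by request slot (paused slots first).
    request_ids = torch.arange(100, 100 + MAX_REQUESTS, device=device)

    # Allocated once; its address never changes, so kernels or a captured CUDA
    # graph can keep reading from it on every step.
    active_request_ids = torch.empty_like(request_ids)

    def build_active_slice(paused_request_count: int, total_request_count: int) -> int:
        """Copy rows [paused, total) to the front of the static buffer; return batch size."""
        batch_size = total_request_count - paused_request_count
        active_request_ids[:batch_size].copy_(
            request_ids[paused_request_count:total_request_count], non_blocking=True
        )
        return batch_size

    batch_size = build_active_slice(paused_request_count=2, total_request_count=5)
    print(active_request_ids[:batch_size].tolist())  # request slots 2..4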
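Patch 1 also keeps the CPU-side copies of request metadata (top_k and similar parameters used by the torch sampling backend) in pinned buffers so the device-to-host copy can be issued with non_blocking=True. A small standalone sketch of that D2H pattern, with made-up tensor names; the stream synchronize stands in for whatever event the controller actually waits on before reading the values on the host:

    import torch

    cuda = torch.cuda.is_available()
    device = "cuda" if cuda else "cpu"

    # Per-request sampling parameter living on the GPU (e.g. top_k).
    top_k_gpu = torch.randint(1, 50, (8,), device=device)

    # Page-locked host buffer: the only kind of host memory that lets the
    # D2H copy below run asynchronously with respect to the CPU.
    top_k_cpu = torch.empty(top_k_gpu.shape, dtype=top_k_gpu.dtype,
                            device="cpu", pin_memory=cuda)

    top_k_cpu.copy_(top_k_gpu, non_blocking=True)

    # The copy is enqueued on the current stream; wait for it before the host
    # reads the values, e.g. to bucketize requests for torch sampling.
    if cuda:
        torch.cuda.current_stream().synchronize()
    print(top_k_cpu.tolist())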
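Patch 2 replaces the per-step cumsum in last_token_logits with a precomputed active_request_last_token_idxs, but the indexing itself is unchanged: for packed variable-length requests, the cumulative sum of query lengths minus one gives the flat index of each request's last token. A toy example with assumed lengths:

    import torch

    # Three packed requests with query lengths 3, 1 and 2 share one flat token
    # stream of 6 tokens; logits therefore has one row per token.
    query_lengths = torch.tensor([3, 1, 2])
    logits = torch.arange(6).unsqueeze(1).expand(6, 4)  # stand-in for [tokens, vocab]

    # cumsum gives exclusive end offsets [3, 4, 6]; minus one -> [2, 3, 5].
    last_token_idxs = torch.cumsum(query_lengths, dim=0) - 1
    print(logits[last_token_idxs, 0].tolist())  # [2, 3, 5]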
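Patch 6's use_cuda_graph(graph_cache, graph_key, fn) helper records fn() into a CUDA graph the first time a key is seen and replays it on every later call, which only works if fn reads and writes tensors at fixed addresses, exactly what the active_* buffers above provide. A possible usage sketch (requires a CUDA device; the warm-up call and the matmul workload are illustrative, not taken from the engine):

    import torch
    from megatron.core.inference.utils import use_cuda_graph

    static_x = torch.randn(4, 4, device="cuda")
    static_y = torch.empty_like(static_x)

    def step():
        # No allocations, no host syncs: output lands in a static-address buffer.
        torch.matmul(static_x, static_x, out=static_y)

    step()                      # warm up kernels/handles outside capture
    torch.cuda.synchronize()

    graphs: dict = {}
    for _ in range(3):
        static_x.uniform_()     # refresh the static input in place
        # First call with this key records the graph (captured work does not
        # execute); every later call replays the recorded kernels.
        use_cuda_graph(graphs, "matmul_4x4", step)
    torch.cuda.synchronize()
    print(static_y.sum().item())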
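torch_awaitable, also from patch 6, records an event on the given (or current) stream and polls it with asyncio.sleep(0), so an async engine loop can keep scheduling work while GPU kernels are in flight instead of blocking in a hard synchronize. A possible way to call it, again assuming a CUDA device and a toy workload:

    import asyncio
    import torch
    from megatron.core.inference.utils import torch_awaitable

    async def main():
        x = torch.randn(1 << 20, device="cuda")
        y = torch.empty_like(x)
        y.copy_(x, non_blocking=True)   # enqueue GPU work on the current stream

        # Yield to the event loop until the event recorded on the stream
        # reports completion, rather than calling torch.cuda.synchronize().
        await torch_awaitable()
        print(float(y.sum()))

    asyncio.run(main())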