diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index c5c04aceb78..038658bbe5c 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -24,6 +24,7 @@ #include "tensorrt_llm/runtime/utils/debugUtils.h" #include "tensorrt_llm/thop/attentionOp.h" #include "tensorrt_llm/thop/thUtils.h" +#include #include #include #include @@ -466,7 +467,8 @@ class Runner : public RunnerBase = spec_decoding_tensor_params[1].value().data_ptr(); enqueue_params.spec_decoding_packed_mask = spec_decoding_tensor_params[2].value().data_ptr(); enqueue_params.spec_decoding_is_generation_length_variable = true; - enqueue_params.spec_decoding_max_generation_length = input_seq_length + 1; + assert(spec_decoding_tensor_params[1].value().dim() == 2); // [batch_size, max_draft_len + 1] + enqueue_params.spec_decoding_max_generation_length = spec_decoding_tensor_params[1].value().sizes()[1]; } // Current mlaGeneration will using fmha to do attention, so we don't go into enqueueGeneration diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index 7e1d3a060b4..518aac4546d 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -11,6 +11,8 @@ if TYPE_CHECKING: from ..speculative.utils import SpecDecodingTensor + from ..speculative.interface import SpecMetadata + from ..speculative.spec_tree_manager import SpecTreeManager from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils, RotaryScalingType) @@ -335,10 +337,15 @@ def restore_from_spec_dec(self) -> None: def update_spec_dec_param( self, + batch_size, is_spec_decoding_enabled, is_spec_dec_tree, is_spec_dec_dynamic_tree, - max_draft_tokens, + max_draft_len, + max_total_draft_tokens, + model_is_wrapped: Optional[bool] = False, + spec_metadata: Optional['SpecMetadata'] = None, + spec_tree_manager: Optional['SpecTreeManager'] = None, spec_decoding_tensor: Optional['SpecDecodingTensor'] = None): """ Hook to be called when using TRTLLM attention backend in spec-dec mode. diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index b2c9710f6fa..35040e9e058 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: from ..speculative.utils import SpecDecodingTensor + from ..speculative.interface import SpecMetadata + from ..speculative.spec_tree_manager import SpecTreeManager from tensorrt_llm._utils import get_sm_version from tensorrt_llm.bindings.internal import thop @@ -1055,13 +1057,30 @@ def prepare_context_mla_with_cached_kv(self, def update_spec_dec_param( self, + batch_size, is_spec_decoding_enabled, is_spec_dec_tree, is_spec_dec_dynamic_tree, - max_draft_tokens, + max_draft_len, + max_total_draft_tokens, + model_is_wrapped: Optional[bool] = False, + spec_metadata: Optional['SpecMetadata'] = None, + spec_tree_manager: Optional['SpecTreeManager'] = None, spec_decoding_tensor: Optional['SpecDecodingTensor'] = None, ): - + ''' + Update the spec-dec parameters for the TRTLLM attention layer. + Args: + batch_size: int, the number of requests in the batch. + is_spec_decoding_enabled: bool, whether the attention need to be spec_decoding mode, which is determined by attention_need_spec_dec_mode() function. + is_spec_dec_tree: bool, whether the spec-dec mode is a tree, i.e., static tree or dynamic tree. 
For linear-tree, it is always False. + is_spec_dec_dynamic_tree: bool, whether using dynamic tree. + max_draft_len: int, the number of the draft layers. + max_total_draft_tokens: int, the number of all nodes in the tree (except the root). + model_is_wrapped: Optional[bool] = False, whether the drafter model is wrapped (i.e, CDL). + spec_metadata: Optional['SpecMetadata'] = None, the metadata of the spec-dec. + spec_tree_manager: Optional['SpecTreeManager'] = None, the spec_tree_manager for draft token tree. + ''' if spec_decoding_tensor is not None: spec_decoding_position_offsets = spec_decoding_tensor.position_offsets spec_decoding_packed_mask = spec_decoding_tensor.packed_mask @@ -1075,9 +1094,9 @@ def update_spec_dec_param( ) < 100 if get_sm_version() >= 100: - if is_spec_dec_tree or is_spec_dec_dynamic_tree: - assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree." - assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree." + if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree: + assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree." + assert not self.is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree." # use_spec_decoding is default to true by default, change in runtime by layers / requests self.use_spec_decoding = self.is_spec_decoding_enabled @@ -1087,16 +1106,18 @@ def update_spec_dec_param( # Parameters can be fixed and not changed during runtime if the if self.is_spec_decoding_enabled: + # These buffers are accessed more like removing input padding, + # rather than using max_total_draft_tokens + 1 as the offset between different requests. 
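As a rough illustration of the unpadded access pattern described above (batch_size = 2 and max_draft_len = 1 are made-up values; the layout follows from the Case 3 branch further below):

    # The first drafter layer writes 2 * (1 + 1) = 4 entries at the front of the
    # flattened [max_num_requests, max_total_draft_tokens + 1] buffer:
    #   spec_decoding_position_offsets.reshape(-1)[:4] == [0, 1, 0, 1]
    # i.e. requests are packed back to back instead of one row per request.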
self.spec_decoding_position_offsets = torch.empty( - [self.max_num_requests, max_draft_tokens + 1], + [self.max_num_requests, max_total_draft_tokens + 1], dtype=torch.int, device='cuda', ) self.spec_decoding_packed_mask = torch.empty( [ - self.max_num_requests, max_draft_tokens + 1, - math.ceil((max_draft_tokens + 1) / 32) + self.max_num_requests, max_total_draft_tokens + 1, + math.ceil((max_total_draft_tokens + 1) / 32) ], dtype=torch.int, device='cuda', @@ -1108,7 +1129,11 @@ def update_spec_dec_param( device='cuda', ) - if self.is_spec_dec_dynamic_tree: + is_target_model = not spec_metadata.is_draft_model if hasattr( + spec_metadata, 'is_draft_model') else False + + # Case 1: dynamic tree + if self.is_spec_dec_tree and self.is_spec_dec_dynamic_tree: assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree" assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree" self.spec_decoding_position_offsets.copy_( @@ -1120,35 +1145,86 @@ def update_spec_dec_param( spec_decoding_generation_lengths, non_blocking=True) else: self.generate_spec_decoding_generation_length( - max_draft_tokens=max_draft_tokens) + max_draft_len=max_total_draft_tokens) + + # Case 2/3: static tree + elif self.is_spec_dec_tree and not self.is_spec_dec_dynamic_tree and spec_metadata is not None: + assert spec_metadata.spec_dec_mode.is_eagle3( + ), "Tree decoding is only supported for Eagle3 now" + + # Case 2: static tree and target model + if is_target_model: + # For the target model, we update the spec-dec parameters with the spec_tree_manager, which is prepared in advance. + self.spec_decoding_position_offsets[:batch_size, :].copy_( + spec_tree_manager.spec_dec_position_offsets[0, :], + non_blocking=True) + self.spec_decoding_packed_mask[:batch_size, :, :].copy_( + spec_tree_manager.spec_dec_packed_mask[0, :, :], + non_blocking=True) + self.spec_decoding_generation_lengths[:batch_size].fill_( + spec_tree_manager.max_total_draft_tokens + 1) + + # Case 3: static tree and the first drafter layer + else: + assert model_is_wrapped == True, "The drafter model should be wrapped" + # The first drafter layer will take the padded tokens as input (padding to the max_draft_len + 1) + # But the spec-dec parameters are still in the shape of max_total_draft_tokens + 1. + # Considering that these spec-dec params are accessed consecutively (without padding) in the attention Op, + # we need to write them consecutively when setting them. + # For the next drafter layers, we will prepare these spec-dec params in the drafting loops. + # position_offsets + position_offset = torch.arange( + max_draft_len + 1, + dtype=torch.int, + device='cpu', + pin_memory=True).repeat(batch_size) + self.spec_decoding_position_offsets.reshape( + -1)[:(max_draft_len + 1) * batch_size].copy_( + position_offset, non_blocking=True) + # packed_mask + dummy_idx = torch.arange(max_draft_len + 1) + spec_decoding_packed_mask = torch.pow( + 2, dummy_idx + 1) - 1 # [max_draft_len + 1] + spec_decoding_packed_mask = spec_decoding_packed_mask.repeat( + batch_size) # [batch_size * (max_draft_len + 1)] + self.spec_decoding_packed_mask.reshape( + -1)[:(max_draft_len + 1) * batch_size].copy_( + spec_decoding_packed_mask, non_blocking=True) + # generation_lengths + self.generate_spec_decoding_generation_length( + max_draft_len=max_draft_len) + + # Case 4: linear tree else: + # Prepare for the linear-tree. # Populate the mask that won't change during inference phase. 
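For reference, the values the three generator helpers below produce for a linear chain, assuming max_draft_len = 3 (a made-up value):

    # spec_decoding_position_offsets[req]   == [0, 1, 2, 3]      # arange(max_draft_len + 1)
    # spec_decoding_packed_mask[req, :, 0]  == [1, 3, 7, 15]     # 2 ** (i + 1) - 1, causal bit patterns
    # spec_decoding_generation_lengths[req] == 4                 # max_draft_len + 1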
self.generate_spec_decoding_position_offsets( - max_draft_tokens=max_draft_tokens) + max_draft_len=max_draft_len) self.generate_spec_decoding_packed_mask( - max_draft_tokens=max_draft_tokens) + max_draft_len=max_draft_len) self.generate_spec_decoding_generation_length( - max_draft_tokens=max_draft_tokens) + max_draft_len=max_draft_len) - def generate_spec_decoding_position_offsets(self, max_draft_tokens): - position_offset = torch.arange(max_draft_tokens + 1, + def generate_spec_decoding_position_offsets(self, max_draft_len): + position_offset = torch.arange(max_draft_len + 1, dtype=torch.int, device='cpu', pin_memory=True) - # fill all the batches with same position offset self.spec_decoding_position_offsets.copy_(position_offset, non_blocking=True) - def generate_spec_decoding_packed_mask(self, max_draft_tokens): - dummy_idx = torch.arange(max_draft_tokens + 1) + def generate_spec_decoding_packed_mask(self, max_draft_len): + # FIXME: remove this limitation + assert max_draft_len < 32, "max_draft_len should be less than 32, will be fixed later" + dummy_idx = torch.arange(max_draft_len + 1) spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1 self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask, non_blocking=True) - def generate_spec_decoding_generation_length(self, max_draft_tokens): + def generate_spec_decoding_generation_length(self, max_draft_len): spec_decoding_generation_length = torch.full((self.max_num_requests, ), - max_draft_tokens + 1) + max_draft_len + 1) self.spec_decoding_generation_lengths[:self.max_num_requests].copy_( spec_decoding_generation_length, non_blocking=True) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index a5d4a9033a4..d8f0f1b8be9 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -47,7 +47,7 @@ get_spec_metadata, update_spec_config_from_model_config) from ..speculative.drafting_loops import ChainDrafter -from ..speculative.eagle3 import Eagle3ResourceManager +from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata from ..speculative.mtp import SampleStateTensorsMTP from ..speculative.utils import SpecDecodingTensor from ..utils import (get_model_extra_attrs, @@ -711,7 +711,7 @@ def _release_batch_context(self, batch: Optional[ScheduledRequests], def _get_num_extra_decoding_steps(self) -> int: """Determines extra decoding steps needed for fused drafting loops.""" if isinstance(self.model, ChainDrafter): - return self.model.max_draft_len + return self.model.max_total_draft_tokens else: assert not self.model_is_wrapped, ( f"Please add logic to determine num_extra_decoding_steps for drafting loop {type(self.model)}" @@ -1210,7 +1210,8 @@ def _prepare_tp_inputs( attn_metadata: AttentionMetadata, spec_metadata: Optional[SpecMetadata] = None, new_tensors_device: Optional[SampleStateTensors] = None, - cache_indirection_buffer: Optional[torch.Tensor] = None): + cache_indirection_buffer: Optional[torch.Tensor] = None, + resource_manager: Optional[ResourceManager] = None): """ Prepare inputs for Pytorch Model. """ @@ -1243,6 +1244,9 @@ def _prepare_tp_inputs( multimodal_params_list = [] mrope_position_ids = [] num_accepted_draft_tokens = [] # per request + # if using tree decoding, we need to store the request type and accepted path for each request, + # which will be used to update the hidden_states_read_indices. 
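As an illustration of the mapping collected here (the request id 17 is a made-up value):

    # request_accepted_path == {17: [1, 3]}
    # i.e. for request 17, the target model accepted tree nodes 1 and 3 (root excluded).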
+ request_accepted_path = {} # per request for request in scheduled_requests.context_requests: request_ids.append(request.py_request_id) @@ -1257,6 +1261,9 @@ def _prepare_tp_inputs( gather_ids.append(len(input_ids) - 1) sequence_lengths.append(len(prompt_tokens)) num_accepted_draft_tokens.append(len(prompt_tokens) - 1) + request_accepted_path[ + request. + py_request_id] = request.py_num_accepted_draft_tokens_indices prompt_lengths.append(len(prompt_tokens)) past_seen_token_num = begin_compute num_cached_tokens_per_seq.append(past_seen_token_num) @@ -1331,11 +1338,22 @@ def _prepare_tp_inputs( assert spec_config.spec_dec_mode.support_overlap_scheduler( ), f"{spec_config.decoding_type} does not support overlap scheduler" + spec_resource_manager, spec_tree_manager = None, None + if spec_config is not None: + spec_resource_manager = resource_manager.get_resource_manager( + ResourceManagerType.SPEC_RESOURCE_MANAGER) + if spec_resource_manager is not None and hasattr( + spec_resource_manager, 'spec_tree_manager'): + spec_tree_manager = spec_resource_manager.spec_tree_manager + # will contain previous batch indices of generation requests previous_batch_indices = [] previous_pos_indices = [] for request in extend_requests: request_ids.append(request.py_request_id) + request_accepted_path[ + request. + py_request_id] = request.py_num_accepted_draft_tokens_indices # the request has no previous tensor: # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or # (2) a dummy request; or @@ -1353,7 +1371,7 @@ def _prepare_tp_inputs( past_seen_token_num = request.max_beam_num_tokens - 1 draft_lens.append(num_draft_tokens) if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( - self.attn_backend): + self.attn_backend) and spec_config.is_linear_tree: # We're treating the prompt lengths as context requests here, so # the the prompt lens should not include the cached tokens. 
prompt_lengths.append(1 + num_draft_tokens) @@ -1366,10 +1384,20 @@ def _prepare_tp_inputs( list( range(len(position_ids), len(position_ids) + 1 + self.runtime_draft_len))) - position_ids.extend( - list( - range(past_seen_token_num, - past_seen_token_num + 1 + num_draft_tokens))) + # For the target model + tree decoding + if not self.is_draft_model and not spec_config.is_linear_tree: + assert spec_tree_manager is not None + assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens + position_ids.extend( + past_seen_token_num + + spec_tree_manager.spec_dec_position_offsets[ + 0] # [max_total_draft_tokens + 1] + ) + else: + position_ids.extend( + list( + range(past_seen_token_num, + past_seen_token_num + 1 + num_draft_tokens))) num_cached_tokens_per_seq.append(past_seen_token_num) request.cached_tokens = num_cached_tokens_per_seq[-1] # update batch index @@ -1389,10 +1417,21 @@ def _prepare_tp_inputs( list( range(len(position_ids), len(position_ids) + 1 + self.runtime_draft_len))) - position_ids.extend( - list( - range(past_seen_token_num, past_seen_token_num + 1 + - self.runtime_draft_len))) + # For the target model + tree decoding + if not self.is_draft_model and not spec_config.is_linear_tree: + assert spec_tree_manager is not None + assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens + position_ids.extend( + past_seen_token_num + + spec_tree_manager.spec_dec_position_offsets[ + 0] # [max_total_draft_tokens + 1] + ) + else: + position_ids.extend( + list( + range( + past_seen_token_num, past_seen_token_num + 1 + + self.runtime_draft_len))) # previous tensor previous_batch_indices.append(previous_batch_idx) previous_pos_indices.extend([previous_batch_idx] * @@ -1407,7 +1446,7 @@ def _prepare_tp_inputs( self.runtime_draft_len + 1) request.cached_tokens = num_cached_tokens_per_seq[-1] if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( - self.attn_backend): + self.attn_backend) and spec_config.is_linear_tree: prompt_lengths.append(1 + self.runtime_draft_len) else: prompt_lengths.append(request.py_prompt_len) @@ -1429,6 +1468,9 @@ def _prepare_tp_inputs( sequence_lengths.append(1 + self.original_max_draft_len) num_accepted_draft_tokens.append( request.py_num_accepted_draft_tokens) + request_accepted_path[ + request. + py_request_id] = request.py_num_accepted_draft_tokens_indices prompt_lengths.append(request.py_prompt_len) past_seen_token_num = begin_compute num_cached_tokens_per_seq.append(past_seen_token_num) @@ -1671,7 +1713,9 @@ def previous_seq_slots_device(): # so that we can update the kv_lens_cuda correctly in _preprocess_inputs. attn_metadata.num_chunked_ctx_requests = 0 if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( - self.attn_backend): + self.attn_backend) and spec_config.is_linear_tree: + # For the tree decoding, we want to use XQA to process the draft tokens for the target model. + # Therefore, we do not treat them as the chunked context requests. 
attn_metadata.num_contexts += len(extend_requests) attn_metadata.num_chunked_ctx_requests = len(extend_requests) @@ -1735,6 +1779,8 @@ def previous_seq_slots_device(): spec_metadata.seq_lens = sequence_lengths spec_metadata.num_accepted_draft_tokens = self.num_accepted_draft_tokens_cuda[:len( num_accepted_draft_tokens)] + if isinstance(spec_metadata, Eagle3SpecMetadata): + spec_metadata.request_accepted_path = request_accepted_path spec_metadata.prepare() inputs['spec_metadata'] = spec_metadata @@ -2232,14 +2278,14 @@ def _get_lora_params_from_requests(self, return lora_params @nvtx_range("_prepare_inputs") - def _prepare_inputs( - self, - scheduled_requests: ScheduledRequests, - kv_cache_manager: KVCacheManager, - attn_metadata: AttentionMetadata, - spec_metadata: Optional[SpecMetadata] = None, - new_tensors_device: Optional[SampleStateTensors] = None, - cache_indirection_buffer: Optional[torch.Tensor] = None): + def _prepare_inputs(self, + scheduled_requests: ScheduledRequests, + kv_cache_manager: KVCacheManager, + attn_metadata: AttentionMetadata, + spec_metadata: Optional[SpecMetadata] = None, + new_tensors_device: Optional[SampleStateTensors] = None, + cache_indirection_buffer: Optional[torch.Tensor] = None, + resource_manager: Optional[ResourceManager] = None): if self.mapping is not None and 'cp_type' in self.mapping.cp_config: cp_type = self.mapping.cp_config['cp_type'] if CpType.STAR == cp_type: @@ -2255,7 +2301,8 @@ def _prepare_inputs( return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager, attn_metadata, spec_metadata, new_tensors_device, - cache_indirection_buffer) + cache_indirection_buffer, + resource_manager) @torch.inference_mode() @with_model_extra_attrs(lambda self: self.model.extra_attrs) @@ -2275,6 +2322,9 @@ def forward( if self.enable_spec_decode: spec_resource_manager = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) + spec_tree_manager = None + if isinstance(spec_resource_manager, Eagle3ResourceManager): + spec_tree_manager = spec_resource_manager.spec_tree_manager spec_metadata = self._set_up_spec_metadata(spec_resource_manager, no_cache=kv_cache_manager is None) @@ -2283,9 +2333,16 @@ def forward( spec_resource_manager, self.is_draft_model, self.attn_backend, self.model_is_wrapped, spec_metadata.is_spec_dec_tree) attn_metadata.update_spec_dec_param( - is_spec_dec_mode, spec_metadata.is_spec_dec_tree, - spec_metadata.is_spec_dec_dynamic_tree, - self.original_max_draft_len, spec_decoding_tensor) + batch_size=scheduled_requests.batch_size, + is_spec_decoding_enabled=is_spec_dec_mode, + is_spec_dec_tree=spec_metadata.is_spec_dec_tree, + is_spec_dec_dynamic_tree=spec_metadata.is_spec_dec_dynamic_tree, + max_draft_len=self.original_max_draft_len, + max_total_draft_tokens=self.original_max_total_draft_tokens, + model_is_wrapped=self.model_is_wrapped, + spec_metadata=spec_metadata, + spec_tree_manager=spec_tree_manager, + spec_decoding_tensor=spec_decoding_tensor) else: spec_resource_manager = None spec_metadata = None @@ -2322,7 +2379,7 @@ def forward( inputs, gather_ids = self._prepare_inputs( padded_requests, kv_cache_manager, attn_metadata, spec_metadata, - new_tensors_device, cache_indirection_buffer) + new_tensors_device, cache_indirection_buffer, resource_manager) self.iter_counter += 1 with with_shared_pool(self.cuda_graph_runner.get_graph_pool()): diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 40123060498..d1703f129e1 100644 --- 
a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -581,7 +581,7 @@ def update_kv_cache_draft_token_location(self, requests = scheduled_batch.all_requests() accepted_draft_token_offsets, packed_accepted_draft_tokens_indices, rewind_draft_token_separate_adjustments = self.locate_accepted_draft_tokens( requests) - past_key_value_lengths = attn_metadata.kv_lens_cuda + past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)] if attn_metadata.kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_pool_pointers is not None and attn_metadata.host_kv_cache_pool_mapping is not None: use_paged_kv_cache = True else: diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 276aa977003..00c6c44e47d 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -770,7 +770,7 @@ def _process_draft_tokens_tree( assert seq_slot is not None eagle_paths = spec_tree_manager.get_eagle_paths(seq_slot) - all_draft_tokens = request.py_draft_tokens # [max_total_draft_tokens] + all_draft_tokens = torch.tensor(request.py_draft_tokens) # [max_total_draft_tokens] all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze( -1 ) # [max_total_draft_tokens] @@ -803,6 +803,9 @@ def _process_draft_tokens_tree( longest_accepted_len = cur_accepted_len longest_match_path_idx = path_idx + request.py_num_accepted_draft_tokens_indices = list( + range(1, spec_tree_manager.max_draft_len + 1) + ) if longest_accepted_len == 0: # No draft tokens are accepted. # Take the top-1 token of the first layer as the next new token. @@ -819,67 +822,11 @@ def _process_draft_tokens_tree( if self._handle_stop_criteria(request, new_token): break + request.py_num_accepted_draft_tokens_indices[: num_accepted_draft_tokens - 1] = ( + eagle_paths[longest_match_path_idx][1:longest_accepted_len] + ) # exclude the root node return num_accepted_draft_tokens - 1 - def _tree_sampling_batch( - self, - requests: list[LlmRequest], - max_num_sequences: int, - seq_slots: torch.Tensor, - model_outputs: dict[str, torch.Tensor], - spec_tree_manager: SpecTreeManager, - ): - if ( - spec_tree_manager.use_dynamic_tree - # FIXME: 'draft_layer_id' is undefined - and draft_layer_id == spec_tree_manager.max_draft_len - 1 # noqa: F821 - ): - # TODO: Re-sample the draft tokens for the last layer. - raise NotImplementedError("Dynamic tree is not fully supported yet.") - - raw_logits = model_outputs["logits"] - num_requests = len(requests) - assert raw_logits.shape[0] % num_requests == 0 - num_logits_per_request = raw_logits.shape[0] // num_requests - request_index = torch.arange(num_requests) - - draft_layer_id = spec_tree_manager.cur_draft_layer_idx - # 1) Get the topK list for the specific draft layer. - top_k_list = spec_tree_manager.get_top_k_list(draft_layer_id) - assert len(top_k_list) == num_logits_per_request - - # Considering that the beam_width of spec-dec can only be 1, we ignore this dimension here. - new_draft_tokens_cuda = torch.empty( - (max_num_sequences, spec_tree_manager.max_total_draft_tokens + 1), - dtype=torch.int64, - device=raw_logits.device, - ) - - top_k_list_cumsum = torch.cumsum(top_k_list, dim=0) - # Different nodes have different topK value. - for i, top_k_list_i in enumerate(top_k_list): - # 2) Extract the logits needed for this layer. 
- logits = raw_logits[request_index * num_logits_per_request + i, :] - assert logits.shape[0] == len(requests) - # 3) Sample the logits according to the topK value. - indices = torch.topk(logits, k=top_k_list_i, dim=-1).indices - # 4) Write to the temporary output tensor. - new_draft_tokens_cuda[ - seq_slots, top_k_list_cumsum[i] - top_k_list_i : top_k_list_cumsum[i] - ] = indices[request_index] - - # 5) Append eagle3 d2t. - self._apply_d2t(new_draft_tokens_cuda, model_outputs) - - # 6) Copy back to the output tensor. - int_new_draft_tokens = ( - new_draft_tokens_cuda.transpose(0, 1).to(torch.int, non_blocking=True).unsqueeze(dim=-1) - ) - - new_draft_tokens_host = int_new_draft_tokens.to("cpu", non_blocking=True) - - return new_draft_tokens_host - @torch.inference_mode() def _process_draft_tokens_rejection_sampling( self, @@ -1433,7 +1380,6 @@ def _process_requests( resource_manager: Optional[ResourceManager] = None, ) -> torch.Tensor: seq_slots = seq_slots.to(dtype=torch.int32) # int32 suffices here - spec_tree_manager = self.get_spec_tree_manager(resource_manager) raw_logits_cuda = model_outputs["logits"] @@ -1467,19 +1413,6 @@ def _process_requests( logits_cuda = self._apply_min_length_penalty(logits_cuda, requests, req_num_steps_list) - # Fast path for drafter model's tree sampling. - if spec_tree_manager is not None and logits_cuda.size(0) == len( - scheduled_requests.all_requests() - ): - new_tokens_host = self._tree_sampling_batch( - requests, - self.max_num_sequences, - seq_slots, - model_outputs, - spec_tree_manager, - ) - return new_tokens_host - # Indexer for accessing tokens in 'logits_cuda', corresponding to the # requests in 'requests'. steps_dim_size = new_tokens_cuda.size(0) diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index 3dd8683ba3e..b29b67a86d0 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -64,6 +64,7 @@ def pad_draft_tokens_for_cuda_graph( """ for req in scheduled_requests.generation_requests: max_draft_tokens = self.max_draft_len + self.max_total_draft_tokens num_draft_tokens = get_draft_token_length(req) req.py_draft_tokens.extend( 0 for _ in range(max_draft_tokens - num_draft_tokens)) diff --git a/tensorrt_llm/_torch/speculative/drafting_loops.py b/tensorrt_llm/_torch/speculative/drafting_loops.py index 886f0111ef8..b37856d0fe3 100644 --- a/tensorrt_llm/_torch/speculative/drafting_loops.py +++ b/tensorrt_llm/_torch/speculative/drafting_loops.py @@ -9,12 +9,14 @@ """ from contextlib import contextmanager +from typing import List, Optional, Tuple import torch from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata from tensorrt_llm._torch.speculative.eagle3 import Eagle3SpecMetadata from tensorrt_llm._torch.speculative.interface import SpecMetadata +from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager @contextmanager @@ -105,6 +107,191 @@ def prepare_for_generation(attn_metadata: AttentionMetadata, return new_position_ids +def prepare_for_generation_with_tree_decoding( + prepare_for_layer_idx: int, new_draft_tokens: List[torch.Tensor], + attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata, + spec_tree_manager: SpecTreeManager, + position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + ''' + Prepare the inputs for the next draft layer. 
What we need to prepare are: + 1) inputs_ids + 2) position_ids + 3) attn_metadata + 3.1) kv_lens_cuda + 3.2) _seq_lens, _seq_lens_cuda + 3.3) host_request_types + 3.4) num_contexts + 3.5) use_spec_decoding + 3.6) spec_decoding_position_offsets + 3.7) spec_decoding_packed_mask + 3.8) spec_decoding_generation_lengths + 4) spec_metadata + 4.1) num_tokens + 4.2) gather_ids + 4.3) hidden_states_read_indices, hidden_states_write_indices + 4.4) is_first_draft + + ''' + + batch_size = attn_metadata.num_seqs + next_layer_gen_len_per_req = spec_tree_manager.spec_dec_generation_lengths_for_drafter_model[ + prepare_for_layer_idx - 1] + + # 1) Prepare the inputs_ids + all_draft_tokens = torch.cat( + new_draft_tokens, + dim=-1) # [batch_size, num_draft_tokens_has_been_generated] + cur_tokens_gather_idx = spec_tree_manager.tokens_gather_idx[ + prepare_for_layer_idx - + 1] - 1 # shape: [next_layer_gen_len_per_req]. -1 is toshift the root node + new_input_ids = all_draft_tokens[:, cur_tokens_gather_idx].reshape( + -1) # [batch_size * next_layer_gen_len_per_req] + + num_accepted_draft_tokens = spec_metadata.num_accepted_draft_tokens[: + batch_size] + seq_lens = attn_metadata.seq_lens_cuda[:batch_size] + last_tokens_idx = None + + # 2) Prepare the position_ids + if prepare_for_layer_idx == 1: + last_tokens_idx = torch.cumsum( + seq_lens, dim=0, + dtype=torch.long) - seq_lens + num_accepted_draft_tokens + new_position_ids = position_ids[0, last_tokens_idx] + 1 # [batch_size] + assert new_position_ids.shape == (batch_size, ) + # For the layer_idx == 1, the input tokens are both expanded from root node. + # Therefore, their position ids are the same. + new_position_ids = torch.repeat_interleave( + new_position_ids, repeats=next_layer_gen_len_per_req, + dim=0) # [batch_size * next_layer_gen_len_per_req] + else: + position_ids = position_ids.reshape(batch_size, -1) + position_ids_start_idx = position_ids[:, 0] # [batch_size] + assert position_ids_start_idx.shape == (batch_size, ) + + new_position_ids = spec_tree_manager.spec_dec_position_offsets_for_drafter_model[ + prepare_for_layer_idx - 1].unsqueeze(0).repeat( + batch_size, 1) # [batch_size, num_next_layer_input_tokens] + new_position_ids = new_position_ids + position_ids_start_idx.unsqueeze( + 1) # [batch_size, num_next_layer_input_tokens] + new_position_ids = new_position_ids.reshape( + -1) # [batch_size * num_next_layer_input_tokens] + + assert new_position_ids.shape == new_input_ids.shape + + # 3) Prepare the attn_metadata + ## 3.1) kv_lens_cuda + if prepare_for_layer_idx == 1: + attn_metadata.kv_lens_cuda[: + batch_size] -= seq_lens - num_accepted_draft_tokens - 1 + attn_metadata.kv_lens_cuda[:batch_size] += next_layer_gen_len_per_req + else: + prev_layer_gen_len_per_req = spec_tree_manager.spec_dec_generation_lengths_for_drafter_model[ + prepare_for_layer_idx - 2] + attn_metadata.kv_lens_cuda[: + batch_size] -= prev_layer_gen_len_per_req # reset to original length before the drafter loop. + attn_metadata.kv_lens_cuda[:batch_size] += next_layer_gen_len_per_req + + # FIXME, update without D2H + # Not updating kv_lens here has no effect on the calculation results, but only affects the calculation performance. 
+ # attn_metadata.kv_lens[:batch_size] = attn_metadata.kv_lens_cuda[:batch_size].cpu() + + ## 3.2) _seq_lens, _seq_lens_cuda + attn_metadata._seq_lens[:batch_size].fill_(next_layer_gen_len_per_req) + attn_metadata._seq_lens_cuda[:batch_size].fill_(next_layer_gen_len_per_req) + attn_metadata.on_update() + + # Update once is enough + if prepare_for_layer_idx == 1: + ## 3.3) host_request_types + attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1) + ## 3.4) num_contexts + attn_metadata.num_contexts = 0 + ## 3.5) use_spec_decoding + attn_metadata.use_spec_decoding = True + + # NOTE: For the spec_decoding_position_offsets, spec_decoding_packed_mask, spec_decoding_generation_lengths, + # They are stored contiguously without padding. This is why we need to reshape them here. + # They're initially allocated a size of [batch_size, max_total_draft_tokens + 1]. + # However, within the drafter, the number of draft tokens input to each drafter layer is less than max_total_draft_tokens, + # so we only need to use part of the buffer. We need to store them contiguously. + ## 3.6) spec_decoding_position_offsets + attn_metadata.spec_decoding_position_offsets.reshape( + -1 + )[:batch_size * + next_layer_gen_len_per_req] = spec_tree_manager.spec_dec_position_offsets_for_drafter_model[ + prepare_for_layer_idx - 1].repeat(batch_size) + + ## 3.7) spec_decoding_packed_mask + attn_metadata.spec_decoding_packed_mask.reshape( + -1, attn_metadata.spec_decoding_packed_mask.size(-1) + )[:batch_size * + next_layer_gen_len_per_req, :] = spec_tree_manager.spec_dec_packed_mask_for_drafter_model[ + prepare_for_layer_idx - 1].repeat(batch_size, 1) + + ## 3.8) spec_decoding_generation_lengths + attn_metadata.spec_decoding_generation_lengths[: + batch_size] = next_layer_gen_len_per_req + + # 4) spec_metadata + ## 4.1) num_tokens + spec_metadata.num_tokens = batch_size * next_layer_gen_len_per_req + + ## 4.2) gather_ids + offset = torch.arange( + batch_size, + device=position_ids.device) * next_layer_gen_len_per_req # [batch_size] + spec_metadata.gather_ids = spec_tree_manager.logits_gather_idx[ + prepare_for_layer_idx - 1].unsqueeze(0).repeat( + batch_size, 1) # [1, num_tokens_has_children] + spec_metadata.gather_ids = spec_metadata.gather_ids + offset.unsqueeze( + 1) # [batch_size, num_tokens_has_children] + spec_metadata.gather_ids = spec_metadata.gather_ids.reshape( + -1) # [batch_size * num_tokens_has_children] + + ## 4.3) hidden_states_read_indices, hidden_states_write_indices + if isinstance(spec_metadata, Eagle3SpecMetadata): + start_idx = None + if prepare_for_layer_idx == 1: + old_write_indices = spec_metadata.hidden_states_write_indices + start_idx = old_write_indices[ + last_tokens_idx] # [batch_size], already take the accepted tokens into account. 
+ else: + prev_layer_gen_len_per_req = spec_tree_manager.spec_dec_generation_lengths_for_drafter_model[ + prepare_for_layer_idx - 2] + last_tokens_idx = torch.arange( + batch_size, + device=position_ids.device) * prev_layer_gen_len_per_req + old_read_indices = spec_metadata.hidden_states_read_indices + start_idx = old_read_indices[last_tokens_idx] # [batch_size] + + start_idx = start_idx.unsqueeze(1) # [batch_size, 1] + + start_read_idx = start_idx + spec_tree_manager.hidden_states_read_indices_offset_for_drafter_model[ + prepare_for_layer_idx - + 1] # [batch_size, next_layer_gen_len_per_req] + spec_metadata.hidden_states_read_indices[:batch_size * + next_layer_gen_len_per_req].copy_( + start_read_idx.reshape(-1) + ) # [batch_size * next_layer_gen_len_per_req] + + start_write_idx = start_idx + spec_tree_manager.hidden_states_write_indices_offset_for_drafter_model[ + prepare_for_layer_idx - + 1] # [batch_size, next_layer_gen_len_per_req] + spec_metadata.hidden_states_write_indices[:batch_size * + next_layer_gen_len_per_req].copy_( + start_write_idx.reshape( + -1) + ) # [batch_size * next_layer_gen_len_per_req] + + if prepare_for_layer_idx == 1: + ## 4.4) is_first_draft + spec_metadata.eagle3_resource_manager.is_first_draft = False + spec_metadata.is_first_draft = False + + return new_input_ids, new_position_ids + + class ChainDrafter(torch.nn.Module): def __init__(self, max_draft_len: int, max_total_draft_tokens: int, @@ -119,46 +306,171 @@ def __init__(self, max_draft_len: int, max_total_draft_tokens: int, def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata, **kwargs) -> dict[str, torch.Tensor]: + spec_tree_manager = None + if isinstance(spec_metadata, Eagle3SpecMetadata): + spec_tree_manager = spec_metadata.eagle3_resource_manager.spec_tree_manager + logits = self.draft_model.forward(input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata, spec_metadata=spec_metadata, return_context_logits=True) - logits = logits[spec_metadata.gather_ids] + batch_size = attn_metadata.num_seqs + vocab_size = logits.shape[-1] + logits = logits[spec_metadata.gather_ids] # [batch_size, vocab_size] + + new_draft_tokens = [ + self.sample(draft_layer_idx=0, + batch_size=batch_size, + logits=logits, + spec_tree_manager=spec_tree_manager) + ] + assert logits.shape == (batch_size, vocab_size) + # When using tree decoding, the first layer's draft tokens are all from the root node's logits. + # Therefore, we repeat the logits and collect them. + draft_logits = [ + logits if spec_tree_manager is None else logits.repeat( + spec_tree_manager.top_k_list_cuda[0], 1).reshape( + batch_size, -1, vocab_size) + ] - new_draft_tokens = [self.sample(logits)] - draft_logits = [logits] with save_metadata_state(attn_metadata, spec_metadata): batch_size = attn_metadata.num_seqs + if spec_tree_manager is None: + new_input_ids = new_draft_tokens[-1] + new_position_ids = prepare_for_generation( + attn_metadata, spec_metadata, position_ids) + else: + new_input_ids, new_position_ids = prepare_for_generation_with_tree_decoding( + prepare_for_layer_idx= + 1, # prepare for the 1st layer, start from the 0-th layer. 
+ new_draft_tokens=new_draft_tokens, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + spec_tree_manager=spec_tree_manager, + position_ids=position_ids) - new_position_ids = prepare_for_generation(attn_metadata, - spec_metadata, - position_ids) - for i in range(self.max_draft_len - 1): + for layer_idx in range(1, self.max_draft_len): logits = self.draft_model.forward( - input_ids=new_draft_tokens[-1], + input_ids=new_input_ids, position_ids=new_position_ids, attn_metadata=attn_metadata, - spec_metadata=spec_metadata) - new_draft_tokens.append(self.sample(logits)) - draft_logits.append(logits) - new_position_ids += 1 - attn_metadata.kv_lens_cuda[:batch_size] += 1 - if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata): - spec_metadata.hidden_states_read_indices[:batch_size].copy_( - spec_metadata.hidden_states_write_indices[:batch_size]) - - return { - "new_draft_tokens": torch.stack(new_draft_tokens), - "draft_logits": torch.stack(draft_logits) - } - - def sample(self, logits: torch.Tensor) -> torch.Tensor: + spec_metadata=spec_metadata, + return_context_logits=False + if spec_tree_manager is None else True) + if spec_tree_manager is not None: + # if using tree decoding, only the last 'num_tokens_has_children' tokens need to be sampled. + logits = logits[spec_metadata.gather_ids] + + new_draft_tokens.append( + self.sample(draft_layer_idx=layer_idx, + batch_size=batch_size, + logits=logits, + spec_tree_manager=spec_tree_manager)) + + if spec_tree_manager is None: + draft_logits.append(logits) + else: + # logits: [batch_size * num_tokens_has_children, vocab_size] + cur_top_k_list = spec_tree_manager.top_k_list_cuda[ + layer_idx] # [num_tokens_has_children] + cur_top_k_list = cur_top_k_list.repeat( + batch_size) # [batch_size * num_tokens_has_children] + logits = torch.repeat_interleave( + logits, repeats=cur_top_k_list, dim=0 + ) # [batch_size * num_tokens_has_children, vocab_size] + logits = logits.reshape( + batch_size, -1, vocab_size + ) # [batch_size, next_layer_draft_tokens, vocab_size] + draft_logits.append(logits) + + if spec_tree_manager is None: + new_input_ids = new_draft_tokens[-1] + new_position_ids += 1 + attn_metadata.kv_lens_cuda[:batch_size] += 1 + if layer_idx == 0 and isinstance(spec_metadata, + Eagle3SpecMetadata): + spec_metadata.hidden_states_read_indices[:batch_size].copy_( + spec_metadata. 
+ hidden_states_write_indices[:batch_size]) + elif layer_idx < spec_tree_manager.max_draft_len - 1: + new_input_ids, new_position_ids = prepare_for_generation_with_tree_decoding( + prepare_for_layer_idx=layer_idx + 1, + new_draft_tokens=new_draft_tokens, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + spec_tree_manager=spec_tree_manager, + position_ids=new_position_ids) + + if spec_tree_manager is None: + return { + "new_draft_tokens": torch.stack(new_draft_tokens), + "draft_logits": torch.stack(draft_logits) + } + else: + # new_draft_tokens: List[torch.Tensor], each tensor is of shape [batch_size, num_draft_tokens_each_layers] + # len(new_draft_tokens) == max_draft_len + return_new_draft_tokens = torch.cat( + new_draft_tokens, + dim=-1) # [batch_size, max_total_draft_tokens] + return_new_draft_tokens = torch.transpose( + return_new_draft_tokens, 0, + 1) # [max_total_draft_tokens, batch_size] + + # draft_logits: List[torch.Tensor], each tensor is of shape [batch_size, num_draft_tokens_each_layers, vocab_size] + return_draft_logits = torch.cat( + draft_logits, + dim=1) # [batch_size, max_total_draft_tokens, vocab_size] + return_draft_logits = torch.transpose( + return_draft_logits, 0, + 1) # [max_total_draft_tokens, batch_size, vocab_size] + + assert return_new_draft_tokens.shape[ + 0] == return_draft_logits.shape[0] + assert return_new_draft_tokens.shape[ + 1] == return_draft_logits.shape[1] + + return { + "new_draft_tokens": return_new_draft_tokens, + "draft_logits": return_draft_logits + } + + def sample( + self, + draft_layer_idx: int, + batch_size: int, + logits: torch.Tensor, + spec_tree_manager: Optional[SpecTreeManager] = None + ) -> torch.Tensor: # TODO: inject the sampler here so we can support non-greedy - tokens = torch.argmax(logits, dim=-1) - if hasattr(self.draft_model.model, "d2t"): - d2t = self.draft_model.model.d2t.data - return tokens + d2t[tokens] + + if spec_tree_manager is None: + tokens = torch.argmax(logits, dim=-1) + if hasattr(self.draft_model.model, "d2t"): + d2t = self.draft_model.model.d2t.data + return tokens + d2t[tokens] + else: + max_topk_list = spec_tree_manager.max_top_k_list_cuda[ + draft_layer_idx] + indices = torch.topk(logits, k=max_topk_list, dim=-1).indices + top_k_list = spec_tree_manager.top_k_list_cuda[draft_layer_idx] + top_k_list = top_k_list.repeat(batch_size) + rows = torch.arange(top_k_list.shape[0], + dtype=torch.int32, + device=logits.device) + row_indices = torch.repeat_interleave(rows, repeats=top_k_list) + col_indices = torch.cat([torch.arange(c) for c in top_k_list]) + tokens = indices[ + row_indices, + col_indices] # [batch_size * num_draft_tokens_this_layer] + + if hasattr(self.draft_model.model, "d2t"): + d2t = self.draft_model.model.d2t.data + tokens = tokens + d2t[tokens] + + # reshape, for better gather later. 
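A worked example of the gather above, assuming batch_size = 1 and top_k_list = [3, 2] for this layer (two parent nodes with 3 and 2 children, made-up values):

    # rows        == [0, 1]
    # row_indices == [0, 0, 0, 1, 1]
    # col_indices == [0, 1, 2, 0, 1]
    # tokens = indices[row_indices, col_indices]   -> 5 draft tokens for this layer
    # after the reshape below: tokens.shape == (1, 5)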
+ tokens = tokens.reshape( + batch_size, -1) # [batch_size, num_draft_tokens_this_layer] return tokens diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 423f4354bd9..c2a72fb2daa 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set import torch from torch import nn @@ -48,8 +48,10 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype, self.max_total_draft_tokens = self.max_draft_len # empty hidden states tensor - max_num_tokens = min(max_num_tokens, - max_num_requests * self.max_seq_len) + max_num_tokens = min(max_num_tokens, max_num_requests * + self.max_seq_len) + (self.max_total_draft_tokens + + 1) * max_num_requests + self.hidden_states = torch.empty( (max_num_tokens, self.hidden_size * config.num_capture_layers), dtype=self.dtype, @@ -122,6 +124,10 @@ class Eagle3SpecMetadata(SpecMetadata): eagle_choices: Optional[List[List[int]]] = None max_total_draft_tokens: int = 0 + # This is to store the request type and accepted path for each request. + # For each request, {key: request_ids, value: accepted_path} + # 'accepted_path' is a list of accepted tokens indices. + request_accepted_path: Optional[Dict[int, List[int]]] = None def __post_init__(self): if self.is_draft_model: @@ -159,6 +165,7 @@ def __post_init__(self): def prepare(self): is_first_draft = self.eagle3_resource_manager.is_first_draft + spec_tree_manager = self.eagle3_resource_manager.spec_tree_manager # Update start indices # Here, we assume the sequence lengths (seq_lens) during the draft model # forward will not exceed those of the target model. So pre-allocate @@ -169,20 +176,52 @@ def prepare(self): slot_id = self.eagle3_resource_manager.slot_manager.get_slot( req_id) self.eagle3_resource_manager.start_indices[slot_id] = start_idx - start_idx += seq_len + # Make sure that the space between two requests is at least max_total_draft_tokens + 1. + start_idx += max(seq_len, self.max_total_draft_tokens + 1) + assert start_idx < self.eagle3_resource_manager.hidden_states.shape[ + 0], f"start_idx {start_idx} is greater than hidden_states.shape[0] {self.eagle3_resource_manager.hidden_states.shape[0]}" + # Prepare hidden states gather ids hidden_states_read_indices = [] hidden_states_write_indices = [] for req_id, seq_len in zip(self.request_ids, self.seq_lens): slot_id = self.eagle3_resource_manager.slot_manager.get_slot(req_id) start_idx = self.eagle3_resource_manager.start_indices[slot_id] + # 1) target model or (is_first_draft and is_linear_tree) # If this is the first draft or the target model forward, we need to - # read/write all of the hidden states, otherwise, only read the last token - if is_first_draft or not self.is_draft_model: + # read/write all of the hidden states + if not self.is_draft_model or (is_first_draft + and spec_tree_manager is None): hidden_states_read_indices.extend( list(range(start_idx, start_idx + seq_len))) hidden_states_write_indices.extend( list(range(start_idx, start_idx + seq_len))) + # 2)is_first_draft and draft_token_tree + # After target model forward, some draft tokens will be accepted. + # These draft tokens' hidden states will be used for draft model's first drafter layer. 
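A small example of the indices built in this branch (start_idx = 100 and accepted_path = [2, 5] are made-up values, so seq_len == 3):

    # accepted_path with the root prepended: [0, 2, 5]
    # hidden_states_read_indices  gets [100, 102, 105]   # only the accepted nodes' hidden states
    # hidden_states_write_indices gets [100, 101, 102]   # range(start_idx, start_idx + seq_len)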
+ elif is_first_draft and spec_tree_manager is not None: + assert req_id in self.request_accepted_path.keys( + ), f"Request {req_id} not found in request_accepted_path" + accepted_path = self.request_accepted_path[req_id] + + if accepted_path == []: + # This is a context request. We need to read all the hidden states. + hidden_states_read_indices.extend( + list(range(start_idx, start_idx + seq_len))) + else: + # This is a generation request. We only read the accepted tokens' hidden states. + assert len( + accepted_path + ) + 1 == seq_len, f"Accepted path length + 1 ({len(accepted_path) + 1}) is not equal to sequence length ({seq_len})" + accepted_path = [0] + accepted_path # add the root node + hidden_states_read_indices.extend([ + start_idx + accepted_draft_token_offset + for accepted_draft_token_offset in accepted_path + ]) + # For the write indices, we just write all the hidden states. + hidden_states_write_indices.extend( + list(range(start_idx, start_idx + seq_len))) + # otherwise: only read the last token else: old_seq_len = self.eagle3_resource_manager.seq_lens[slot_id] hidden_states_read_indices.append(start_idx + old_seq_len - 1) diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 41be42a54df..fbd563db37a 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -130,16 +130,32 @@ def attention_need_spec_dec_mode( spec_resource_manager: BaseResourceManager, is_draft_model: bool, attention_backend: Type[AttentionBackend], - use_chain_drafter: bool, + use_chain_drafter: bool, # CDL is_spec_dec_tree: bool, ): """ If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode). + Args: + spec_resource_manager: the resource manager for the spec-dec mode. + is_draft_model: whether the model is a draft model. + attention_backend: the attention backend. + use_chain_drafter: whether to use capturable drafting loops (CDL). For the target model, it is always False. + is_spec_dec_tree: whether the spec-dec mode is a tree, i.e., static tree or dynamic tree. 
""" is_trtllm_attention = issubclass(attention_backend, TrtllmAttention) - return self.is_eagle3_one_model() or ( - self.is_eagle3() and spec_resource_manager.is_first_draft - and is_trtllm_attention and use_chain_drafter and is_draft_model) + # Case 1: one model + use_case_1 = self.is_eagle3_one_model() + # Case 2: eagle3 two model + draft model + CDL + is_first_draft + TRTLLM attention + use_case_2 = self.is_eagle3( + ) and spec_resource_manager.is_first_draft and use_chain_drafter and is_draft_model and is_trtllm_attention + # Case 3: eagle3 two model + tree decoding + draft model + CDL + TRTLLM attention + use_case_3 = self.is_eagle3( + ) and is_spec_dec_tree and is_draft_model and use_chain_drafter and is_trtllm_attention + # Case 4: eagle3 two model + tree decoding + target model + TRTLLM attention + use_case_4 = self.is_eagle3( + ) and is_spec_dec_tree and not is_draft_model and is_trtllm_attention + + return use_case_1 or use_case_2 or use_case_3 or use_case_4 @staticmethod def from_string(name: Optional[str]) -> "SpeculativeDecodingMode": diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 959a17b7d1c..3b202f61de5 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -167,6 +167,9 @@ def _create_accepted_tokens_request_for_trtllm_attn( new_request.state = LlmRequestState.GENERATION_IN_PROGRESS new_request.py_num_accepted_draft_tokens = request.py_num_accepted_draft_tokens new_request.py_is_first_draft = True + # For tree decoding, we need to store the accepted tokens indices for these requests, + # which will be used to update the hidden_states_read_indices. + new_request.py_num_accepted_draft_tokens_indices = request.py_num_accepted_draft_tokens_indices return new_request def _create_draft_request_for_request( @@ -570,7 +573,7 @@ def process_static_draft_outputs( # Chunked prefill request in progress; no need to append draft tokens continue py_draft_logits = [] - for token_idx in range(self.max_draft_len): + for token_idx in range(self.max_total_draft_tokens): target_model_req.py_draft_tokens.append( draft_tokens_host[token_idx][req_idx]) py_draft_logits.append(draft_logits[token_idx][req_idx]) diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index 74ec518d602..10edc7ecc13 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -171,6 +171,8 @@ def __init__( assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool." self.spec_config = spec_config self.max_draft_len = spec_config.max_draft_len + self.max_total_draft_tokens = spec_config.max_total_draft_tokens + assert self.max_draft_len == self.max_total_draft_tokens, "NGram only supports linear tree." self.spec_resource_manager = ngram_pool_manager def prepare_draft_tokens( diff --git a/tensorrt_llm/_torch/speculative/spec_tree_manager.py b/tensorrt_llm/_torch/speculative/spec_tree_manager.py index b84083cfb3b..b9ce12965a8 100644 --- a/tensorrt_llm/_torch/speculative/spec_tree_manager.py +++ b/tensorrt_llm/_torch/speculative/spec_tree_manager.py @@ -1,3 +1,4 @@ +import math from typing import List, Optional import torch @@ -11,6 +12,8 @@ class SpecTreeManager: cur_draft_layer_idx: int # The current index of the drafter layer # Auxiliary buffers + # The top k list for each draft layer. 
+ top_k_list = [] # The user input eagle choices, only available when using static tree. eagle_choices: Optional[List[List[int]]] = None # If dynamice tree, each request has their own tree. If static tree, all requests share the same tree. @@ -20,11 +23,59 @@ class SpecTreeManager: # shape: [num_trees, max_total_draft_tokens + 1, max_draft_len + 1] eagle_paths: Optional[torch.Tensor] = None - # The spec decoding mask. - # shape: [num_trees, max_total_draft_tokens + 1, max_total_draft_tokens + 1], include the root node. + # The spec decoding mask. Include the root node. + # shape: [num_trees, max_total_draft_tokens + 1, max_total_draft_tokens + 1], device tensor. spec_dec_mask_matrix: Optional[torch.Tensor] = None - # shape: [num_trees, max_total_draft_tokens + 1], pad the 0-1 matrix to int32 vector. - spec_dec_pack_mask: Optional[torch.Tensor] = None + + # The packed decoding mask for the target model to verify the draft tokens. Pad the 0-1 matrix to int32 vector. + # shape: [num_trees, max_total_draft_tokens + 1], device tensor. + spec_dec_packed_mask: Optional[torch.Tensor] = None + + # The spec position offsets for the target model to verify the draft tokens. + # shape: [num_trees, max_total_draft_tokens + 1], device tensor. + spec_dec_position_offsets: Optional[torch.Tensor] = None + + # TODO: Optimized together with the subsequent dynamic tree. + # Auxiliary buffers for the static tree. + # Considering that the static tree does not modify the tree structure during inference, we can calculate some buffers in advance. + # NOTE: Most of these buffers are introduced due to limitations of XQA: + # With tree attention, XQA cannot simply take the tokens to be processed in the next round as input. Instead, it needs to take ALL of their parent nodes as input. + # This incurs additional computation, but it is unavoidable. + + # NOTE: The reason why most of these auxiliary buffers are with `len == max_draft_len - 1` is that: we do not need to prepare specific input data for the first draft layer. + + # The top k value for each draft layer. Device tensor. + top_k_list_cuda: list[torch.Tensor] = None + + # The max top k value for each draft layer. Device tensor. + max_top_k_list_cuda: list[torch.Tensor] = None + + # Gather the required draft tokens from all currently generated draft tokens as the input of the next draft layer. + # Only the nodes has child(s) this layer and all their parents nodes will be gathered. + # Device tensor. len(tokens_gather_idx) == max_draft_len - 1. Each element is a tensor with shape [num_tokens_for_next_layer]. + tokens_gather_idx: list[torch.Tensor] = None + + # Gather the required logits from all currently generated logits. + # Device tensor. len(tokens_gather_idx) == max_draft_len - 1. + logits_gather_idx: list[torch.Tensor] = None + + # The packed mask for the drafter model's attention (i.e., xqa). + # Device tensor. len(spec_dec_packed_mask_for_drafter_model) == max_draft_len - 1. Each element is a tensor with shape [max_total_draft_tokens + 1, math.ceil((self.max_total_draft_tokens + 1) / 32)]. + spec_dec_packed_mask_for_drafter_model: list[torch.Tensor] = None + + # The position offset for the drafter model's attention (i.e., xqa). + # Device tensor. len(spec_dec_position_offsets_for_drafter_model) == max_draft_len - 1. Each element is a tensor with shape [max_total_draft_tokens + 1]. + spec_dec_position_offsets_for_drafter_model: list[torch.Tensor] = None + + # The generation length for the drafter model's attention (i.e., xqa). + # Device tensor. 
shape: [max_draft_len] + spec_dec_generation_lengths_for_drafter_model: torch.Tensor = None + + # The read/write indices offset for the drafter model. len(hidden_states_write_indices_offset_for_drafter_model) == max_draft_len - 1. Each element is a tensor with shape [num_tokens_for_next_layer]. + hidden_states_write_indices_offset_for_drafter_model: list[ + torch.Tensor] = None + hidden_states_read_indices_offset_for_drafter_model: list[ + torch.Tensor] = None def __init__(self, max_num_requests: int, use_dynamic_tree: bool, max_total_draft_tokens: int, max_draft_len: int, @@ -36,7 +87,8 @@ def __init__(self, max_num_requests: int, use_dynamic_tree: bool, self.eagle_choices = eagle_choices self.num_trees = max_num_requests if use_dynamic_tree else 1 self.dynamic_tree_max_topK = dynamic_tree_max_topK - self.cur_draft_layer_idx = -1 + self.cur_draft_layer_idx = 0 + self.top_k_list = [] # Initialize the buffers self.eagle_paths = torch.ones( @@ -46,25 +98,32 @@ def __init__(self, max_num_requests: int, use_dynamic_tree: bool, device='cpu', pin_memory=True, ) * -1 + self.spec_dec_mask_matrix = torch.eye( self.max_total_draft_tokens + 1, - dtype=torch.bool, - device='cpu', - pin_memory=True, + dtype=torch.int32, + device='cuda', ).unsqueeze(0).repeat(self.num_trees, 1, 1) - self.spec_dec_pack_mask = torch.zeros( - (self.num_trees + 1, self.max_total_draft_tokens + 1), + + self.spec_dec_packed_mask = torch.zeros( + (self.num_trees, self.max_total_draft_tokens + 1, + math.ceil((self.max_total_draft_tokens + 1) / 32)), dtype=torch.int32, - device='cpu', - pin_memory=True, + device='cuda', + ) + self.spec_dec_position_offsets = torch.zeros( + (self.num_trees, self.max_total_draft_tokens + 1), + dtype=torch.int32, + device='cuda', ) - self.top_k_list = [] if self.use_dynamic_tree: self.init_tree_info_for_dynamic_tree() else: self.init_tree_info_for_static_tree() + # self.dump_tree_info() + def init_tree_info_for_dynamic_tree(self): # For the dynamic tree # To the internal layer, the number of nodes is the same as the dynamic_tree_max_topK. 
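A minimal, self-contained sketch of the bookkeeping the static-tree hunk below performs (the eagle_choices value is made up; the in-tree code additionally stores top_k_list as pinned int32 tensors and mirrors it to CUDA):

    eagle_choices = [[0], [1], [0, 0], [0, 1]]  # paths from the root, so max_draft_len == 2
    index_of = {str(c): i + 1 for i, c in enumerate(eagle_choices)}  # node 0 is the root
    nodes_per_layer = [[0]] + [[] for _ in range(2)]
    children = {i: [] for i in range(len(eagle_choices) + 1)}
    for c in eagle_choices:
        nodes_per_layer[len(c)].append(index_of[str(c)])
        parent = 0 if len(c) == 1 else index_of[str(c[:-1])]
        children[parent].append(index_of[str(c)])
    top_k_list = [[len(children[n]) for n in layer if children[n]]
                  for layer in nodes_per_layer[:-1]]
    # nodes_per_layer == [[0], [1, 2], [3, 4]], top_k_list == [[2], [2]]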
@@ -77,13 +136,13 @@ def init_tree_info_for_dynamic_tree(self): # For the static tree def init_tree_info_for_static_tree(self): - index_mapping_set = {} - nodes_list_per_layer = [[] for _ in range(self.max_draft_len + 1)] + self.index_mapping_set = {} + self.nodes_list_per_layer = [[] for _ in range(self.max_draft_len + 1)] child_nodes_list = [[] for _ in range(self.max_total_draft_tokens + 1)] # 1) Map the index for i, choice in enumerate(self.eagle_choices): - index_mapping_set[str(choice)] = i + 1 + self.index_mapping_set[str(choice)] = i + 1 # 2) Reconstruct the eagle_paths self.eagle_paths.fill_(-1) @@ -91,44 +150,157 @@ def init_tree_info_for_static_tree(self): for i, choice in enumerate(self.eagle_choices): self.eagle_paths[0][i + 1][0] = 0 for j in range(len(choice)): - self.eagle_paths[0][i + 1][j + 1] = index_mapping_set[str( + self.eagle_paths[0][i + 1][j + 1] = self.index_mapping_set[str( choice[:j + 1])] # 3) Compute node_list_per_layer - nodes_list_per_layer[0].append(0) # root node + self.nodes_list_per_layer[0].append(0) # root node for choice in self.eagle_choices: cur_layer = len(choice) - nodes_list_per_layer[cur_layer].append( - index_mapping_set[str(choice)]) + self.nodes_list_per_layer[cur_layer].append( + self.index_mapping_set[str(choice)]) # 4) Compute child_nodes_list for choice in self.eagle_choices: if len(choice) == 1: # root node's children - child_nodes_list[0].append(index_mapping_set[str(choice)]) + child_nodes_list[0].append(self.index_mapping_set[str(choice)]) else: - child_nodes_list[index_mapping_set[str(choice[:-1])]].append( - index_mapping_set[str(choice)]) + child_nodes_list[self.index_mapping_set[str( + choice[:-1])]].append(self.index_mapping_set[str(choice)]) # 5) Compute top_k_list for i in range(self.max_draft_len): - cur_layer_nodes = nodes_list_per_layer[i] + cur_layer_nodes = self.nodes_list_per_layer[i] tmp_top_k_list = [ len(child_nodes_list[node]) for node in cur_layer_nodes if len(child_nodes_list[node]) > 0 ] - assert sum(tmp_top_k_list) == len(nodes_list_per_layer[i + 1]) + assert sum(tmp_top_k_list) == len(self.nodes_list_per_layer[i + 1]) self.top_k_list.append( torch.tensor(tmp_top_k_list, dtype=torch.int32, device='cpu', pin_memory=True)) - # 6) Compute the spec decoding according to the eagle_paths + # 6) Compute the spec decoding according to the eagle_paths for the target model for i, path in enumerate(self.eagle_paths[0]): indices = path[path > -1] self.spec_dec_mask_matrix[0][i, indices] = 1 - self.spec_dec_pack_mask = self.compute_spec_dec_pack_mask( - self.spec_dec_mask_matrix) + self.compute_spec_dec_packed_mask(self.spec_dec_mask_matrix, + self.spec_dec_packed_mask) + + # 7) Compute the spec position offsets for the target model + start_idx = 0 + for i in range(self.max_draft_len + 1): + num_nodes_this_layer = len(self.nodes_list_per_layer[i]) + self.spec_dec_position_offsets[:, start_idx:start_idx + + num_nodes_this_layer] = i + start_idx += num_nodes_this_layer + + ### Compute the auxiliary buffers for the drafter model + # 8) Copy top_k_list_cuda + self.top_k_list_cuda = [] + for i in range(self.max_draft_len): + self.top_k_list_cuda.append(self.top_k_list[i].to(device='cuda')) + self.max_top_k_list_cuda = torch.tensor( + [max(top_k_list) for top_k_list in self.top_k_list_cuda], + dtype=torch.int32, + device='cuda') + + # 9) Compute the tokens_gather_idx, include the root node + self.tokens_gather_idx = [] + for cur_layer_nodes in self.nodes_list_per_layer[1:]: + parents_set = set() + for node in cur_layer_nodes: + if 
len(child_nodes_list[node]) > 0: + parents_set.update( + self.spec_dec_mask_matrix[0][node].nonzero().reshape( + -1)[1:].tolist()) + self.tokens_gather_idx.append( + torch.tensor(list(parents_set), + dtype=torch.int32, + device='cuda')) + + # 10) Compute the logits_gather_idx + self.logits_gather_idx = [] + for nodes_per_layer, cur_gather_idx in zip( + self.nodes_list_per_layer[1:], self.tokens_gather_idx): + num_nodes_has_children = 0 + for node in nodes_per_layer: + if len(child_nodes_list[node]) > 0: + num_nodes_has_children += 1 + cur_input_tokens_len = cur_gather_idx.shape[0] + assert cur_input_tokens_len >= num_nodes_has_children + self.logits_gather_idx.append( + torch.tensor(range( + cur_input_tokens_len - num_nodes_has_children, + cur_input_tokens_len), + dtype=torch.int32, + device='cuda')) + + for cur_gather_idx in self.tokens_gather_idx: + tmp_logits_gather_idx = [] + for idx, node in enumerate(cur_gather_idx.tolist()): + if len(child_nodes_list[node]) > 0: + tmp_logits_gather_idx.append(idx) + self.logits_gather_idx.append( + torch.tensor(tmp_logits_gather_idx, + dtype=torch.int32, + device='cuda')) + + # 11) Compute the spec_dec_packed_mask_for_drafter_model + self.spec_dec_packed_mask_for_drafter_model = [] + for cur_gather_idx in self.tokens_gather_idx: + tmp_mast_matrix = self.spec_dec_mask_matrix[0][ + cur_gather_idx, :][:, cur_gather_idx] + tmp_packed_mask = torch.zeros( + (1, cur_gather_idx.shape[0], + math.ceil((self.max_total_draft_tokens + 1) / 32)), + dtype=torch.int32, + device='cuda') + self.compute_spec_dec_packed_mask(tmp_mast_matrix.unsqueeze(0), + tmp_packed_mask) + self.spec_dec_packed_mask_for_drafter_model.append( + tmp_packed_mask.squeeze(0)) + + # 12) Compute the spec_dec_position_offsets_for_drafter_model + self.spec_dec_position_offsets_for_drafter_model = [] + for cur_gather_idx in self.tokens_gather_idx: + self.spec_dec_position_offsets_for_drafter_model.append( + torch.tensor( + self.spec_dec_position_offsets[0][cur_gather_idx.tolist()] - + 1, # shift the root node + dtype=torch.int32, + device='cuda')) + + # 13) Compute the spec_dec_generation_lengths_for_drafter_model + self.spec_dec_generation_lengths_for_drafter_model = torch.tensor( + [ + cur_gather_idx.shape[0] + for cur_gather_idx in self.tokens_gather_idx + ], + dtype=torch.int32, + device='cuda') + + # 14) Compute the hidden_states_write_indices_for_drafter_model + self.hidden_states_write_indices_offset_for_drafter_model = [] + for cur_gather_idx in self.tokens_gather_idx: + self.hidden_states_write_indices_offset_for_drafter_model.append( + torch.tensor(cur_gather_idx.tolist(), + dtype=torch.int32, + device='cuda')) + + # 15) Compute the hidden_states_read_indices_for_drafter_model + self.hidden_states_read_indices_offset_for_drafter_model = [] + for cur_gather_idx in self.tokens_gather_idx: + tmp_parent_nodes = [] + for node in cur_gather_idx: + tmp_parent_nodes.append( + self.spec_dec_mask_matrix[0][node].nonzero().reshape( + -1).tolist()[-2]) + self.hidden_states_read_indices_offset_for_drafter_model.append( + torch.tensor(tmp_parent_nodes, dtype=torch.int32, + device='cuda')) # Get the eagle_paths def get_eagle_paths(self, tree_idx=0): @@ -149,18 +321,75 @@ def get_top_k_list(self, draft_layer_id): return self.top_k_list[draft_layer_id] # Compute the packed mask according to the mask matrix - def compute_spec_dec_pack_mask(self, mask_matrix): - # mask_matrix: shape: [num_trees, max_total_draft_tokens + 1, max_total_draft_tokens + 1] - int_tensor = mask_matrix.to(torch.int32) - weights 
= torch.pow( - 2, torch.arange(mask_matrix.shape[-1], device=mask_matrix.device)) - packed_mask = torch.sum(int_tensor * weights, dim=-1) - return packed_mask + def compute_spec_dec_packed_mask(self, mask_matrix, packed_mask): + # mask_matrix: shape: [num_trees, num_process_tokens, num_process_tokens] + # packed_mask: shape: [num_trees, num_process_tokens, math.ceil((max_total_draft_tokens + 1) / 32)] + assert mask_matrix.ndim == 3 + assert packed_mask.ndim == 3 + num_trees = mask_matrix.size(0) + num_process_tokens = mask_matrix.size(1) + assert mask_matrix.shape == (num_trees, num_process_tokens, + num_process_tokens) + assert packed_mask.shape == (num_trees, num_process_tokens, + math.ceil( + (self.max_total_draft_tokens + 1) / + 32)) + if num_process_tokens == 0: + return + + num_blocks = math.ceil((self.max_total_draft_tokens + 1) / 32) + int_tensor = mask_matrix.reshape( + -1, num_process_tokens + ) # shape: [num_trees * num_process_tokens, num_process_tokens] + packed_mask = packed_mask.reshape( + -1, + num_blocks) # shape: [num_trees * num_process_tokens, num_blocks] + + for block_idx in range(num_blocks): + start_idx = block_idx * 32 + end_idx = min(start_idx + 32, num_process_tokens) + if end_idx < start_idx: + break + block_bits = int_tensor[:, start_idx:end_idx] + weight = torch.pow( + 2, + torch.arange(end_idx - start_idx, + dtype=torch.int32, + device=int_tensor.device)) + block_value = torch.sum(block_bits * weight, dim=-1) + packed_mask[:, block_idx] = block_value + + packed_mask = packed_mask.reshape(num_trees, num_process_tokens, + num_blocks) # Print the tree info def dump_tree_info(self): + print(f"TopK list: {self.top_k_list}") if not self.use_dynamic_tree: + print(f"Max top k list cuda: {self.max_top_k_list_cuda}") print(f"Static tree: {self.eagle_paths}") - print(f"TopK list: {self.top_k_list}") - print(f"Spec dec mask matrix: {self.spec_dec_mask_matrix.int()}") - print(f"Spec dec pack mask: {self.spec_dec_pack_mask}") + print(f"Index mapping set: {self.index_mapping_set}") + print(f"Nodes list per layer: {self.nodes_list_per_layer}") + print( + f"Spec dec position offsets: {self.spec_dec_position_offsets}") + print(f"Spec dec mask matrix: {self.spec_dec_mask_matrix.int()}") + print(f"Spec dec pack mask: {self.spec_dec_packed_mask}") + print("Auxiliary buffers for the static tree.") + print(f"TopK list cuda: {self.top_k_list_cuda}") + print(f"Tokens gather idx: {self.tokens_gather_idx}") + print(f"Logits gather idx: {self.logits_gather_idx}") + print( + f"Spec dec packed mask for drafter model: {self.spec_dec_packed_mask_for_drafter_model}" + ) + print( + f"Spec dec position offsets for drafter model: {self.spec_dec_position_offsets_for_drafter_model}" + ) + print( + f"Spec dec generation lengths for drafter model: {self.spec_dec_generation_lengths_for_drafter_model}" + ) + print( + f"Hidden states write indices offset for drafter model: {self.hidden_states_write_indices_offset_for_drafter_model}" + ) + print( + f"Hidden states read indices offset for drafter model: {self.hidden_states_read_indices_offset_for_drafter_model}" + ) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 4a01844fb82..b417ccd3c4c 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -519,6 +519,10 @@ def spec_dec_mode(self): return TorchSpeculativeDecodingMode.from_string( self.decoding_type.upper()) + @functools.cached_property + def is_linear_tree(self) -> bool: + return self.max_draft_len == self.max_total_draft_tokens + 
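Stepping back to the compute_spec_dec_packed_mask helper added in spec_tree_manager.py above: each row of the 0-1 mask matrix is split into 32-column blocks, and block k of a row becomes one int32 whose bit j is set when column 32*k + j of that row is 1. Below is a minimal standalone sketch of that packing convention (plain PyTorch, hypothetical pack_mask helper, not the TRT-LLM implementation).

import math
import torch

def pack_mask(mask_matrix: torch.Tensor, num_padded_cols: int) -> torch.Tensor:
    """Pack each row of a 0/1 mask into int32 words, 32 columns per word."""
    num_rows, num_cols = mask_matrix.shape
    num_blocks = math.ceil(num_padded_cols / 32)
    packed = torch.zeros(num_rows, num_blocks, dtype=torch.int32)
    for block in range(num_blocks):
        start, end = block * 32, min(block * 32 + 32, num_cols)
        if end <= start:
            break
        # Bit j of this word corresponds to column start + j.
        weights = torch.pow(2, torch.arange(end - start, dtype=torch.int32))
        packed[:, block] = (mask_matrix[:, start:end].to(torch.int32) * weights).sum(dim=-1)
    return packed

# Rows attend to themselves and their ancestors; e.g. the third row below
# (ones at columns 0 and 2) packs to 1 + 4 = 5.
mask = torch.tensor([[1, 0, 0],
                     [0, 1, 0],
                     [1, 0, 1]])
print(pack_mask(mask, num_padded_cols=13))  # tensor([[1], [2], [5]], dtype=torch.int32)

The same convention explains the small reference values used in the unit tests below, where a token's packed mask encodes the column indices of itself and its ancestors as set bits.
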
class KvCacheConnectorConfig(StrictBaseModel): """ diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 3f1bab6c942..9a8fefb7500 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -2057,6 +2057,41 @@ def test_ptp_quickstart_advanced_eagle3(llm_root, llm_venv, model_name, _check_mem_usage(running_log, [25.2, 0, 0, 0]) +@pytest.mark.parametrize("model_name,model_path,eagle_model_path", [ + ("Llama-3.1-8b-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct", + "EAGLE3-LLaMA3.1-Instruct-8B"), +]) +def test_draft_token_tree_quickstart_advanced_eagle3(llm_root, llm_venv, + model_name, model_path, + eagle_model_path): + print(f"Testing {model_name}.") + example_root = Path(os.path.join(llm_root, "examples", "llm-api")) + with tempfile.NamedTemporaryFile(mode='w+t', + suffix=f".{model_name}.log", + dir="./", + delete=True, + delete_on_close=True) as running_log: + llm_venv.run_cmd([ + str(example_root / "quickstart_advanced.py"), + "--prompt", + "You are a good assistant. Please tell me the capital of France is", + "--spec_decode_max_draft_len", + "3", + "--spec_decode_algo", + "eagle3", + "--model_dir", + f"{llm_models_root()}/{model_path}", + "--draft_model_dir", + f"{llm_models_root()}/{eagle_model_path}", + "--disable_kv_cache_reuse", + "--disable_overlap_scheduler", + "--eagle_choices", + "[[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]]", + ], + stdout=running_log) + _check_mem_usage(running_log, [27, 0, 0, 0]) + + @pytest.mark.parametrize("model_name,model_path", [ ("Llama-3.1-8B-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct"), ]) diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py index 0fdfa7ff0fa..9cf2beb3d52 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama.py @@ -516,7 +516,7 @@ def run_forward(input_ids, position_ids, attn_metadata): use_spec_decoding = True is_spec_dec_tree = True is_spec_dec_dynamic_tree = True - max_draft_tokens = gen_input_ids_0.size(-1) - 1 + max_total_draft_tokens = gen_input_ids_0.size(-1) - 1 attn_metadata_gen_phase_0 = metadata_cls( seq_lens=torch.tensor([gen_input_ids_0.size(-1)], dtype=torch.int), @@ -540,10 +540,12 @@ def run_forward(input_ids, position_ids, attn_metadata): packed_mask=spec_decoding_packed_mask) attn_metadata_gen_phase_0.update_spec_dec_param( + batch_size=batch_size, is_spec_decoding_enabled=is_spec_decoding_enabled, is_spec_dec_dynamic_tree=is_spec_dec_dynamic_tree, is_spec_dec_tree=is_spec_dec_tree, - max_draft_tokens=max_draft_tokens, + max_draft_len=max_total_draft_tokens, + max_total_draft_tokens=max_total_draft_tokens, spec_decoding_tensor=spec_decoding_tensor, ) @@ -586,10 +588,12 @@ def run_forward(input_ids, position_ids, attn_metadata): [gen_input_ids_1.size(-1)], dtype=torch.int) attn_metadata_gen_phase_0.kv_cache_params.num_cached_tokens_per_seq = num_cached_tokens_per_seq_1 attn_metadata_gen_phase_0.update_spec_dec_param( + batch_size=batch_size, is_spec_decoding_enabled=is_spec_decoding_enabled, is_spec_dec_tree=is_spec_dec_tree, is_spec_dec_dynamic_tree=False, - max_draft_tokens=gen_input_ids_1.size(-1) - 1) + max_draft_len=gen_input_ids_1.size(-1) - 1, + max_total_draft_tokens=gen_input_ids_1.size(-1) - 1) gen_position_ids_1 = [ torch.full( @@ -630,10 +634,12 @@ def run_forward(input_ids, position_ids, attn_metadata): is_spec_dec_tree=is_spec_dec_tree, 
is_spec_dec_dynamic_tree=False) attn_metadata_ref.update_spec_dec_param( + batch_size=batch_size, is_spec_decoding_enabled=is_spec_decoding_enabled, is_spec_dec_tree=is_spec_dec_tree, is_spec_dec_dynamic_tree=False, - max_draft_tokens=gen_input_ids_ref.size(-1) - 1, + max_draft_len=gen_input_ids_ref.size(-1) - 1, + max_total_draft_tokens=gen_input_ids_ref.size(-1) - 1, ) gen_position_ids_ref = [ diff --git a/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py b/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py new file mode 100644 index 00000000000..6ad12313a32 --- /dev/null +++ b/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py @@ -0,0 +1,663 @@ +import math +import os +import sys +import unittest + +import torch +from utils.llm_data import llm_models_root + +from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata +from tensorrt_llm._torch.speculative.drafting_loops import \ + prepare_for_generation_with_tree_decoding +from tensorrt_llm._torch.speculative.eagle3 import (Eagle3ResourceManager, + Eagle3SpecMetadata) +from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager +from tensorrt_llm.llmapi import EagleDecodingConfig + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + + +def test_draft_token_static_tree_prepare_for_generation(): + # Fix parameters + models_path = llm_models_root() + eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" # It will not actually be used. + use_dynamic_tree = False + max_new_tokens = 128 + kv_cache_manager = None + + # Create related object and run test + def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens, + max_draft_len, eagle_choices, input_seq_lens_cuda, + input_kv_lens_cuda, input_num_accepted_draft_tokens, + input_hidden_states_write_indices, + input_hidden_states_read_indices, input_draft_tokens, + input_position_ids, ref_inputs_ids, ref_position_ids, + ref_attn_metadata, ref_spec_metadata): + + # 1) Create attention metadata + attn_metadata = TrtllmAttentionMetadata( + max_num_requests=max_batch_size, + max_num_tokens=max_new_tokens, + kv_cache_manager=kv_cache_manager) + + # Set initial values + attn_metadata._seq_lens_cuda = input_seq_lens_cuda # set from input + attn_metadata.kv_lens_cuda = input_kv_lens_cuda # set from input + attn_metadata._seq_lens = torch.zeros([max_batch_size], device='cpu') + attn_metadata.host_request_types = torch.zeros([max_batch_size], + device='cuda') + attn_metadata.spec_decoding_position_offsets = torch.zeros( + [max_batch_size, max_total_draft_tokens + 1], + dtype=torch.int, + device='cuda', + ) + attn_metadata.spec_decoding_packed_mask = torch.zeros( + [ + max_batch_size, max_total_draft_tokens + 1, + math.ceil(max_total_draft_tokens / 32) + ], + dtype=torch.int, + device='cuda', + ) + attn_metadata.spec_decoding_generation_lengths = torch.zeros( + [max_batch_size], + dtype=torch.int, + device='cuda', + ) + + # 2) Create spec metadata + spec_config = EagleDecodingConfig( + max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, + speculative_model_dir=eagle_model_dir, + eagle3_one_model=False, + eagle_choices=eagle_choices, + use_dynamic_tree=use_dynamic_tree, + ) + + eagle3_resource_manager = Eagle3ResourceManager( + spec_config, + torch.bfloat16, + 1024, + max_batch_size, + max_new_tokens, + max_new_tokens, + ) + + spec_tree_manager = SpecTreeManager( + max_num_requests=max_batch_size, + 
use_dynamic_tree=spec_config.use_dynamic_tree, + max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=spec_config.max_total_draft_tokens, + eagle_choices=spec_config.eagle_choices, + dynamic_tree_max_topK=spec_config.dynamic_tree_max_topK, + ) + + spec_metadata = Eagle3SpecMetadata( + max_draft_len=spec_config.max_draft_len, + spec_dec_mode=spec_config.spec_dec_mode, + max_num_requests=max_batch_size, + num_layers=32, + hidden_size=1024, + max_num_tokens=max_new_tokens, + dtype=torch.bfloat16, + is_draft_model=True, + eagle3_resource_manager=eagle3_resource_manager, + layers_to_capture=spec_config.eagle3_layers_to_capture, + max_total_draft_tokens=spec_config.max_total_draft_tokens, + eagle_choices=spec_config.eagle_choices, + is_spec_dec_tree=spec_config.eagle_choices is not None + or spec_config.use_dynamic_tree, + is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree, + ) + + # Set initial values + spec_metadata.num_accepted_draft_tokens = input_num_accepted_draft_tokens # set from input + spec_metadata.num_tokens = 0 + spec_metadata.hidden_states_write_indices = input_hidden_states_write_indices # set from input + spec_metadata.hidden_states_read_indices = input_hidden_states_read_indices # set from input + + # 3) Run the function + output_input_ids, output_position_ids = prepare_for_generation_with_tree_decoding( + prepare_for_layer_idx=prepare_for_layer_idx, + new_draft_tokens=input_draft_tokens, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + spec_tree_manager=spec_tree_manager, + position_ids=input_position_ids, + ) + + # Compare input_ids and position_ids + print( + f"output_input_ids: {output_input_ids}, ref_output_input_ids: {ref_inputs_ids}" + ) + print( + f"output_position_ids: {output_position_ids}, ref_output_position_ids: {ref_position_ids}" + ) + + # Compare the attention metadata + print( + f"attn_metadata.kv_lens_cuda: {attn_metadata.kv_lens_cuda}, ref_attn_metadata.kv_lens_cuda: {ref_attn_metadata['kv_lens_cuda']}" + ) + print( + f"attn_metadata._seq_lens: {attn_metadata._seq_lens}, ref_attn_metadata._seq_lens: {ref_attn_metadata['_seq_lens']}" + ) + print( + f"attn_metadata._seq_lens_cuda: {attn_metadata._seq_lens_cuda}, ref_attn_metadata._seq_lens_cuda: {ref_attn_metadata['_seq_lens_cuda']}" + ) + print( + f"attn_metadata.host_request_types: {attn_metadata.host_request_types}, ref_attn_metadata.host_request_types: {ref_attn_metadata['host_request_types']}" + ) + print( + f"attn_metadata.num_contexts: {attn_metadata.num_contexts}, ref_attn_metadata.num_contexts: {ref_attn_metadata['num_contexts']}" + ) + print( + f"attn_metadata.spec_decoding_position_offsets: {attn_metadata.spec_decoding_position_offsets}, ref_attn_metadata.spec_decoding_position_offsets: {ref_attn_metadata['spec_decoding_position_offsets']}" + ) + print( + f"attn_metadata.spec_decoding_packed_mask: {attn_metadata.spec_decoding_packed_mask}, ref_attn_metadata.spec_decoding_packed_mask: {ref_attn_metadata['spec_decoding_packed_mask']}" + ) + print( + f"attn_metadata.spec_decoding_generation_lengths: {attn_metadata.spec_decoding_generation_lengths}, ref_attn_metadata.spec_decoding_generation_lengths: {ref_attn_metadata['spec_decoding_generation_lengths']}" + ) + + # Compare the spec metadata + print( + f"spec_metadata.num_tokens: {spec_metadata.num_tokens}, ref_spec_metadata.num_tokens: {ref_spec_metadata['num_tokens']}" + ) + print( + f"spec_metadata.gather_ids: {spec_metadata.gather_ids}, ref_spec_metadata.gather_ids: {ref_spec_metadata['gather_ids']}" + ) + print( + 
f"spec_metadata.hidden_states_read_indices: {spec_metadata.hidden_states_read_indices}, ref_spec_metadata.hidden_states_read_indices: {ref_spec_metadata['hidden_states_read_indices']}" + ) + print( + f"spec_metadata.hidden_states_write_indices: {spec_metadata.hidden_states_write_indices}, ref_spec_metadata.hidden_states_write_indices: {ref_spec_metadata['hidden_states_write_indices']}" + ) + + assert torch.all(output_input_ids == ref_inputs_ids) + assert torch.all(output_position_ids == ref_position_ids) + assert torch.all( + attn_metadata.kv_lens_cuda == ref_attn_metadata['kv_lens_cuda']) + assert torch.all( + attn_metadata._seq_lens == ref_attn_metadata['_seq_lens']) + assert torch.all( + attn_metadata._seq_lens_cuda == ref_attn_metadata['_seq_lens_cuda']) + assert torch.all(attn_metadata.host_request_types == + ref_attn_metadata['host_request_types']) + assert torch.all( + torch.tensor(attn_metadata.num_contexts) == torch.tensor( + ref_attn_metadata['num_contexts'])) + assert torch.all(attn_metadata.spec_decoding_generation_lengths == + ref_attn_metadata['spec_decoding_generation_lengths']) + total_process_tokens = attn_metadata.spec_decoding_generation_lengths.sum( + ) + print(f"total_process_tokens: {total_process_tokens}") + assert torch.all( + attn_metadata.spec_decoding_position_offsets.reshape( + -1)[:total_process_tokens] == + ref_attn_metadata['spec_decoding_position_offsets'] + [:total_process_tokens]) + assert torch.all( + attn_metadata.spec_decoding_packed_mask.reshape( + -1, attn_metadata.spec_decoding_packed_mask.size( + -1))[:total_process_tokens, :] == + ref_attn_metadata['spec_decoding_packed_mask'] + [:total_process_tokens, :]) + + assert torch.all( + torch.tensor(spec_metadata.num_tokens) == torch.tensor( + ref_spec_metadata['num_tokens'])) + assert torch.all( + spec_metadata.gather_ids == ref_spec_metadata['gather_ids']) + assert torch.all( + spec_metadata.hidden_states_read_indices[:ref_spec_metadata[ + 'hidden_states_read_indices'].shape[0]] == + ref_spec_metadata['hidden_states_read_indices']) + assert torch.all( + spec_metadata.hidden_states_write_indices[:ref_spec_metadata[ + 'hidden_states_write_indices'].shape[0]] == + ref_spec_metadata['hidden_states_write_indices']) + + ################## CASE 1 static tree, batch size = 1, prefill, prepare_for_layer_idx = 1 ########################## + max_batch_size = 1 + prepare_for_layer_idx = 1 + max_total_draft_tokens = 12 + max_draft_len = 3 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 1], [1, 0, 0]] + + prompt_len_1 = 15 + + input_draft_tokens = [ + torch.tensor([20, 21, 22], dtype=torch.int32, + device='cuda').reshape(1, 3) + ] + input_position_ids = torch.arange(prompt_len_1, + dtype=torch.int32, + device='cuda').reshape(1, prompt_len_1) + + input_seq_lens_cuda = torch.tensor([prompt_len_1], + dtype=torch.int32, + device='cuda') + input_kv_lens_cuda = torch.tensor([prompt_len_1], + dtype=torch.int32, + device='cuda') + input_num_accepted_draft_tokens = torch.tensor([prompt_len_1 - 1], + dtype=torch.int32, + device='cuda') + input_hidden_states_write_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + input_hidden_states_write_indices[:prompt_len_1] = torch.arange( + prompt_len_1, dtype=torch.long, device='cuda') + input_hidden_states_read_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + + ref_inputs_ids = torch.tensor([20, 21, 22], + dtype=torch.int32, + device='cuda') + ref_position_ids = 
torch.tensor([15, 15, 15], + dtype=torch.int32, + device='cuda') + + ref_attn_metadata = {} + ref_attn_metadata['kv_lens_cuda'] = torch.tensor([18], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['_seq_lens'] = torch.tensor([3], + dtype=torch.int32, + device='cpu') + ref_attn_metadata['_seq_lens_cuda'] = torch.tensor([3], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['host_request_types'] = torch.tensor([0], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['num_contexts'] = 0 + ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor( + [0, 0, 0], dtype=torch.int32, device='cuda') + ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor( + [1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1) + ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor( + [3], dtype=torch.int32, device='cuda') + + ref_spec_metadata = {} + ref_spec_metadata['num_tokens'] = 3 + ref_spec_metadata['gather_ids'] = torch.tensor([0, 1, 2], + dtype=torch.int32, + device='cuda') + ref_spec_metadata['hidden_states_read_indices'] = torch.tensor( + [14, 14, 14], dtype=torch.int32, device='cuda') + ref_spec_metadata['hidden_states_write_indices'] = torch.tensor( + [15, 16, 17], dtype=torch.int32, device='cuda') + + run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens, + max_draft_len, eagle_choices, input_seq_lens_cuda, + input_kv_lens_cuda, input_num_accepted_draft_tokens, + input_hidden_states_write_indices, + input_hidden_states_read_indices, input_draft_tokens, + input_position_ids, ref_inputs_ids, ref_position_ids, + ref_attn_metadata, ref_spec_metadata) + + ################## CASE 2 static tree, batch size = 2, both prefill, prepare_for_layer_idx = 1 ########################## + max_batch_size = 2 + prepare_for_layer_idx = 1 + max_total_draft_tokens = 12 + max_draft_len = 3 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 1], [1, 0, 0]] + + prompt_len_1 = 15 + prompt_len_2 = 10 + + input_draft_tokens = [ + torch.tensor([20, 21, 22, 30, 31, 32], dtype=torch.int32, + device='cuda').reshape(2, 3) + ] + input_position_ids_1 = torch.arange(prompt_len_1, + dtype=torch.int32, + device='cuda').reshape(1, prompt_len_1) + input_position_ids_2 = torch.arange(prompt_len_2, + dtype=torch.int32, + device='cuda').reshape(1, prompt_len_2) + input_position_ids = torch.cat([input_position_ids_1, input_position_ids_2], + dim=1) + + input_seq_lens_cuda = torch.tensor([prompt_len_1, prompt_len_2], + dtype=torch.int32, + device='cuda') + input_kv_lens_cuda = torch.tensor([prompt_len_1, prompt_len_2], + dtype=torch.int32, + device='cuda') + input_num_accepted_draft_tokens = torch.tensor( + [prompt_len_1 - 1, prompt_len_2 - 1], dtype=torch.int32, device='cuda') + input_hidden_states_write_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + input_hidden_states_write_indices[:prompt_len_1 + + prompt_len_2] = torch.arange( + prompt_len_1 + prompt_len_2, + dtype=torch.long, + device='cuda') + input_hidden_states_read_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + + ref_inputs_ids = torch.tensor([20, 21, 22, 30, 31, 32], + dtype=torch.int32, + device='cuda') + ref_position_ids = torch.tensor([15, 15, 15, 10, 10, 10], + dtype=torch.int32, + device='cuda') + + ref_attn_metadata = {} + ref_attn_metadata['kv_lens_cuda'] = torch.tensor([18, 13], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['_seq_lens'] = torch.tensor([3, 3], + 
dtype=torch.int32, + device='cpu') + ref_attn_metadata['_seq_lens_cuda'] = torch.tensor([3, 3], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['host_request_types'] = torch.tensor([0, 0], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['num_contexts'] = 0 + ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor( + [0, 0, 0, 0, 0, 0], dtype=torch.int32, device='cuda') + ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor( + [1, 2, 4, 1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1) + ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor( + [3, 3], dtype=torch.int32, device='cuda') + + ref_spec_metadata = {} + ref_spec_metadata['num_tokens'] = 6 + ref_spec_metadata['gather_ids'] = torch.tensor([0, 1, 2, 3, 4, 5], + dtype=torch.int32, + device='cuda') + ref_spec_metadata['hidden_states_read_indices'] = torch.tensor( + [14, 14, 14, 24, 24, 24], dtype=torch.int32, device='cuda') + ref_spec_metadata['hidden_states_write_indices'] = torch.tensor( + [15, 16, 17, 25, 26, 27], dtype=torch.int32, device='cuda') + + run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens, + max_draft_len, eagle_choices, input_seq_lens_cuda, + input_kv_lens_cuda, input_num_accepted_draft_tokens, + input_hidden_states_write_indices, + input_hidden_states_read_indices, input_draft_tokens, + input_position_ids, ref_inputs_ids, ref_position_ids, + ref_attn_metadata, ref_spec_metadata) + + ################## CASE 3 static tree, batch size = 2, one prefill, one decode, prepare_for_layer_idx = 1 ########################## + max_batch_size = 2 + prepare_for_layer_idx = 1 + max_total_draft_tokens = 12 + max_draft_len = 3 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 1], [1, 0, 0]] + + prompt_len_1 = 15 # prefill + prompt_len_2 = 18 + seq_len_2 = 3 + 1 # accepted 2 draft tokens. For the 0-th drafter layer, the sequence length will be pad to max_draft_len + 1 + + input_draft_tokens = [ + torch.tensor([20, 21, 22, 30, 31, 32], dtype=torch.int32, + device='cuda').reshape(2, 3) + ] # sample from the 0-th drafter layer + input_position_ids_1 = torch.arange(prompt_len_1, + dtype=torch.int32, + device='cuda').reshape(1, prompt_len_1) + input_position_ids_2 = torch.tensor( + [18, 19, 20, 21], dtype=torch.int32, + device='cuda').reshape(1, max_draft_len + 1) # for target model + input_position_ids = torch.cat([input_position_ids_1, input_position_ids_2], + dim=1) + + input_seq_lens_cuda = torch.tensor([prompt_len_1, seq_len_2], + dtype=torch.int32, + device='cuda') + input_kv_lens_cuda = torch.tensor([prompt_len_1, prompt_len_2 + seq_len_2], + dtype=torch.int32, + device='cuda') + input_num_accepted_draft_tokens = torch.tensor( + [prompt_len_1 - 1, 2], dtype=torch.int32, + device='cuda') # Suppose 2 are received. 
+ input_hidden_states_write_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + input_hidden_states_write_indices[:prompt_len_1 + seq_len_2] = torch.arange( + prompt_len_1 + seq_len_2, dtype=torch.long, device='cuda') + input_hidden_states_read_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + + ref_inputs_ids = torch.tensor([20, 21, 22, 30, 31, 32], + dtype=torch.int32, + device='cuda') + ref_position_ids = torch.tensor([15, 15, 15, 21, 21, 21], + dtype=torch.int32, + device='cuda') + + ref_attn_metadata = {} + ref_attn_metadata['kv_lens_cuda'] = torch.tensor([18, 24], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['_seq_lens'] = torch.tensor([3, 3], + dtype=torch.int32, + device='cpu') + ref_attn_metadata['_seq_lens_cuda'] = torch.tensor([3, 3], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['host_request_types'] = torch.tensor([0, 0], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['num_contexts'] = 0 + ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor( + [0, 0, 0, 0, 0, 0], dtype=torch.int32, device='cuda') + ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor( + [1, 2, 4, 1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1) + ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor( + [3, 3], dtype=torch.int32, device='cuda') + + ref_spec_metadata = {} + ref_spec_metadata['num_tokens'] = 6 + ref_spec_metadata['gather_ids'] = torch.tensor([0, 1, 2, 3, 4, 5], + dtype=torch.int32, + device='cuda') + ref_spec_metadata['hidden_states_read_indices'] = torch.tensor( + [14, 14, 14, 17, 17, 17], dtype=torch.int32, device='cuda') + ref_spec_metadata['hidden_states_write_indices'] = torch.tensor( + [15, 16, 17, 18, 19, 20], dtype=torch.int32, device='cuda') + + run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens, + max_draft_len, eagle_choices, input_seq_lens_cuda, + input_kv_lens_cuda, input_num_accepted_draft_tokens, + input_hidden_states_write_indices, + input_hidden_states_read_indices, input_draft_tokens, + input_position_ids, ref_inputs_ids, ref_position_ids, + ref_attn_metadata, ref_spec_metadata) + + ################## CASE 4 static tree, batch size = 1, one prefill, prepare_for_layer_idx = 2 ########################## + max_batch_size = 1 + prepare_for_layer_idx = 2 + max_total_draft_tokens = 12 + max_draft_len = 3 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 1], [1, 0, 0]] + + prompt_len_1 = 15 + + input_draft_tokens = [ + torch.tensor([20, 21, 22], dtype=torch.int32, device='cuda').reshape( + 1, 3), # sample after the 0-th drafter layer + torch.tensor([30, 31, 32, 33, 34, 35], dtype=torch.int32, + device='cuda').reshape( + 1, 6) # sample after the 1-th drafter layer + ] + input_position_ids = torch.tensor([15, 15, 15], + dtype=torch.int32, + device='cuda').unsqueeze(0) + input_seq_lens_cuda = torch.tensor([3], dtype=torch.int32, device='cuda') + input_kv_lens_cuda = torch.tensor([18], dtype=torch.int32, device='cuda') + input_num_accepted_draft_tokens = torch.tensor([prompt_len_1 - 1], + dtype=torch.int32, + device='cuda') + + input_hidden_states_read_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + input_hidden_states_read_indices[:3] = torch.tensor([14, 14, 14], + dtype=torch.long, + device='cuda') + input_hidden_states_write_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + input_hidden_states_write_indices[:3] = 
torch.tensor([15, 16, 17], + dtype=torch.long, + device='cuda') + + ref_inputs_ids = torch.tensor([20, 21, 30, 31, 33], + dtype=torch.int32, + device='cuda') + ref_position_ids = torch.tensor([15, 15, 16, 16, 16], + dtype=torch.int32, + device='cuda') + + ref_attn_metadata = {} + ref_attn_metadata['kv_lens_cuda'] = torch.tensor([20], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['_seq_lens'] = torch.tensor([5], + dtype=torch.int32, + device='cpu') + ref_attn_metadata['_seq_lens_cuda'] = torch.tensor([5], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['host_request_types'] = torch.tensor([0], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['num_contexts'] = 0 + ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor( + [0, 0, 1, 1, 1], dtype=torch.int32, device='cuda') + ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor( + [1, 2, 5, 9, 18], dtype=torch.int32, device='cuda').unsqueeze(1) + ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor( + [5], dtype=torch.int32, device='cuda') + + ref_spec_metadata = {} + ref_spec_metadata['num_tokens'] = 5 + ref_spec_metadata['gather_ids'] = torch.tensor([2, 3, 4], + dtype=torch.int32, + device='cuda') + ref_spec_metadata['hidden_states_read_indices'] = torch.tensor( + [14, 14, 15, 15, 16], dtype=torch.int32, device='cuda') + ref_spec_metadata['hidden_states_write_indices'] = torch.tensor( + [15, 16, 18, 19, 21], dtype=torch.int32, device='cuda') + + run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens, + max_draft_len, eagle_choices, input_seq_lens_cuda, + input_kv_lens_cuda, input_num_accepted_draft_tokens, + input_hidden_states_write_indices, + input_hidden_states_read_indices, input_draft_tokens, + input_position_ids, ref_inputs_ids, ref_position_ids, + ref_attn_metadata, ref_spec_metadata) + + ################## CASE 5 static tree, batch size = 2, one prefill, one decode, prepare_for_layer_idx = 2 ########################## + max_batch_size = 2 + prepare_for_layer_idx = 2 + max_total_draft_tokens = 12 + max_draft_len = 3 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 1], [1, 0, 0]] + + prompt_len_1 = 15 + prompt_len_2 = 18 # decode + + input_draft_tokens = [ + torch.tensor([20, 21, 22, 23, 24, 25], dtype=torch.int32, + device='cuda').reshape( + 2, 3), # sample after the 0-th drafter layer + torch.tensor([30, 31, 32, 33, 34, 35, 40, 41, 42, 43, 44, 45], + dtype=torch.int32, + device='cuda').reshape( + 2, 6) # sample after the 1-th drafter layer + ] + input_position_ids = torch.tensor([15, 15, 15, 21, 21, 21], + dtype=torch.int32, + device='cuda').unsqueeze(0) + input_seq_lens_cuda = torch.tensor([3, 3], dtype=torch.int32, device='cuda') + input_kv_lens_cuda = torch.tensor([18, 24], + dtype=torch.int32, + device='cuda') + input_num_accepted_draft_tokens = torch.tensor([prompt_len_1 - 1, 2], + dtype=torch.int32, + device='cuda') + + input_hidden_states_read_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + input_hidden_states_read_indices[:6] = torch.tensor( + [14, 14, 14, 26, 26, 26], dtype=torch.long, device='cuda') + input_hidden_states_write_indices = torch.zeros([max_new_tokens], + dtype=torch.long, + device='cuda') + input_hidden_states_write_indices[:6] = torch.tensor( + [15, 16, 17, 27, 28, 29], dtype=torch.long, device='cuda') + + ref_inputs_ids = torch.tensor([20, 21, 30, 31, 33, 23, 24, 40, 41, 43], + dtype=torch.int32, + device='cuda') + ref_position_ids = 
torch.tensor([15, 15, 16, 16, 16, 21, 21, 22, 22, 22], + dtype=torch.int32, + device='cuda') + + ref_attn_metadata = {} + ref_attn_metadata['kv_lens_cuda'] = torch.tensor([20, 26], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['_seq_lens'] = torch.tensor([5, 5], + dtype=torch.int32, + device='cpu') + ref_attn_metadata['_seq_lens_cuda'] = torch.tensor([5, 5], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['host_request_types'] = torch.tensor([0], + dtype=torch.int32, + device='cuda') + ref_attn_metadata['num_contexts'] = 0 + ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor( + [0, 0, 1, 1, 1, 0, 0, 1, 1, 1], dtype=torch.int32, device='cuda') + ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor( + [1, 2, 5, 9, 18, 1, 2, 5, 9, 18], dtype=torch.int32, + device='cuda').unsqueeze(1) + ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor( + [5, 5], dtype=torch.int32, device='cuda') + + ref_spec_metadata = {} + ref_spec_metadata['num_tokens'] = 10 + ref_spec_metadata['gather_ids'] = torch.tensor([2, 3, 4, 7, 8, 9], + dtype=torch.int32, + device='cuda') + ref_spec_metadata['hidden_states_read_indices'] = torch.tensor( + [14, 14, 15, 15, 16, 26, 26, 27, 27, 28], + dtype=torch.int32, + device='cuda') + ref_spec_metadata['hidden_states_write_indices'] = torch.tensor( + [15, 16, 18, 19, 21, 27, 28, 30, 31, 33], + dtype=torch.int32, + device='cuda') + + run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens, + max_draft_len, eagle_choices, input_seq_lens_cuda, + input_kv_lens_cuda, input_num_accepted_draft_tokens, + input_hidden_states_write_indices, + input_hidden_states_read_indices, input_draft_tokens, + input_position_ids, ref_inputs_ids, ref_position_ids, + ref_attn_metadata, ref_spec_metadata) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py b/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py index eac47f48e60..1a3100d879e 100644 --- a/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py +++ b/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py @@ -5,23 +5,30 @@ import torch from utils.llm_data import llm_models_root -from tensorrt_llm import SamplingParams -from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest, - SamplingConfig) -from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler +from tensorrt_llm._torch.speculative.drafting_loops import ChainDrafter from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager from tensorrt_llm.llmapi import EagleDecodingConfig sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +class DummyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.model_config = None + self.config = None + self.model = {} + + def forward(self, *args, **kwargs) -> torch.Tensor: + pass + + def test_draft_token_static_tree_sampling(): # Fix parameters models_path = llm_models_root() eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" # It will not actually be used. 
- beam_width = 1 use_dynamic_tree = False - max_new_tokens = 128 # Create related object and run test def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, @@ -43,57 +50,28 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, eagle_choices=spec_config.eagle_choices, dynamic_tree_max_topK=spec_config.dynamic_tree_max_topK, ) - spec_tree_manager.cur_draft_layer_idx = draft_layer_id - torch_sampler = TorchSampler( - TorchSampler.Args( - max_draft_len=spec_config.max_draft_len, - max_seq_len=1024, - max_total_draft_tokens=spec_config.max_total_draft_tokens, - max_num_sequences=max_batch_size, - max_beam_width=beam_width)) - # Prepare tree sampling inputs - scheduled_requests = [] - for req_id in range(max_batch_size): - req = LlmRequest( - request_id=req_id, - max_new_tokens=max_new_tokens, - input_tokens=range(req_id * 10, (req_id + 1) * 10), - sampling_config=SamplingConfig( - SamplingParams()._get_sampling_config()), - is_streaming=False, - ) - scheduled_requests.append(req) - seq_slots = torch.tensor(range(max_batch_size), - dtype=torch.int64, - device='cuda') - new_tokens = torch.zeros((spec_config.max_total_draft_tokens + 1, - max_batch_size, beam_width), - dtype=torch.int, - device='cuda') - - model_outputs = { - "logits": logits, - } + # Create the chain drafter + chain_drafter = ChainDrafter( + max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=spec_config.max_total_draft_tokens, + draft_model=DummyModel(), + ) - new_tokens_host = torch_sampler._tree_sampling_batch( - requests=scheduled_requests, - max_num_sequences=max_batch_size, - seq_slots=seq_slots, - model_outputs=model_outputs, - spec_tree_manager=spec_tree_manager) - new_tokens_host = new_tokens_host.squeeze(dim=-1).transpose( - 0, 1) # shape: [max_batch_size, max_total_draft_tokens + 1] + output_tokens = chain_drafter.sample( + draft_layer_idx=draft_layer_id, + batch_size=max_batch_size, + logits=logits, + spec_tree_manager=spec_tree_manager, + ) print( - f"new_tokens_host.shape: {new_tokens_host.shape}, new_tokens_host: {new_tokens_host}" + f"ref_new_tokens.shape: {ref_new_tokens.shape}, ref_new_tokens: {ref_new_tokens}" ) print( - f"ref_new_tokens.shape: {ref_new_tokens.shape}, ref_new_tokens: {ref_new_tokens}" + f"output_tokens.shape: {output_tokens.shape}, output_tokens: {output_tokens}" ) - - assert torch.all(new_tokens_host[:, :num_new_draft_tokens] == - ref_new_tokens[:, :num_new_draft_tokens]) + assert torch.all(output_tokens == ref_new_tokens) ################## CASE 1 static tree, batch size = 1, draft_layer_id = 0 ########################## max_batch_size = 1 @@ -109,10 +87,9 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, ], # top3 indices = [4, 1, 9] ], device='cuda') - ref_new_tokens = torch.tensor( - [ - [4, 1, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], device='cpu') # shape: [max_batch_size, max_total_draft_tokens + 1] + ref_new_tokens = torch.tensor([ + [4, 1, 9], + ], device='cuda') # shape: [max_batch_size, num_new_draft_tokens] num_new_draft_tokens = 3 run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, logits, ref_new_tokens, @@ -136,10 +113,9 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, ], # top1 indices = [9] ], device='cuda') - ref_new_tokens = torch.tensor( - [ - [4, 1, 9, 7, 1, 9, 0, 0, 0, 0, 0, 0, 0], - ], device='cpu') # shape: [max_batch_size, max_total_draft_tokens + 1] + ref_new_tokens = torch.tensor([ + [4, 1, 9, 7, 1, 9], + ], device='cuda') # shape: 
[max_batch_size, num_new_draft_tokens] num_new_draft_tokens = 6 run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, logits, ref_new_tokens, @@ -162,10 +138,9 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, ], # top1 indices = [9] ], device='cuda') - ref_new_tokens = torch.tensor( - [ - [4, 7, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], device='cpu') # shape: [max_batch_size, max_total_draft_tokens + 1] + ref_new_tokens = torch.tensor([ + [4, 7, 9], + ], device='cuda') # shape: [max_batch_size, num_new_draft_tokens] num_new_draft_tokens = 3 run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, logits, ref_new_tokens, @@ -186,12 +161,10 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, ], # top3 indices = [4, 2, 9] ], device='cuda') - ref_new_tokens = torch.tensor( - [ - [4, 1, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [4, 2, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - device='cpu') # shape: [max_batch_size, max_total_draft_tokens + 1] + ref_new_tokens = torch.tensor([ + [4, 1, 9], + [4, 2, 9], + ], device='cuda') # shape: [max_batch_size, num_new_draft_tokens] num_new_draft_tokens = 3 run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, logits, ref_new_tokens, @@ -222,10 +195,9 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, device='cuda') ref_new_tokens = torch.tensor( [ - [4, 1, 9, 7, 1, 9, 0, 0, 0, 0, 0, 0, 0], - [5, 1, 8, 6, 1, 8, 0, 0, 0, 0, 0, 0, 0], - ], - device='cpu') # shape: [max_batch_size, max_total_draft_tokens + 1] + [4, 1, 9, 7, 1, 9], + [5, 1, 8, 6, 1, 8], + ], device='cuda') # shape: [max_batch_size, num_new_draft_tokens] num_new_draft_tokens = 6 run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, logits, ref_new_tokens, @@ -254,12 +226,10 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, ], # top1 indices = [0] ], device='cuda') - ref_new_tokens = torch.tensor( - [ - [4, 7, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [5, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - device='cpu') # shape: [max_batch_size, max_total_draft_tokens + 1] + ref_new_tokens = torch.tensor([ + [4, 7, 9], + [5, 8, 0], + ], device='cuda') # shape: [max_batch_size, num_new_draft_tokens] num_new_draft_tokens = 3 run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, logits, ref_new_tokens, @@ -288,14 +258,8 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, device='cuda') ref_new_tokens = torch.tensor( [ - [ - 6, 9, 4, 5, 18, 8, 12, 15, 11, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0 - ], - ], - device='cpu') # shape: [max_batch_size, max_total_draft_tokens + 1] + [6, 9, 4, 5, 18, 8, 12, 15, 11, 17], + ], device='cuda') # shape: [max_batch_size, num_new_draft_tokens] num_new_draft_tokens = 10 run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, logits, ref_new_tokens, @@ -390,44 +354,8 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, 11, # top-1 14, # top-1 10, # top-1 - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0 ]], - device='cpu') # shape: [max_batch_size, max_total_draft_tokens + 1] + 
device='cuda') # shape: [max_batch_size, num_new_draft_tokens] num_new_draft_tokens = 28 run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, max_draft_len, eagle_choices, logits, ref_new_tokens,