From 462c89dcc68faa2e2e3f9f4bebbcb5af5fda712f Mon Sep 17 00:00:00 2001 From: Yue Weng <25103990+yweng0828@users.noreply.github.com> Date: Wed, 1 Oct 2025 06:42:42 +0000 Subject: [PATCH 1/2] add runtime logic Signed-off-by: Yue Weng <25103990+yweng0828@users.noreply.github.com> --- .../_torch/attention_backend/interface.py | 6 +- .../_torch/attention_backend/trtllm.py | 124 ++- .../_torch/auto_deploy/shim/ad_executor.py | 13 +- .../_torch/models/modeling_deepseekv3.py | 2 +- tensorrt_llm/_torch/pyexecutor/_util.py | 15 +- .../_torch/pyexecutor/cuda_graph_runner.py | 9 +- .../_torch/pyexecutor/model_engine.py | 176 ++++- tensorrt_llm/_torch/pyexecutor/py_executor.py | 18 +- .../_torch/pyexecutor/py_executor_creator.py | 12 +- tensorrt_llm/_torch/pyexecutor/sampler.py | 85 +- tensorrt_llm/_torch/speculative/drafter.py | 9 +- tensorrt_llm/_torch/speculative/eagle3.py | 2 +- tensorrt_llm/_torch/speculative/interface.py | 25 +- .../_torch/speculative/model_drafter.py | 99 ++- .../_torch/speculative/spec_tree_manager.py | 193 ++++- tensorrt_llm/_torch/speculative/utils.py | 4 +- tensorrt_llm/llmapi/llm_args.py | 58 +- tests/integration/defs/test_e2e.py | 35 + .../test_lists/test-db/l0_h100.yml | 1 + .../test_draft_token_tree_runtime.py | 724 ++++++++++++++++++ .../test_draft_token_tree_verification.py | 21 +- 21 files changed, 1400 insertions(+), 231 deletions(-) create mode 100644 tests/unittest/_torch/speculative/test_draft_token_tree_runtime.py diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index cdbe7c8c978..59df03af4e0 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -329,8 +329,10 @@ def restore_from_spec_dec(self) -> None: setattr(self, f, v) self._saved_tensors.clear() - def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree, - is_spec_dec_dynamic_tree, max_draft_tokens): + def update_spec_dec_param(self, scheduled_requests, + is_spec_decoding_enabled, spec_metadata, + spec_tree_manager, max_draft_len, + max_total_draft_tokens): """ Hook to be called when using TRTLLM attention backend in spec-dec mode. """ diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 90bc6df7848..ccbd30ad897 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -1027,8 +1027,10 @@ def prepare_context_mla_with_cached_kv(self, self.ctx_kv_indptr[:self.num_contexts + 1].copy_( self.host_ctx_kv_indptr[:self.num_contexts + 1], non_blocking=True) - def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree, - is_spec_dec_dynamic_tree, max_draft_tokens): + def update_spec_dec_param(self, scheduled_requests, + is_spec_decoding_enabled, spec_metadata, + spec_tree_manager, max_draft_len, + max_total_draft_tokens): # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree. 
self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version( ) < 100 @@ -1036,64 +1038,130 @@ def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree, # use_spec_decoding is default to true by default, change in runtime by layers / requests self.use_spec_decoding = self.is_spec_decoding_enabled - self.is_spec_dec_tree = is_spec_dec_tree - self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree + self.is_spec_dec_tree = False if spec_tree_manager is None else True + self.is_spec_dec_dynamic_tree = False if spec_tree_manager is None else spec_tree_manager.use_dynamic_tree # Parameters can be fixed and not changed during runtime if the if self.is_spec_decoding_enabled: - self.spec_decoding_position_offsets = torch.empty( - [self.max_num_requests, max_draft_tokens + 1], + self.spec_decoding_position_offsets = torch.zeros( + [self.max_num_requests, max_total_draft_tokens + 1], dtype=torch.int, device='cuda', ) - self.spec_decoding_packed_mask = torch.empty( + self.spec_decoding_packed_mask = torch.zeros( [ - self.max_num_requests, max_draft_tokens + 1, - math.ceil(max_draft_tokens / 32) + self.max_num_requests, max_total_draft_tokens + 1, + math.ceil(max_total_draft_tokens / 32) ], dtype=torch.int, device='cuda', ) - self.spec_decoding_generation_lengths = torch.empty( + self.spec_decoding_generation_lengths = torch.zeros( [self.max_num_requests], dtype=torch.int, device='cuda', ) - if self.is_spec_dec_dynamic_tree: - assert False, "currently dynamic tree is not supported" + # Prepare the spec-dec mask, position offset and generation length for static tree of dynamic tree. + if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree: + assert spec_metadata.spec_dec_mode.is_eagle3( + ), "Tree decoding is only supported for Eagle3 now" + # If is the drafter model + if spec_metadata.is_draft_model: + # Layer cur_draft_layer_idx already executed, we are now preparing the spec-dec params for executing cur_draft_layer_idx + 1 layer. + cur_draft_layer_idx = spec_tree_manager.cur_draft_layer_idx + + # If is the dynamic tree + if self.is_spec_dec_dynamic_tree: + # TODO: add dynamic tree logic + assert False, "Dynamic tree is not supported yet" + # If is the static tree + else: + # For the first drafter layer, we treat it as a chunked prefill. + # It takes the accepted tokens + new generated tokens as input. + if cur_draft_layer_idx == 0: + # In fact, the drafter model calls 'fmha_v2' or 'xqa' at this time. + self.generate_spec_decoding_position_offsets( + max_total_draft_tokens=max_draft_len) + self.generate_spec_decoding_packed_mask( + max_total_draft_tokens=max_draft_len) + self.generate_spec_decoding_generation_length( + max_total_draft_tokens=max_draft_len) + else: + num_process_draft_tokens = spec_tree_manager.get_cumulative_draft_lens( + cur_draft_layer_idx - 1) + + self.spec_decoding_position_offsets[:, :num_process_draft_tokens].copy_( + spec_tree_manager.spec_dec_position_offsets[ + 0:, 1:1 + num_process_draft_tokens] - + 1, # Exclude the root node + non_blocking=True) + self.spec_decoding_packed_mask[:, : + num_process_draft_tokens, :].copy_( + spec_tree_manager + . 
+ spec_dec_pack_mask[ + 0:, 1: + num_process_draft_tokens + + 1, :] / 2, + non_blocking=True + ) + self.spec_decoding_generation_lengths[:].fill_( + num_process_draft_tokens) + + # If is the target model + else: + # If is the dynamic tree + if self.is_spec_dec_dynamic_tree: + # TODO: add dynamic tree logic + assert False, "Dynamic tree is not supported yet" + # If is the static tree + else: + self.spec_decoding_position_offsets[ + :, + ].copy_( + spec_tree_manager.spec_dec_position_offsets[0:, :], + non_blocking=True) + self.spec_decoding_packed_mask[:, :, :].copy_( + spec_tree_manager.spec_dec_pack_mask[0:, :, :], + non_blocking=True) + self.spec_decoding_generation_lengths[:].fill_( + spec_tree_manager.max_total_draft_tokens + 1) else: + # Prepare for the linear-tree. # Populate the mask that won't change during inference phase. self.generate_spec_decoding_position_offsets( - max_draft_tokens=max_draft_tokens) + max_total_draft_tokens=max_total_draft_tokens) self.generate_spec_decoding_packed_mask( - max_draft_tokens=max_draft_tokens) + max_total_draft_tokens=max_total_draft_tokens) self.generate_spec_decoding_generation_length( - max_draft_tokens=max_draft_tokens) + max_total_draft_tokens=max_total_draft_tokens) - def generate_spec_decoding_position_offsets(self, max_draft_tokens): - assert not self.is_spec_dec_tree, "only chained/linear tree is supported now" - position_offset = torch.arange(max_draft_tokens + 1, + def generate_spec_decoding_position_offsets(self, max_total_draft_tokens): + position_offset = torch.arange(max_total_draft_tokens + 1, dtype=torch.int, device='cpu', pin_memory=True) # fill all the batches with same position offset - self.spec_decoding_position_offsets.copy_(position_offset, - non_blocking=True) - - def generate_spec_decoding_packed_mask(self, max_draft_tokens): - assert not self.is_spec_dec_tree, "only chained/linear tree is supported now" - dummy_idx = torch.arange(max_draft_tokens + 1) + self.spec_decoding_position_offsets[:, :max_total_draft_tokens + + 1].copy_(position_offset, + non_blocking=True) + + def generate_spec_decoding_packed_mask(self, max_total_draft_tokens): + # TODO: fix this limitation + assert max_total_draft_tokens < 32, "max_total_draft_tokens should be less than 32, will be fixed later" + dummy_idx = torch.arange(max_total_draft_tokens + 1) spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1 - self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask, - non_blocking=True) + self.spec_decoding_packed_mask[:, :max_total_draft_tokens + 1, + 0].copy_(spec_decoding_packed_mask, + non_blocking=True) - def generate_spec_decoding_generation_length(self, max_draft_tokens): + def generate_spec_decoding_generation_length(self, max_total_draft_tokens): spec_decoding_generation_length = torch.full((self.max_num_requests, ), - max_draft_tokens + 1) + max_total_draft_tokens + 1) self.spec_decoding_generation_lengths[:self.max_num_requests].copy_( spec_decoding_generation_length, non_blocking=True) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 7b7b52e584c..73f1e0524ad 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -315,13 +315,11 @@ def create_autodeploy_executor(ad_config: LlmArgs): max_draft_len = ( 0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len ) - max_total_draft_tokens = 0 - if ad_config.speculative_config is None: - 
max_total_draft_tokens = 0 - elif hasattr(ad_config.speculative_config, "max_total_draft_tokens"): - max_total_draft_tokens = ad_config.speculative_config.max_total_draft_tokens - else: - max_total_draft_tokens = max_draft_len + max_total_draft_tokens = ( + 0 + if ad_config.speculative_config is None + else ad_config.speculative_config.max_total_draft_tokens + ) # initialize model engine engine = ADEngine.build_from_config(ad_config=ad_config) @@ -374,6 +372,7 @@ def create_autodeploy_executor(ad_config: LlmArgs): max_input_len=ad_config.max_input_len, max_batch_size=ad_config.max_batch_size, max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, max_beam_width=ad_config.max_beam_width, ) return py_executor diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index ba2b33f18b1..b6289672583 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -512,7 +512,7 @@ def __init__( aux_stream: Optional[torch.cuda.Stream] = None, ): config = model_config.pretrained_config - predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1 + predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1 super().__init__(hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 71719d4f8ab..6bb1cd2b836 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -202,10 +202,10 @@ def _get_token_num_for_estimation(self) -> int: if not pytorch_backend_config.disable_overlap_scheduler: num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1 if spec_cfg is not None: - num_extra_tokens_per_seq += spec_cfg.max_draft_len + num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens if spec_cfg is not None: - num_extra_tokens_per_seq += spec_cfg.max_draft_len + num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg) if self._dummy_reqs is None: @@ -751,6 +751,8 @@ def create_py_executor_instance( max_beam_width=max_beam_width, max_draft_len=spec_config.max_draft_len if spec_config is not None else 0, + max_total_draft_tokens=spec_config.max_total_draft_tokens + if spec_config is not None else 0, kv_cache_transceiver=kv_cache_transceiver, guided_decoder=guided_decoder, start_worker=start_worker, @@ -767,13 +769,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int, max_num_sequences = max_batch_size * mapping.pp_size max_draft_len = (0 if speculative_config is None else speculative_config.max_draft_len) - max_total_draft_tokens = 0 - if speculative_config is None: - max_total_draft_tokens = 0 - elif hasattr(speculative_config, 'max_total_draft_tokens'): - max_total_draft_tokens = speculative_config.max_total_draft_tokens - else: - max_total_draft_tokens = max_draft_len + max_total_draft_tokens = (0 if speculative_config is None else + speculative_config.max_total_draft_tokens) return TorchSampler.Args( max_seq_len=max_seq_len, diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 4c097ac0d2a..71ee654f1b2 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ 
b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -93,7 +93,8 @@ def enable_spec_decode(self): @property def max_possible_draft_len(self): engine = self._get_engine() - return (engine.original_max_draft_len if self.enable_spec_decode else 0) + return (engine.original_max_total_draft_tokens + if self.enable_spec_decode else 0) def get_graph_key( self, @@ -102,10 +103,12 @@ def get_graph_key( engine = self._get_engine() if engine.is_draft_model and spec_resource_manager is not None and isinstance( spec_resource_manager, Eagle3ResourceManager): + # If 'is_first_draft' is True, even with tree decoding, the length of draft_len will only be 'max_draft_len', not 'max_total_draft_token'. + # Because we will pad the input to 'max_draft_len' length for the first draft layer. draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0 key = (batch_size, draft_len, spec_resource_manager.is_first_draft) else: - draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0 + draft_len = self.spec_config.max_total_draft_tokens if self.enable_spec_decode else 0 key = (batch_size, draft_len, False) return key @@ -344,7 +347,7 @@ def _get_padded_batch(self, batch: ScheduledRequests, self.padding_dummy_request = kv_cache_manager.add_dummy_requests( [CUDA_GRAPH_DUMMY_REQUEST_ID], is_gen=True, - max_num_draft_tokens=engine.max_draft_len, + max_num_draft_tokens=engine.max_total_draft_tokens, use_mrope=engine.use_mrope, max_beam_width=engine.max_beam_width)[0] self.padding_dummy_request.is_cuda_graph_dummy = True diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index a0f650f6aff..29d9da9bd7e 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -92,11 +92,11 @@ def warmup(self, resource_manager: ResourceManager) -> None: def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int], max_batch_size: int, max_num_tokens: int, - max_draft_len: int, + max_total_draft_tokens: int, enable_padding: bool) -> list[int]: # This is the largest possible batch size for a pure decoding batch. 
max_cuda_graph_bs = min(max_batch_size, - int(max_num_tokens / (1 + max_draft_len))) + int(max_num_tokens / (1 + max_total_draft_tokens))) result = [] # This function assumes cuda_graph_batch_sizes is sorted @@ -159,11 +159,13 @@ def __init__( ExpertStatistic.create(self.dist.rank) self.pytorch_backend_config = pytorch_backend_config self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0 + self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0 # The draft model won't have any draft tokens attached to # generation requests when we invoke it autoregressively if spec_config is not None and is_draft_model: spec_config.max_draft_len = 0 + spec_config.max_total_draft_tokens = 0 self.spec_config = spec_config self.is_spec_decode = spec_config is not None self.enable_spec_decode = self.is_spec_decode @@ -269,7 +271,7 @@ def __init__( self.spec_metadata = None update_spec_config_from_model_config(self.spec_config, self.model.config) - max_num_draft_tokens = self.original_max_draft_len * batch_size + max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ), dtype=torch.int, device='cuda') @@ -289,9 +291,11 @@ def __init__( self.without_logits = self.spec_config.spec_dec_mode.without_logits( ) or self.model_is_wrapped self.max_draft_len = spec_config.max_draft_len + self.max_total_draft_tokens = spec_config.max_total_draft_tokens else: self.without_logits = False self.max_draft_len = 0 + self.max_total_draft_tokens = 0 self.guided_decoder: Optional[CapturableGuidedDecoder] = None @@ -312,7 +316,7 @@ def __init__( self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes( pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size, - self.max_num_tokens, self.original_max_draft_len, + self.max_num_tokens, self.original_max_total_draft_tokens, self._cuda_graph_padding_enabled ) if pytorch_backend_config.cuda_graph_batch_sizes else [] @@ -353,7 +357,7 @@ def __init__( @property def runtime_draft_len(self): - return self.max_draft_len if self.enable_spec_decode else 0 + return self.max_total_draft_tokens if self.enable_spec_decode else 0 def set_lora_model_config(self, lora_target_modules: list[str], @@ -460,6 +464,8 @@ def warmup(self, resource_manager: ResourceManager) -> None: def get_num_extra_decoding_steps(): if isinstance(self.model, ChainDrafter): + # We should use max_draft_len instead of max_total_draft_tokens here, + # because max_draft_len indicates the real number of draft layers. 
return self.model.max_draft_len else: assert not self.model_is_wrapped, ( @@ -597,7 +603,7 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int): num_ctx_requests + num_gen_tokens)), token_nums=[1] * num_gen_tokens, is_gen=True, - max_num_draft_tokens=self.max_draft_len, + max_num_draft_tokens=self.max_total_draft_tokens, use_mrope=self.use_mrope) if spec_resource_manager is not None: spec_resource_manager.add_dummy_requests(request_ids=list( @@ -612,7 +618,7 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int): curr_max_num_tokens = min( kv_cache_manager.get_num_available_tokens( - self.original_max_draft_len), self.max_num_tokens, + self.original_max_total_draft_tokens), self.max_num_tokens, self.batch_size * (self.max_seq_len - 1)) def get_autotune_warmup_request(): @@ -695,11 +701,11 @@ def release_batch(result: ScheduledRequests | None): cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes, reverse=True) # Create CUDA graphs for different draft lengths - draft_lengths = [self.max_draft_len] + draft_lengths = [self.max_total_draft_tokens] # For non-draft model, we also capture the CUDA graph instance for draft length 0, # so that when we disable spec decode at runtime, we can still run the captured graph. # Note that for one engine mode, we are not able to turn off spec decode at runtime. - if (not self.is_draft_model and self.max_draft_len > 0 + if (not self.is_draft_model and self.max_total_draft_tokens > 0 and not self.spec_config.spec_dec_mode.use_one_engine() # Assume that speculation is always on if the user didn't give us a max_concurrency # value. This will save on memory. @@ -707,7 +713,7 @@ def release_batch(result: ScheduledRequests | None): draft_lengths.append(0) if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance( spec_resource_manager, Eagle3ResourceManager): - draft_lengths.append(self.original_max_draft_len) + draft_lengths.append(self.original_max_total_draft_tokens) for bs in cuda_graph_batch_sizes: if bs > self.batch_size: @@ -934,7 +940,7 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]): """ if self.enable_spec_decode and not self._disable_overlap_scheduler: # When enabling overlap scheduler, the kv cache for draft tokens will - # be prepared in advance by using the max_draft_len. But we need to use + # be prepared in advance by using the max_total_draft_tokens. But we need to use # new_tokens_lens_device to get the real past kv lengths and the # correct position ids. And to avoid blocking the async data transfer, # we need to preprocess the inputs in forward to update the position_ids and @@ -1089,7 +1095,8 @@ def _prepare_tp_inputs( attn_metadata: AttentionMetadata, spec_metadata: Optional[SpecMetadata] = None, new_tensors_device: Optional[SampleStateTensors] = None, - cache_indirection_buffer: Optional[torch.Tensor] = None): + cache_indirection_buffer: Optional[torch.Tensor] = None, + resource_manager: Optional[ResourceManager] = None): """ Prepare inputs for Pytorch Model. 
""" @@ -1209,6 +1216,12 @@ def _prepare_tp_inputs( assert spec_config.spec_dec_mode.support_overlap_scheduler( ), f"{spec_config.decoding_type} does not support overlap scheduler" + spec_resource_manager, spec_tree_manager = None, None + if spec_config is not None: + spec_resource_manager = resource_manager.get_resource_manager( + ResourceManagerType.SPEC_RESOURCE_MANAGER) + spec_tree_manager = spec_resource_manager.spec_tree_manager + # will contain previous batch indices of generation requests previous_batch_indices = [] previous_pos_indices = [] @@ -1244,10 +1257,23 @@ def _prepare_tp_inputs( list( range(len(position_ids), len(position_ids) + 1 + self.runtime_draft_len))) - position_ids.extend( - list( - range(past_seen_token_num, - past_seen_token_num + 1 + num_draft_tokens))) + + if spec_tree_manager is None: + position_ids.extend( + list( + range(past_seen_token_num, + past_seen_token_num + 1 + num_draft_tokens))) + else: + # for tree spec decoding + cur_position_ids = [past_seen_token_num + ] * (1 + num_draft_tokens) + slot_idx = request.py_seq_slot if spec_tree_manager.use_dynamic_tree else 0 + cur_position_ids = [ + pos + offset for pos, offset in zip( + cur_position_ids, spec_tree_manager. + spec_dec_position_offsets[slot_idx, :].tolist()) + ] + position_ids.extend(cur_position_ids) num_cached_tokens_per_seq.append(past_seen_token_num) # update batch index request.py_batch_idx = request.py_seq_slot @@ -1294,6 +1320,8 @@ def _prepare_tp_inputs( draft_lens.append(0) begin_compute = len( all_prompt_tokens) - self.original_max_draft_len - 1 + # We use 'original_max_draft_len' here but no 'original_max_total_draft_tokens' because for the first draft layer, + # the input tokens are the draft tokens that accepted by the target model and padded to 'max_draft_len' length. end_compute = begin_compute + self.original_max_draft_len + 1 prompt_tokens = all_prompt_tokens[begin_compute:end_compute] position_ids.extend( @@ -1334,13 +1362,83 @@ def _prepare_tp_inputs( if beam == first_beam: previous_batch_indices.append(request.py_batch_idx) past_seen_token_num = request.max_beam_num_tokens - position_ids.append(past_seen_token_num) - num_cached_tokens_per_seq.append(past_seen_token_num) + + # num_cached_tokens_per_seq.append(past_seen_token_num) prompt_lengths.append(request.py_prompt_len) draft_lens.append(0) - sequence_lengths.append(1) num_accepted_draft_tokens.append(0) - gather_ids.append(len(position_ids) - 1) + # For the two-model + static/dynamic tree. + # During the generation phase of the drafter model, the request will be treated as a normal generation request. + if self.is_draft_model and spec_tree_manager is not None and spec_tree_manager.cur_draft_layer_idx > 0: + # prepare input for the `cur_draft_layer_idx` layer. `cur_draft_layer_idx` has not been executed yet. + # 1) Prepare 'input_ids' + # Get the draft layer id to be executed. + cur_draft_layer_idx = spec_tree_manager.cur_draft_layer_idx + # How many draft tokens have been generated after executing the `cur_draft_layer_idx - 1` layer. + cumulative_draft_lens = spec_tree_manager.get_cumulative_draft_lens( + cur_draft_layer_idx - 1) + # How many draft tokens are generated in the `cur_draft_layer_idx - 1` layer. + cur_layer_draft_len = spec_tree_manager.get_current_layer_draft_len( + cur_draft_layer_idx - 1) + + # How many historical draft tokens need to be extracted from the request.get_tokens(0). 
+ historical_draft_tokens_len = cumulative_draft_lens - cur_layer_draft_len + assert historical_draft_tokens_len >= 0 + + num_cached_tokens_per_seq.append( + past_seen_token_num - historical_draft_tokens_len) + + cumulative_draft_tokens = [] + # Because we need all the draft tokens that have been generated in this iteration as input, + # and these "historical" draft tokens exist in the request tokens, we need to extract them. + if historical_draft_tokens_len > 0: + cumulative_draft_tokens.extend( + request.get_tokens(0) + [-historical_draft_tokens_len:]) + + # Extend `cur_draft_layer_idx - 1` layer's draft tokens, which store in the new_tokens_device + # TODO: .cpu() may have performance issue. + + cumulative_draft_tokens.extend( + new_tokens_device.reshape( + -1, spec_tree_manager.max_total_draft_tokens + 1, + 1)[request.py_seq_slot, :cur_layer_draft_len, + 0].cpu().tolist()) + + assert len(cumulative_draft_tokens) == cumulative_draft_lens + input_ids.extend(cumulative_draft_tokens) + + # 2) Prepare position_ids + # TODO: .cpu() may have performance issue. + position_ids_start_value = past_seen_token_num - historical_draft_tokens_len + slot_idx = request.py_seq_slot if spec_tree_manager.use_dynamic_tree else 0 + cur_position_ids = spec_tree_manager.spec_dec_position_offsets[ + slot_idx, 1:cumulative_draft_lens + + 1].cpu().tolist() # '+1' is to skip the root node + position_ids.extend([ + position_ids_start_value + i - 1 + for i in cur_position_ids + ] # '-1' is one offset to the root node + ) + + # 3) Prepare sequence_lengths + sequence_lengths.append(cumulative_draft_lens) + + # 4) Prepare gather_ids + # Although we take all "historical" draft tokens as input, only some of them will be expanded in the next drafter layer. + # We only need to gather_ids these draft tokens. + cur_gather_ids = spec_tree_manager.get_gather_ids( + cur_draft_layer_idx) + position_start_offset = len( + position_ids) - cumulative_draft_lens - 1 + gather_ids.extend( + [position_start_offset + i for i in cur_gather_ids]) + + else: + num_cached_tokens_per_seq.append(past_seen_token_num) + position_ids.append(past_seen_token_num) + sequence_lengths.append(1) + gather_ids.append(len(position_ids) - 1) # Multimodal multimodal_params = MultimodalParams( @@ -1541,8 +1639,11 @@ def previous_seq_slots_device(): attn_metadata.num_chunked_ctx_requests = 0 if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( self.attn_backend): - attn_metadata.num_contexts += len(extend_requests) - attn_metadata.num_chunked_ctx_requests = len(extend_requests) + # For the linear-tree, we will use fmha_v2 for the target model in generation phase. + # But for the static/dynamic tree, we will use xqa for the target model in generation phase. 
+ if spec_tree_manager is None: + attn_metadata.num_contexts += len(extend_requests) + attn_metadata.num_chunked_ctx_requests = len(extend_requests) attn_metadata.kv_cache_params = KVCacheParams( use_cache=True, @@ -2098,14 +2199,14 @@ def _get_lora_params_from_requests(self, return lora_params @nvtx_range("_prepare_inputs") - def _prepare_inputs( - self, - scheduled_requests: ScheduledRequests, - kv_cache_manager: KVCacheManager, - attn_metadata: AttentionMetadata, - spec_metadata: Optional[SpecMetadata] = None, - new_tensors_device: Optional[SampleStateTensors] = None, - cache_indirection_buffer: Optional[torch.Tensor] = None): + def _prepare_inputs(self, + scheduled_requests: ScheduledRequests, + kv_cache_manager: KVCacheManager, + attn_metadata: AttentionMetadata, + spec_metadata: Optional[SpecMetadata] = None, + new_tensors_device: Optional[SampleStateTensors] = None, + cache_indirection_buffer: Optional[torch.Tensor] = None, + resource_manager: Optional[ResourceManager] = None): if self.mapping is not None and 'cp_type' in self.mapping.cp_config: cp_type = self.mapping.cp_config['cp_type'] if CpType.STAR == cp_type: @@ -2117,7 +2218,8 @@ def _prepare_inputs( return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager, attn_metadata, spec_metadata, new_tensors_device, - cache_indirection_buffer) + cache_indirection_buffer, + resource_manager) @torch.inference_mode() @with_model_extra_attrs(lambda self: self.model.extra_attrs) @@ -2139,14 +2241,16 @@ def forward( spec_metadata = self._set_up_spec_metadata(spec_resource_manager, no_cache=kv_cache_manager is None) + # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode( spec_resource_manager, self.is_draft_model, self.attn_backend, - self.model_is_wrapped) + self.model_is_wrapped, spec_metadata.is_spec_dec_tree) attn_metadata.update_spec_dec_param( - is_spec_dec_mode, spec_metadata.is_spec_dec_tree, - spec_metadata.is_spec_dec_dynamic_tree, - self.original_max_draft_len) + scheduled_requests, is_spec_dec_mode, spec_metadata, + spec_resource_manager.spec_tree_manager, + self.original_max_draft_len, + self.original_max_total_draft_tokens) else: spec_resource_manager = None spec_metadata = None @@ -2183,7 +2287,7 @@ def forward( inputs, gather_ids = self._prepare_inputs( padded_requests, kv_cache_manager, attn_metadata, spec_metadata, - new_tensors_device, cache_indirection_buffer) + new_tensors_device, cache_indirection_buffer, resource_manager) self.iter_counter += 1 diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 02e95fa1e5b..d9db32b0a64 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -156,6 +156,7 @@ def __init__(self, max_batch_size: int = 8, max_beam_width: int = 1, max_draft_len: int = 0, + max_total_draft_tokens: int = 0, kv_cache_transceiver: Optional[KvCacheTransceiver] = None, guided_decoder: Optional[GuidedDecoder] = None, garbage_collection_gen0_threshold: Optional[int] = None, @@ -191,6 +192,7 @@ def __init__(self, self.active = True self.max_beam_width = max_beam_width self.max_draft_len = max_draft_len + self.max_total_draft_tokens = max_total_draft_tokens self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens self.print_log = model_engine.pytorch_backend_config.print_iter_log self.enable_iter_perf_stats = 
model_engine.pytorch_backend_config.enable_iter_perf_stats @@ -982,7 +984,7 @@ def _prepare_and_schedule_batch(self): self.use_spec_decode = self.drafter.should_use_spec_decode( self.active_requests, self.max_batch_size, self.model_engine.max_num_tokens, - self.model_engine.spec_config.max_draft_len) + self.model_engine.spec_config.max_total_draft_tokens) self.model_engine.enable_spec_decode = self.use_spec_decode # Set up draft_tokens in active_requests, because they could be used in the scheduling stage. @@ -991,10 +993,10 @@ def _prepare_and_schedule_batch(self): LlmRequestState.GENERATION_IN_PROGRESS, LlmRequestState.DISAGG_GENERATION_INIT): continue - max_draft_len = self.model_engine.spec_config.max_draft_len + max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens request.draft_tokens = [ 0 - ] * max_draft_len if max_draft_len > 0 else [] + ] * max_total_draft_tokens if max_total_draft_tokens > 0 else [] # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch, # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet. @@ -1161,11 +1163,11 @@ def _prepare_draft_requests(self): continue req.py_last_draft_tokens = req.py_draft_tokens - max_draft_len = self.model_engine.spec_config.max_draft_len + max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens - if max_draft_len > 0 and self.use_spec_decode: - req.py_draft_tokens = [0] * max_draft_len - req.py_draft_pages_allocated = max_draft_len + if max_total_draft_tokens > 0 and self.use_spec_decode: + req.py_draft_tokens = [0] * max_total_draft_tokens + req.py_draft_pages_allocated = max_total_draft_tokens else: req.py_draft_tokens = [] req.py_draft_pages_allocated = 0 @@ -1553,7 +1555,7 @@ def _pad_attention_dp_dummy_request(self): request_ids=[0], is_gen=True, prepare_resource=True, - max_num_draft_tokens=self.max_draft_len, + max_num_draft_tokens=self.max_total_draft_tokens, )[0] llm_request.is_attention_dp_dummy = True spec_resource_manager = self.resource_manager.get_resource_manager( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 1c28a7b43b7..0e786c58041 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -345,7 +345,8 @@ def create_py_executor( use_chain_drafter = ( guided_decoding_config is None and draft_spec_config._allow_greedy_draft_tokens - and pytorch_backend_config.attn_backend == "TRTLLM") + and pytorch_backend_config.attn_backend == "TRTLLM" + and spec_config.is_linear_tree) else: use_chain_drafter = False @@ -355,7 +356,7 @@ def create_py_executor( def drafting_loop_wrapper(model): from tensorrt_llm._torch.speculative.drafting_loops import \ ChainDrafter - + assert spec_config.is_linear_tree, "ChainDrafter only supports linear tree now" return ChainDrafter(spec_config.max_draft_len, model) else: drafting_loop_wrapper = None @@ -396,11 +397,11 @@ def drafting_loop_wrapper(model): if not pytorch_backend_config.disable_overlap_scheduler: model_engine_max_seq_len = model_engine.max_seq_len + 1 if spec_config is not None: - model_engine_max_seq_len += spec_config.max_draft_len + model_engine_max_seq_len += spec_config.max_total_draft_tokens if spec_config is not None: model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config) - model_engine_max_seq_len += spec_config.max_draft_len + model_engine_max_seq_len += 
spec_config.max_total_draft_tokens max_seq_len = model_engine_max_seq_len max_num_tokens = model_engine.max_num_tokens @@ -470,7 +471,8 @@ def drafting_loop_wrapper(model): "vocab_size_padded": model_engine.model.vocab_size_padded } if spec_config is not None: - kwargs["max_num_draft_tokens"] = spec_config.max_draft_len + kwargs[ + "max_num_draft_tokens"] = spec_config.max_total_draft_tokens if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder( ): diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 18edf8946c8..a876d0645d0 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -801,7 +801,7 @@ class Args: def __init__(self, args: Args): self.max_seq_len = args.max_seq_len - self.max_tokens = args.max_draft_len + 1 + self.max_tokens = args.max_total_draft_tokens + 1 assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1" self.max_num_sequences = args.max_num_sequences @@ -936,8 +936,8 @@ def _process_draft_tokens_tree(self, request: LlmRequest, we can find the longest match by comparing all the paths. Args: request: LlmRequest. The request with draft tokens. - new_tokens: torch.Tensor. [max_draft_len + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer. The tokens generated by the target model - The relationship between [max_draft_len + 1] and the draft token tree: + new_tokens: torch.Tensor. [max_total_draft_tokens + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer. The tokens generated by the target model + The relationship between [max_total_draft_tokens + 1] and the draft token tree: If the current node is accepted, what is the NEXT token_id that the target model will generate? For example, new_tokens[0, req_idx, 1] indicates the NEXT token_id sampled from the root node in the draft token tree if it is accepted. We know that the root node in the draft token tree is always accepted. Therefore, new_tokens[0, req_idx, 1] indicates the token_id following the root node, @@ -951,11 +951,15 @@ def _process_draft_tokens_tree(self, request: LlmRequest, if get_draft_token_length(request) == 0: cur_draft_layer_idx = spec_tree_manager.cur_draft_layer_idx - # TODO: For the last layer of the dynamic tree, we need to resampling all the draft tokens. - cur_layer_num_nodes = sum( - spec_tree_manager.get_top_k_list(cur_draft_layer_idx)) - for i in range(cur_layer_num_nodes): - new_token = add_token(request, new_tokens, beam=0, step=i) + if spec_tree_manager.use_dynamic_tree and cur_draft_layer_idx == spec_tree_manager.max_draft_len - 1: + # TODO: For the last layer of the dynamic tree, we need to resampling all the draft tokens. + raise NotImplementedError("Dynamic tree is not supported yet") + else: + # Add the draft tokens of the previous layer. 
+ for i in range( + spec_tree_manager.get_current_layer_draft_len( + cur_draft_layer_idx - 1)): + new_token = add_token(request, new_tokens, beam=0, step=i) return 0 else: # handle the target model request @@ -964,7 +968,9 @@ def _process_draft_tokens_tree(self, request: LlmRequest, assert seq_slot is not None eagle_paths = spec_tree_manager.get_eagle_paths(seq_slot) - all_draft_tokens = request.py_draft_tokens # [max_total_draft_tokens] + all_draft_tokens = torch.tensor( + request.py_draft_tokens, dtype=torch.int, + device='cpu') # [max_total_draft_tokens] all_target_tokens = new_tokens[:, seq_slot, :].squeeze( -1) # [max_total_draft_tokens] assert all_target_tokens.shape[ @@ -1015,8 +1021,9 @@ def _process_draft_tokens_tree(self, request: LlmRequest, return num_accepted_draft_tokens - 1 - def _tree_sampling_batch(self, requests: list[LlmRequest], - max_num_sequences: int, seq_slots: torch.Tensor, + def _tree_sampling_batch(self, new_tokens_cuda: torch.Tensor, + requests: list[LlmRequest], max_num_sequences: int, + seq_slots: torch.Tensor, model_outputs: dict[str, torch.Tensor], spec_tree_manager: SpecTreeManager): @@ -1037,10 +1044,9 @@ def _tree_sampling_batch(self, requests: list[LlmRequest], assert len(top_k_list) == num_logits_per_request # Considering that the beam_width of spec-dec can only be 1, we ignore this dimension here. - new_draft_tokens_cuda = torch.empty( - (max_num_sequences, spec_tree_manager.max_total_draft_tokens + 1), - dtype=torch.int64, - device=raw_logits.device) + new_tokens_cuda = new_tokens_cuda.zero_() + new_tokens_cuda = new_tokens_cuda.reshape( + max_num_sequences, spec_tree_manager.max_total_draft_tokens + 1) top_k_list_cumsum = torch.cumsum(top_k_list, dim=0) # Different nodes have different topK value. @@ -1051,21 +1057,21 @@ def _tree_sampling_batch(self, requests: list[LlmRequest], # 3) Sample the logits according to the topK value. indices = torch.topk(logits, k=top_k_list_i, dim=-1).indices # 4) Write to the temporary output tensor. - new_draft_tokens_cuda[ + new_tokens_cuda[ seq_slots, top_k_list_cumsum[i] - - top_k_list_i:top_k_list_cumsum[i]] = indices[request_index] + top_k_list_i:top_k_list_cumsum[i]] = indices[request_index].to( + torch.int) # 5) Append eagle3 d2t. - self._apply_d2t(new_draft_tokens_cuda, model_outputs) + self._apply_d2t(new_tokens_cuda[seq_slots, :top_k_list_cumsum[-1]], + model_outputs) # 6) Copy back to the output tensor. 
- int_new_draft_tokens = new_draft_tokens_cuda.transpose(0, 1).to( + new_tokens_cuda = new_tokens_cuda.transpose(0, 1).to( torch.int, non_blocking=True).unsqueeze(dim=-1) + new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True) - new_draft_tokens_host = int_new_draft_tokens.to("cpu", - non_blocking=True) - - return new_draft_tokens_host + return new_tokens_host def _process_draft_tokens_rejection_sampling( self, request: LlmRequest, new_tokens: torch.Tensor) -> int: @@ -1115,10 +1121,22 @@ def process_draft_tokens( if request.py_draft_logits is None: spec_tree_manager = self.get_spec_tree_manager(resource_manager) if spec_tree_manager is not None: - num_accepted = self._process_draft_tokens_tree( - request, - new_tokens=new_tokens, - spec_tree_manager=spec_tree_manager) + if get_draft_token_length(request) == 0: + # draft request + num_draft_tokens_this_layer = spec_tree_manager.get_current_layer_draft_len( + spec_tree_manager.cur_draft_layer_idx - 1) + for i in range(num_draft_tokens_this_layer): + new_token = add_token(request, + new_tokens, + beam=self.BEAM, + step=i) + return 0 + else: + # target request + num_accepted = self._process_draft_tokens_tree( + request, + new_tokens=new_tokens, + spec_tree_manager=spec_tree_manager) else: num_accepted = self._process_draft_tokens_greedy( request, new_tokens=new_tokens) @@ -1150,7 +1168,7 @@ def update_requests( if req.state == LlmRequestState.GENERATION_COMPLETE: continue processed = 1 - num_accepted = self.process_draft_tokens(req, new_tokens, + num_accepted = self.process_draft_tokens(req, state.host.new_tokens, resource_manager) if get_draft_token_length(req) > 0: req.py_num_accepted_draft_tokens = num_accepted @@ -1665,11 +1683,12 @@ def _process_requests( if spec_tree_manager is not None and logits_cuda.size(0) == len( scheduled_requests.all_requests()): new_tokens_host = self._tree_sampling_batch( - requests, - self.max_num_sequences, - seq_slots, - model_outputs, - spec_tree_manager, + new_tokens_cuda=new_tokens_cuda, + requests=requests, + max_num_sequences=self.max_num_sequences, + seq_slots=seq_slots, + model_outputs=model_outputs, + spec_tree_manager=spec_tree_manager, ) return new_tokens_host diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index 74384206740..9dcab0a451b 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -29,14 +29,14 @@ def prepare_draft_tokens( @final def should_use_spec_decode(self, requests: List[LlmRequest], max_batch_size: int, max_num_tokens: int, - max_draft_len: int) -> bool: + max_total_draft_tokens: int) -> bool: """ You probably don't want to override this. ModelEngine assumes that speculation is always on if max_concurrency is not specified by the user's spec config. 
""" - # Inputs typically validated upstream: max_batch_size>0, max_num_tokens>0, max_draft_len>=0 + # Inputs typically validated upstream: max_batch_size>0, max_num_tokens>0, max_total_draft_tokens>=0 if self.max_concurrency is None: return True @@ -45,7 +45,7 @@ def should_use_spec_decode(self, requests: List[LlmRequest], if not requests or max_batch_size <= 0 or max_num_tokens <= 0: return False - tokens_per_request = 1 + max_draft_len + tokens_per_request = 1 + max_total_draft_tokens token_cap = max_num_tokens // tokens_per_request if token_cap <= 0: return False @@ -63,7 +63,8 @@ def pad_draft_tokens_for_cuda_graph( scheduled_requests: The scheduled requests to pad """ for req in scheduled_requests.generation_requests: - max_draft_tokens = self.max_draft_tokens + max_draft_tokens = self.max_draft_len num_draft_tokens = get_draft_token_length(req) + req.py_draft_tokens.extend( 0 for _ in range(max_draft_tokens - num_draft_tokens)) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 571850c82da..4da703d12a6 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -32,11 +32,11 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype, max_num_tokens: int): self.dtype = dtype self.max_draft_len = config.max_draft_len + self.max_total_draft_tokens = config.max_total_draft_tokens self.hidden_size = hidden_size self.max_num_requests = max_num_requests self.max_seq_len = max_seq_len self.slot_manager = SlotManager(max_num_requests) - self.max_total_draft_tokens = config.max_total_draft_tokens # empty hidden states tensor max_num_tokens = min(max_num_tokens, diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 16522e98320..2cad43c9cf8 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -117,14 +117,21 @@ def attention_need_spec_dec_mode( is_draft_model: bool, attention_backend: Type[AttentionBackend], use_chain_drafter: bool, + is_spec_dec_tree: bool, ): """ If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode). """ is_trtllm_attention = issubclass(attention_backend, TrtllmAttention) - return self.is_eagle3_one_model() or ( - self.is_eagle3() and spec_resource_manager.is_first_draft - and is_trtllm_attention and use_chain_drafter and is_draft_model) + return ( + self.is_eagle3_one_model() # one model + or (self.is_eagle3() and spec_resource_manager.is_first_draft + and is_trtllm_attention and use_chain_drafter and is_draft_model + ) # two model + the first drafter + use_chain_drafter + or ( + self.is_eagle3() and is_spec_dec_tree and is_trtllm_attention + ) # two model + non-linear tree (static tree or dynamic tree) (both target model and drafter model) + trtllm attention backend + ) @staticmethod def from_string(name: Optional[str]) -> "SpeculativeDecodingMode": @@ -140,8 +147,10 @@ class SpecMetadata: """ # The max number of requests in a single batch. max_num_requests: int - # The max number of draft tokens. + # The number of draft layers. (Also the number of draft tokens for the linear tree.) max_draft_len: int + # The max number of draft tokens for the static tree and dynamic tree . + max_total_draft_tokens: int # The number of gen-phase sequences in the batch. num_generations: int = 0 # Whether CUDA graph is enabled. 
@@ -177,9 +186,13 @@ class SpecMetadata: # The number of layers num_layers: int = 0 - # if spec-dec tree is a tree or a chain (linear tree) - is_spec_dec_tree: bool = False # if spec-dec tree wouldn't be changed at all, the mask won't be computed every step. + # NOTE: For the linear tree, though it can be treated as a special case of static tree. + # NOTE: But we do not set `is_spec_dec_tree` to True for this cases. + # NOTE: i.e., for the linear tree, is_spec_dec_tree == False and is_spec_dec_dynamic_tree == False. + # whether the spec-dec mode is a tree (can be static tree or dynamic tree). + is_spec_dec_tree: bool = False + # whether the spec-dec mode is a dynamic tree. is_spec_dec_dynamic_tree: bool = False def __post_init__(self): diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 0a1a58d8575..88f317b48ac 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -12,8 +12,7 @@ from ..pyexecutor.guided_decoder import GuidedDecoder from ..pyexecutor.handle_logits import HandleLogits from ..pyexecutor.llm_request import LlmRequest, LlmRequestState -from ..pyexecutor.resource_manager import (BaseResourceManager, ResourceManager, - ResourceManagerType) +from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager from ..pyexecutor.sampler import Sampler, SampleState, SampleStateTensors from ..pyexecutor.scheduler import ScheduledRequests from ..pyexecutor.seq_slot_manager import SeqSlotManager @@ -46,7 +45,8 @@ def __init__( self, spec_config: "DecodingBaseConfig", draft_model_engine: "ModelEngine", - max_draft_tokens: int, + max_draft_len: int, + max_total_draft_tokens: int, draft_seq_slot_manager: SeqSlotManager, sampler: Sampler, spec_resource_manager: Optional[BaseResourceManager] = None, @@ -57,8 +57,11 @@ def __init__( # Validate required parameters if draft_model_engine is None: raise ValueError("draft_model_engine cannot be None") - if max_draft_tokens < 0: - raise ValueError("max_draft_tokens must be >= 0") + if max_draft_len < 0: + raise ValueError("max_draft_len must be >= 0") + if max_total_draft_tokens < 0: + raise ValueError("max_total_draft_tokens must be >= 0") + assert max_draft_len <= max_total_draft_tokens # Model and resource management self.draft_model_engine = draft_model_engine @@ -67,7 +70,8 @@ def __init__( # Configuration self.spec_config = spec_config - self.max_draft_tokens = max_draft_tokens + self.max_draft_len = max_draft_len + self.max_total_draft_tokens = max_total_draft_tokens # Sampling self.sampler = sampler self.guided_decoder = guided_decoder @@ -78,6 +82,10 @@ def __init__( assert guided_decoder is None assert spec_config._allow_greedy_draft_tokens + self.spec_tree_manager = None + if self.spec_resource_manager is not None: + self.spec_tree_manager = self.spec_resource_manager.spec_tree_manager + def _create_draft_request(self, request: LlmRequest, input_tokens: Optional[List]) -> LlmRequest: """Create a draft request with common parameters.""" @@ -148,9 +156,9 @@ def _create_accepted_tokens_request_for_trtllm_attn( Create a chunked context request for accepted tokens. Only applicable if the draft model needs to recompute KV cache for accepted tokens (e.g. 
eagle 3) """ - # Pad input_tokens to max_draft_tokens + # Pad input_tokens to max_draft_len input_tokens.extend( - 0 for _ in range(self.max_draft_tokens - num_accepted_tokens)) + 0 for _ in range(self.max_draft_len - num_accepted_tokens)) new_request = self._create_draft_request(request, input_tokens) new_request.state = LlmRequestState.GENERATION_IN_PROGRESS new_request.py_num_accepted_draft_tokens = request.py_num_accepted_draft_tokens @@ -333,6 +341,8 @@ def sample_async( num_context_logits_prefix_sum, resource_manager) except Exception as e: + traceback.print_exc() + str(e) logger.error(f"Error in sampling: {str(e)}") return None @@ -345,18 +355,12 @@ def update_request_states(self, if request.context_remaining_length == 0: request.state = LlmRequestState.GENERATION_IN_PROGRESS - def update_cur_draft_layer_idx( - self, - cur_draft_layer_idx: int, - resource_manager: Optional[ResourceManager] = None): - spec_resource_manager = resource_manager.get_resource_manager( - ResourceManagerType.SPEC_RESOURCE_MANAGER) - if spec_resource_manager is None: - return None - - spec_tree_manager = spec_resource_manager.spec_tree_manager - if spec_tree_manager is not None: - spec_tree_manager.cur_draft_layer_idx = cur_draft_layer_idx + def update_cur_draft_layer_idx(self, cur_draft_layer_idx: int): + """ + Update the current draft layer index in spec tree manager. + """ + if self.spec_tree_manager is not None: + self.spec_tree_manager.cur_draft_layer_idx = cur_draft_layer_idx def update_requests( self, @@ -378,7 +382,16 @@ def process_decoded_tokens( self.draft_seq_slot_manager.free_resources(req) continue - target_model_req.py_draft_tokens.append(req.get_last_tokens(0)) + # linear-tree + if self.spec_tree_manager is None: + target_model_req.py_draft_tokens.append(req.get_last_tokens(0)) + # static tree or dynamic tree + else: + num_draft_tokens_cur_draft_layer = self.spec_tree_manager.get_current_layer_draft_len( + self.spec_tree_manager.cur_draft_layer_idx - 1) + target_model_req.py_draft_tokens.extend( + req.get_tokens(0)[-num_draft_tokens_cur_draft_layer:]) + target_model_req.py_draft_logits = req.py_result.generation_logits # forwards Nones if req.state != LlmRequestState.GENERATION_COMPLETE and len( target_model_req.py_draft_tokens @@ -457,9 +470,10 @@ def _convert_draft_tensors( if has_draft_tokens: # We already updated the target state, so the new_tokens_lens should be all ones. new_tokens_lens = torch.ones(batch_size, device=device) - next_draft_tokens = torch.zeros(batch_size, - self.max_draft_tokens, - device=device) + next_draft_tokens = torch.zeros( + batch_size, + self.max_total_draft_tokens, # wy: (?) + device=device) # Create a new SampleStateTensorsMTP object with the additional fields updated_tensors = SampleStateTensorsMTP( @@ -543,7 +557,7 @@ def process_static_draft_outputs( outputs_host = outputs.host.new_tokens outputs.sampler_event.synchronize() - for token_idx in range(self.max_draft_tokens): + for token_idx in range(self.max_total_draft_tokens): # wy: (?) 
for req_idx, req in enumerate(draft_batch.all_requests()): target_model_req = req_id_to_old_request[req.py_request_id] if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS: @@ -560,24 +574,26 @@ def process_static_draft_outputs( def process_dynamic_draft_outputs( self, - outputs: Any, + previous_draft_state: SampleState, req_id_to_old_request: Dict[int, LlmRequest], resource_manager: Optional[ResourceManager] = None) -> None: """ Process outputs from dynamic draft loop, update target requests, and clean up resources. """ - self.update_requests(outputs, resource_manager) - self.process_decoded_tokens(outputs.scheduled_requests, + self.update_requests(previous_draft_state, resource_manager) + self.process_decoded_tokens(previous_draft_state.scheduled_requests, req_id_to_old_request) def _execute_draft_iteration( - self, draft_batch: ScheduledRequests, - resource_manager: ResourceManager, - previous_draft_state: Optional[SampleState], - cur_draft_layer_idx: int) -> Tuple[Any, Optional[SampleState]]: - self.update_cur_draft_layer_idx( - cur_draft_layer_idx, resource_manager - ) # Update the current draft layer index in the resource manager. + self, + draft_batch: ScheduledRequests, + resource_manager: ResourceManager, + previous_draft_state: Optional[SampleState], + cur_draft_layer_idx: int, + ) -> Tuple[Any, Optional[SampleState]]: + + # Update the current draft layer index in the resource manager. + self.update_cur_draft_layer_idx(cur_draft_layer_idx) """Forward pass through the draft model.""" outputs = self.forward_draft_model( draft_batch, @@ -631,7 +647,7 @@ def _execute_draft_loop( previous_draft_state = initial_draft_state # Generate remaining draft tokens iteratively - for i in range(self.max_draft_tokens - 1): + for i in range(self.max_draft_len - 1): if len(draft_batch.generation_requests) == 0: break @@ -693,7 +709,7 @@ def generate_draft_tokens_with_overlap( # Initial forward pass self.update_cur_draft_layer_idx( - 0, resource_manager + cur_draft_layer_idx=0 ) # Update the current draft layer index in the resource manager. outputs = self.forward_draft_model(draft_batch, resource_manager, @@ -707,7 +723,7 @@ def generate_draft_tokens_with_overlap( target_inputs, outputs, draft_position=0, - draft_length=self.max_draft_tokens, + draft_length=self.max_draft_len, # wy: (?) draft_batch=draft_batch, req_id_to_old_request=req_id_to_old_request) @@ -776,7 +792,7 @@ def prepare_draft_tokens( return self.update_cur_draft_layer_idx( - 0, resource_manager + cur_draft_layer_idx=0 ) # Update the current draft layer index in the resource manager. # Initial forward pass. May do the complete drafting loop # if use_static_draft_loop is set. 
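# ---------------------------------------------------------------------------
# NOTE (illustrative aside, not part of the patch): with a static tree the
# drafter is still run once per layer (the initial forward pass above covers
# layer 0, and _execute_draft_loop runs the remaining max_draft_len - 1
# iterations), but each iteration now contributes a whole layer of draft
# tokens rather than a single token. A rough sketch of that bookkeeping, using
# the example tree from the SpecTreeManager comments (top_k_list = [[3], [3, 2, 1]]);
# the names below are illustrative stand-ins for SpecTreeManager's top_k_list /
# cumulative_draft_lens / get_current_layer_draft_len bookkeeping, not its real API.
from itertools import accumulate

top_k_list = [[3], [3, 2, 1]]  # layer 0 expands the root top-3; layer 1 expands those 3 nodes top-3/2/1
max_draft_len = len(top_k_list)  # number of drafter layers -> 2
layer_draft_len = [sum(k) for k in top_k_list]  # draft tokens produced per layer -> [3, 6]
cumulative_draft_lens = list(accumulate(layer_draft_len))  # totals after each layer -> [3, 9]
max_total_draft_tokens = cumulative_draft_lens[-1]  # all tree nodes except the root -> 9

py_draft_tokens = []
for cur_draft_layer_idx in range(max_draft_len):
    # One drafter forward per layer; process_decoded_tokens() then extends the
    # target request by the current layer's draft length (placeholder ids here).
    py_draft_tokens.extend([0] * layer_draft_len[cur_draft_layer_idx])

assert len(py_draft_tokens) == max_total_draft_tokens == 9
# ---------------------------------------------------------------------------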
@@ -803,9 +819,12 @@ def prepare_draft_tokens( None, sample_state) # Final cleanup + self.update_cur_draft_layer_idx( + cur_draft_layer_idx=self.max_draft_len) if previous_draft_state is not None: self.process_dynamic_draft_outputs(previous_draft_state, - req_id_to_old_request) + req_id_to_old_request, + resource_manager) except Exception as e: traceback.print_exc() diff --git a/tensorrt_llm/_torch/speculative/spec_tree_manager.py b/tensorrt_llm/_torch/speculative/spec_tree_manager.py index b84083cfb3b..b96f02e1121 100644 --- a/tensorrt_llm/_torch/speculative/spec_tree_manager.py +++ b/tensorrt_llm/_torch/speculative/spec_tree_manager.py @@ -1,3 +1,4 @@ +import math from typing import List, Optional import torch @@ -7,7 +8,7 @@ class SpecTreeManager: use_dynamic_tree: bool # Whether using dynamic tree max_total_draft_tokens: int # The number of all nodes in the tree (except the root) dynamic_tree_max_topK: int # If using dynamic tree, the number of nodes to expand each time. - max_draft_len: int # The number of drafter layer. When using linear-tree, the max_draft_len is the same as max_total_draft_tokens. + max_draft_len: int # The number of drafter layer. cur_draft_layer_idx: int # The current index of the drafter layer # Auxiliary buffers @@ -26,6 +27,10 @@ class SpecTreeManager: # shape: [num_trees, max_total_draft_tokens + 1], pad the 0-1 matrix to int32 vector. spec_dec_pack_mask: Optional[torch.Tensor] = None + # The spec position offsets. + # shape: [num_trees, max_total_draft_tokens + 1]. + spec_dec_position_offsets: Optional[torch.Tensor] = None + def __init__(self, max_num_requests: int, use_dynamic_tree: bool, max_total_draft_tokens: int, max_draft_len: int, eagle_choices: [List[List[int]]], dynamic_tree_max_topK: int): @@ -36,7 +41,7 @@ def __init__(self, max_num_requests: int, use_dynamic_tree: bool, self.eagle_choices = eagle_choices self.num_trees = max_num_requests if use_dynamic_tree else 1 self.dynamic_tree_max_topK = dynamic_tree_max_topK - self.cur_draft_layer_idx = -1 + self.cur_draft_layer_idx = 0 # Initialize the buffers self.eagle_paths = torch.ones( @@ -53,37 +58,93 @@ def __init__(self, max_num_requests: int, use_dynamic_tree: bool, pin_memory=True, ).unsqueeze(0).repeat(self.num_trees, 1, 1) self.spec_dec_pack_mask = torch.zeros( - (self.num_trees + 1, self.max_total_draft_tokens + 1), + (self.num_trees, self.max_total_draft_tokens + 1, + math.ceil((self.max_total_draft_tokens + 1) / 32)), dtype=torch.int32, device='cpu', pin_memory=True, ) + self.spec_dec_position_offsets = torch.zeros( + (self.num_trees, self.max_total_draft_tokens + 1), + dtype=torch.int32, + device='cpu', + pin_memory=True, + ) + + # top_k_list[i] means the topK values for the i-th layer. + # Examples: + # top_k_list[0] = [3] means the topK values for the 0-th layer (aka, expand the root node) is 3. + # top_k_list[1] = [3, 2, 1] means the topK values for the 1-th layer's nodes is 3, 2, 1, respectively. self.top_k_list = [] + # cumulative_draft_lens[i] means the cumulative draft token lengths AFTER the i-th layer. DO NOT include the root node. + # Examples: + # cumulative_draft_lens[0] = 3 means the draft tokens generated AFTER the 0-th layer will be 3. + # cumulative_draft_lens[1] = 9 (3 + 3 + 2 + 1 = 9) means the cumulative generated AFTER the 1-st layer will be 9. + self.cumulative_draft_lens = [] + + # For each drafter's generation requests, we take all selected draft tokens from the current layer and above (excluding the root node) as input tokens. 
+        # However, only some tokens in the current layer will have child nodes, i.e., they will continue to expand and generate the next layer.
+        # We only need to set the gather_ids for these tokens.
+        # NOTE: For the static tree, each node is selected.
+        # gather_ids_per_layer is a list of offsets for the draft tokens that need gather_ids in the draft_layer_id-th layer.
+        # The offset is relative to the root node - 1 (-1 is because the root node is excluded).
+        self.gather_ids_per_layer = []
+
+        # Auxiliary variables for the static tree.
+        # Considering that the static tree is a fixed tree, we can use some auxiliary variables to record some
+        # information in advance to avoid repeated calculations between different iterations.
+        # nodes_list_per_layer[i] means the nodes list for the i-th layer. Include the root node.
+        self.nodes_list_per_layer = []
+        # Mapping choices to unique indices.
+        self.index_mapping_set = {}
+
         if self.use_dynamic_tree:
             self.init_tree_info_for_dynamic_tree()
         else:
             self.init_tree_info_for_static_tree()
 
+        # self.dump_tree_info()
+
     def init_tree_info_for_dynamic_tree(self):
         # For the dynamic tree
         # To the internal layer, the number of nodes is the same as the dynamic_tree_max_topK.
-        self.top_k_list = [
-            torch.ones(self.dynamic_tree_max_topK,
-                       dtype=torch.int32,
-                       device='cpu',
-                       pin_memory=True) * self.dynamic_tree_max_topK
+        self.top_k_list.append(
+            torch.tensor([self.dynamic_tree_max_topK],
+                         dtype=torch.int32,
+                         device='cpu',
+                         pin_memory=True))
+        for i in range(self.max_draft_len - 1):
+            self.top_k_list.append(
+                torch.tensor([
+                    self.dynamic_tree_max_topK
+                    for _ in range(self.dynamic_tree_max_topK)
+                ],
+                             dtype=torch.int32,
+                             device='cpu',
+                             pin_memory=True))
+
+        self.cumulative_draft_lens = [
+            i * self.dynamic_tree_max_topK
+            for i in range(1, self.max_draft_len + 1)
         ]
+        self.gather_ids_per_layer.append([0])
+        for i in range(1, self.max_draft_len):
+            self.gather_ids_per_layer.append(
+                list(
+                    range(self.dynamic_tree_max_topK * (i - 1),
+                          self.dynamic_tree_max_topK * i)))
+
     # For the static tree
     def init_tree_info_for_static_tree(self):
-        index_mapping_set = {}
-        nodes_list_per_layer = [[] for _ in range(self.max_draft_len + 1)]
+        self.index_mapping_set = {}
+        self.nodes_list_per_layer = [[] for _ in range(self.max_draft_len + 1)]
         child_nodes_list = [[] for _ in range(self.max_total_draft_tokens + 1)]
 
         # 1) Map the index
         for i, choice in enumerate(self.eagle_choices):
-            index_mapping_set[str(choice)] = i + 1
+            self.index_mapping_set[str(choice)] = i + 1
 
         # 2) Reconstruct the eagle_paths
         self.eagle_paths.fill_(-1)
@@ -91,44 +152,70 @@ def init_tree_info_for_static_tree(self):
         for i, choice in enumerate(self.eagle_choices):
             self.eagle_paths[0][i + 1][0] = 0
             for j in range(len(choice)):
-                self.eagle_paths[0][i + 1][j + 1] = index_mapping_set[str(
+                self.eagle_paths[0][i + 1][j + 1] = self.index_mapping_set[str(
                     choice[:j + 1])]
 
         # 3) Compute node_list_per_layer
-        nodes_list_per_layer[0].append(0)  # root node
+        self.nodes_list_per_layer[0].append(0)  # root node
         for choice in self.eagle_choices:
             cur_layer = len(choice)
-            nodes_list_per_layer[cur_layer].append(
-                index_mapping_set[str(choice)])
+            self.nodes_list_per_layer[cur_layer].append(
+                self.index_mapping_set[str(choice)])
 
         # 4) Compute child_nodes_list
         for choice in self.eagle_choices:
             if len(choice) == 1:
                 # root node's children
-                child_nodes_list[0].append(index_mapping_set[str(choice)])
+                child_nodes_list[0].append(self.index_mapping_set[str(choice)])
             else:
-                child_nodes_list[index_mapping_set[str(choice[:-1])]].append(
- index_mapping_set[str(choice)]) + child_nodes_list[self.index_mapping_set[str( + choice[:-1])]].append(self.index_mapping_set[str(choice)]) # 5) Compute top_k_list for i in range(self.max_draft_len): - cur_layer_nodes = nodes_list_per_layer[i] + cur_layer_nodes = self.nodes_list_per_layer[i] tmp_top_k_list = [ len(child_nodes_list[node]) for node in cur_layer_nodes if len(child_nodes_list[node]) > 0 ] - assert sum(tmp_top_k_list) == len(nodes_list_per_layer[i + 1]) + assert sum(tmp_top_k_list) == len(self.nodes_list_per_layer[i + 1]) self.top_k_list.append( torch.tensor(tmp_top_k_list, dtype=torch.int32, device='cpu', pin_memory=True)) + # 6) Compute cumulative_draft_lens + self.cumulative_draft_lens.append(len(self.nodes_list_per_layer[1])) + for i in range(1, self.max_draft_len): + self.cumulative_draft_lens.append( + self.cumulative_draft_lens[i - 1] + + len(self.nodes_list_per_layer[i + 1])) - # 6) Compute the spec decoding according to the eagle_paths + # 7) Compute the spec decoding according to the eagle_paths for i, path in enumerate(self.eagle_paths[0]): indices = path[path > -1] self.spec_dec_mask_matrix[0][i, indices] = 1 - self.spec_dec_pack_mask = self.compute_spec_dec_pack_mask( - self.spec_dec_mask_matrix) + self.compute_spec_dec_pack_mask(self.spec_dec_mask_matrix, + self.spec_dec_pack_mask) + + # 8) Compute the spec position offsets + start_idx = 0 + for i in range(self.max_draft_len + 1): + num_nodes_this_layer = len(self.nodes_list_per_layer[i]) + self.spec_dec_position_offsets[:, start_idx:start_idx + + num_nodes_this_layer] = i + start_idx += num_nodes_this_layer + + # 9) Compute the gather_ids_per_layer + self.gather_ids_per_layer.append([0]) + for i in range(1, self.max_draft_len): + cur_gather_ids = [] + for path in self.eagle_choices: + if len(path) == i: + # Has child node(s) + if (len(child_nodes_list[self.index_mapping_set[str(path)]]) + > 0): + cur_gather_ids.append(self.index_mapping_set[str(path)]) + self.gather_ids_per_layer.append(cur_gather_ids) # Get the eagle_paths def get_eagle_paths(self, tree_idx=0): @@ -145,22 +232,70 @@ def get_eagle_paths(self, tree_idx=0): # Get the topK list for the specific draft layer def get_top_k_list(self, draft_layer_id): - assert draft_layer_id >= 0 + assert draft_layer_id >= 0 and draft_layer_id < self.max_draft_len return self.top_k_list[draft_layer_id] + # Get the cumulative draft token lengths AFTER the i-th layer. DO NOT include the root node. + def get_cumulative_draft_lens(self, draft_layer_id): + assert draft_layer_id >= 0 and draft_layer_id < self.max_draft_len + return self.cumulative_draft_lens[draft_layer_id] + + # Get the draft token lengths for the specific draft layer. + def get_current_layer_draft_len(self, draft_layer_id): + assert draft_layer_id >= 0 and draft_layer_id < self.max_draft_len + if self.use_dynamic_tree: + return self.dynamic_tree_max_topK + else: + return len(self.nodes_list_per_layer[draft_layer_id + + 1]) # +1 to skip the root node + + # Get the gather ids for the specific draft layer. + # Return: A list []. 
+ def get_gather_ids(self, draft_layer_id): + assert draft_layer_id > 0 and draft_layer_id < self.max_draft_len + return self.gather_ids_per_layer[draft_layer_id] + # Compute the packed mask according to the mask matrix - def compute_spec_dec_pack_mask(self, mask_matrix): + def compute_spec_dec_pack_mask(self, mask_matrix, packed_mask): # mask_matrix: shape: [num_trees, max_total_draft_tokens + 1, max_total_draft_tokens + 1] + # packed_mask: shape: [num_trees, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32)] + num_blocks = math.ceil((self.max_total_draft_tokens + 1) / 32) int_tensor = mask_matrix.to(torch.int32) - weights = torch.pow( - 2, torch.arange(mask_matrix.shape[-1], device=mask_matrix.device)) - packed_mask = torch.sum(int_tensor * weights, dim=-1) - return packed_mask + int_tensor = int_tensor.reshape(-1, self.max_total_draft_tokens + 1) + packed_mask = packed_mask.reshape(-1, num_blocks) + + for block_idx in range(num_blocks): + start_idx = block_idx * 32 + end_idx = min(start_idx + 32, self.max_total_draft_tokens + 1) + block_bits = int_tensor[:, start_idx:end_idx] + weight = torch.pow( + 2, + torch.arange(end_idx - start_idx, + dtype=torch.int32, + device=int_tensor.device)) + block_value = torch.sum(block_bits * weight, dim=-1) + packed_mask[:, block_idx] = block_value + + packed_mask = packed_mask.reshape(self.num_trees, + self.max_total_draft_tokens + 1, + num_blocks) # Print the tree info def dump_tree_info(self): if not self.use_dynamic_tree: print(f"Static tree: {self.eagle_paths}") + print(f"Nodes list per layer: {self.nodes_list_per_layer}") + print(f"Index mapping set: {self.index_mapping_set}") print(f"TopK list: {self.top_k_list}") + print(f"Cumulative draft lens: {self.cumulative_draft_lens}") + print(f"Gather ids per layer: {self.gather_ids_per_layer}") print(f"Spec dec mask matrix: {self.spec_dec_mask_matrix.int()}") print(f"Spec dec pack mask: {self.spec_dec_pack_mask}") + print(f"Spec dec position offsets: {self.spec_dec_position_offsets}") + + def print_mask_matrix_from_packed_mask(self): + for i in range(self.num_trees): + for j in range(self.max_total_draft_tokens + 1): + num_blocks = math.ceil((self.max_total_draft_tokens + 1) / 32) + for k in range(num_blocks - 1, -1, -1): + print(bin(self.spec_dec_pack_mask[i, j, k])[2:], end='') diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 56b44704c0e..7a1d1101a02 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -42,7 +42,8 @@ def get_spec_metadata(spec_config, is_mtp_eagle=False, max_total_draft_tokens=spec_config.max_total_draft_tokens, eagle_choices=spec_config.eagle_choices, - is_spec_dec_tree=spec_config.eagle_choices is not None, + is_spec_dec_tree=spec_config.eagle_choices is not None + or spec_config.use_dynamic_tree, is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree, ) if spec_config.spec_dec_mode.is_eagle3_one_model(): @@ -143,6 +144,7 @@ def get_spec_drafter(model_engine, return ModelDrafter(spec_config, draft_model_engine, spec_config.max_draft_len, + spec_config.max_total_draft_tokens, SeqSlotManager(max_num_requests), sampler, spec_resource_manager=spec_resource_manager, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 9a7faede4ce..2d14130a151 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -357,6 +357,8 @@ class _ModelFormatKind(Enum): class DecodingBaseConfig(StrictBaseModel): max_draft_len: 
Optional[int] = None + # The number of draft tokens in the draft tokens tree. + max_total_draft_tokens: Optional[int] = None speculative_model_dir: Optional[Union[str, Path]] = None # PyTorch only. @@ -436,6 +438,12 @@ class MedusaDecodingConfig(DecodingBaseConfig): medusa_choices: Optional[List[List[int]]] = None num_medusa_heads: Optional[int] = None + def __init__(self, **kwargs): + super().__init__() + for attr_name, attr_value in kwargs.items(): + setattr(self, attr_name, attr_value) + self.max_total_draft_tokens = self.max_draft_len # Current Medusa only support linear tree + @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -454,8 +462,6 @@ class EagleDecodingConfig(DecodingBaseConfig): use_dynamic_tree: Optional[bool] = False # The topK value for each layer when enable dynamic tree. dynamic_tree_max_topK: Optional[int] = None - # The number of draft tokens in the draft tokens tree. - max_total_draft_tokens: Optional[int] = None # The number of eagle layer. will not be used in pytorch flow, just for compatibility with TRT flow num_eagle_layers: Optional[int] = None # The number of non-leaves in each layer. @@ -490,7 +496,7 @@ def __init__(self, **kwargs): # Checks whether the input eagle choices is valid # and reset the max_draft_len and num_eagle_layers if necessary if self.eagle_choices is not None: - # If eagle_choices is provided, use_dynamic_tree will not be used + # If eagle_choices is provided, use_dynamic_tree should not be used assert not self.use_dynamic_tree, "If eagle_choices is provided, use_dynamic_tree need to be False" # Get num_eagle_layers from eagle_choices @@ -513,6 +519,9 @@ def __init__(self, **kwargs): assert self.dynamic_tree_max_topK is not None and self.dynamic_tree_max_topK > 0, "dynamic_tree_max_topK should be provided, which indicates the number of nodes to expand each time" assert self.max_total_draft_tokens is not None and self.max_total_draft_tokens > 0, "max_total_draft_tokens should be provided, which indicates the total nodes of the final draft tree. (exclude the root node)" + if self.eagle3_one_model: + assert self.is_linear_tree, "Eagle3 one-model does not support tree decoding now. Please use Eagle3 two-model." 
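Editor's note: for orientation, a usage sketch of the tree-related fields validated in `EagleDecodingConfig.__init__` above, mirroring the static 12-node, 3-layer tree used by the new unit and e2e tests in this patch (the draft-model directory is a placeholder):

```python
from tensorrt_llm.llmapi import EagleDecodingConfig

# Static tree: 3 drafter layers, 12 draft nodes (root excluded), two-model Eagle3.
spec_config = EagleDecodingConfig(
    max_draft_len=3,                 # number of drafter layers
    max_total_draft_tokens=12,       # all tree nodes except the root
    eagle_choices=[[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1],
                   [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]],
    eagle3_one_model=False,          # tree decoding requires the two-model path
    use_dynamic_tree=False,          # eagle_choices and dynamic tree are mutually exclusive
    speculative_model_dir="<eagle3-draft-model-dir>",  # placeholder
)
```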
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)
@@ -561,12 +570,24 @@ def num_capture_layers(self) -> int:
             return len(self.eagle3_layers_to_capture)
         return 3
 
+    @functools.cached_property
+    def is_linear_tree(self) -> bool:
+        if self.eagle_choices is None and not self.use_dynamic_tree:
+            return True
+        return False
+
 
 class UserProvidedDecodingConfig(DecodingBaseConfig):
     # Cannot use real type annotations due to circular imports
     drafter: object  # Type is Drafter
     resource_manager: object = None  # Type is Optional[ResourceManager]
 
+    def __init__(self, **kwargs):
+        super().__init__()
+        for attr_name, attr_value in kwargs.items():
+            setattr(self, attr_name, attr_value)
+        self.max_total_draft_tokens = self.max_draft_len  # Current UserProvided only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)
@@ -599,6 +620,12 @@ class NGramDecodingConfig(DecodingBaseConfig):
     is_use_oldest: bool = True
     is_public_pool: bool = True
 
+    def __init__(self, **kwargs):
+        super().__init__()
+        for attr_name, attr_value in kwargs.items():
+            setattr(self, attr_name, attr_value)
+        self.max_total_draft_tokens = self.max_draft_len  # Current NGram only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)
@@ -611,6 +638,12 @@ def supports_backend(self, backend: str) -> bool:
 
 class DraftTargetDecodingConfig(DecodingBaseConfig):
 
+    def __init__(self, **kwargs):
+        super().__init__()
+        for attr_name, attr_value in kwargs.items():
+            setattr(self, attr_name, attr_value)
+        self.max_total_draft_tokens = self.max_draft_len  # Current DraftTarget only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)
@@ -640,11 +673,17 @@ class MTPDecodingConfig(DecodingBaseConfig):
     BEGIN_THINKING_PHASE_TOKEN: int = 128798
     END_THINKING_PHASE_TOKEN: int = 128799
 
+    def __init__(self, **kwargs):
+        super().__init__()
+        for attr_name, attr_value in kwargs.items():
+            setattr(self, attr_name, attr_value)
+            if attr_name == 'num_nextn_predict_layers':
+                self.max_draft_len = attr_value
+                self.max_total_draft_tokens = attr_value  # Current MTP only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
-        out = cls(**data)
-        out.max_draft_len = out.num_nextn_predict_layers
-        return out
+        return cls(**data)
 
     decoding_type: ClassVar[str] = "MTP"
 
@@ -678,6 +717,12 @@ class AutoDecodingConfig(DecodingBaseConfig):
 
     Attributes that are inherited from the base class are ignored.
""" + def __init__(self, **kwargs): + super().__init__() + for attr_name, attr_value in kwargs.items(): + setattr(self, attr_name, attr_value) + self.max_total_draft_tokens = self.max_draft_len # Current Auto only support linear tree + @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -1020,6 +1065,7 @@ def validate_positive_values(cls, v): def __init__(self, **data): super().__init__(**data) + self.max_total_draft_tokens = self.max_draft_len # Current Lookahead only support linear tree self._check_fields() def calculate_speculative_resource(self): diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index df54cd3e456..2183931564a 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -2049,6 +2049,41 @@ def test_ptp_quickstart_advanced_eagle3(llm_root, llm_venv, model_name, _check_mem_usage(running_log, [25.2, 0, 0, 0]) +@pytest.mark.parametrize("model_name,model_path,eagle_model_path", [ + ("Llama-3.1-8b-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct", + "EAGLE3-LLaMA3.1-Instruct-8B"), +]) +def test_draft_token_tree_quickstart_advanced_eagle3(llm_root, llm_venv, + model_name, model_path, + eagle_model_path): + print(f"Testing {model_name}.") + example_root = Path(os.path.join(llm_root, "examples", "llm-api")) + with tempfile.NamedTemporaryFile(mode='w+t', + suffix=f".{model_name}.log", + dir="./", + delete=True, + delete_on_close=True) as running_log: + llm_venv.run_cmd([ + str(example_root / "quickstart_advanced.py"), + "--prompt", + "You are a good assistant. Please tell me the capital of France is", + "--spec_decode_max_draft_len", + "3", + "--spec_decode_algo", + "eagle3", + "--model_dir", + f"{llm_models_root()}/{model_path}", + "--draft_model_dir", + f"{llm_models_root()}/{eagle_model_path}", + "--disable_kv_cache_reuse", + "--disable_overlap_scheduler", + "--eagle_choices", + "[[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]]", + ], + stdout=running_log) + _check_mem_usage(running_log, [25.2, 0, 0, 0]) + + @pytest.mark.parametrize("model_name,model_path", [ ("Llama-3.1-8B-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct"), ]) diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 586161fa15b..4543ba4afa3 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -111,6 +111,7 @@ l0_h100: - test_e2e.py::test_openai_chat_harmony - test_e2e.py::test_openai_responses - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] TIMEOUT (90) + - test_e2e.py::test_draft_token_tree_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B] # Only support hopper yet - test_e2e.py::test_trtllm_benchmark_serving[llama-3.1-model/Meta-Llama-3.1-8B] # ------------- AutoDeploy tests --------------- - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype diff --git a/tests/unittest/_torch/speculative/test_draft_token_tree_runtime.py b/tests/unittest/_torch/speculative/test_draft_token_tree_runtime.py new file mode 100644 index 00000000000..35d7d1b37d0 --- /dev/null +++ b/tests/unittest/_torch/speculative/test_draft_token_tree_runtime.py @@ -0,0 +1,724 @@ +import os +import sys +import unittest +from dataclasses import dataclass + +import torch +from utils.llm_data import llm_models_root + +import tensorrt_llm +from 
tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest +from tensorrt_llm._torch.pyexecutor.model_engine import PyTorchModelEngine + +# isort: off +from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager, + ResourceManager, + ResourceManagerType + ) +# isort: on +from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors +from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests +from tensorrt_llm._torch.speculative.eagle3 import (Eagle3ResourceManager, + Eagle3SpecMetadata) +from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.llmapi import EagleDecodingConfig, SamplingParams +from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig +from tensorrt_llm.mapping import Mapping + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + + +# Test the update_spec_dec_param function for static tree +def test_draft_token_static_tree_prepare_spec_params(): + models_path = llm_models_root() + eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" # It will not actually be used. + + max_num_tokens = 1024 + kv_cache_manager = None + scheduled_requests = [ + ] # for the static tree, we will not use the scheduled requests in update_spec_dec_param + use_dynamic_tree = False + + def run_test(max_batch_size, max_draft_len, max_total_draft_tokens, + cur_draft_layer_idx, eagle_choices, is_draft_model, + is_spec_decoding_enabled, ref_spec_decoding_position_offsets, + ref_spec_decoding_packed_mask, + ref_spec_decoding_generation_lengths): + + attn_metadata = TrtllmAttentionMetadata( + max_num_requests=max_batch_size, + max_num_tokens=max_num_tokens, + kv_cache_manager=kv_cache_manager) + + spec_config = EagleDecodingConfig( + max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, + speculative_model_dir=eagle_model_dir, + eagle3_one_model=False, + eagle_choices=eagle_choices, + use_dynamic_tree=use_dynamic_tree, + ) + + spec_tree_manager = SpecTreeManager( + max_num_requests=max_batch_size, + use_dynamic_tree=spec_config.use_dynamic_tree, + max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=spec_config.max_total_draft_tokens, + eagle_choices=spec_config.eagle_choices, + dynamic_tree_max_topK=spec_config.dynamic_tree_max_topK, + ) + spec_tree_manager.cur_draft_layer_idx = cur_draft_layer_idx + + spec_metadata = Eagle3SpecMetadata( + max_draft_len=spec_config.max_draft_len, + spec_dec_mode=spec_config.spec_dec_mode, + max_num_requests=max_batch_size, + num_layers=32, + hidden_size=1024, + max_num_tokens=max_num_tokens, + dtype=torch.bfloat16, + is_draft_model=is_draft_model, + eagle3_resource_manager=None, + layers_to_capture=spec_config.eagle3_layers_to_capture, + max_total_draft_tokens=spec_config.max_total_draft_tokens, + eagle_choices=spec_config.eagle_choices, + is_spec_dec_tree=spec_config.eagle_choices is not None + or spec_config.use_dynamic_tree, + is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree, + ) + + attn_metadata.update_spec_dec_param( + scheduled_requests=scheduled_requests, + is_spec_decoding_enabled=is_spec_decoding_enabled, + spec_metadata=spec_metadata, + spec_tree_manager=spec_tree_manager, + max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, + ) + + 
print( + f"attn_metadata.spec_decoding_position_offsets: {attn_metadata.spec_decoding_position_offsets}" + ) + print( + f"ref_spec_decoding_position_offsets: {ref_spec_decoding_position_offsets}" + ) + + print( + f"attn_metadata.spec_decoding_packed_mask: {attn_metadata.spec_decoding_packed_mask}" + ) + print(f"ref_spec_decoding_packed_mask: {ref_spec_decoding_packed_mask}") + + print( + f"attn_metadata.spec_decoding_generation_lengths: {attn_metadata.spec_decoding_generation_lengths}" + ) + print( + f"ref_spec_decoding_generation_lengths: {ref_spec_decoding_generation_lengths}" + ) + + if is_spec_decoding_enabled: + assert torch.all(attn_metadata.spec_decoding_position_offsets == + ref_spec_decoding_position_offsets) + assert torch.all(attn_metadata.spec_decoding_packed_mask == + ref_spec_decoding_packed_mask) + assert torch.all(attn_metadata.spec_decoding_generation_lengths == + ref_spec_decoding_generation_lengths) + else: + assert attn_metadata.spec_decoding_position_offsets is None + assert attn_metadata.spec_decoding_packed_mask is None + assert attn_metadata.spec_decoding_generation_lengths is None + + ################## CASE 1 is_spec_decoding_enabled = False ########################## + max_batch_size = 1 + is_spec_decoding_enabled = False + max_draft_len = 3 + max_total_draft_tokens = 12 + cur_draft_layer_idx = 0 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]] + is_draft_model = False + ref_spec_decoding_position_offsets = None + ref_spec_decoding_packed_mask = None + ref_spec_decoding_generation_lengths = None + + run_test(max_batch_size, max_draft_len, max_total_draft_tokens, + cur_draft_layer_idx, eagle_choices, is_draft_model, + is_spec_decoding_enabled, ref_spec_decoding_position_offsets, + ref_spec_decoding_packed_mask, + ref_spec_decoding_generation_lengths) + + ################## CASE 2 target model ########################## + max_batch_size = 1 + is_spec_decoding_enabled = True + max_draft_len = 3 + max_total_draft_tokens = 12 + cur_draft_layer_idx = 0 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]] + is_draft_model = False # i.e, target model + ref_spec_decoding_position_offsets = torch.tensor( + [[0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3]], + dtype=torch.int, + device='cuda') + ref_spec_decoding_packed_mask = torch.tensor( + [1, 3, 5, 9, 19, 35, 67, 133, 261, 521, 1043, 2083, 4229], + dtype=torch.int, + device='cuda').reshape(1, max_total_draft_tokens + 1, 1) + ref_spec_decoding_generation_lengths = torch.tensor([13], + dtype=torch.int, + device='cuda') + + run_test(max_batch_size, max_draft_len, max_total_draft_tokens, + cur_draft_layer_idx, eagle_choices, is_draft_model, + is_spec_decoding_enabled, ref_spec_decoding_position_offsets, + ref_spec_decoding_packed_mask, + ref_spec_decoding_generation_lengths) + + ################## CASE 3 target model, batch_size = 2 ########################## + max_batch_size = 2 + is_spec_decoding_enabled = True + max_draft_len = 3 + max_total_draft_tokens = 12 + cur_draft_layer_idx = 0 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]] + is_draft_model = False # i.e, target model + ref_spec_decoding_position_offsets = torch.tensor( + [[0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3]], + dtype=torch.int, + device='cuda').repeat(2, 1, 1) + ref_spec_decoding_packed_mask = torch.tensor( + [1, 3, 5, 9, 19, 35, 67, 133, 261, 521, 1043, 2083, 4229], + 
dtype=torch.int, + device='cuda').reshape(1, max_total_draft_tokens + 1, + 1).repeat(2, 1, 1) + ref_spec_decoding_generation_lengths = torch.tensor([13], + dtype=torch.int, + device='cuda').repeat( + 2, 1) + + run_test(max_batch_size, max_draft_len, max_total_draft_tokens, + cur_draft_layer_idx, eagle_choices, is_draft_model, + is_spec_decoding_enabled, ref_spec_decoding_position_offsets, + ref_spec_decoding_packed_mask, + ref_spec_decoding_generation_lengths) + + ################## CASE 4 target model, bigger tree ########################## + max_batch_size = 1 + is_spec_decoding_enabled = True + max_draft_len = 4 + max_total_draft_tokens = 20 + cur_draft_layer_idx = 0 + eagle_choices = [[0], [1], [2], [3], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [3, 0], [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 2, 0], + [1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], + [0, 1, 0, 0]] + is_draft_model = False # i.e, target model + ref_spec_decoding_position_offsets = torch.tensor( + [[0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4]], + dtype=torch.int, + device='cuda') + ref_spec_decoding_packed_mask = torch.tensor( + [ + 1, 3, 5, 9, 17, 35, 67, 131, 261, 517, 1033, 2065, 4131, 8227, + 16451, 32899, 65797, 135203, 266275, 532515, 1065027 + ], + dtype=torch.int, + device='cuda').reshape(1, max_total_draft_tokens + 1, 1) + ref_spec_decoding_generation_lengths = torch.tensor([21], + dtype=torch.int, + device='cuda') + + run_test(max_batch_size, max_draft_len, max_total_draft_tokens, + cur_draft_layer_idx, eagle_choices, is_draft_model, + is_spec_decoding_enabled, ref_spec_decoding_position_offsets, + ref_spec_decoding_packed_mask, + ref_spec_decoding_generation_lengths) + + ################## CASE 5 drafter model, drafter_layer_idx = 0 ########################## + max_batch_size = 1 + is_spec_decoding_enabled = True + max_draft_len = 3 + max_total_draft_tokens = 12 + cur_draft_layer_idx = 0 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]] + is_draft_model = True # i.e, drafter model + # These tensors are not used in the first drafter layer. 
+ ref_spec_decoding_position_offsets = torch.tensor( + [[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + dtype=torch.int, + device='cuda') + ref_spec_decoding_packed_mask = torch.tensor( + [[1, 3, 7, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + dtype=torch.int, + device='cuda').reshape(1, max_total_draft_tokens + 1, 1) + ref_spec_decoding_generation_lengths = torch.tensor([4], + dtype=torch.int, + device='cuda') + + run_test(max_batch_size, max_draft_len, max_total_draft_tokens, + cur_draft_layer_idx, eagle_choices, is_draft_model, + is_spec_decoding_enabled, ref_spec_decoding_position_offsets, + ref_spec_decoding_packed_mask, + ref_spec_decoding_generation_lengths) + + ################## CASE 6 drafter model, drafter_layer_idx = 1 ########################## + max_batch_size = 1 + is_spec_decoding_enabled = True + max_draft_len = 3 + max_total_draft_tokens = 12 + cur_draft_layer_idx = 1 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]] + is_draft_model = True # i.e, drafter model + ref_spec_decoding_position_offsets = torch.tensor( + [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + dtype=torch.int, + device='cuda') + ref_spec_decoding_packed_mask = torch.tensor( + [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int, + device='cuda').reshape(1, max_total_draft_tokens + 1, 1) + ref_spec_decoding_generation_lengths = torch.tensor([3], + dtype=torch.int, + device='cuda') + + run_test(max_batch_size, max_draft_len, max_total_draft_tokens, + cur_draft_layer_idx, eagle_choices, is_draft_model, + is_spec_decoding_enabled, ref_spec_decoding_position_offsets, + ref_spec_decoding_packed_mask, + ref_spec_decoding_generation_lengths) + + ################## CASE 7 drafter model, drafter_layer_idx = 2 ########################## + max_batch_size = 1 + is_spec_decoding_enabled = True + max_draft_len = 3 + max_total_draft_tokens = 12 + cur_draft_layer_idx = 2 + eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]] + is_draft_model = True # i.e, drafter model + ref_spec_decoding_position_offsets = torch.tensor( + [[0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]], + dtype=torch.int, + device='cuda') + ref_spec_decoding_packed_mask = torch.tensor( + [1, 2, 4, 9, 17, 33, 66, 130, 260, 0, 0, 0, 0], + dtype=torch.int, + device='cuda').reshape(1, max_total_draft_tokens + 1, 1) + ref_spec_decoding_generation_lengths = torch.tensor([9], + dtype=torch.int, + device='cuda') + + run_test(max_batch_size, max_draft_len, max_total_draft_tokens, + cur_draft_layer_idx, eagle_choices, is_draft_model, + is_spec_decoding_enabled, ref_spec_decoding_position_offsets, + ref_spec_decoding_packed_mask, + ref_spec_decoding_generation_lengths) + + +############################################################################################################################## + + +def _create_request(input_tokens, req_id: int, is_first_draft: bool): + sampling_params = SamplingParams() + kwargs = { + "request_id": + req_id, + "max_new_tokens": + 128, + "input_tokens": + input_tokens, + "sampling_config": + tensorrt_llm.bindings.SamplingConfig( + sampling_params._get_sampling_config()), + "is_streaming": + False, + } + request = LlmRequest(**kwargs) + request.paged_kv_block_ids = [] + request.py_is_first_draft = is_first_draft + request.py_seq_slot = req_id + request.py_batch_idx = req_id + + return request + + +@dataclass +class Config: + torch_dtype: torch.dtype + num_key_value_heads: int = 16 + num_attention_heads: int = 16 + 
hidden_size: int = 256 + architectures: list[str] = None + + @property + def head_dim(self) -> int: + return self.hidden_size // self.num_attention_heads + + +class DummyModel(torch.nn.Module): + + def __init__(self, dtype: torch.dtype): + super().__init__() + self.model_config = ModelConfig(pretrained_config=Config( + torch_dtype=dtype)) + self.recorded_position_ids = None + + def infer_max_seq_len(self): + return 2048 + + @property + def config(self): + return self.model_config.pretrained_config + + def forward(self, *args, **kwargs) -> torch.Tensor: + input_ids = kwargs["input_ids"] + self.recorded_position_ids = kwargs["position_ids"] + batch_size = input_ids.size(0) + return {"logits": torch.randn((batch_size, 2048), device='cuda')} + + +class DummyModelEngine(PyTorchModelEngine): + + def __init__(self, + pytorch_backend_config: PyTorchConfig, + spec_config: DecodingBaseConfig, + batch_size: int, + dtype: torch.dtype, + max_seq_len: int = 128, + max_total_draft_tokens: int = 12, + is_draft_model: bool = False) -> None: + self.dtype = dtype + mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(), + tp_size=tensorrt_llm.mpi_world_size(), + rank=tensorrt_llm.mpi_rank()) + self.model_is_wrapped = False + self.hidden_size = 2048 + self.max_num_tokens = max_seq_len + self.max_seq_len = max_seq_len + + super().__init__( + model_path="", + pytorch_backend_config=pytorch_backend_config, + checkpoint_loader=None, + batch_size=batch_size, + max_seq_len=max_seq_len, + mapping=mapping, + spec_config=spec_config, + is_draft_model=is_draft_model, + model=DummyModel(self.dtype), + ) + self.max_total_draft_tokens = max_total_draft_tokens + + +def create_model_engine_and_kvcache(spec_config, + max_num_requests, + batch_size, + use_cuda_graph, + max_seq_len, + max_total_draft_tokens, + is_draft_model, + config: PyTorchConfig = None): + tokens_per_block = 1 + max_tokens = 258 # Atleast 1 more than the max seq len + num_layers = 1 + + config = config if config else PyTorchConfig( + use_cuda_graph=use_cuda_graph, + cuda_graph_padding_enabled=use_cuda_graph) + + if use_cuda_graph: + config.cuda_graph_batch_sizes = [ + 1, 2, 4, 8, 16, 32, 64, 128 + ] if config.cuda_graph_batch_sizes is None else config.cuda_graph_batch_sizes + + model_engine = DummyModelEngine( + pytorch_backend_config=config, + spec_config=spec_config, + batch_size=max_num_requests, + dtype=torch.half, + max_seq_len=max_seq_len, + max_total_draft_tokens=max_total_draft_tokens, + is_draft_model=is_draft_model) + + kv_cache_config = KvCacheConfig(max_tokens=max_tokens) + mapping = Mapping(world_size=1, tp_size=1, rank=0) + kv_cache_manager = KVCacheManager( + kv_cache_config, + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, + num_layers=num_layers, + num_kv_heads=model_engine.model.config.num_key_value_heads, + head_dim=model_engine.model.config.head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_tokens, + max_batch_size=batch_size, + mapping=mapping, + dtype=tensorrt_llm.bindings.DataType.HALF, + ) + + return model_engine, kv_cache_manager + + +# from executor.test_pytorch_model_engine import create_model_engine_and_kvcache, _create_request +# Test the prepare_inputs function for static tree +def test_draft_token_static_tree_prepare_inputs(): + + max_num_requests = 1 + max_batch_size = max_num_requests + batch_size = 1 + use_cuda_graph = False + max_num_tokens = 128 + max_seq_len = 128 + hidden_size = 1024 + + # Use same tree + max_draft_len = 3 + max_total_draft_tokens = 12 + eagle_model_dir = "" + 
eagle_choices = [[0], [1], [2], [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], + [2, 0], [0, 0, 0], [0, 1, 0], [1, 0, 0]] + use_dynamic_tree = False + + def run_test(scheduled_requests, new_tensors_device, is_draft_model, + cur_draft_layer_idx, ref_input_ids, ref_position_ids, + ref_gather_ids): + + # 1) Create spec related config, resource managers + spec_config = EagleDecodingConfig( + max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, + speculative_model_dir=eagle_model_dir, + eagle3_one_model=False, + eagle_choices=eagle_choices, + use_dynamic_tree=use_dynamic_tree, + ) + eagle3_resource_manager = Eagle3ResourceManager( + spec_config, + torch.half, + hidden_size, + max_num_requests, + max_seq_len, + max_num_tokens, + ) + eagle3_resource_manager.spec_tree_manager.cur_draft_layer_idx = cur_draft_layer_idx + + # 2) Create model engine and kv cache manager + model_engine, kv_cache_manager = create_model_engine_and_kvcache( + spec_config=spec_config, + max_num_requests=max_num_requests, + batch_size=batch_size, + use_cuda_graph=use_cuda_graph, + max_seq_len=max_seq_len, + max_total_draft_tokens=max_total_draft_tokens, + is_draft_model=is_draft_model) + model_engine._disable_overlap_scheduler = True + + for req in scheduled_requests.all_requests(): + kv_cache_manager.add_dummy_requests([req.request_id], + [len(req.get_tokens(0))]) + eagle3_resource_manager.add_dummy_requests([req.request_id]) + + resource_manager = ResourceManager({ + ResourceManagerType.KV_CACHE_MANAGER: + kv_cache_manager, + ResourceManagerType.SPEC_RESOURCE_MANAGER: + eagle3_resource_manager + }) + + # 3) Create attn metadata + attn_metadata = TrtllmAttentionMetadata( + max_num_requests=max_batch_size, + max_num_tokens=max_num_tokens, + kv_cache_manager=kv_cache_manager) + attn_metadata.max_seq_len = max_seq_len + attn_metadata._max_seq_len_storage = max_seq_len + + # 4) Create spec metadata + spec_metadata = Eagle3SpecMetadata( + max_draft_len=spec_config.max_draft_len, + spec_dec_mode=spec_config.spec_dec_mode, + max_num_requests=max_batch_size, + num_layers=32, + hidden_size=1024, + max_num_tokens=max_num_tokens, + dtype=torch.bfloat16, + is_draft_model=is_draft_model, + eagle3_resource_manager=eagle3_resource_manager, + layers_to_capture=spec_config.eagle3_layers_to_capture, + max_total_draft_tokens=spec_config.max_total_draft_tokens, + eagle_choices=spec_config.eagle_choices, + is_spec_dec_tree=spec_config.eagle_choices is not None + or spec_config.use_dynamic_tree, + is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree, + ) + model_engine.spec_metadata = spec_metadata + + # 5) Run the prepare_tp_inputs function + inputs, gather_ids = model_engine._prepare_tp_inputs( + scheduled_requests=scheduled_requests, + kv_cache_manager=kv_cache_manager, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + new_tensors_device=new_tensors_device, + cache_indirection_buffer=None, + resource_manager=resource_manager) + + print(f"inputs['input_ids']: {inputs['input_ids']}") + print(f"ref_input_ids: {ref_input_ids}") + assert torch.all(inputs['input_ids'] == ref_input_ids) + + print(f"inputs['position_ids']: {inputs['position_ids']}") + print(f"ref_position_ids: {ref_position_ids}") + assert torch.all(inputs['position_ids'].squeeze(0) == ref_position_ids) + + print(f"gather_ids: {gather_ids}") + print(f"ref_gather_ids: {ref_gather_ids}") + assert torch.all(gather_ids == ref_gather_ids) + + ################## CASE 1 target model, the generation phase ########################## + is_draft_model = False + 
scheduled_requests = ScheduledRequests() + target_gen_request = _create_request( + input_tokens=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + req_id=0, + is_first_draft=False) + target_gen_request.py_draft_tokens = [ + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 + ] + scheduled_requests.generation_requests.append(target_gen_request) + cur_draft_layer_idx = 0 + + ref_input_ids = torch.tensor( + [14, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + dtype=torch.int, + device='cuda') + ref_position_ids = torch.tensor( + [14, 15, 15, 15, 16, 16, 16, 16, 16, 16, 17, 17, 17], + dtype=torch.int, + device='cuda') + ref_gather_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + dtype=torch.int, + device='cuda') + + run_test(scheduled_requests, None, is_draft_model, cur_draft_layer_idx, + ref_input_ids, ref_position_ids, ref_gather_ids) + + ################## CASE 2 drafter model, context phase, the first drafter layer ########################## + is_draft_model = True + scheduled_requests = ScheduledRequests() + # '[1:]prompt + new token' already done in model_drafter.py::_prepare_draft_batch() + # The input request here are the draft batch. + drafter_request = _create_request( + input_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + req_id=0, + is_first_draft=False) + scheduled_requests.context_requests.append(drafter_request) + cur_draft_layer_idx = 0 + + ref_input_ids = torch.tensor( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + dtype=torch.int, + device='cuda') + ref_position_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + dtype=torch.int, + device='cuda') + ref_gather_ids = torch.tensor([14], dtype=torch.int, device='cuda') + + run_test(scheduled_requests, None, is_draft_model, cur_draft_layer_idx, + ref_input_ids, ref_position_ids, ref_gather_ids) + + ################## CASE 3 drafter model, the first drafter layer ########################## + is_draft_model = True + scheduled_requests = ScheduledRequests() + # the input_toeksn already be pad to max_draft_len + 1 + drafter_request = _create_request(input_tokens=[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 0, 0 + ], + req_id=0, + is_first_draft=True) + drafter_request.py_num_accepted_draft_tokens = 1 # + scheduled_requests.generation_requests.append(drafter_request) + cur_draft_layer_idx = 0 # Prepare to execute the 0-th drafter layer + + ref_input_ids = torch.tensor([17, 18, 0, 0], dtype=torch.int, + device='cuda') # max_draft_len + 1 + ref_position_ids = torch.tensor([16, 17, 18, 19], + dtype=torch.int, + device='cuda') + ref_gather_ids = torch.tensor([1], dtype=torch.int, device='cuda') + + run_test(scheduled_requests, None, is_draft_model, cur_draft_layer_idx, + ref_input_ids, ref_position_ids, ref_gather_ids) + + ################## CASE 4 drafter model, the second drafter layer ########################## + is_draft_model = True + scheduled_requests = ScheduledRequests() + drafter_request = _create_request(input_tokens=[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17 + ], + req_id=0, + is_first_draft=False) + scheduled_requests.generation_requests.append(drafter_request) + cur_draft_layer_idx = 1 # Prepare to execute the 1-st drafter layer + + new_tensors = torch.zeros((max_total_draft_tokens + 1, max_num_requests, 1), + dtype=torch.int, + device='cuda') + new_tensors[:3, 0, 0] = torch.tensor([30, 31, 32], + dtype=torch.int, + device='cuda') + new_tensors_device = SampleStateTensors(new_tokens=new_tensors) + + ref_input_ids = 
torch.tensor([30, 31, 32], dtype=torch.int, device='cuda') + ref_position_ids = torch.tensor([17, 17, 17], + dtype=torch.int, + device='cuda') + ref_gather_ids = torch.tensor([0, 1, 2], dtype=torch.int, device='cuda') + + run_test(scheduled_requests, new_tensors_device, is_draft_model, + cur_draft_layer_idx, ref_input_ids, ref_position_ids, + ref_gather_ids) + + ################## CASE 5 drafter model, the third drafter layer ########################## + is_draft_model = True + scheduled_requests = ScheduledRequests() + drafter_request = _create_request( + # 30, 31, 32 are from the previous drafter layer + input_tokens=[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 30, 31, + 32 + ], + req_id=0, + is_first_draft=False) + scheduled_requests.generation_requests.append(drafter_request) + cur_draft_layer_idx = 2 # Prepare to execute the 2-nd drafter la + + new_tensors = torch.zeros((max_total_draft_tokens + 1, max_num_requests, 1), + dtype=torch.int, + device='cuda') + new_tensors[:6, 0, 0] = torch.tensor([40, 41, 42, 43, 44, 45], + dtype=torch.int, + device='cuda') + new_tensors_device = SampleStateTensors(new_tokens=new_tensors) + + # 30, 31, 32 are from the previous drafter layer + ref_input_ids = torch.tensor([30, 31, 32, 40, 41, 42, 43, 44, 45], + dtype=torch.int, + device='cuda') + ref_position_ids = torch.tensor([17, 17, 17, 18, 18, 18, 18, 18, 18], + dtype=torch.int, + device='cuda') + ref_gather_ids = torch.tensor([3, 4, 6], dtype=torch.int, device='cuda') + + run_test(scheduled_requests, new_tensors_device, is_draft_model, + cur_draft_layer_idx, ref_input_ids, ref_position_ids, + ref_gather_ids) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py b/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py index bfeab8c4270..b4ddcaca3fc 100644 --- a/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py +++ b/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py @@ -229,10 +229,9 @@ def test_static_tree_verification_for_target_model(): sampling_config=SamplingConfig(SamplingParams()._get_sampling_config()), is_streaming=False, ) - input_request.py_draft_tokens = torch.tensor( - [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22], - dtype=torch.int, - device='cpu') # all draft tokens + input_request.py_draft_tokens = [ + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22 + ] # shape: [max_total_draft_tokens + 1, max_batch_size, beam_width] input_new_tokens = torch.tensor( [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 0], @@ -274,10 +273,9 @@ def test_static_tree_verification_for_target_model(): sampling_config=SamplingConfig(SamplingParams()._get_sampling_config()), is_streaming=False, ) - input_request.py_draft_tokens = torch.tensor( - [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22], - dtype=torch.int, - device='cpu') # all draft tokens, [max_total_draft_tokens] + input_request.py_draft_tokens = [ + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22 + ] # shape: [max_total_draft_tokens + 1, max_batch_size, beam_width] input_new_tokens = torch.tensor( [11, 15, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112 @@ -319,10 +317,9 @@ def test_static_tree_verification_for_target_model(): sampling_config=SamplingConfig(SamplingParams()._get_sampling_config()), is_streaming=False, ) - input_request.py_draft_tokens = torch.tensor( - [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22], - dtype=torch.int, - device='cpu') # all draft tokens, 
[max_total_draft_tokens] + input_request.py_draft_tokens = [ + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22 + ] # shape: [max_total_draft_tokens + 1, max_batch_size, beam_width] input_new_tokens = torch.tensor( [11, 14, 102, 103, 20, 105, 106, 107, 108, 109, 110, 111, 112 From 47a2f4571e1c3223643f8f5b0006c7fe086e760c Mon Sep 17 00:00:00 2001 From: Yue Weng <25103990+yweng0828@users.noreply.github.com> Date: Sat, 11 Oct 2025 09:01:18 +0000 Subject: [PATCH 2/2] fix random issue Signed-off-by: Yue Weng <25103990+yweng0828@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 3 +- tensorrt_llm/_torch/speculative/eagle3.py | 38 +++++++++++++++++++++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index a876d0645d0..adb05f907d1 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1063,8 +1063,7 @@ def _tree_sampling_batch(self, new_tokens_cuda: torch.Tensor, torch.int) # 5) Append eagle3 d2t. - self._apply_d2t(new_tokens_cuda[seq_slots, :top_k_list_cumsum[-1]], - model_outputs) + self._apply_d2t(new_tokens_cuda, model_outputs) # 6) Copy back to the output tensor. new_tokens_cuda = new_tokens_cuda.transpose(0, 1).to( diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 4da703d12a6..6f23c7d7425 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -143,6 +143,7 @@ def __post_init__(self): self.is_spec_dec_dynamic_tree = False def prepare(self): + spec_tree_manager = self.eagle3_resource_manager.spec_tree_manager is_first_draft = self.eagle3_resource_manager.is_first_draft # Update start indices # Here, we assume the sequence lengths (seq_lens) during the draft model @@ -169,9 +170,40 @@ def prepare(self): hidden_states_write_indices.extend( list(range(start_idx, start_idx + seq_len))) else: - old_seq_len = self.eagle3_resource_manager.seq_lens[slot_id] - hidden_states_read_indices.append(start_idx + old_seq_len - 1) - hidden_states_write_indices.append(start_idx + seq_len - 1) + if spec_tree_manager is not None: + cur_draft_layer_idx = spec_tree_manager.cur_draft_layer_idx + if cur_draft_layer_idx == 1: + # copy the last hidden states to the beginning of the buffer + old_seq_len = self.eagle3_resource_manager.seq_lens[ + slot_id] + last_hidden_states_idx = start_idx + old_seq_len - 1 + eagle3_hidden_states = self.eagle3_resource_manager.hidden_states + eagle3_hidden_states[start_idx, :].copy_( + eagle3_hidden_states[last_hidden_states_idx, :], + non_blocking=True) + + for gather_ids, top_k_list in zip( + spec_tree_manager. 
+ gather_ids_per_layer[:cur_draft_layer_idx], + spec_tree_manager.top_k_list[:cur_draft_layer_idx]): + for gather_id, top_k_list_i in zip( + gather_ids, top_k_list): + hidden_states_read_indices.extend([gather_id] * + top_k_list_i) + + hidden_states_write_indices.extend( + list( + range( + start_idx + 1, start_idx + 1 + + spec_tree_manager.get_cumulative_draft_lens( + cur_draft_layer_idx - 1)))) + assert len(hidden_states_read_indices) == self.num_tokens + assert len(hidden_states_write_indices) == self.num_tokens + else: + old_seq_len = self.eagle3_resource_manager.seq_lens[slot_id] + hidden_states_read_indices.append(start_idx + old_seq_len - + 1) + hidden_states_write_indices.append(start_idx + seq_len - 1) self.eagle3_resource_manager.seq_lens[slot_id] = seq_len # Prepare hidden states gather ids self.hidden_states_read_indices_host = torch.tensor(