6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -329,8 +329,10 @@ def restore_from_spec_dec(self) -> None:
setattr(self, f, v)
self._saved_tensors.clear()

def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
is_spec_dec_dynamic_tree, max_draft_tokens):
def update_spec_dec_param(self, scheduled_requests,
is_spec_decoding_enabled, spec_metadata,
spec_tree_manager, max_draft_len,
max_total_draft_tokens):
"""
Hook to be called when using the TRTLLM attention backend in spec-dec mode.
"""
124 changes: 96 additions & 28 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -1027,73 +1027,141 @@ def prepare_context_mla_with_cached_kv(self,
self.ctx_kv_indptr[:self.num_contexts + 1].copy_(
self.host_ctx_kv_indptr[:self.num_contexts + 1], non_blocking=True)

def update_spec_dec_param(self, is_spec_decoding_enabled, is_spec_dec_tree,
is_spec_dec_dynamic_tree, max_draft_tokens):
def update_spec_dec_param(self, scheduled_requests,
is_spec_decoding_enabled, spec_metadata,
spec_tree_manager, max_draft_len,
max_total_draft_tokens):
# spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
) < 100

# use_spec_decoding defaults to True; it can be changed at runtime by layers / requests
self.use_spec_decoding = self.is_spec_decoding_enabled

self.is_spec_dec_tree = is_spec_dec_tree
self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
self.is_spec_dec_tree = spec_tree_manager is not None
self.is_spec_dec_dynamic_tree = spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree

# Parameters can be fixed and not changed during runtime if the
if self.is_spec_decoding_enabled:
self.spec_decoding_position_offsets = torch.empty(
[self.max_num_requests, max_draft_tokens + 1],
self.spec_decoding_position_offsets = torch.zeros(
[self.max_num_requests, max_total_draft_tokens + 1],
dtype=torch.int,
device='cuda',
)

self.spec_decoding_packed_mask = torch.empty(
self.spec_decoding_packed_mask = torch.zeros(
[
self.max_num_requests, max_draft_tokens + 1,
math.ceil(max_draft_tokens / 32)
self.max_num_requests, max_total_draft_tokens + 1,
math.ceil(max_total_draft_tokens / 32)
],
dtype=torch.int,
device='cuda',
)

self.spec_decoding_generation_lengths = torch.empty(
self.spec_decoding_generation_lengths = torch.zeros(
[self.max_num_requests],
dtype=torch.int,
device='cuda',
)

if self.is_spec_dec_dynamic_tree:
assert False, "currently dynamic tree is not supported"
# Prepare the spec-dec mask, position offsets and generation lengths for a static or dynamic tree.
if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
assert spec_metadata.spec_dec_mode.is_eagle3(
), "Tree decoding is only supported for Eagle3 now"
# If this is the drafter model
if spec_metadata.is_draft_model:
# Layer cur_draft_layer_idx has already been executed; we are now preparing the spec-dec params for executing layer cur_draft_layer_idx + 1.
cur_draft_layer_idx = spec_tree_manager.cur_draft_layer_idx

# If this is a dynamic tree
if self.is_spec_dec_dynamic_tree:
# TODO: add dynamic tree logic
assert False, "Dynamic tree is not supported yet"
# If this is a static tree
else:
# We treat the first drafter layer as a chunked prefill.
# It takes the accepted tokens + newly generated tokens as input.
if cur_draft_layer_idx == 0:
# In fact, the drafter model calls 'fmha_v2' or 'xqa' at this time.
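# Note: the helpers below build linear-chain tensors sized by max_draft_len (not max_total_draft_tokens), because the input to the first draft layer is padded to max_draft_len.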
self.generate_spec_decoding_position_offsets(
max_total_draft_tokens=max_draft_len)
self.generate_spec_decoding_packed_mask(
max_total_draft_tokens=max_draft_len)
self.generate_spec_decoding_generation_length(
max_total_draft_tokens=max_draft_len)
else:
num_process_draft_tokens = spec_tree_manager.get_cumulative_draft_lens(
cur_draft_layer_idx - 1)

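# The copies below skip the root node: the position offsets drop the first column (hence the -1 re-base), and the packed mask is divided by 2, i.e. right-shifted by one bit to remove the root-node column.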
self.spec_decoding_position_offsets[:, :num_process_draft_tokens].copy_(
    spec_tree_manager.spec_dec_position_offsets[0:, 1:1 + num_process_draft_tokens] - 1,  # Exclude the root node
    non_blocking=True)
self.spec_decoding_packed_mask[:, :num_process_draft_tokens, :].copy_(
    spec_tree_manager.spec_dec_pack_mask[0:, 1:num_process_draft_tokens + 1, :] / 2,
    non_blocking=True)
self.spec_decoding_generation_lengths[:].fill_(
    num_process_draft_tokens)

# If this is the target model
else:
# If this is a dynamic tree
if self.is_spec_dec_dynamic_tree:
# TODO: add dynamic tree logic
assert False, "Dynamic tree is not supported yet"
# If this is a static tree
else:
self.spec_decoding_position_offsets[:, :].copy_(
    spec_tree_manager.spec_dec_position_offsets[0:, :],
    non_blocking=True)
self.spec_decoding_packed_mask[:, :, :].copy_(
spec_tree_manager.spec_dec_pack_mask[0:, :, :],
non_blocking=True)
self.spec_decoding_generation_lengths[:].fill_(
spec_tree_manager.max_total_draft_tokens + 1)
else:
# Prepare for the linear tree.
# Populate the mask, which won't change during the inference phase.
self.generate_spec_decoding_position_offsets(
max_draft_tokens=max_draft_tokens)
max_total_draft_tokens=max_total_draft_tokens)
self.generate_spec_decoding_packed_mask(
max_draft_tokens=max_draft_tokens)
max_total_draft_tokens=max_total_draft_tokens)
self.generate_spec_decoding_generation_length(
max_draft_tokens=max_draft_tokens)
max_total_draft_tokens=max_total_draft_tokens)

def generate_spec_decoding_position_offsets(self, max_draft_tokens):
assert not self.is_spec_dec_tree, "only chained/linear tree is supported now"
position_offset = torch.arange(max_draft_tokens + 1,
def generate_spec_decoding_position_offsets(self, max_total_draft_tokens):
position_offset = torch.arange(max_total_draft_tokens + 1,
dtype=torch.int,
device='cpu',
pin_memory=True)

# Fill all batches with the same position offsets
self.spec_decoding_position_offsets.copy_(position_offset,
non_blocking=True)

def generate_spec_decoding_packed_mask(self, max_draft_tokens):
assert not self.is_spec_dec_tree, "only chained/linear tree is supported now"
dummy_idx = torch.arange(max_draft_tokens + 1)
self.spec_decoding_position_offsets[:, :max_total_draft_tokens +
1].copy_(position_offset,
non_blocking=True)

def generate_spec_decoding_packed_mask(self, max_total_draft_tokens):
# TODO: fix this limitation
assert max_total_draft_tokens < 32, "max_total_draft_tokens must be less than 32; this limitation will be fixed later"
dummy_idx = torch.arange(max_total_draft_tokens + 1)
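# Linear chain: row i of the packed mask has its lowest (i + 1) bits set (2**(i + 1) - 1), so draft token i attends to the root token and every earlier draft token.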
spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
non_blocking=True)
self.spec_decoding_packed_mask[:, :max_total_draft_tokens + 1,
0].copy_(spec_decoding_packed_mask,
non_blocking=True)

def generate_spec_decoding_generation_length(self, max_draft_tokens):
def generate_spec_decoding_generation_length(self, max_total_draft_tokens):
spec_decoding_generation_length = torch.full((self.max_num_requests, ),
max_draft_tokens + 1)
max_total_draft_tokens + 1)
self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
spec_decoding_generation_length, non_blocking=True)

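A minimal sketch of the values the linear-tree helpers above produce, assuming max_total_draft_tokens = 3 and a single request; the variable names are local to the sketch and shapes/device placement are simplified:

import torch

max_total_draft_tokens = 3

# Position offsets: draft token i sits i steps after the last accepted token.
position_offsets = torch.arange(max_total_draft_tokens + 1, dtype=torch.int)  # [0, 1, 2, 3]

# Packed causal mask: row i keeps its lowest (i + 1) bits set, so token i
# attends to the root token and every earlier draft token.
dummy_idx = torch.arange(max_total_draft_tokens + 1)
packed_mask = torch.pow(2, dummy_idx + 1) - 1  # [1, 3, 7, 15] == 0b1, 0b11, 0b111, 0b1111

# Generation length: every request processes all draft tokens plus the root.
generation_length = max_total_draft_tokens + 1  # 4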
13 changes: 6 additions & 7 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -315,13 +315,11 @@ def create_autodeploy_executor(ad_config: LlmArgs):
max_draft_len = (
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len
)
max_total_draft_tokens = 0
if ad_config.speculative_config is None:
max_total_draft_tokens = 0
elif hasattr(ad_config.speculative_config, "max_total_draft_tokens"):
max_total_draft_tokens = ad_config.speculative_config.max_total_draft_tokens
else:
max_total_draft_tokens = max_draft_len
max_total_draft_tokens = (
0
if ad_config.speculative_config is None
else ad_config.speculative_config.max_total_draft_tokens
)

# initialize model engine
engine = ADEngine.build_from_config(ad_config=ad_config)
@@ -374,6 +372,7 @@ def create_autodeploy_executor(ad_config: LlmArgs):
max_input_len=ad_config.max_input_len,
max_batch_size=ad_config.max_batch_size,
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
max_beam_width=ad_config.max_beam_width,
)
return py_executor
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -512,7 +512,7 @@ def __init__(
aux_stream: Optional[torch.cuda.Stream] = None,
):
config = model_config.pretrained_config
predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
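# With tree decoding, a request can carry up to max_total_draft_tokens draft tokens per step, so per-sequence buffers are sized on the full tree rather than on max_draft_len.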
super().__init__(hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
15 changes: 6 additions & 9 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -202,10 +202,10 @@ def _get_token_num_for_estimation(self) -> int:
if not pytorch_backend_config.disable_overlap_scheduler:
num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
if spec_cfg is not None:
num_extra_tokens_per_seq += spec_cfg.max_draft_len
num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens

if spec_cfg is not None:
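# Reserve estimation room for every draft token in the tree, not just the longest chain.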
num_extra_tokens_per_seq += spec_cfg.max_draft_len
num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)

if self._dummy_reqs is None:
Expand Down Expand Up @@ -751,6 +751,8 @@ def create_py_executor_instance(
max_beam_width=max_beam_width,
max_draft_len=spec_config.max_draft_len
if spec_config is not None else 0,
max_total_draft_tokens=spec_config.max_total_draft_tokens
if spec_config is not None else 0,
kv_cache_transceiver=kv_cache_transceiver,
guided_decoder=guided_decoder,
start_worker=start_worker,
@@ -767,13 +769,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
max_num_sequences = max_batch_size * mapping.pp_size
max_draft_len = (0 if speculative_config is None else
speculative_config.max_draft_len)
max_total_draft_tokens = 0
if speculative_config is None:
max_total_draft_tokens = 0
elif hasattr(speculative_config, 'max_total_draft_tokens'):
max_total_draft_tokens = speculative_config.max_total_draft_tokens
else:
max_total_draft_tokens = max_draft_len
max_total_draft_tokens = (0 if speculative_config is None else
speculative_config.max_total_draft_tokens)

return TorchSampler.Args(
max_seq_len=max_seq_len,
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -93,7 +93,8 @@ def enable_spec_decode(self):
@property
def max_possible_draft_len(self):
engine = self._get_engine()
return (engine.original_max_draft_len if self.enable_spec_decode else 0)
return (engine.original_max_total_draft_tokens
if self.enable_spec_decode else 0)

def get_graph_key(
self,
@@ -102,10 +103,12 @@ def get_graph_key(
engine = self._get_engine()
if engine.is_draft_model and spec_resource_manager is not None and isinstance(
spec_resource_manager, Eagle3ResourceManager):
# If 'is_first_draft' is True, even with tree decoding, draft_len will only be 'max_draft_len', not 'max_total_draft_tokens',
# because we pad the input to 'max_draft_len' for the first draft layer.
draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
else:
draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
draft_len = self.spec_config.max_total_draft_tokens if self.enable_spec_decode else 0
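# With spec decode enabled, generation graphs are keyed on the full tree width (max_total_draft_tokens) rather than on max_draft_len.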
key = (batch_size, draft_len, False)
return key

@@ -344,7 +347,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
is_gen=True,
max_num_draft_tokens=engine.max_draft_len,
max_num_draft_tokens=engine.max_total_draft_tokens,
use_mrope=engine.use_mrope,
max_beam_width=engine.max_beam_width)[0]
self.padding_dummy_request.is_cuda_graph_dummy = True