cpp/tensorrt_llm/thop/attentionOp.cpp (3 additions, 1 deletion)
@@ -24,6 +24,7 @@
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include "tensorrt_llm/thop/attentionOp.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <assert.h>
#include <cstdint>
#include <functional>
#include <torch/extension.h>
@@ -466,7 +467,8 @@ class Runner : public RunnerBase
= spec_decoding_tensor_params[1].value().data_ptr<int32_t>();
enqueue_params.spec_decoding_packed_mask = spec_decoding_tensor_params[2].value().data_ptr<int32_t>();
enqueue_params.spec_decoding_is_generation_length_variable = true;
enqueue_params.spec_decoding_max_generation_length = input_seq_length + 1;
assert(spec_decoding_tensor_params[1].value().dim() == 2); // [batch_size, max_draft_len + 1]
enqueue_params.spec_decoding_max_generation_length = spec_decoding_tensor_params[1].value().sizes()[1];
}

// The current mlaGeneration uses fmha to do attention, so we don't go into enqueueGeneration
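To make the intent of the C++ change explicit, here is a minimal Python sketch of the shape assumption it encodes, assuming spec_decoding_tensor_params[1] is a [batch_size, max_draft_len + 1] tensor as the new comment states; the tensor name and sizes below are illustrative placeholders, not the actual kernel inputs.

```python
# A minimal sketch of the shape assumption behind the change above (names and
# sizes here are illustrative, not the real kernel inputs).
import torch

batch_size, max_draft_len = 4, 3
spec_dec_tensor = torch.zeros(batch_size, max_draft_len + 1, dtype=torch.int32)

assert spec_dec_tensor.dim() == 2  # [batch_size, max_draft_len + 1]
# Previously hard-wired to input_seq_length + 1; now taken from the tensor itself.
spec_decoding_max_generation_length = spec_dec_tensor.size(1)
print(spec_decoding_max_generation_length)  # 4 == max_draft_len + 1
```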
tensorrt_llm/_torch/attention_backend/interface.py (8 additions, 1 deletion)
@@ -11,6 +11,8 @@

if TYPE_CHECKING:
from ..speculative.utils import SpecDecodingTensor
from ..speculative.interface import SpecMetadata
from ..speculative.spec_tree_manager import SpecTreeManager

from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
RotaryScalingType)
@@ -335,10 +337,15 @@ def restore_from_spec_dec(self) -> None:

def update_spec_dec_param(
self,
batch_size,
is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree,
max_draft_tokens,
max_draft_len,
max_total_draft_tokens,
model_is_wrapped: Optional[bool] = False,
spec_metadata: Optional['SpecMetadata'] = None,
spec_tree_manager: Optional['SpecTreeManager'] = None,
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None):
"""
Hook to be called when using TRTLLM attention backend in spec-dec mode.
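The split of the old max_draft_tokens argument into max_draft_len and max_total_draft_tokens is the core of this signature change. The sketch below is a hypothetical illustration of how the two sizes relate for a static draft tree; the tree shape and numbers are examples, not values produced by TRT-LLM.

```python
# Hypothetical static draft tree with 3 drafting layers (depth) and 6 nodes
# excluding the root; the numbers are illustrative only.
#
#        root
#       /    \
#      a      b        layer 1
#     / \      \
#    c   d      e      layer 2
#    |
#    f                 layer 3
#
max_draft_len = 3           # number of draft layers (tree depth)
max_total_draft_tokens = 6  # all tree nodes except the root

# For a linear tree (a single chain) the two values coincide.
linear_max_draft_len = 3
linear_max_total_draft_tokens = 3
assert linear_max_total_draft_tokens == linear_max_draft_len
assert max_total_draft_tokens >= max_draft_len
```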
tensorrt_llm/_torch/attention_backend/trtllm.py (96 additions, 20 deletions)
@@ -8,6 +8,8 @@

if TYPE_CHECKING:
from ..speculative.utils import SpecDecodingTensor
from ..speculative.interface import SpecMetadata
from ..speculative.spec_tree_manager import SpecTreeManager

from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.bindings.internal import thop
@@ -1055,13 +1057,30 @@ def prepare_context_mla_with_cached_kv(self,

def update_spec_dec_param(
self,
batch_size,
is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree,
max_draft_tokens,
max_draft_len,
max_total_draft_tokens,
model_is_wrapped: Optional[bool] = False,
spec_metadata: Optional['SpecMetadata'] = None,
spec_tree_manager: Optional['SpecTreeManager'] = None,
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
):

'''
Update the spec-dec parameters for the TRTLLM attention layer.
Args:
batch_size: int, the number of requests in the batch.
is_spec_decoding_enabled: bool, whether the attention needs to run in spec_decoding mode, as determined by the attention_need_spec_dec_mode() function.
is_spec_dec_tree: bool, whether the spec-dec mode uses a tree, i.e., a static or dynamic tree. For a linear tree, it is always False.
is_spec_dec_dynamic_tree: bool, whether a dynamic tree is used.
max_draft_len: int, the number of draft layers.
max_total_draft_tokens: int, the number of nodes in the tree (excluding the root).
model_is_wrapped: Optional[bool] = False, whether the drafter model is wrapped (i.e., CDL).
spec_metadata: Optional['SpecMetadata'] = None, the spec-dec metadata.
spec_tree_manager: Optional['SpecTreeManager'] = None, the spec_tree_manager for the draft-token tree.
'''
if spec_decoding_tensor is not None:
spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
@@ -1075,9 +1094,9 @@ def update_spec_dec_param(
) < 100

if get_sm_version() >= 100:
if is_spec_dec_tree or is_spec_dec_dynamic_tree:
assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
assert not self.is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."

# use_spec_decoding defaults to true; it can be changed at runtime per layer / request
self.use_spec_decoding = self.is_spec_decoding_enabled
@@ -1087,16 +1106,18 @@ def update_spec_dec_param(

# Parameters can be fixed and not changed during runtime if the
if self.is_spec_decoding_enabled:
# These buffers are accessed as if input padding were removed (i.e., consecutively),
# rather than with an offset of max_total_draft_tokens + 1 between different requests.
self.spec_decoding_position_offsets = torch.empty(
[self.max_num_requests, max_draft_tokens + 1],
[self.max_num_requests, max_total_draft_tokens + 1],
dtype=torch.int,
device='cuda',
)

self.spec_decoding_packed_mask = torch.empty(
[
self.max_num_requests, max_draft_tokens + 1,
math.ceil((max_draft_tokens + 1) / 32)
self.max_num_requests, max_total_draft_tokens + 1,
math.ceil((max_total_draft_tokens + 1) / 32)
],
dtype=torch.int,
device='cuda',
@@ -1108,7 +1129,11 @@ def update_spec_dec_param(
device='cuda',
)

if self.is_spec_dec_dynamic_tree:
is_target_model = not spec_metadata.is_draft_model if hasattr(
spec_metadata, 'is_draft_model') else False

# Case 1: dynamic tree
if self.is_spec_dec_tree and self.is_spec_dec_dynamic_tree:
assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree"
assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree"
self.spec_decoding_position_offsets.copy_(
@@ -1120,35 +1145,86 @@ def update_spec_dec_param(
spec_decoding_generation_lengths, non_blocking=True)
else:
self.generate_spec_decoding_generation_length(
max_draft_tokens=max_draft_tokens)
max_draft_len=max_total_draft_tokens)

# Case 2/3: static tree
elif self.is_spec_dec_tree and not self.is_spec_dec_dynamic_tree and spec_metadata is not None:
assert spec_metadata.spec_dec_mode.is_eagle3(
), "Tree decoding is only supported for Eagle3 now"

# Case 2: static tree and target model
if is_target_model:
# For the target model, we update the spec-dec parameters with the spec_tree_manager, which is prepared in advance.
self.spec_decoding_position_offsets[:batch_size, :].copy_(
spec_tree_manager.spec_dec_position_offsets[0, :],
non_blocking=True)
self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
spec_tree_manager.spec_dec_packed_mask[0, :, :],
non_blocking=True)
self.spec_decoding_generation_lengths[:batch_size].fill_(
spec_tree_manager.max_total_draft_tokens + 1)

# Case 3: static tree and the first drafter layer
else:
assert model_is_wrapped == True, "The drafter model should be wrapped"
# The first drafter layer takes the padded tokens as input (padded to max_draft_len + 1),
# but the spec-dec parameters still have a per-request width of max_total_draft_tokens + 1.
# Since these spec-dec params are accessed consecutively (without padding) in the attention Op,
# we need to write them consecutively when setting them.
# For the subsequent drafter layers, these spec-dec params are prepared in the drafting loops.
# position_offsets
position_offset = torch.arange(
max_draft_len + 1,
dtype=torch.int,
device='cpu',
pin_memory=True).repeat(batch_size)
self.spec_decoding_position_offsets.reshape(
-1)[:(max_draft_len + 1) * batch_size].copy_(
position_offset, non_blocking=True)
# packed_mask
dummy_idx = torch.arange(max_draft_len + 1)
spec_decoding_packed_mask = torch.pow(
2, dummy_idx + 1) - 1 # [max_draft_len + 1]
spec_decoding_packed_mask = spec_decoding_packed_mask.repeat(
batch_size) # [batch_size * (max_draft_len + 1)]
self.spec_decoding_packed_mask.reshape(
-1)[:(max_draft_len + 1) * batch_size].copy_(
spec_decoding_packed_mask, non_blocking=True)
# generation_lengths
self.generate_spec_decoding_generation_length(
max_draft_len=max_draft_len)

# Case 4: linear tree
else:
# Prepare for the linear tree.
# Populate the mask that won't change during the inference phase.
self.generate_spec_decoding_position_offsets(
max_draft_tokens=max_draft_tokens)
max_draft_len=max_draft_len)
self.generate_spec_decoding_packed_mask(
max_draft_tokens=max_draft_tokens)
max_draft_len=max_draft_len)
self.generate_spec_decoding_generation_length(
max_draft_tokens=max_draft_tokens)
max_draft_len=max_draft_len)

def generate_spec_decoding_position_offsets(self, max_draft_tokens):
position_offset = torch.arange(max_draft_tokens + 1,
def generate_spec_decoding_position_offsets(self, max_draft_len):
position_offset = torch.arange(max_draft_len + 1,
dtype=torch.int,
device='cpu',
pin_memory=True)

# fill all the batches with the same position offset
self.spec_decoding_position_offsets.copy_(position_offset,
non_blocking=True)

def generate_spec_decoding_packed_mask(self, max_draft_tokens):
dummy_idx = torch.arange(max_draft_tokens + 1)
def generate_spec_decoding_packed_mask(self, max_draft_len):
# FIXME: remove this limitation
assert max_draft_len < 32, "max_draft_len should be less than 32, will be fixed later"
dummy_idx = torch.arange(max_draft_len + 1)
spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
non_blocking=True)

def generate_spec_decoding_generation_length(self, max_draft_tokens):
def generate_spec_decoding_generation_length(self, max_draft_len):
spec_decoding_generation_length = torch.full((self.max_num_requests, ),
max_draft_tokens + 1)
max_draft_len + 1)
self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
spec_decoding_generation_length, non_blocking=True)

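As a cross-check of the packed-mask math and the "write consecutively" comments in Case 3 above, here is a standalone sketch under assumed sizes; the variable names are local to the example, and only the fill patterns mirror the code in this diff.

```python
# Sketch of the two buffer-filling patterns used above: the linear-chain mask
# and the unpadded, flattened writes for the first drafter layer of a static tree.
# All sizes are assumed for illustration.
import math
import torch

max_num_requests = 8
batch_size = 2
max_draft_len = 3           # chain length per request (number of draft layers)
max_total_draft_tokens = 6  # worst-case tree size the buffers are allocated for

# Buffers are allocated once with a per-request width of max_total_draft_tokens + 1.
position_offsets = torch.zeros(max_num_requests, max_total_draft_tokens + 1, dtype=torch.int)
packed_mask = torch.zeros(max_num_requests, max_total_draft_tokens + 1,
                          math.ceil((max_total_draft_tokens + 1) / 32), dtype=torch.int)

# Linear chain: token i attends to tokens 0..i, so its first 32-bit mask word
# has the low (i + 1) bits set, i.e. 2**(i + 1) - 1.
idx = torch.arange(max_draft_len + 1)
chain_mask = torch.pow(2, idx + 1) - 1  # tensor([ 1,  3,  7, 15]) for max_draft_len = 3

# First drafter layer of a static tree: the attention op reads these buffers
# consecutively (no per-request padding), so values are written back-to-back into
# the flattened buffers rather than at stride max_total_draft_tokens + 1.
offsets = torch.arange(max_draft_len + 1, dtype=torch.int).repeat(batch_size)
position_offsets.reshape(-1)[:(max_draft_len + 1) * batch_size].copy_(offsets)
packed_mask.reshape(-1)[:(max_draft_len + 1) * batch_size].copy_(chain_mask.repeat(batch_size))

print(position_offsets.reshape(-1)[:8])  # tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32)
print(packed_mask.reshape(-1)[:8])       # tensor([ 1,  3,  7, 15,  1,  3,  7, 15], dtype=torch.int32)
```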