
Commit 040bc9a

init code

Signed-off-by: Yue Weng <[email protected]>
1 parent a7c2c8c commit 040bc9a

File tree

13 files changed: +1551 -342 lines changed

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 4 additions & 3 deletions
@@ -336,9 +336,10 @@ def restore_from_spec_dec(self) -> None:
     def update_spec_dec_param(
         self,
         is_spec_decoding_enabled,
-        is_spec_dec_tree,
-        is_spec_dec_dynamic_tree,
-        max_draft_tokens,
+        spec_metadata,
+        spec_tree_manager,
+        max_draft_len,
+        max_total_draft_tokens,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None):
         """
         Hook to be called when using TRTLLM attention backend in spec-dec mode.
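
Callers no longer pass the two tree booleans; as the trtllm.py hunk below shows, both are now derived from whether a spec_tree_manager is present. The old max_draft_tokens is also split into max_draft_len and max_total_draft_tokens, presumably tree depth versus total node count (the two coincide for a linear chain). A minimal sketch of the flag derivation, with FakeSpecTreeManager as a hypothetical stand-in for the real manager object:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSpecTreeManager:
    """Hypothetical stand-in; only the attribute used below is modeled."""
    use_dynamic_tree: bool


def derive_tree_flags(spec_tree_manager: Optional[FakeSpecTreeManager]):
    # Mirrors the assignments added in trtllm.py: a tree is in use iff a
    # manager is provided, and the dynamic flag comes from the manager.
    is_spec_dec_tree = spec_tree_manager is not None
    is_spec_dec_dynamic_tree = (False if spec_tree_manager is None else
                                spec_tree_manager.use_dynamic_tree)
    return is_spec_dec_tree, is_spec_dec_dynamic_tree


assert derive_tree_flags(None) == (False, False)
assert derive_tree_flags(FakeSpecTreeManager(use_dynamic_tree=False)) == (True, False)
assert derive_tree_flags(FakeSpecTreeManager(use_dynamic_tree=True)) == (True, True)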

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 51 additions & 39 deletions
@@ -1051,47 +1051,46 @@ def prepare_context_mla_with_cached_kv(self,
     def update_spec_dec_param(
         self,
         is_spec_decoding_enabled,
-        is_spec_dec_tree,
-        is_spec_dec_dynamic_tree,
-        max_draft_tokens,
+        spec_metadata,
+        spec_tree_manager,
+        max_draft_len,
+        max_total_draft_tokens,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
 
         if spec_decoding_tensor is not None:
-            spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
-            spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
-            spec_decoding_generation_lengths = spec_decoding_tensor.generation_lengths
+            spec_decoding_tensor.position_offsets
+            spec_decoding_tensor.packed_mask
+            spec_decoding_tensor.generation_lengths
         else:
-            spec_decoding_position_offsets = None
-            spec_decoding_packed_mask = None
-            spec_decoding_generation_lengths = None
+            pass
         # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
         self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
         ) < 100
 
+        self.is_spec_dec_tree = False if spec_tree_manager is None else True
+        self.is_spec_dec_dynamic_tree = False if spec_tree_manager is None else spec_tree_manager.use_dynamic_tree
+
         if get_sm_version() >= 100:
-            if is_spec_dec_tree or is_spec_dec_dynamic_tree:
-                assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
-                assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
+            if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
+                assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
+                assert not self.is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."
 
         # use_spec_decoding is default to true by default, change in runtime by layers / requests
         self.use_spec_decoding = self.is_spec_decoding_enabled
 
-        self.is_spec_dec_tree = is_spec_dec_tree
-        self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
-
         # Parameters can be fixed and not changed during runtime if the
         if self.is_spec_decoding_enabled:
             self.spec_decoding_position_offsets = torch.empty(
-                [self.max_num_requests, max_draft_tokens + 1],
+                [self.max_num_requests, max_total_draft_tokens + 1],
                 dtype=torch.int,
                 device='cuda',
             )
 
             self.spec_decoding_packed_mask = torch.empty(
                 [
-                    self.max_num_requests, max_draft_tokens + 1,
-                    math.ceil((max_draft_tokens + 1) / 32)
+                    self.max_num_requests, max_total_draft_tokens + 1,
+                    math.ceil((max_total_draft_tokens + 1) / 32)
                 ],
                 dtype=torch.int,
                 device='cuda',
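
Both buffers are now sized by max_total_draft_tokens: one slot per draft token plus the root, and, for the packed mask, ceil((max_total_draft_tokens + 1) / 32) int32 words per row, since the mask packs one bit per attended token. A small sketch of the shape arithmetic (spec_dec_buffer_shapes is a hypothetical helper, not part of the diff):

import math


def spec_dec_buffer_shapes(max_num_requests: int, max_total_draft_tokens: int):
    # One slot per draft token plus one for the root/accepted token.
    num_tokens = max_total_draft_tokens + 1
    # Each mask row packs one attention bit per token, 32 bits per int32 word.
    num_mask_words = math.ceil(num_tokens / 32)
    return {
        "spec_decoding_position_offsets": (max_num_requests, num_tokens),
        "spec_decoding_packed_mask": (max_num_requests, num_tokens, num_mask_words),
        "spec_decoding_generation_lengths": (max_num_requests,),
    }


# e.g. 8 requests, up to 12 total draft tokens -> one int32 word per mask row
shapes = spec_dec_buffer_shapes(8, 12)
assert shapes["spec_decoding_packed_mask"] == (8, 13, 1)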
@@ -1103,30 +1102,41 @@ def update_spec_dec_param(
                 device='cuda',
             )
 
-            if self.is_spec_dec_dynamic_tree:
-                assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree"
-                assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree"
-                self.spec_decoding_position_offsets.copy_(
-                    spec_decoding_position_offsets, non_blocking=True)
-                self.spec_decoding_packed_mask.copy_(spec_decoding_packed_mask,
-                                                     non_blocking=True)
-                if spec_decoding_generation_lengths is not None:
-                    self.spec_decoding_generation_lengths.copy_(
-                        spec_decoding_generation_lengths, non_blocking=True)
+            # Prepare the spec-dec mask, position offsets and generation lengths for a static or dynamic tree.
+            # We only prepare them for the target model here;
+            # for the drafter model, they are prepared in the drafting loops.
+            is_target_model = not spec_metadata.is_draft_model
+            is_using_tree = self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree
+            if is_target_model and is_using_tree:
+                assert spec_metadata.spec_dec_mode.is_eagle3(
+                ), "Tree decoding is only supported for Eagle3 now"
+                # If it is a dynamic tree
+                if self.is_spec_dec_dynamic_tree:
+                    # TODO: add dynamic tree logic
+                    assert False, "Dynamic tree is not supported yet"
+                # If it is a static tree
                 else:
-                    self.generate_spec_decoding_generation_length(
-                        max_draft_tokens=max_draft_tokens)
+                    self.spec_decoding_position_offsets[:, :].copy_(
+                        spec_tree_manager.spec_dec_position_offsets[0, :],
+                        non_blocking=True)
+                    self.spec_decoding_packed_mask[:, :, :].copy_(
+                        spec_tree_manager.spec_dec_packed_mask[0, :, :],
+                        non_blocking=True)
+                    self.spec_decoding_generation_lengths[:].fill_(
+                        spec_tree_manager.max_total_draft_tokens + 1)
             else:
+                # Prepare for the linear tree.
                 # Populate the mask that won't change during inference phase.
                 self.generate_spec_decoding_position_offsets(
-                    max_draft_tokens=max_draft_tokens)
+                    max_total_draft_tokens=max_total_draft_tokens)
                 self.generate_spec_decoding_packed_mask(
-                    max_draft_tokens=max_draft_tokens)
+                    max_total_draft_tokens=max_total_draft_tokens)
                 self.generate_spec_decoding_generation_length(
-                    max_draft_tokens=max_draft_tokens)
+                    max_total_draft_tokens=max_total_draft_tokens)
 
-    def generate_spec_decoding_position_offsets(self, max_draft_tokens):
-        position_offset = torch.arange(max_draft_tokens + 1,
+    def generate_spec_decoding_position_offsets(self, max_total_draft_tokens):
+        position_offset = torch.arange(max_total_draft_tokens + 1,
                                        dtype=torch.int,
                                        device='cpu',
                                        pin_memory=True)
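
In the static-tree branch, every request shares the same precomputed tree, so row 0 of the manager's position-offset and packed-mask templates is broadcast into all request slots, and the generation length is fixed at max_total_draft_tokens + 1. A minimal torch sketch of that broadcast pattern; the template tensors here are hypothetical stand-ins for spec_tree_manager.spec_dec_position_offsets and spec_dec_packed_mask:

import torch

max_num_requests, num_tokens, num_words = 4, 8, 1

# Hypothetical per-tree templates (the manager would hold the real ones).
tmpl_offsets = torch.arange(num_tokens, dtype=torch.int).unsqueeze(0)  # [1, T]
tmpl_mask = torch.randint(0, 2**num_tokens, (1, num_tokens, num_words),
                          dtype=torch.int)                             # [1, T, W]

offsets = torch.empty(max_num_requests, num_tokens, dtype=torch.int)
packed_mask = torch.empty(max_num_requests, num_tokens, num_words, dtype=torch.int)
gen_lengths = torch.empty(max_num_requests, dtype=torch.int)

# Row 0 of each template is copied into every request slot (copy_ broadcasts
# over dim 0), and every request uses the full tree: max_total_draft_tokens + 1.
offsets.copy_(tmpl_offsets[0, :])
packed_mask.copy_(tmpl_mask[0, :, :])
gen_lengths.fill_(num_tokens)

assert (offsets == tmpl_offsets[0]).all() and (gen_lengths == num_tokens).all()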
@@ -1135,15 +1145,17 @@ def generate_spec_decoding_position_offsets(self, max_draft_tokens):
         self.spec_decoding_position_offsets.copy_(position_offset,
                                                   non_blocking=True)
 
-    def generate_spec_decoding_packed_mask(self, max_draft_tokens):
-        dummy_idx = torch.arange(max_draft_tokens + 1)
+    def generate_spec_decoding_packed_mask(self, max_total_draft_tokens):
+        # TODO: fix this limitation
+        assert max_total_draft_tokens < 32, "max_total_draft_tokens should be less than 32, will be fixed later"
+        dummy_idx = torch.arange(max_total_draft_tokens + 1)
         spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
         self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
                                                       non_blocking=True)
 
-    def generate_spec_decoding_generation_length(self, max_draft_tokens):
+    def generate_spec_decoding_generation_length(self, max_total_draft_tokens):
         spec_decoding_generation_length = torch.full((self.max_num_requests, ),
-                                                     max_draft_tokens + 1)
+                                                     max_total_draft_tokens + 1)
         self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
             spec_decoding_generation_length, non_blocking=True)
 
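
For the linear chain, row i of the packed mask is 2^(i+1) - 1, i.e. the low i + 1 bits set, so draft token i attends to itself and all earlier tokens. The whole row must fit in a single int32 word, which is what the new max_total_draft_tokens < 32 assert enforces. A quick check of that bit pattern:

import torch

max_total_draft_tokens = 4  # must stay below 32: one int32 word per mask row
idx = torch.arange(max_total_draft_tokens + 1)
packed = torch.pow(2, idx + 1) - 1  # row i == binary 0...011...1 with i+1 ones

for i, word in enumerate(packed.tolist()):
    # Low i+1 bits set, all higher bits clear: causal mask for a linear chain.
    assert word == (1 << (i + 1)) - 1
    assert all(((word >> j) & 1) == (1 if j <= i else 0)
               for j in range(max_total_draft_tokens + 1))

print(packed.tolist())  # [1, 3, 7, 15, 31]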
