Commit b17a837

fix ci
Signed-off-by: Yue Weng <[email protected]>
1 parent ef8b2b6 commit b17a837

7 files changed: +91 -48 lines

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 6 additions & 2 deletions

@@ -11,6 +11,8 @@
 
 if TYPE_CHECKING:
     from ..speculative.utils import SpecDecodingTensor
+    from ..speculative.interface import SpecMetadata
+    from ..speculative.spec_tree_manager import SpecTreeManager
 
 from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
                                      RotaryScalingType)
@@ -337,10 +339,12 @@ def update_spec_dec_param(
         self,
         batch_size,
         is_spec_decoding_enabled,
-        spec_metadata,
-        spec_tree_manager,
+        is_spec_dec_tree,
+        is_spec_dec_dynamic_tree,
         max_draft_len,
         max_total_draft_tokens,
+        spec_metadata: Optional['SpecMetadata'] = None,
+        spec_tree_manager: Optional['SpecTreeManager'] = None,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None):
         """
         Hook to be called when using TRTLLM attention backend in spec-dec mode.
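
Net effect of this hunk: spec_metadata and spec_tree_manager move from required positional parameters to defaulted keyword parameters, so call sites that never touch tree decoding can omit them. A minimal sketch of the pattern; the class below is a hypothetical stand-in, not the real backend:

    from typing import Optional

    class AttentionMetadataSketch:
        """Hypothetical stand-in for the real attention metadata class."""

        def update_spec_dec_param(self,
                                  batch_size: int,
                                  is_spec_decoding_enabled: bool,
                                  is_spec_dec_tree: bool,
                                  is_spec_dec_dynamic_tree: bool,
                                  max_draft_len: int,
                                  max_total_draft_tokens: int,
                                  spec_metadata: Optional[object] = None,
                                  spec_tree_manager: Optional[object] = None,
                                  spec_decoding_tensor: Optional[object] = None):
            # Tree-specific objects are optional; linear-tree callers omit them.
            return spec_tree_manager is not None

    # A linear-tree caller can now drop the trailing optionals entirely.
    AttentionMetadataSketch().update_spec_dec_param(
        batch_size=1,
        is_spec_decoding_enabled=True,
        is_spec_dec_tree=False,
        is_spec_dec_dynamic_tree=False,
        max_draft_len=3,
        max_total_draft_tokens=3)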

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 47 additions & 30 deletions

@@ -8,6 +8,8 @@
 
 if TYPE_CHECKING:
     from ..speculative.utils import SpecDecodingTensor
+    from ..speculative.interface import SpecMetadata
+    from ..speculative.spec_tree_manager import SpecTreeManager
 
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.internal import thop
@@ -1052,25 +1054,26 @@ def update_spec_dec_param(
         self,
         batch_size,
         is_spec_decoding_enabled,
-        spec_metadata,
-        spec_tree_manager,
+        is_spec_dec_tree,
+        is_spec_dec_dynamic_tree,
         max_draft_len,
         max_total_draft_tokens,
+        spec_metadata: Optional['SpecMetadata'] = None,
+        spec_tree_manager: Optional['SpecTreeManager'] = None,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
         if spec_decoding_tensor is not None:
-            spec_decoding_tensor.position_offsets
-            spec_decoding_tensor.packed_mask
-            spec_decoding_tensor.generation_lengths
+            spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
+            spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
+            spec_decoding_generation_lengths = spec_decoding_tensor.generation_lengths
         else:
-            pass
+            spec_decoding_position_offsets = None
+            spec_decoding_packed_mask = None
+            spec_decoding_generation_lengths = None
         # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
         self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
         ) < 100
 
-        self.is_spec_dec_tree = False if spec_tree_manager is None else True
-        self.is_spec_dec_dynamic_tree = False if spec_tree_manager is None else spec_tree_manager.use_dynamic_tree
-
         if get_sm_version() >= 100:
             if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
                 assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
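
Worth calling out: the removed lines were bare attribute accesses with no assignment, so they were no-ops and the unpacked values never existed; the else branch was likewise a bare pass. The fix binds each field to a local with a None fallback. A minimal reproduction of the fixed pattern, with a hypothetical dataclass standing in for SpecDecodingTensor:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeSpecDecodingTensor:
        """Hypothetical stand-in for SpecDecodingTensor."""
        position_offsets: Optional[list] = None
        packed_mask: Optional[list] = None
        generation_lengths: Optional[list] = None

    def unpack(spec_decoding_tensor: Optional[FakeSpecDecodingTensor]):
        # Bind each field to a local, defaulting to None when no tensor is
        # given; the pre-fix code evaluated the attributes without assigning.
        if spec_decoding_tensor is not None:
            position_offsets = spec_decoding_tensor.position_offsets
            packed_mask = spec_decoding_tensor.packed_mask
            generation_lengths = spec_decoding_tensor.generation_lengths
        else:
            position_offsets = packed_mask = generation_lengths = None
        return position_offsets, packed_mask, generation_lengths

    print(unpack(None))  # (None, None, None)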
@@ -1079,6 +1082,9 @@ def update_spec_dec_param(
         # use_spec_decoding is default to true by default, change in runtime by layers / requests
         self.use_spec_decoding = self.is_spec_decoding_enabled
 
+        self.is_spec_dec_tree = is_spec_dec_tree
+        self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
+
         # Parameters can be fixed and not changed during runtime if the
         if self.is_spec_decoding_enabled:
             # These buffers are accessed more like removing input padding,
@@ -1104,28 +1110,40 @@ def update_spec_dec_param(
                 device='cuda',
             )
 
-            # Prepare the spec-dec mask, position offset and generation length for static tree of dynamic tree.
-            # We only prepare the spec-dec mask, position offset and generation length for the target model here.
-            # For the drafter model, we will prepare them in the drafting loops.
-            is_target_model = not spec_metadata.is_draft_model
-            is_using_tree = self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree
-            if is_target_model and is_using_tree:
+            is_target_model = not spec_metadata.is_draft_model if hasattr(
+                spec_metadata, 'is_draft_model') else False
+
+            if self.is_spec_dec_tree and self.is_spec_dec_dynamic_tree:
+                # dynamic tree
+                assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree"
+                assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree"
+                self.spec_decoding_position_offsets.copy_(
+                    spec_decoding_position_offsets, non_blocking=True)
+                self.spec_decoding_packed_mask.copy_(spec_decoding_packed_mask,
+                                                     non_blocking=True)
+                if spec_decoding_generation_lengths is not None:
+                    self.spec_decoding_generation_lengths.copy_(
+                        spec_decoding_generation_lengths, non_blocking=True)
+                else:
+                    self.generate_spec_decoding_generation_length(
+                        batch_size=batch_size,
+                        max_draft_len=max_total_draft_tokens)
+            elif self.is_spec_dec_tree and not self.is_spec_dec_dynamic_tree and spec_metadata is not None and is_target_model:
+                # static tree and target model
+                # Prepare the spec-dec mask, position offset and generation length for static tree.
+                # We only prepare the spec-dec mask, position offset and generation length for the target model here.
+                # For the drafter model, we will prepare them in the drafting loops.
+
                 assert spec_metadata.spec_dec_mode.is_eagle3(
                 ), "Tree decoding is only supported for Eagle3 now"
-                # If is the dynamic tree
-                if self.is_spec_dec_dynamic_tree:
-                    # TODO: add dynamic tree logic
-                    assert False, "Dynamic tree is not supported yet"
-                # If is the static tree
-                else:
-                    self.spec_decoding_position_offsets[:batch_size, :].copy_(
-                        spec_tree_manager.spec_dec_position_offsets[0, :],
-                        non_blocking=True)
-                    self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
-                        spec_tree_manager.spec_dec_packed_mask[0, :, :],
-                        non_blocking=True)
-                    self.spec_decoding_generation_lengths[:batch_size].fill_(
-                        spec_tree_manager.max_total_draft_tokens + 1)
+                self.spec_decoding_position_offsets[:batch_size, :].copy_(
+                    spec_tree_manager.spec_dec_position_offsets[0, :],
+                    non_blocking=True)
+                self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
+                    spec_tree_manager.spec_dec_packed_mask[0, :, :],
+                    non_blocking=True)
+                self.spec_decoding_generation_lengths[:batch_size].fill_(
+                    spec_tree_manager.max_total_draft_tokens + 1)
             else:
                 # Prepare for the linear-tree.
                 # Populate the mask that won't change during inference phase.
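
The rewritten block is a three-way dispatch on the tree flags: a dynamic tree copies the caller-provided offsets and mask, a static tree on the target model copies the precomputed buffers from spec_tree_manager, and everything else falls through to the linear-tree path. A condensed sketch of just that control flow, with illustrative labels in place of the real buffer copies:

    def dispatch(is_tree: bool, is_dynamic_tree: bool, is_target_model: bool,
                 has_spec_metadata: bool) -> str:
        # Condensed control flow of the rewritten block (labels illustrative).
        if is_tree and is_dynamic_tree:
            return "dynamic-tree: copy caller-provided offsets/mask"
        elif is_tree and not is_dynamic_tree and has_spec_metadata and is_target_model:
            return "static-tree target: copy buffers from spec_tree_manager"
        else:
            return "linear-tree: fill fixed mask/offsets"

    assert dispatch(True, True, False, False).startswith("dynamic")
    assert dispatch(True, False, True, True).startswith("static")
    assert dispatch(False, False, True, True).startswith("linear")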
@@ -1142,7 +1160,6 @@ def generate_spec_decoding_position_offsets(self, batch_size,
                                 dtype=torch.int,
                                 device='cpu',
                                 pin_memory=True).repeat(batch_size)
-        #
         # fill all the batches with same position offset
         self.spec_decoding_position_offsets.reshape(-1)[:(max_draft_len + 1) *
                                                         batch_size].copy_(

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 12 additions & 4 deletions

@@ -1341,7 +1341,9 @@ def _prepare_tp_inputs(
         if spec_config is not None:
             spec_resource_manager = resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
-            spec_tree_manager = spec_resource_manager.spec_tree_manager
+            if spec_resource_manager is not None and hasattr(
+                    spec_resource_manager, 'spec_tree_manager'):
+                spec_tree_manager = spec_resource_manager.spec_tree_manager
 
         # will contain previous batch indices of generation requests
         previous_batch_indices = []
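
The new guard tolerates a spec resource manager that is None or simply never defines spec_tree_manager (not every speculative mode builds trees). As an aside, getattr with a default would fold both checks into one expression; the sketch below is an alternative spelling, not the code the commit uses:

    class PoolManagerWithoutTrees:
        """Hypothetical manager that has no spec_tree_manager attribute."""

    spec_resource_manager = PoolManagerWithoutTrees()
    # getattr with a default handles both a missing attribute and a None value.
    spec_tree_manager = getattr(spec_resource_manager, 'spec_tree_manager', None)
    assert spec_tree_manager is None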
@@ -2330,9 +2332,15 @@ def forward(
                 spec_resource_manager, self.is_draft_model, self.attn_backend,
                 self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
             attn_metadata.update_spec_dec_param(
-                scheduled_requests.batch_size, is_spec_dec_mode, spec_metadata,
-                spec_tree_manager, self.original_max_draft_len,
-                self.original_max_total_draft_tokens, spec_decoding_tensor)
+                batch_size=scheduled_requests.batch_size,
+                is_spec_decoding_enabled=is_spec_dec_mode,
+                is_spec_dec_tree=spec_metadata.is_spec_dec_tree,
+                is_spec_dec_dynamic_tree=spec_metadata.is_spec_dec_dynamic_tree,
+                max_draft_len=spec_metadata.max_draft_len,
+                max_total_draft_tokens=spec_metadata.max_total_draft_tokens,
+                spec_metadata=spec_metadata,
+                spec_tree_manager=spec_tree_manager,
+                spec_decoding_tensor=spec_decoding_tensor)
         else:
             spec_resource_manager = None
             spec_metadata = None
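
Converting this call to keyword arguments is what makes the signature change above safe to land: under the old positional call, spec_metadata would now bind to the is_spec_dec_tree slot without raising anything. A tiny illustration of that failure mode, with a trimmed hypothetical signature:

    def update_spec_dec_param(batch_size, enabled, is_tree, is_dynamic_tree):
        # Hypothetical, heavily trimmed signature.
        return is_tree

    meta = object()  # stands in for spec_metadata
    # Positional call against the reordered signature misbinds silently:
    assert update_spec_dec_param(1, True, meta, None) is meta
    # Keyword call binds correctly, or raises TypeError on a bad name:
    assert update_spec_dec_param(batch_size=1, enabled=True,
                                 is_tree=False, is_dynamic_tree=False) is False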

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 10 additions & 8 deletions

@@ -48,10 +48,10 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
         self.max_total_draft_tokens = self.max_draft_len
 
         # empty hidden states tensor
-        max_num_tokens = min(
-            max_num_tokens, max_num_requests *
-            (self.max_total_draft_tokens + 1)) + (self.max_total_draft_tokens +
-                                                  1)
+        max_num_tokens = min(max_num_tokens, max_num_requests *
+                             self.max_seq_len) + (self.max_total_draft_tokens +
+                                                  1) * max_num_requests
+
         self.hidden_states = torch.empty(
             (max_num_tokens, self.hidden_size * config.num_capture_layers),
             dtype=self.dtype,
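
Plugging made-up numbers into the two formulas shows why the sizing changed: the old expression capped the hidden-states buffer by draft tokens per request, which can be far smaller than what a full target-model forward produces; the new one budgets max_seq_len tokens per request plus one draft chunk per request. Values below are illustrative only:

    # Illustrative numbers only (not from the commit):
    max_num_tokens_in = 8192
    max_num_requests = 4
    max_seq_len = 1024
    max_total_draft_tokens = 3

    # Old sizing: capped by draft tokens per request, plus one extra chunk.
    old = min(max_num_tokens_in,
              max_num_requests * (max_total_draft_tokens + 1)) + (
                  max_total_draft_tokens + 1)
    # New sizing: capped by full sequence length, plus one chunk per request.
    new = min(max_num_tokens_in,
              max_num_requests * max_seq_len) + (
                  max_total_draft_tokens + 1) * max_num_requests

    print(old, new)  # 20 4112 -> the old cap starved the hidden-states buffer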
@@ -165,6 +165,7 @@ def __post_init__(self):
 
     def prepare(self):
         is_first_draft = self.eagle3_resource_manager.is_first_draft
+        spec_tree_manager = self.eagle3_resource_manager.spec_tree_manager
         # Update start indices
         # Here, we assume the sequence lengths (seq_lens) during the draft model
         # forward will not exceed those of the target model. So pre-allocate
@@ -186,18 +187,19 @@ def prepare(self):
         for req_id, seq_len in zip(self.request_ids, self.seq_lens):
             slot_id = self.eagle3_resource_manager.slot_manager.get_slot(req_id)
             start_idx = self.eagle3_resource_manager.start_indices[slot_id]
-            # 1) target model
+            # 1) target model or (is_first_draft and is_linear_tree)
             # If this is the first draft or the target model forward, we need to
             # read/write all of the hidden states
-            if not self.is_draft_model:
+            if not self.is_draft_model or (is_first_draft
+                                           and spec_tree_manager is None):
                 hidden_states_read_indices.extend(
                     list(range(start_idx, start_idx + seq_len)))
                 hidden_states_write_indices.extend(
                     list(range(start_idx, start_idx + seq_len)))
-            # 2)is_first_draft
+            # 2) is_first_draft and draft_token_tree
             # After target model forward, some draft tokens will be accepted.
             # These draft tokens' hidden states will be used for draft model's first drafter layer.
-            elif is_first_draft:
+            elif is_first_draft and spec_tree_manager is not None:
                 assert req_id in self.request_accepted_path.keys(
                 ), f"Request {req_id} not found in request_accepted_path"
                 accepted_path = self.request_accepted_path[req_id]
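
With the tree manager in scope, prepare() now distinguishes a linear tree (no spec_tree_manager) from a draft-token tree on the first draft pass. A condensed truth table of the branch selection; the return labels are illustrative, not the real buffer operations:

    def branch(is_draft_model: bool, is_first_draft: bool,
               has_tree_manager: bool) -> str:
        # Condensed version of the prepare() branch logic (labels illustrative).
        if not is_draft_model or (is_first_draft and not has_tree_manager):
            return "read/write all hidden states"
        elif is_first_draft and has_tree_manager:
            return "gather hidden states along the accepted path"
        return "subsequent draft step"

    assert branch(False, False, True) == "read/write all hidden states"
    assert branch(True, True, False) == "read/write all hidden states"
    assert branch(True, True, True) == "gather hidden states along the accepted path"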

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 2 additions & 0 deletions

@@ -171,6 +171,8 @@ def __init__(
         assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
         self.spec_config = spec_config
         self.max_draft_len = spec_config.max_draft_len
+        self.max_total_draft_tokens = spec_config.max_total_draft_tokens
+        assert self.max_draft_len == self.max_total_draft_tokens, "NGram only supports linear tree."
         self.spec_resource_manager = ngram_pool_manager
 
     def prepare_draft_tokens(
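
The added assert pins NGram to linear trees, where the chain depth and the total draft-token count are the same number. A tiny sketch with a hypothetical config object:

    from dataclasses import dataclass

    @dataclass
    class FakeSpecConfig:
        """Hypothetical stand-in for the real spec config."""
        max_draft_len: int
        max_total_draft_tokens: int

    cfg = FakeSpecConfig(max_draft_len=4, max_total_draft_tokens=4)
    # Linear tree: one draft token per level, so the two limits must match.
    assert cfg.max_draft_len == cfg.max_total_draft_tokens, \
        "NGram only supports linear tree."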

tensorrt_llm/llmapi/llm_args.py

Lines changed: 4 additions & 0 deletions

@@ -495,6 +495,10 @@ def spec_dec_mode(self):
         return TorchSpeculativeDecodingMode.from_string(
             self.decoding_type.upper())
 
+    @functools.cached_property
+    def is_linear_tree(self) -> bool:
+        return self.max_draft_len == self.max_total_draft_tokens
+
 
 class KvCacheConnectorConfig(StrictBaseModel):
     """

tests/unittest/_torch/modeling/test_modeling_llama.py

Lines changed: 10 additions & 4 deletions

@@ -516,7 +516,7 @@ def run_forward(input_ids, position_ids, attn_metadata):
         use_spec_decoding = True
         is_spec_dec_tree = True
         is_spec_dec_dynamic_tree = True
-        max_draft_tokens = gen_input_ids_0.size(-1) - 1
+        max_total_draft_tokens = gen_input_ids_0.size(-1) - 1
 
         attn_metadata_gen_phase_0 = metadata_cls(
             seq_lens=torch.tensor([gen_input_ids_0.size(-1)], dtype=torch.int),
@@ -540,10 +540,12 @@ def run_forward(input_ids, position_ids, attn_metadata):
             packed_mask=spec_decoding_packed_mask)
 
         attn_metadata_gen_phase_0.update_spec_dec_param(
+            batch_size=batch_size,
             is_spec_decoding_enabled=is_spec_decoding_enabled,
             is_spec_dec_dynamic_tree=is_spec_dec_dynamic_tree,
             is_spec_dec_tree=is_spec_dec_tree,
-            max_draft_tokens=max_draft_tokens,
+            max_draft_len=max_total_draft_tokens,
+            max_total_draft_tokens=max_total_draft_tokens,
             spec_decoding_tensor=spec_decoding_tensor,
         )

@@ -586,10 +588,12 @@ def run_forward(input_ids, position_ids, attn_metadata):
             [gen_input_ids_1.size(-1)], dtype=torch.int)
         attn_metadata_gen_phase_0.kv_cache_params.num_cached_tokens_per_seq = num_cached_tokens_per_seq_1
         attn_metadata_gen_phase_0.update_spec_dec_param(
+            batch_size=batch_size,
             is_spec_decoding_enabled=is_spec_decoding_enabled,
             is_spec_dec_tree=is_spec_dec_tree,
             is_spec_dec_dynamic_tree=False,
-            max_draft_tokens=gen_input_ids_1.size(-1) - 1)
+            max_draft_len=gen_input_ids_1.size(-1) - 1,
+            max_total_draft_tokens=gen_input_ids_1.size(-1) - 1)
 
         gen_position_ids_1 = [
             torch.full(
@@ -630,10 +634,12 @@ def run_forward(input_ids, position_ids, attn_metadata):
             is_spec_dec_tree=is_spec_dec_tree,
             is_spec_dec_dynamic_tree=False)
         attn_metadata_ref.update_spec_dec_param(
+            batch_size=batch_size,
             is_spec_decoding_enabled=is_spec_decoding_enabled,
             is_spec_dec_tree=is_spec_dec_tree,
             is_spec_dec_dynamic_tree=False,
-            max_draft_tokens=gen_input_ids_ref.size(-1) - 1,
+            max_draft_len=gen_input_ids_ref.size(-1) - 1,
+            max_total_draft_tokens=gen_input_ids_ref.size(-1) - 1,
         )
 
         gen_position_ids_ref = [
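
The test updates follow the new signature: batch_size is now passed explicitly, and the retired max_draft_tokens parameter is split into max_draft_len and max_total_draft_tokens, which are deliberately equal in these linear-tree cases. A toy illustration of why the two values coincide for a chain but not for a branching tree:

    # Linear chain: depth == node count, so the two parameters agree.
    chain = ["t1", "t2", "t3"]       # illustrative draft tokens
    max_draft_len = len(chain)       # depth of the chain
    max_total_draft_tokens = len(chain)  # total nodes in the "tree"
    assert max_draft_len == max_total_draft_tokens

    # Branching tree: depth < node count, so they must be tracked separately.
    tree_nodes = ["t1", "t2", "t1a", "t1b", "t2a", "t2b"]  # illustrative
    max_draft_len = 2                        # longest root-to-leaf path
    max_total_draft_tokens = len(tree_nodes)  # 6
    assert max_draft_len != max_total_draft_tokens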
