
Commit 655723b

fix ci
Signed-off-by: Yue Weng <[email protected]>
1 parent e0aed01 commit 655723b

File tree

7 files changed: +91 / -48 lines


tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 6 additions & 2 deletions

@@ -11,6 +11,8 @@
 
 if TYPE_CHECKING:
     from ..speculative.utils import SpecDecodingTensor
+    from ..speculative.interface import SpecMetadata
+    from ..speculative.spec_tree_manager import SpecTreeManager
 
 from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
                                      RotaryScalingType)
@@ -337,10 +339,12 @@ def update_spec_dec_param(
         self,
         batch_size,
         is_spec_decoding_enabled,
-        spec_metadata,
-        spec_tree_manager,
+        is_spec_dec_tree,
+        is_spec_dec_dynamic_tree,
         max_draft_len,
         max_total_draft_tokens,
+        spec_metadata: Optional['SpecMetadata'] = None,
+        spec_tree_manager: Optional['SpecTreeManager'] = None,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None):
         """
         Hook to be called when using TRTLLM attention backend in spec-dec mode.
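Under the reworked signature, the tree-shape booleans are passed explicitly and the heavier objects (spec_metadata, spec_tree_manager, spec_decoding_tensor) become optional keywords. A minimal sketch of a linear-tree call against this hook; only the parameter names come from the diff above, the values are placeholders:

    # Hypothetical linear-tree invocation of the reworked hook.
    attn_metadata.update_spec_dec_param(
        batch_size=8,
        is_spec_decoding_enabled=True,
        is_spec_dec_tree=False,
        is_spec_dec_dynamic_tree=False,
        max_draft_len=3,
        max_total_draft_tokens=3,
    )  # spec_metadata / spec_tree_manager / spec_decoding_tensor default to None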

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 47 additions & 30 deletions

@@ -8,6 +8,8 @@
 
 if TYPE_CHECKING:
     from ..speculative.utils import SpecDecodingTensor
+    from ..speculative.interface import SpecMetadata
+    from ..speculative.spec_tree_manager import SpecTreeManager
 
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.internal import thop
@@ -1057,25 +1059,26 @@ def update_spec_dec_param(
         self,
         batch_size,
         is_spec_decoding_enabled,
-        spec_metadata,
-        spec_tree_manager,
+        is_spec_dec_tree,
+        is_spec_dec_dynamic_tree,
         max_draft_len,
         max_total_draft_tokens,
+        spec_metadata: Optional['SpecMetadata'] = None,
+        spec_tree_manager: Optional['SpecTreeManager'] = None,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
         if spec_decoding_tensor is not None:
-            spec_decoding_tensor.position_offsets
-            spec_decoding_tensor.packed_mask
-            spec_decoding_tensor.generation_lengths
+            spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
+            spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
+            spec_decoding_generation_lengths = spec_decoding_tensor.generation_lengths
         else:
-            pass
+            spec_decoding_position_offsets = None
+            spec_decoding_packed_mask = None
+            spec_decoding_generation_lengths = None
         # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
         self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
         ) < 100
 
-        self.is_spec_dec_tree = False if spec_tree_manager is None else True
-        self.is_spec_dec_dynamic_tree = False if spec_tree_manager is None else spec_tree_manager.use_dynamic_tree
-
         if get_sm_version() >= 100:
             if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
                 assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
@@ -1084,6 +1087,9 @@ def update_spec_dec_param(
         # use_spec_decoding is default to true by default, change in runtime by layers / requests
         self.use_spec_decoding = self.is_spec_decoding_enabled
 
+        self.is_spec_dec_tree = is_spec_dec_tree
+        self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree
+
         # Parameters can be fixed and not changed during runtime if the
         if self.is_spec_decoding_enabled:
             # These buffers are accessed more like removing input padding,
@@ -1109,28 +1115,40 @@ def update_spec_dec_param(
                 device='cuda',
             )
 
-            # Prepare the spec-dec mask, position offset and generation length for static tree of dynamic tree.
-            # We only prepare the spec-dec mask, position offset and generation length for the target model here.
-            # For the drafter model, we will prepare them in the drafting loops.
-            is_target_model = not spec_metadata.is_draft_model
-            is_using_tree = self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree
-            if is_target_model and is_using_tree:
+            is_target_model = not spec_metadata.is_draft_model if hasattr(
+                spec_metadata, 'is_draft_model') else False
+
+            if self.is_spec_dec_tree and self.is_spec_dec_dynamic_tree:
+                # dynamic tree
+                assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree"
+                assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree"
+                self.spec_decoding_position_offsets.copy_(
+                    spec_decoding_position_offsets, non_blocking=True)
+                self.spec_decoding_packed_mask.copy_(spec_decoding_packed_mask,
+                                                     non_blocking=True)
+                if spec_decoding_generation_lengths is not None:
+                    self.spec_decoding_generation_lengths.copy_(
+                        spec_decoding_generation_lengths, non_blocking=True)
+                else:
+                    self.generate_spec_decoding_generation_length(
+                        batch_size=batch_size,
+                        max_draft_len=max_total_draft_tokens)
+            elif self.is_spec_dec_tree and not self.is_spec_dec_dynamic_tree and spec_metadata is not None and is_target_model:
+                # static tree and target model
+                # Prepare the spec-dec mask, position offset and generation length for static tree.
+                # We only prepare the spec-dec mask, position offset and generation length for the target model here.
+                # For the drafter model, we will prepare them in the drafting loops.
+
                 assert spec_metadata.spec_dec_mode.is_eagle3(
                 ), "Tree decoding is only supported for Eagle3 now"
-                # If is the dynamic tree
-                if self.is_spec_dec_dynamic_tree:
-                    # TODO: add dynamic tree logic
-                    assert False, "Dynamic tree is not supported yet"
-                # If is the static tree
-                else:
-                    self.spec_decoding_position_offsets[:batch_size, :].copy_(
-                        spec_tree_manager.spec_dec_position_offsets[0, :],
-                        non_blocking=True)
-                    self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
-                        spec_tree_manager.spec_dec_packed_mask[0, :, :],
-                        non_blocking=True)
-                    self.spec_decoding_generation_lengths[:batch_size].fill_(
-                        spec_tree_manager.max_total_draft_tokens + 1)
+                self.spec_decoding_position_offsets[:batch_size, :].copy_(
+                    spec_tree_manager.spec_dec_position_offsets[0, :],
+                    non_blocking=True)
+                self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
+                    spec_tree_manager.spec_dec_packed_mask[0, :, :],
+                    non_blocking=True)
+                self.spec_decoding_generation_lengths[:batch_size].fill_(
+                    spec_tree_manager.max_total_draft_tokens + 1)
             else:
                 # Prepare for the linear-tree.
                 # Populate the mask that won't change during inference phase.
@@ -1147,7 +1165,6 @@ def generate_spec_decoding_position_offsets(self, batch_size,
             dtype=torch.int,
             device='cpu',
             pin_memory=True).repeat(batch_size)
-        #
         # fill all the batches with same position offset
         self.spec_decoding_position_offsets.reshape(-1)[:(max_draft_len + 1) *
                                                         batch_size].copy_(
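The rewritten body now dispatches three ways: a dynamic tree takes its position offsets and packed mask from spec_decoding_tensor (falling back to generate_spec_decoding_generation_length when lengths are absent), a static tree on the target model copies them from spec_tree_manager, and everything else falls through to the linear-tree path. As background on the packed-mask buffer being filled here, a self-contained sketch of how a linear chain's causal mask packs into int32 words; this assumes the common bit-packing layout (bit j of word j // 32 marks token j as visible) and is not code from this commit:

    import numpy as np
    import torch

    def linear_tree_packed_mask(num_tokens: int) -> torch.Tensor:
        # Row i lets draft token i attend to tokens 0..i.
        num_words = (num_tokens + 31) // 32
        mask = np.zeros((num_tokens, num_words), dtype=np.uint32)
        for i in range(num_tokens):
            for j in range(i + 1):
                mask[i, j // 32] |= np.uint32(1) << np.uint32(j % 32)
        # Reinterpret the bits as int32 to match the kernel-facing dtype.
        return torch.from_numpy(mask.view(np.int32))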

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 12 additions & 4 deletions

@@ -1342,7 +1342,9 @@ def _prepare_tp_inputs(
         if spec_config is not None:
             spec_resource_manager = resource_manager.get_resource_manager(
                 ResourceManagerType.SPEC_RESOURCE_MANAGER)
-            spec_tree_manager = spec_resource_manager.spec_tree_manager
+            if spec_resource_manager is not None and hasattr(
+                    spec_resource_manager, 'spec_tree_manager'):
+                spec_tree_manager = spec_resource_manager.spec_tree_manager
 
         # will contain previous batch indices of generation requests
         previous_batch_indices = []
@@ -2331,9 +2333,15 @@ def forward(
                 spec_resource_manager, self.is_draft_model, self.attn_backend,
                 self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
             attn_metadata.update_spec_dec_param(
-                scheduled_requests.batch_size, is_spec_dec_mode, spec_metadata,
-                spec_tree_manager, self.original_max_draft_len,
-                self.original_max_total_draft_tokens, spec_decoding_tensor)
+                batch_size=scheduled_requests.batch_size,
+                is_spec_decoding_enabled=is_spec_dec_mode,
+                is_spec_dec_tree=spec_metadata.is_spec_dec_tree,
+                is_spec_dec_dynamic_tree=spec_metadata.is_spec_dec_dynamic_tree,
+                max_draft_len=spec_metadata.max_draft_len,
+                max_total_draft_tokens=spec_metadata.max_total_draft_tokens,
+                spec_metadata=spec_metadata,
+                spec_tree_manager=spec_tree_manager,
+                spec_decoding_tensor=spec_decoding_tensor)
         else:
             spec_resource_manager = None
             spec_metadata = None
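Switching this call site to keyword arguments is what makes the parameter reorder safe: the old positional call would have silently fed spec_metadata into the new is_spec_dec_tree slot. The hasattr() guard in _prepare_tp_inputs could also be written more compactly; a sketch of an equivalent, relying on getattr() returning its default both when the manager is None and when the attribute is missing:

    # Equivalent one-line guard (a sketch, not the committed code):
    spec_tree_manager = getattr(spec_resource_manager, 'spec_tree_manager', None)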

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 10 additions & 8 deletions

@@ -48,10 +48,10 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
         self.max_total_draft_tokens = self.max_draft_len
 
         # empty hidden states tensor
-        max_num_tokens = min(
-            max_num_tokens, max_num_requests *
-            (self.max_total_draft_tokens + 1)) + (self.max_total_draft_tokens +
-                                                  1)
+        max_num_tokens = min(max_num_tokens, max_num_requests *
+                             self.max_seq_len) + (self.max_total_draft_tokens +
+                                                  1) * max_num_requests
+
         self.hidden_states = torch.empty(
             (max_num_tokens, self.hidden_size * config.num_capture_layers),
             dtype=self.dtype,
@@ -165,6 +165,7 @@ def __post_init__(self):
 
     def prepare(self):
         is_first_draft = self.eagle3_resource_manager.is_first_draft
+        spec_tree_manager = self.eagle3_resource_manager.spec_tree_manager
         # Update start indices
         # Here, we assume the sequence lengths (seq_lens) during the draft model
         # forward will not exceed those of the target model. So pre-allocate
@@ -186,18 +187,19 @@ def prepare(self):
         for req_id, seq_len in zip(self.request_ids, self.seq_lens):
             slot_id = self.eagle3_resource_manager.slot_manager.get_slot(req_id)
             start_idx = self.eagle3_resource_manager.start_indices[slot_id]
-            # 1) target model
+            # 1) target model or (is_first_draft and is_linear_tree)
             # If this is the first draft or the target model forward, we need to
             # read/write all of the hidden states
-            if not self.is_draft_model:
+            if not self.is_draft_model or (is_first_draft
+                                           and spec_tree_manager is None):
                 hidden_states_read_indices.extend(
                     list(range(start_idx, start_idx + seq_len)))
                 hidden_states_write_indices.extend(
                     list(range(start_idx, start_idx + seq_len)))
-            # 2) is_first_draft
+            # 2) is_first_draft and draft_token_tree
             # After target model forward, some draft tokens will be accepted.
             # These draft tokens' hidden states will be used for draft model's first drafter layer.
-            elif is_first_draft:
+            elif is_first_draft and spec_tree_manager is not None:
                 assert req_id in self.request_accepted_path.keys(
                 ), f"Request {req_id} not found in request_accepted_path"
                 accepted_path = self.request_accepted_path[req_id]
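The resizing of the hidden-states buffer in __init__ is easiest to see with numbers. A worked sketch under assumed, purely illustrative values:

    # Illustrative values only; none of these come from the commit.
    max_num_tokens, max_num_requests = 8192, 16
    max_seq_len, max_total_draft_tokens = 1024, 3

    # Old sizing: capped at requests * (draft tokens + 1), one extra group.
    old = min(max_num_tokens,
              max_num_requests * (max_total_draft_tokens + 1)) \
          + (max_total_draft_tokens + 1)                          # 64 + 4 = 68
    # New sizing: capped at requests * max_seq_len, plus per-request draft slots.
    new = min(max_num_tokens, max_num_requests * max_seq_len) \
          + (max_total_draft_tokens + 1) * max_num_requests       # 8192 + 64 = 8256

The buffer is now sized for full target-model sequences rather than only for the draft tokens, which the earlier formula appears to have undercounted.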

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 2 additions & 0 deletions

@@ -171,6 +171,8 @@ def __init__(
         assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
         self.spec_config = spec_config
         self.max_draft_len = spec_config.max_draft_len
+        self.max_total_draft_tokens = spec_config.max_total_draft_tokens
+        assert self.max_draft_len == self.max_total_draft_tokens, "NGram only supports linear tree."
         self.spec_resource_manager = ngram_pool_manager
 
     def prepare_draft_tokens(
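The two added lines make the linear-tree restriction fail fast at construction time instead of misbehaving later. A hypothetical config that would now trip the assertion (the tree shape is invented for illustration):

    # max_draft_len is the longest root-to-leaf path; max_total_draft_tokens
    # counts every node in the draft tree.
    spec_config.max_draft_len = 3
    spec_config.max_total_draft_tokens = 7   # e.g. a depth-3 binary tree: 1 + 2 + 4
    # Constructing the NGram drafter with this config now raises:
    # AssertionError: NGram only supports linear tree.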

tensorrt_llm/llmapi/llm_args.py

Lines changed: 4 additions & 0 deletions

@@ -520,6 +520,10 @@ def spec_dec_mode(self):
         return TorchSpeculativeDecodingMode.from_string(
             self.decoding_type.upper())
 
+    @functools.cached_property
+    def is_linear_tree(self) -> bool:
+        return self.max_draft_len == self.max_total_draft_tokens
+
 
 class KvCacheConnectorConfig(StrictBaseModel):
     """

tests/unittest/_torch/modeling/test_modeling_llama.py

Lines changed: 10 additions & 4 deletions

@@ -516,7 +516,7 @@ def run_forward(input_ids, position_ids, attn_metadata):
     use_spec_decoding = True
     is_spec_dec_tree = True
     is_spec_dec_dynamic_tree = True
-    max_draft_tokens = gen_input_ids_0.size(-1) - 1
+    max_total_draft_tokens = gen_input_ids_0.size(-1) - 1
 
     attn_metadata_gen_phase_0 = metadata_cls(
         seq_lens=torch.tensor([gen_input_ids_0.size(-1)], dtype=torch.int),
@@ -540,10 +540,12 @@ def run_forward(input_ids, position_ids, attn_metadata):
         packed_mask=spec_decoding_packed_mask)
 
     attn_metadata_gen_phase_0.update_spec_dec_param(
+        batch_size=batch_size,
         is_spec_decoding_enabled=is_spec_decoding_enabled,
         is_spec_dec_dynamic_tree=is_spec_dec_dynamic_tree,
         is_spec_dec_tree=is_spec_dec_tree,
-        max_draft_tokens=max_draft_tokens,
+        max_draft_len=max_total_draft_tokens,
+        max_total_draft_tokens=max_total_draft_tokens,
         spec_decoding_tensor=spec_decoding_tensor,
     )
@@ -586,10 +588,12 @@ def run_forward(input_ids, position_ids, attn_metadata):
         [gen_input_ids_1.size(-1)], dtype=torch.int)
     attn_metadata_gen_phase_0.kv_cache_params.num_cached_tokens_per_seq = num_cached_tokens_per_seq_1
     attn_metadata_gen_phase_0.update_spec_dec_param(
+        batch_size=batch_size,
         is_spec_decoding_enabled=is_spec_decoding_enabled,
         is_spec_dec_tree=is_spec_dec_tree,
         is_spec_dec_dynamic_tree=False,
-        max_draft_tokens=gen_input_ids_1.size(-1) - 1)
+        max_draft_len=gen_input_ids_1.size(-1) - 1,
+        max_total_draft_tokens=gen_input_ids_1.size(-1) - 1)
 
     gen_position_ids_1 = [
         torch.full(
@@ -630,10 +634,12 @@ def run_forward(input_ids, position_ids, attn_metadata):
         is_spec_dec_tree=is_spec_dec_tree,
         is_spec_dec_dynamic_tree=False)
     attn_metadata_ref.update_spec_dec_param(
+        batch_size=batch_size,
         is_spec_decoding_enabled=is_spec_decoding_enabled,
         is_spec_dec_tree=is_spec_dec_tree,
         is_spec_dec_dynamic_tree=False,
-        max_draft_tokens=gen_input_ids_ref.size(-1) - 1,
+        max_draft_len=gen_input_ids_ref.size(-1) - 1,
+        max_total_draft_tokens=gen_input_ids_ref.size(-1) - 1,
     )
 
     gen_position_ids_ref = [
