
Commit ef8b2b6

fix XQA issue

Signed-off-by: Yue Weng <[email protected]>
Parent: 040bc9a

9 files changed: +106 additions, -91 deletions

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 3 additions & 1 deletion
@@ -24,6 +24,7 @@
 #include "tensorrt_llm/runtime/utils/debugUtils.h"
 #include "tensorrt_llm/thop/attentionOp.h"
 #include "tensorrt_llm/thop/thUtils.h"
+#include <assert.h>
 #include <cstdint>
 #include <functional>
 #include <torch/extension.h>
@@ -466,7 +467,8 @@ class Runner : public RunnerBase
             = spec_decoding_tensor_params[1].value().data_ptr<int32_t>();
         enqueue_params.spec_decoding_packed_mask = spec_decoding_tensor_params[2].value().data_ptr<int32_t>();
         enqueue_params.spec_decoding_is_generation_length_variable = true;
-        enqueue_params.spec_decoding_max_generation_length = input_seq_length + 1;
+        assert(spec_decoding_tensor_params[1].value().dim() == 2); // [batch_size, max_draft_len + 1]
+        enqueue_params.spec_decoding_max_generation_length = spec_decoding_tensor_params[1].value().sizes()[1];
     }
 
     // Current mlaGeneration will using fmha to do attention, so we don't go into enqueueGeneration
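
The added assert pins down the shape contract this fix relies on: spec_decoding_tensor_params[1] is two-dimensional, [batch_size, max_draft_len + 1], so the kernel's max generation length is read from its second dimension instead of being derived from input_seq_length. A minimal Python sketch of that contract (names and shapes are assumptions inferred from the diff comment, not the op's actual binding code):

    import torch

    batch_size, max_draft_len = 2, 3
    # Assumed layout of spec_decoding_tensor_params[1]: [batch_size, max_draft_len + 1]
    spec_tensor = torch.zeros(batch_size, max_draft_len + 1, dtype=torch.int32)

    assert spec_tensor.dim() == 2
    spec_decoding_max_generation_length = spec_tensor.size(1)
    print(spec_decoding_max_generation_length)  # 4, i.e. max_draft_len + 1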

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 1 addition & 0 deletions
@@ -335,6 +335,7 @@ def restore_from_spec_dec(self) -> None:
 
     def update_spec_dec_param(
         self,
+        batch_size,
         is_spec_decoding_enabled,
         spec_metadata,
         spec_tree_manager,

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 38 additions & 29 deletions
@@ -1050,14 +1050,14 @@ def prepare_context_mla_with_cached_kv(self,
 
     def update_spec_dec_param(
         self,
+        batch_size,
         is_spec_decoding_enabled,
         spec_metadata,
         spec_tree_manager,
         max_draft_len,
         max_total_draft_tokens,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
-
         if spec_decoding_tensor is not None:
             spec_decoding_tensor.position_offsets
             spec_decoding_tensor.packed_mask
@@ -1081,6 +1081,8 @@ def update_spec_dec_param(
 
         # Parameters can be fixed and not changed during runtime if the
         if self.is_spec_decoding_enabled:
+            # These buffers are accessed more like removing input padding,
+            # rather than using max_total_draft_tokens + 1 as the offset between different requests.
            self.spec_decoding_position_offsets = torch.empty(
                [self.max_num_requests, max_total_draft_tokens + 1],
                dtype=torch.int,
@@ -1116,47 +1118,54 @@
                    assert False, "Dynamic tree is not supported yet"
                # If is the static tree
                else:
-                    self.spec_decoding_position_offsets[
-                        :,
-                    ].copy_(spec_tree_manager.spec_dec_position_offsets[0, :],
-                            non_blocking=True)
-                    self.spec_decoding_packed_mask[:, :, :].copy_(
+                    self.spec_decoding_position_offsets[:batch_size, :].copy_(
+                        spec_tree_manager.spec_dec_position_offsets[0, :],
+                        non_blocking=True)
+                    self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
                        spec_tree_manager.spec_dec_packed_mask[0, :, :],
                        non_blocking=True)
-                    self.spec_decoding_generation_lengths[:].fill_(
+                    self.spec_decoding_generation_lengths[:batch_size].fill_(
                        spec_tree_manager.max_total_draft_tokens + 1)
            else:
                # Prepare for the linear-tree.
                # Populate the mask that won't change during inference phase.
                self.generate_spec_decoding_position_offsets(
-                    max_total_draft_tokens=max_total_draft_tokens)
+                    batch_size=batch_size, max_draft_len=max_draft_len)
                self.generate_spec_decoding_packed_mask(
-                    max_total_draft_tokens=max_total_draft_tokens)
+                    batch_size=batch_size, max_draft_len=max_draft_len)
                self.generate_spec_decoding_generation_length(
-                    max_total_draft_tokens=max_total_draft_tokens)
+                    batch_size=batch_size, max_draft_len=max_draft_len)
 
-    def generate_spec_decoding_position_offsets(self, max_total_draft_tokens):
-        position_offset = torch.arange(max_total_draft_tokens + 1,
+    def generate_spec_decoding_position_offsets(self, batch_size,
+                                                max_draft_len):
+        position_offset = torch.arange(max_draft_len + 1,
                                        dtype=torch.int,
                                        device='cpu',
-                                       pin_memory=True)
-
+                                       pin_memory=True).repeat(batch_size)
+        #
         # fill all the batches with same position offset
-        self.spec_decoding_position_offsets.copy_(position_offset,
-                                                  non_blocking=True)
-
-    def generate_spec_decoding_packed_mask(self, max_total_draft_tokens):
-        # TODO: fix this limitation
-        assert max_total_draft_tokens < 32, "max_total_draft_tokens should be less than 32, will be fixed later"
-        dummy_idx = torch.arange(max_total_draft_tokens + 1)
-        spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
-        self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
-                                                      non_blocking=True)
-
-    def generate_spec_decoding_generation_length(self, max_total_draft_tokens):
-        spec_decoding_generation_length = torch.full((self.max_num_requests, ),
-                                                     max_total_draft_tokens + 1)
-        self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
+        self.spec_decoding_position_offsets.reshape(-1)[:(max_draft_len + 1) *
+                                                        batch_size].copy_(
+                                                            position_offset,
+                                                            non_blocking=True)
+
+    def generate_spec_decoding_packed_mask(self, batch_size, max_draft_len):
+        # FIXME: remove this limitation
+        assert max_draft_len < 32, "max_draft_len should be less than 32, will be fixed later"
+        dummy_idx = torch.arange(max_draft_len + 1)
+        spec_decoding_packed_mask = torch.pow(
+            2, dummy_idx + 1) - 1  # [max_draft_len + 1]
+        spec_decoding_packed_mask = spec_decoding_packed_mask.repeat(
+            batch_size)  # [batch_size * (max_draft_len + 1)]
+        self.spec_decoding_packed_mask.reshape(
+            -1)[:(max_draft_len + 1) * batch_size].copy_(
+                spec_decoding_packed_mask, non_blocking=True)
+
+    def generate_spec_decoding_generation_length(self, batch_size,
+                                                 max_draft_len):
+        spec_decoding_generation_length = torch.full((batch_size, ),
+                                                     max_draft_len + 1)
+        self.spec_decoding_generation_lengths[:batch_size].copy_(
            spec_decoding_generation_length, non_blocking=True)
 
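For the linear-tree (chain) path, the packed mask for draft position i is just the low i + 1 bits set, and the per-request values are now written into the front of the flattened buffers instead of one padded row per request. A small standalone sketch of that layout (plain tensors standing in for the metadata buffers; sizes are assumed for illustration):

    import torch

    max_num_requests, max_total_draft_tokens = 4, 12  # assumed buffer capacities
    batch_size, max_draft_len = 2, 2                  # active requests this step

    position_offsets = torch.zeros(max_num_requests, max_total_draft_tokens + 1, dtype=torch.int)
    packed_mask = torch.zeros(max_num_requests, max_total_draft_tokens + 1, 1, dtype=torch.int)
    generation_lengths = torch.zeros(max_num_requests, dtype=torch.int)

    # Linear chain: offsets 0..max_draft_len per request, mask for position i has bits 0..i set.
    offsets = torch.arange(max_draft_len + 1, dtype=torch.int).repeat(batch_size)
    masks = (torch.pow(2, torch.arange(max_draft_len + 1) + 1) - 1).repeat(batch_size)

    # Unpadded layout: only the first batch_size * (max_draft_len + 1) flattened slots are valid.
    n = batch_size * (max_draft_len + 1)
    position_offsets.reshape(-1)[:n] = offsets
    packed_mask.reshape(-1)[:n] = masks
    generation_lengths[:batch_size] = max_draft_len + 1

    print(position_offsets.reshape(-1)[:n].tolist())  # [0, 1, 2, 0, 1, 2]
    print(packed_mask.reshape(-1)[:n].tolist())       # [1, 3, 7, 1, 3, 7]
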
tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 2 deletions
@@ -2330,8 +2330,8 @@ def forward(
                 spec_resource_manager, self.is_draft_model, self.attn_backend,
                 self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
             attn_metadata.update_spec_dec_param(
-                is_spec_dec_mode, spec_metadata, spec_tree_manager,
-                self.original_max_draft_len,
+                scheduled_requests.batch_size, is_spec_dec_mode, spec_metadata,
+                spec_tree_manager, self.original_max_draft_len,
                 self.original_max_total_draft_tokens, spec_decoding_tensor)
         else:
             spec_resource_manager = None

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 1 addition & 1 deletion
@@ -581,7 +581,7 @@ def update_kv_cache_draft_token_location(self,
         requests = scheduled_batch.all_requests()
         accepted_draft_token_offsets, packed_accepted_draft_tokens_indices, rewind_draft_token_separate_adjustments = self.locate_accepted_draft_tokens(
             requests)
-        past_key_value_lengths = attn_metadata.kv_lens_cuda
+        past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)]
         if attn_metadata.kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_pool_pointers is not None and attn_metadata.host_kv_cache_pool_mapping is not None:
             use_paged_kv_cache = True
         else:
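
kv_lens_cuda is preallocated for the maximum number of requests, so the slice narrows it to the requests actually scheduled in this batch before the lengths are consumed. A tiny sketch of that distinction (CPU tensor and sizes assumed for illustration):

    import torch

    max_num_requests = 8
    kv_lens = torch.zeros(max_num_requests, dtype=torch.int32)  # stand-in for attn_metadata.kv_lens_cuda
    requests = ["req0", "req1", "req2"]                          # placeholder request objects
    kv_lens[:len(requests)] = torch.tensor([17, 5, 9], dtype=torch.int32)

    past_key_value_lengths = kv_lens[:len(requests)]  # shape [3], not [8]
    assert past_key_value_lengths.tolist() == [17, 5, 9]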

tensorrt_llm/_torch/speculative/drafting_loops.py

Lines changed: 15 additions & 13 deletions
@@ -192,6 +192,9 @@ def prepare_for_generation_with_tree_decoding(
                              batch_size] -= prev_layer_gen_len_per_req  # reset to original length before the drafter loop.
     attn_metadata.kv_lens_cuda[:batch_size] += next_layer_gen_len_per_req
 
+    # FIXME, update without D2H
+    # attn_metadata.kv_lens[:batch_size] = attn_metadata.kv_lens_cuda[:batch_size].cpu()
+
     ## 3.2) _seq_lens, _seq_lens_cuda
     attn_metadata._seq_lens[:batch_size].fill_(next_layer_gen_len_per_req)
     attn_metadata._seq_lens_cuda[:batch_size].fill_(next_layer_gen_len_per_req)
@@ -207,23 +210,22 @@ def prepare_for_generation_with_tree_decoding(
     attn_metadata.use_spec_decoding = True
 
     ## 3.6) spec_decoding_position_offsets
-    attn_metadata.spec_decoding_position_offsets[:, :
-                                                 next_layer_gen_len_per_req] = spec_tree_manager.spec_dec_position_offsets_for_drafter_model[
-                                                     prepare_for_layer_idx -
-                                                     1].unsqueeze(0)
-    attn_metadata.spec_decoding_position_offsets[:,
-                                                 next_layer_gen_len_per_req:] = 0
+    attn_metadata.spec_decoding_position_offsets.reshape(
+        -1
+    )[:batch_size *
+      next_layer_gen_len_per_req] = spec_tree_manager.spec_dec_position_offsets_for_drafter_model[
+          prepare_for_layer_idx - 1].repeat(batch_size)
 
     ## 3.7) spec_decoding_packed_mask
-    attn_metadata.spec_decoding_packed_mask[:, :
-                                            next_layer_gen_len_per_req, :] = spec_tree_manager.spec_dec_packed_mask_for_drafter_model[
-                                                prepare_for_layer_idx -
-                                                1].unsqueeze(0)
-    attn_metadata.spec_decoding_packed_mask[:,
-                                            next_layer_gen_len_per_req:, :] = 0
+    attn_metadata.spec_decoding_packed_mask.reshape(
+        -1, attn_metadata.spec_decoding_packed_mask.size(-1)
+    )[:batch_size *
+      next_layer_gen_len_per_req, :] = spec_tree_manager.spec_dec_packed_mask_for_drafter_model[
+          prepare_for_layer_idx - 1].repeat(batch_size, 1)
 
     ## 3.8) spec_decoding_generation_lengths
-    attn_metadata.spec_decoding_generation_lengths[:] = next_layer_gen_len_per_req
+    attn_metadata.spec_decoding_generation_lengths[:
+                                                   batch_size] = next_layer_gen_len_per_req
 
     # 4) spec_metadata
     ## 4.1) num_tokens
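
These writes now pack each request's next_layer_gen_len_per_req entries contiguously at the front of the flattened buffers rather than filling one padded row per request and zeroing the tail. A minimal sketch of the two layouts (plain tensors with assumed sizes standing in for the metadata fields):

    import torch

    batch_size, buf_width, gen_len = 2, 13, 3
    layer_offsets = torch.tensor([0, 1, 1], dtype=torch.int)       # stand-in for spec_dec_position_offsets_for_drafter_model[layer]
    buf = torch.zeros(batch_size + 1, buf_width, dtype=torch.int)  # preallocated for more requests than are active

    # Old layout: padded per-row writes, valid entries strided by buf_width.
    # buf[:, :gen_len] = layer_offsets.unsqueeze(0); buf[:, gen_len:] = 0

    # New layout: the first batch_size * gen_len flattened slots hold all valid entries back to back.
    buf.reshape(-1)[:batch_size * gen_len] = layer_offsets.repeat(batch_size)
    print(buf.reshape(-1)[:batch_size * gen_len].tolist())  # [0, 1, 1, 0, 1, 1]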

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 20 additions & 8 deletions
@@ -130,20 +130,32 @@ def attention_need_spec_dec_mode(
         spec_resource_manager: BaseResourceManager,
         is_draft_model: bool,
         attention_backend: Type[AttentionBackend],
-        use_chain_drafter: bool,
+        use_chain_drafter: bool,  # CDL
         is_spec_dec_tree: bool,
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
+        Args:
+            spec_resource_manager: the resource manager for the spec-dec mode.
+            is_draft_model: whether the model is a draft model.
+            attention_backend: the attention backend.
+            use_chain_drafter: whether to use capturable drafting loops (CDL). For the target model, it is always False.
+            is_spec_dec_tree: whether the spec-dec mode is a tree, i.e., static tree or dynamic tree.
         """
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
-        return (self.is_eagle3_one_model()  # one model
-                or (self.is_eagle3() and spec_resource_manager.is_first_draft
-                    and is_trtllm_attention and use_chain_drafter
-                    and is_draft_model)  # two model + first drafter + CDL
-                or (self.is_eagle3() and is_trtllm_attention
-                    and is_spec_dec_tree)  # two model + tree
-                )
+        # Case 1: one model
+        use_case_1 = self.is_eagle3_one_model()
+        # Case 2: eagle3 two model + draft model + CDL + is_first_draft + TRTLLM attention
+        use_case_2 = self.is_eagle3(
+        ) and spec_resource_manager.is_first_draft and use_chain_drafter and is_draft_model and is_trtllm_attention
+        # Case 3: eagle3 two model + tree decoding + draft model + CDL + TRTLLM attention
+        use_case_3 = self.is_eagle3(
+        ) and is_spec_dec_tree and is_draft_model and use_chain_drafter and is_trtllm_attention
+        # Case 4: eagle3 two model + tree decoding + target model + TRTLLM attention
+        use_case_4 = self.is_eagle3(
+        ) and is_spec_dec_tree and not is_draft_model and is_trtllm_attention
+
+        return use_case_1 or use_case_2 or use_case_3 or use_case_4
 
     @staticmethod
     def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
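
Restated as a standalone predicate over plain booleans (a simplified sketch of the same four cases, with the mode checks flattened into flags; not the class's actual method):

    def need_spec_dec_mode(is_eagle3_one_model: bool, is_eagle3: bool,
                           is_first_draft: bool, is_draft_model: bool,
                           use_chain_drafter: bool, is_spec_dec_tree: bool,
                           is_trtllm_attention: bool) -> bool:
        # Case 1: one-model EAGLE3
        case_1 = is_eagle3_one_model
        # Case 2: two-model EAGLE3, first draft pass of the draft model with CDL
        case_2 = (is_eagle3 and is_first_draft and use_chain_drafter
                  and is_draft_model and is_trtllm_attention)
        # Case 3: two-model EAGLE3, tree decoding, draft model with CDL
        case_3 = (is_eagle3 and is_spec_dec_tree and is_draft_model
                  and use_chain_drafter and is_trtllm_attention)
        # Case 4: two-model EAGLE3, tree decoding, target model
        case_4 = (is_eagle3 and is_spec_dec_tree and not is_draft_model
                  and is_trtllm_attention)
        return case_1 or case_2 or case_3 or case_4

    # Target model with a static tree on the TRTLLM backend falls under case 4:
    print(need_spec_dec_mode(False, True, False, False, False, True, True))  # True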

tensorrt_llm/_torch/speculative/spec_tree_manager.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def __init__(self, max_num_requests: int, use_dynamic_tree: bool,
         else:
             self.init_tree_info_for_static_tree()
 
-        self.dump_tree_info()
+        # self.dump_tree_info()
 
     def init_tree_info_for_dynamic_tree(self):
         # For the dynamic tree

tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py

Lines changed: 25 additions & 36 deletions
@@ -188,12 +188,22 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
     assert torch.all(
         torch.tensor(attn_metadata.num_contexts) == torch.tensor(
             ref_attn_metadata['num_contexts']))
-    assert torch.all(attn_metadata.spec_decoding_position_offsets ==
-                     ref_attn_metadata['spec_decoding_position_offsets'])
-    assert torch.all(attn_metadata.spec_decoding_packed_mask ==
-                     ref_attn_metadata['spec_decoding_packed_mask'])
     assert torch.all(attn_metadata.spec_decoding_generation_lengths ==
                      ref_attn_metadata['spec_decoding_generation_lengths'])
+    total_process_tokens = attn_metadata.spec_decoding_generation_lengths.sum(
+    )
+    print(f"total_process_tokens: {total_process_tokens}")
+    assert torch.all(
+        attn_metadata.spec_decoding_position_offsets.reshape(
+            -1)[:total_process_tokens] ==
+        ref_attn_metadata['spec_decoding_position_offsets']
+        [:total_process_tokens])
+    assert torch.all(
+        attn_metadata.spec_decoding_packed_mask.reshape(
+            -1, attn_metadata.spec_decoding_packed_mask.size(
+                -1))[:total_process_tokens, :] ==
+        ref_attn_metadata['spec_decoding_packed_mask']
+        [:total_process_tokens, :])
 
     assert torch.all(
         torch.tensor(spec_metadata.num_tokens) == torch.tensor(
@@ -267,13 +277,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda')
+        [0, 0, 0], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1, 1)
+        [1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [3], dtype=torch.int32, device='cuda')
 
@@ -361,14 +367,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').repeat(max_batch_size, 1)
+        [0, 0, 0, 0, 0, 0], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1,
-                               1).repeat(max_batch_size, 1, 1)
+        [1, 2, 4, 1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [3, 3], dtype=torch.int32, device='cuda')
 
@@ -455,14 +456,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').repeat(max_batch_size, 1)
+        [0, 0, 0, 0, 0, 0], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1,
-                               1).repeat(max_batch_size, 1, 1)
+        [1, 2, 4, 1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [3, 3], dtype=torch.int32, device='cuda')
 
@@ -545,13 +541,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda')
+        [0, 0, 1, 1, 1], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 5, 9, 18, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1, 1)
+        [1, 2, 5, 9, 18], dtype=torch.int32, device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [5], dtype=torch.int32, device='cuda')
 
@@ -637,13 +629,10 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
         device='cuda')
     ref_attn_metadata['num_contexts'] = 0
     ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda')
+        [0, 0, 1, 1, 1, 0, 0, 1, 1, 1], dtype=torch.int32, device='cuda')
     ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 5, 9, 18, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1, 1)
+        [1, 2, 5, 9, 18, 1, 2, 5, 9, 18], dtype=torch.int32,
+        device='cuda').unsqueeze(1)
     ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
         [5, 5], dtype=torch.int32, device='cuda')
 
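The reworked assertions compare only the valid, unpadded prefix: total_process_tokens is the sum of the per-request generation lengths, and only that many flattened entries of each buffer are checked against the reference. A toy version of the check (CPU tensors and sizes assumed; values mirror the two-request case above):

    import torch

    generation_lengths = torch.tensor([3, 3], dtype=torch.int32)
    position_offsets = torch.zeros(2, 13, dtype=torch.int32)   # [max_batch_size, max_total_draft_tokens + 1]
    packed_mask = torch.zeros(2, 13, 1, dtype=torch.int32)
    position_offsets.reshape(-1)[:6] = torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.int32)
    packed_mask.reshape(-1)[:6] = torch.tensor([1, 2, 4, 1, 2, 4], dtype=torch.int32)

    total_process_tokens = generation_lengths.sum()  # 6
    ref_offsets = torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.int32)
    ref_mask = torch.tensor([1, 2, 4, 1, 2, 4], dtype=torch.int32).unsqueeze(1)

    assert torch.all(position_offsets.reshape(-1)[:total_process_tokens] == ref_offsets)
    assert torch.all(packed_mask.reshape(-1, packed_mask.size(-1))[:total_process_tokens, :] == ref_mask)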