
Commit e0aed01

fix XQA issue
Signed-off-by: Yue Weng <[email protected]>
1 parent 7146339 commit e0aed01

File tree

9 files changed: +106 -91 lines changed


cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 3 additions & 1 deletion
@@ -24,6 +24,7 @@
 #include "tensorrt_llm/runtime/utils/debugUtils.h"
 #include "tensorrt_llm/thop/attentionOp.h"
 #include "tensorrt_llm/thop/thUtils.h"
+#include <assert.h>
 #include <cstdint>
 #include <functional>
 #include <torch/extension.h>
@@ -466,7 +467,8 @@ class Runner : public RunnerBase
             = spec_decoding_tensor_params[1].value().data_ptr<int32_t>();
         enqueue_params.spec_decoding_packed_mask = spec_decoding_tensor_params[2].value().data_ptr<int32_t>();
         enqueue_params.spec_decoding_is_generation_length_variable = true;
-        enqueue_params.spec_decoding_max_generation_length = input_seq_length + 1;
+        assert(spec_decoding_tensor_params[1].value().dim() == 2); // [batch_size, max_draft_len + 1]
+        enqueue_params.spec_decoding_max_generation_length = spec_decoding_tensor_params[1].value().sizes()[1];
     }

     // Current mlaGeneration will using fmha to do attention, so we don't go into enqueueGeneration
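
Note: the replaced line previously derived the kernel's max generation length from input_seq_length + 1; it is now taken from the second dimension of the speculative-decoding tensor at index 1, which is laid out as [batch_size, max_draft_len + 1]. A minimal Python-side sketch of that relationship, with a made-up helper name, purely for illustration:

import torch

# Hypothetical illustration: the tensor handed to the op has shape
# [batch_size, max_draft_len + 1], so its second dimension already encodes the
# maximum number of tokens queried per request in one generation step.
def max_generation_length(position_offsets: torch.Tensor) -> int:
    assert position_offsets.dim() == 2  # [batch_size, max_draft_len + 1]
    return position_offsets.size(1)     # mirrors sizes()[1] in the C++ change

batch_size, max_draft_len = 4, 2
position_offsets = torch.arange(max_draft_len + 1,
                                dtype=torch.int32).repeat(batch_size, 1)
print(max_generation_length(position_offsets))  # 3, i.e. max_draft_len + 1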

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 1 addition & 0 deletions
@@ -335,6 +335,7 @@ def restore_from_spec_dec(self) -> None:

     def update_spec_dec_param(
         self,
+        batch_size,
         is_spec_decoding_enabled,
         spec_metadata,
         spec_tree_manager,

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 38 additions & 29 deletions
@@ -1055,14 +1055,14 @@ def prepare_context_mla_with_cached_kv(self,

     def update_spec_dec_param(
         self,
+        batch_size,
         is_spec_decoding_enabled,
         spec_metadata,
         spec_tree_manager,
         max_draft_len,
         max_total_draft_tokens,
         spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
     ):
-
         if spec_decoding_tensor is not None:
             spec_decoding_tensor.position_offsets
             spec_decoding_tensor.packed_mask
@@ -1086,6 +1086,8 @@ def update_spec_dec_param(

         # Parameters can be fixed and not changed during runtime if the
         if self.is_spec_decoding_enabled:
+            # These buffers are accessed more like removing input padding,
+            # rather than using max_total_draft_tokens + 1 as the offset between different requests.
             self.spec_decoding_position_offsets = torch.empty(
                 [self.max_num_requests, max_total_draft_tokens + 1],
                 dtype=torch.int,
@@ -1121,47 +1123,54 @@ def update_spec_dec_param(
                     assert False, "Dynamic tree is not supported yet"
                 # If is the static tree
                 else:
-                    self.spec_decoding_position_offsets[
-                        :,
-                    ].copy_(spec_tree_manager.spec_dec_position_offsets[0, :],
-                            non_blocking=True)
-                    self.spec_decoding_packed_mask[:, :, :].copy_(
+                    self.spec_decoding_position_offsets[:batch_size, :].copy_(
+                        spec_tree_manager.spec_dec_position_offsets[0, :],
+                        non_blocking=True)
+                    self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
                         spec_tree_manager.spec_dec_packed_mask[0, :, :],
                         non_blocking=True)
-                    self.spec_decoding_generation_lengths[:].fill_(
+                    self.spec_decoding_generation_lengths[:batch_size].fill_(
                         spec_tree_manager.max_total_draft_tokens + 1)
             else:
                 # Prepare for the linear-tree.
                 # Populate the mask that won't change during inference phase.
                 self.generate_spec_decoding_position_offsets(
-                    max_total_draft_tokens=max_total_draft_tokens)
+                    batch_size=batch_size, max_draft_len=max_draft_len)
                 self.generate_spec_decoding_packed_mask(
-                    max_total_draft_tokens=max_total_draft_tokens)
+                    batch_size=batch_size, max_draft_len=max_draft_len)
                 self.generate_spec_decoding_generation_length(
-                    max_total_draft_tokens=max_total_draft_tokens)
+                    batch_size=batch_size, max_draft_len=max_draft_len)

-    def generate_spec_decoding_position_offsets(self, max_total_draft_tokens):
-        position_offset = torch.arange(max_total_draft_tokens + 1,
+    def generate_spec_decoding_position_offsets(self, batch_size,
+                                                max_draft_len):
+        position_offset = torch.arange(max_draft_len + 1,
                                        dtype=torch.int,
                                        device='cpu',
-                                       pin_memory=True)
-
+                                       pin_memory=True).repeat(batch_size)
+        #
        # fill all the batches with same position offset
-        self.spec_decoding_position_offsets.copy_(position_offset,
-                                                  non_blocking=True)
-
-    def generate_spec_decoding_packed_mask(self, max_total_draft_tokens):
-        # TODO: fix this limitation
-        assert max_total_draft_tokens < 32, "max_total_draft_tokens should be less than 32, will be fixed later"
-        dummy_idx = torch.arange(max_total_draft_tokens + 1)
-        spec_decoding_packed_mask = torch.pow(2, dummy_idx + 1) - 1
-        self.spec_decoding_packed_mask[:, :, 0].copy_(spec_decoding_packed_mask,
-                                                      non_blocking=True)
-
-    def generate_spec_decoding_generation_length(self, max_total_draft_tokens):
-        spec_decoding_generation_length = torch.full((self.max_num_requests, ),
-                                                     max_total_draft_tokens + 1)
-        self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
+        self.spec_decoding_position_offsets.reshape(-1)[:(max_draft_len + 1) *
+                                                        batch_size].copy_(
+                                                            position_offset,
+                                                            non_blocking=True)
+
+    def generate_spec_decoding_packed_mask(self, batch_size, max_draft_len):
+        # FIXME: remove this limitation
+        assert max_draft_len < 32, "max_draft_len should be less than 32, will be fixed later"
+        dummy_idx = torch.arange(max_draft_len + 1)
+        spec_decoding_packed_mask = torch.pow(
+            2, dummy_idx + 1) - 1  # [max_draft_len + 1]
+        spec_decoding_packed_mask = spec_decoding_packed_mask.repeat(
+            batch_size)  # [batch_size * (max_draft_len + 1)]
+        self.spec_decoding_packed_mask.reshape(
+            -1)[:(max_draft_len + 1) * batch_size].copy_(
+                spec_decoding_packed_mask, non_blocking=True)
+
+    def generate_spec_decoding_generation_length(self, batch_size,
+                                                 max_draft_len):
+        spec_decoding_generation_length = torch.full((batch_size, ),
+                                                     max_draft_len + 1)
+        self.spec_decoding_generation_lengths[:batch_size].copy_(
             spec_decoding_generation_length, non_blocking=True)

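For the linear-tree path, the rewritten generate_* helpers above fill only the first batch_size * (max_draft_len + 1) entries of the flattened, max_num_requests-sized buffers (the padding-removed layout mentioned in the added comment), and the packed mask for draft position i is the causal bitmask 2^(i+1) - 1. A small standalone sketch of that layout with made-up sizes; the names are illustrative stand-ins, not the attention metadata class itself:

import torch

max_num_requests, max_draft_len, batch_size = 8, 2, 3
gen_len = max_draft_len + 1  # tokens queried per request in one step

# Preallocated like self.spec_decoding_* above: sized by max_num_requests.
position_offsets = torch.zeros(max_num_requests, gen_len, dtype=torch.int32)
packed_mask = torch.zeros(max_num_requests, gen_len, 1, dtype=torch.int32)

# Per-token values for one request: offsets 0..max_draft_len and causal
# bitmasks 2^(i+1) - 1, i.e. token i attends to tokens 0..i.
offsets_one_req = torch.arange(gen_len, dtype=torch.int32)
mask_one_req = torch.pow(2, offsets_one_req + 1) - 1  # [1, 3, 7]

# Padding-removed write: only the first batch_size * gen_len flattened slots
# are populated; requests are packed back to back, not strided by the maximum.
position_offsets.reshape(-1)[:batch_size * gen_len] = offsets_one_req.repeat(batch_size)
packed_mask.reshape(-1)[:batch_size * gen_len] = mask_one_req.repeat(batch_size)

print(position_offsets.reshape(-1)[:batch_size * gen_len])  # 0 1 2 0 1 2 0 1 2
print(packed_mask.reshape(-1)[:batch_size * gen_len])       # 1 3 7 1 3 7 1 3 7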

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 2 deletions
@@ -2331,8 +2331,8 @@ def forward(
                 spec_resource_manager, self.is_draft_model, self.attn_backend,
                 self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
             attn_metadata.update_spec_dec_param(
-                is_spec_dec_mode, spec_metadata, spec_tree_manager,
-                self.original_max_draft_len,
+                scheduled_requests.batch_size, is_spec_dec_mode, spec_metadata,
+                spec_tree_manager, self.original_max_draft_len,
                 self.original_max_total_draft_tokens, spec_decoding_tensor)
         else:
             spec_resource_manager = None

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 1 addition & 1 deletion
@@ -581,7 +581,7 @@ def update_kv_cache_draft_token_location(self,
         requests = scheduled_batch.all_requests()
         accepted_draft_token_offsets, packed_accepted_draft_tokens_indices, rewind_draft_token_separate_adjustments = self.locate_accepted_draft_tokens(
             requests)
-        past_key_value_lengths = attn_metadata.kv_lens_cuda
+        past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)]
         if attn_metadata.kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_pool_pointers is not None and attn_metadata.host_kv_cache_pool_mapping is not None:
             use_paged_kv_cache = True
         else:
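
The slice above matters because the attention metadata buffers are presumably sized for the maximum batch, while update_kv_cache_draft_token_location only handles the currently scheduled requests. A toy illustration with hypothetical sizes and names:

import torch

max_num_requests = 8
kv_lens_cuda = torch.zeros(max_num_requests, dtype=torch.int32)
kv_lens_cuda[:3] = torch.tensor([17, 42, 9], dtype=torch.int32)  # three live requests

requests = ["req0", "req1", "req2"]             # stand-in for scheduled_batch.all_requests()
past_key_value_lengths = kv_lens_cuda[:len(requests)]
print(past_key_value_lengths.tolist())          # [17, 42, 9], unused slots excluded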

tensorrt_llm/_torch/speculative/drafting_loops.py

Lines changed: 15 additions & 13 deletions
@@ -192,6 +192,9 @@ def prepare_for_generation_with_tree_decoding(
                        batch_size] -= prev_layer_gen_len_per_req  # reset to original length before the drafter loop.
    attn_metadata.kv_lens_cuda[:batch_size] += next_layer_gen_len_per_req

+    # FIXME, update without D2H
+    # attn_metadata.kv_lens[:batch_size] = attn_metadata.kv_lens_cuda[:batch_size].cpu()
+
    ## 3.2) _seq_lens, _seq_lens_cuda
    attn_metadata._seq_lens[:batch_size].fill_(next_layer_gen_len_per_req)
    attn_metadata._seq_lens_cuda[:batch_size].fill_(next_layer_gen_len_per_req)
@@ -207,23 +210,22 @@ def prepare_for_generation_with_tree_decoding(
    attn_metadata.use_spec_decoding = True

    ## 3.6) spec_decoding_position_offsets
-    attn_metadata.spec_decoding_position_offsets[:, :
-                                                 next_layer_gen_len_per_req] = spec_tree_manager.spec_dec_position_offsets_for_drafter_model[
-                                                     prepare_for_layer_idx -
-                                                     1].unsqueeze(0)
-    attn_metadata.spec_decoding_position_offsets[:,
-                                                 next_layer_gen_len_per_req:] = 0
+    attn_metadata.spec_decoding_position_offsets.reshape(
+        -1
+    )[:batch_size *
+      next_layer_gen_len_per_req] = spec_tree_manager.spec_dec_position_offsets_for_drafter_model[
+          prepare_for_layer_idx - 1].repeat(batch_size)

    ## 3.7) spec_decoding_packed_mask
-    attn_metadata.spec_decoding_packed_mask[:, :
-                                            next_layer_gen_len_per_req, :] = spec_tree_manager.spec_dec_packed_mask_for_drafter_model[
-                                                prepare_for_layer_idx -
-                                                1].unsqueeze(0)
-    attn_metadata.spec_decoding_packed_mask[:,
-                                            next_layer_gen_len_per_req:, :] = 0
+    attn_metadata.spec_decoding_packed_mask.reshape(
+        -1, attn_metadata.spec_decoding_packed_mask.size(-1)
+    )[:batch_size *
+      next_layer_gen_len_per_req, :] = spec_tree_manager.spec_dec_packed_mask_for_drafter_model[
+          prepare_for_layer_idx - 1].repeat(batch_size, 1)

    ## 3.8) spec_decoding_generation_lengths
-    attn_metadata.spec_decoding_generation_lengths[:] = next_layer_gen_len_per_req
+    attn_metadata.spec_decoding_generation_lengths[:
+                                                   batch_size] = next_layer_gen_len_per_req

    # 4) spec_metadata
    ## 4.1) num_tokens
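
Steps 3.6-3.8 now write the drafter-layer offsets and masks for all batch_size requests contiguously into the flattened metadata buffers instead of broadcasting along the padded request dimension and zeroing the tail. A toy sketch of that write pattern; the [0, 0, 1, 1, 1] / [1, 2, 5, 9, 18] values mirror the expectations in the unit test at the end of this commit, and the buffer names are stand-ins:

import torch

max_num_requests, max_total_draft_tokens, batch_size = 4, 12, 2
gen_len = 5  # next_layer_gen_len_per_req for this drafter layer

# Buffers shaped like attn_metadata.spec_decoding_* in the diff.
position_offsets = torch.zeros(max_num_requests, max_total_draft_tokens + 1, dtype=torch.int32)
packed_mask = torch.zeros(max_num_requests, max_total_draft_tokens + 1, 1, dtype=torch.int32)

# Hypothetical per-layer values from the spec tree manager for one request.
layer_offsets = torch.tensor([0, 0, 1, 1, 1], dtype=torch.int32)          # [gen_len]
layer_mask = torch.tensor([[1], [2], [5], [9], [18]], dtype=torch.int32)  # [gen_len, 1]

# Flattened, padding-removed writes: only the first batch_size * gen_len slots/rows.
position_offsets.reshape(-1)[:batch_size * gen_len] = layer_offsets.repeat(batch_size)
packed_mask.reshape(-1, packed_mask.size(-1))[:batch_size * gen_len, :] = layer_mask.repeat(batch_size, 1)

print(position_offsets.reshape(-1)[:batch_size * gen_len].tolist())
# [0, 0, 1, 1, 1, 0, 0, 1, 1, 1]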

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 20 additions & 8 deletions
@@ -130,20 +130,32 @@ def attention_need_spec_dec_mode(
         spec_resource_manager: BaseResourceManager,
         is_draft_model: bool,
         attention_backend: Type[AttentionBackend],
-        use_chain_drafter: bool,
+        use_chain_drafter: bool,  # CDL
         is_spec_dec_tree: bool,
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
+        Args:
+            spec_resource_manager: the resource manager for the spec-dec mode.
+            is_draft_model: whether the model is a draft model.
+            attention_backend: the attention backend.
+            use_chain_drafter: whether to use capturable drafting loops (CDL). For the target model, it is always False.
+            is_spec_dec_tree: whether the spec-dec mode is a tree, i.e., static tree or dynamic tree.
         """
         is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
-        return (self.is_eagle3_one_model()  # one model
-                or (self.is_eagle3() and spec_resource_manager.is_first_draft
-                    and is_trtllm_attention and use_chain_drafter
-                    and is_draft_model)  # two model + first drafter + CDL
-                or (self.is_eagle3() and is_trtllm_attention
-                    and is_spec_dec_tree)  # two model + tree
-                )
+        # Case 1: one model
+        use_case_1 = self.is_eagle3_one_model()
+        # Case 2: eagle3 two model + draft model + CDL + is_first_draft + TRTLLM attention
+        use_case_2 = self.is_eagle3(
+        ) and spec_resource_manager.is_first_draft and use_chain_drafter and is_draft_model and is_trtllm_attention
+        # Case 3: eagle3 two model + tree decoding + draft model + CDL + TRTLLM attention
+        use_case_3 = self.is_eagle3(
+        ) and is_spec_dec_tree and is_draft_model and use_chain_drafter and is_trtllm_attention
+        # Case 4: eagle3 two model + tree decoding + target model + TRTLLM attention
+        use_case_4 = self.is_eagle3(
+        ) and is_spec_dec_tree and not is_draft_model and is_trtllm_attention
+
+        return use_case_1 or use_case_2 or use_case_3 or use_case_4

     @staticmethod
     def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
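
A self-contained distillation of the four cases above, with the spec-dec mode, backend, and resource-manager state reduced to plain booleans; this only illustrates the predicate's truth conditions, not the real method:

def need_spec_dec_mode(is_eagle3_one_model: bool, is_eagle3: bool,
                       is_first_draft: bool, use_chain_drafter: bool,
                       is_draft_model: bool, is_spec_dec_tree: bool,
                       is_trtllm_attention: bool) -> bool:
    case_1 = is_eagle3_one_model                                      # one model
    case_2 = (is_eagle3 and is_first_draft and use_chain_drafter      # two model, first drafter, CDL
              and is_draft_model and is_trtllm_attention)
    case_3 = (is_eagle3 and is_spec_dec_tree and is_draft_model       # two model, tree, drafter, CDL
              and use_chain_drafter and is_trtllm_attention)
    case_4 = (is_eagle3 and is_spec_dec_tree and not is_draft_model   # two model, tree, target
              and is_trtllm_attention)
    return case_1 or case_2 or case_3 or case_4

# Example: target model with tree decoding and TRTLLM attention -> True (case 4).
print(need_spec_dec_mode(False, True, False, False, False, True, True))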

tensorrt_llm/_torch/speculative/spec_tree_manager.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def __init__(self, max_num_requests: int, use_dynamic_tree: bool,
         else:
             self.init_tree_info_for_static_tree()

-        self.dump_tree_info()
+        # self.dump_tree_info()

     def init_tree_info_for_dynamic_tree(self):
         # For the dynamic tree

tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py

Lines changed: 25 additions & 36 deletions
@@ -188,12 +188,22 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
    assert torch.all(
        torch.tensor(attn_metadata.num_contexts) == torch.tensor(
            ref_attn_metadata['num_contexts']))
-    assert torch.all(attn_metadata.spec_decoding_position_offsets ==
-                     ref_attn_metadata['spec_decoding_position_offsets'])
-    assert torch.all(attn_metadata.spec_decoding_packed_mask ==
-                     ref_attn_metadata['spec_decoding_packed_mask'])
    assert torch.all(attn_metadata.spec_decoding_generation_lengths ==
                     ref_attn_metadata['spec_decoding_generation_lengths'])
+    total_process_tokens = attn_metadata.spec_decoding_generation_lengths.sum(
+    )
+    print(f"total_process_tokens: {total_process_tokens}")
+    assert torch.all(
+        attn_metadata.spec_decoding_position_offsets.reshape(
+            -1)[:total_process_tokens] ==
+        ref_attn_metadata['spec_decoding_position_offsets']
+        [:total_process_tokens])
+    assert torch.all(
+        attn_metadata.spec_decoding_packed_mask.reshape(
+            -1, attn_metadata.spec_decoding_packed_mask.size(
+                -1))[:total_process_tokens, :] ==
+        ref_attn_metadata['spec_decoding_packed_mask']
+        [:total_process_tokens, :])

    assert torch.all(
        torch.tensor(spec_metadata.num_tokens) == torch.tensor(
@@ -267,13 +277,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
        device='cuda')
    ref_attn_metadata['num_contexts'] = 0
    ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda')
+        [0, 0, 0], dtype=torch.int32, device='cuda')
    ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1, 1)
+        [1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1)
    ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
        [3], dtype=torch.int32, device='cuda')

@@ -361,14 +367,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
        device='cuda')
    ref_attn_metadata['num_contexts'] = 0
    ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').repeat(max_batch_size, 1)
+        [0, 0, 0, 0, 0, 0], dtype=torch.int32, device='cuda')
    ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1,
-                               1).repeat(max_batch_size, 1, 1)
+        [1, 2, 4, 1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1)
    ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
        [3, 3], dtype=torch.int32, device='cuda')

@@ -455,14 +456,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
        device='cuda')
    ref_attn_metadata['num_contexts'] = 0
    ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').repeat(max_batch_size, 1)
+        [0, 0, 0, 0, 0, 0], dtype=torch.int32, device='cuda')
    ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1,
-                               1).repeat(max_batch_size, 1, 1)
+        [1, 2, 4, 1, 2, 4], dtype=torch.int32, device='cuda').unsqueeze(1)
    ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
        [3, 3], dtype=torch.int32, device='cuda')

@@ -545,13 +541,9 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
        device='cuda')
    ref_attn_metadata['num_contexts'] = 0
    ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda')
+        [0, 0, 1, 1, 1], dtype=torch.int32, device='cuda')
    ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 5, 9, 18, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1, 1)
+        [1, 2, 5, 9, 18], dtype=torch.int32, device='cuda').unsqueeze(1)
    ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
        [5], dtype=torch.int32, device='cuda')

@@ -637,13 +629,10 @@ def run_test(max_batch_size, prepare_for_layer_idx, max_total_draft_tokens,
        device='cuda')
    ref_attn_metadata['num_contexts'] = 0
    ref_attn_metadata['spec_decoding_position_offsets'] = torch.tensor(
-        [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda')
+        [0, 0, 1, 1, 1, 0, 0, 1, 1, 1], dtype=torch.int32, device='cuda')
    ref_attn_metadata['spec_decoding_packed_mask'] = torch.tensor(
-        [1, 2, 5, 9, 18, 0, 0, 0, 0, 0, 0, 0, 0],
-        dtype=torch.int32,
-        device='cuda').reshape(1, max_total_draft_tokens + 1, 1)
+        [1, 2, 5, 9, 18, 1, 2, 5, 9, 18], dtype=torch.int32,
+        device='cuda').unsqueeze(1)
    ref_attn_metadata['spec_decoding_generation_lengths'] = torch.tensor(
        [5, 5], dtype=torch.int32, device='cuda')
