Skip to content

Commit 60589cf

Browse files
committed
fix bug 5556020
Signed-off-by: qgai <[email protected]>
1 parent f49f42d commit 60589cf

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,9 +470,9 @@ def _update_target_inputs_with_draft_tokens(
                 continue
 
             # Get the index of the draft/target tokens in the device tensor
-            draft_idx = req_idx if self.use_static_draft_loop else request.py_batch_idx
+            draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
             target_idx = req_id_to_old_request[
-                request.py_request_id].py_batch_idx
+                request.py_request_id].py_seq_slot
             target_inputs.new_tokens[draft_position + 1:draft_position +
                                      draft_length + 1, target_idx,
                                      0] = draft_tensors[0:draft_length,

0 commit comments

Comments
 (0)