Skip to content

Commit 2624b3f

Browse files
committed
[Bugfix] Resolve MTP > 1 issue when lm head tp > 1
Previously, the dummy run executed compute_logits only once, regardless of num_speculative_tokens. When lm-head tensor parallelism exceeded 1, this caused execute_model to hang on compute_logits, since the dummy run and the real run issued different numbers of compute_logits calls across the tensor-parallel ranks. The fix ensures the dummy run invokes compute_logits the same number of times as num_speculative_tokens, so both runs stay in step. Signed-off-by: Jade Zheng <[email protected]>
1 parent fff258b commit 2624b3f

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def dummy_run(self,
136136
num_reqs: int = 0,
137137
num_tokens_across_dp: Optional[torch.Tensor] = None,
138138
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
139-
batch_descriptor=None):
139+
batch_descriptor=None,
140+
dummy_compute_logits=lambda hidden_states: None):
140141
moe_comm_type = self.runner._select_moe_comm_method(
141142
num_tokens, with_prefill)
142143
with set_ascend_forward_context(None,
@@ -148,6 +149,7 @@ def dummy_run(self,
148149
positions=self.positions[:num_tokens],
149150
hidden_states=self.hidden_states[:num_tokens],
150151
)
152+
dummy_compute_logits(self.hidden_states)
151153

152154
def generate_token_ids(self,
153155
valid_sampled_token_ids: list[list[int]],

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ def dummy_run(self,
232232
num_reqs: int = 0,
233233
num_tokens_across_dp=None,
234234
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
235-
batch_descriptor=None) -> None:
235+
batch_descriptor=None,
236+
dummy_compute_logits=lambda hidden_states: None) -> None:
236237

237238
(
238239
num_tokens,
@@ -316,6 +317,7 @@ def dummy_run(self,
316317
self.update_stream, forward_context,
317318
positions.shape[0],
318319
self.vllm_config.speculative_config)
320+
dummy_compute_logits(previous_hidden_states)
319321
if with_prefill:
320322
break
321323

@@ -775,6 +777,7 @@ def _propose(
775777
logits = self.model.compute_logits(sample_hidden_states)
776778
if lmhead_tp_enable() and num_indices < logits.shape[0]:
777779
logits = logits[:num_indices]
780+
last_token_indices = last_token_indices[:num_indices]
778781
draft_token_ids = logits.argmax(dim=-1)
779782

780783
if self.num_speculative_tokens == 1:
@@ -840,7 +843,7 @@ def _propose(
840843
# For the requests that exceed the max model length, we set the
841844
# sequence length to 1 to minimize their overheads in attention.
842845
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
843-
attn_metadata_i.seq_lens.device, non_blocking=True)
846+
attn_metadata_i.seq_lens.device, non_blocking=False)
844847
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
845848
exceeds_max_model_len_cpu, 1)
846849
# Mask out the slot mappings that exceed the max model length.

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3086,6 +3086,14 @@ def dummy_compute_logits(hidden_states):
30863086
return self.model.compute_logits(
30873087
hidden_states[dummy_indices])
30883088

3089+
def dummy_drafter_compute_logits(hidden_states):
3090+
return self.drafter.compute_logits(
3091+
hidden_states[dummy_indices])
3092+
3093+
else:
3094+
dummy_compute_logits = lambda hidden_states: None
3095+
dummy_drafter_compute_logits = lambda hidden_states: None
3096+
30893097
with set_ascend_forward_context(
30903098
attn_metadata,
30913099
self.vllm_config,
@@ -3105,8 +3113,7 @@ def dummy_compute_logits(hidden_states):
31053113
with_prefill, is_torchair_compile, input_ids, positions,
31063114
attn_metadata, num_tokens, intermediate_tensors,
31073115
inputs_embeds)
3108-
if need_dummy_logits:
3109-
dummy_compute_logits(hidden_states)
3116+
dummy_compute_logits(hidden_states)
31103117

31113118
if self.drafter:
31123119
self.drafter.dummy_run(
@@ -3115,10 +3122,8 @@ def dummy_compute_logits(hidden_states):
31153122
num_reqs=num_reqs,
31163123
num_tokens_across_dp=num_tokens_across_dp,
31173124
aclgraph_runtime_mode=aclgraph_runtime_mode,
3118-
batch_descriptor=batch_descriptor)
3119-
if need_dummy_logits:
3120-
self.drafter.model.compute_logits(
3121-
hidden_states[dummy_indices])
3125+
batch_descriptor=batch_descriptor,
3126+
dummy_compute_logits=dummy_drafter_compute_logits)
31223127
if self.in_profile_run and self.dynamic_eplb:
31233128
self.model.clear_all_moe_loads()
31243129
if not self.in_profile_run and self.dynamic_eplb:

0 commit comments

Comments (0)