4 changes: 3 additions & 1 deletion vllm_ascend/spec_decode/eagle_proposer.py
@@ -136,7 +136,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp: Optional[torch.Tensor] = None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None):
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None):
         moe_comm_type = self.runner._select_moe_comm_method(
             num_tokens, with_prefill)
         with set_ascend_forward_context(None,
@@ -148,6 +149,7 @@ def dummy_run(self,
                 positions=self.positions[:num_tokens],
                 hidden_states=self.hidden_states[:num_tokens],
             )
+            dummy_compute_logits(self.hidden_states)
 
     def generate_token_ids(self,
                            valid_sampled_token_ids: list[list[int]],
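The pattern in both proposers is the same: dummy_run grows an optional dummy_compute_logits hook that defaults to a no-op lambda, so existing call sites keep working while the runner can inject a logits warmup when lmhead tensor parallelism requires every rank to enter compute_logits. A minimal sketch of that pattern, with hypothetical Proposer/hidden_size names standing in for the real classes:

from typing import Callable, Optional

import torch


class Proposer:
    """Hypothetical stand-in for EagleProposer/MtpProposer."""

    def __init__(self, hidden_size: int = 8) -> None:
        self.hidden_size = hidden_size

    def dummy_run(
        self,
        num_tokens: int,
        dummy_compute_logits: Callable[[torch.Tensor], Optional[torch.Tensor]]
        = lambda hidden_states: None,  # default: do nothing, as in the diff
    ) -> None:
        # Stand-in for the real dummy forward pass.
        hidden_states = torch.zeros(num_tokens, self.hidden_size)
        # Invoked unconditionally; callers that need the logits warmup pass
        # a real callback, everyone else pays nothing.
        dummy_compute_logits(hidden_states)


proposer = Proposer()
proposer.dummy_run(4, dummy_compute_logits=lambda h: h.sum())
proposer.dummy_run(4)  # legacy call sites are unaffected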
7 changes: 5 additions & 2 deletions vllm_ascend/spec_decode/mtp_proposer.py
@@ -232,7 +232,8 @@ def dummy_run(self,
                   num_reqs: int = 0,
                   num_tokens_across_dp=None,
                   aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-                  batch_descriptor=None) -> None:
+                  batch_descriptor=None,
+                  dummy_compute_logits=lambda hidden_states: None) -> None:
 
         (
             num_tokens,
@@ -316,6 +317,7 @@ def dummy_run(self,
                     self.update_stream, forward_context,
                     positions.shape[0],
                     self.vllm_config.speculative_config)
+            dummy_compute_logits(previous_hidden_states)
             if with_prefill:
                 break
 
@@ -775,6 +777,7 @@ def _propose(
         logits = self.model.compute_logits(sample_hidden_states)
         if lmhead_tp_enable() and num_indices < logits.shape[0]:
             logits = logits[:num_indices]
+            last_token_indices = last_token_indices[:num_indices]
         draft_token_ids = logits.argmax(dim=-1)
 
         if self.num_speculative_tokens == 1:
@@ -840,7 +843,7 @@ def _propose(
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
             exceeds_max_model_len_cpu = exceeds_max_model_len.to(
-                attn_metadata_i.seq_lens.device, non_blocking=True)
+                attn_metadata_i.seq_lens.device, non_blocking=False)
             attn_metadata_i.seq_lens[:batch_size].masked_fill_(
                 exceeds_max_model_len_cpu, 1)
             # Mask out the slot mappings that exceed the max model length.
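Two of the _propose changes are easy to miss. First, when lmhead_tp_enable() pads the batch, logits is truncated to num_indices, so last_token_indices must be truncated identically or the two tensors drift out of alignment downstream. Second, the host-to-device copy of exceeds_max_model_len switches to non_blocking=False: a non-blocking copy from CPU memory can return before the data has actually landed on the device, and the masked_fill_ that immediately consumes the mask could then read stale values (that is my reading of the motivation; the diff itself does not state it). A toy illustration of the alignment invariant, with made-up shapes and values:

import torch

num_indices = 3  # real rows; the rest is lmhead-TP padding
logits = torch.randn(5, 32)                         # padded logits
last_token_indices = torch.tensor([2, 6, 9, 0, 0])  # padded positions

# Truncate both tensors together, as the fixed code does.
logits = logits[:num_indices]
last_token_indices = last_token_indices[:num_indices]

draft_token_ids = logits.argmax(dim=-1)
# Every draft token still lines up with the position it was sampled for.
assert draft_token_ids.shape[0] == last_token_indices.shape[0]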
32 changes: 18 additions & 14 deletions vllm_ascend/worker/model_runner_v1.py
@@ -3076,14 +3076,21 @@ def _dummy_run(
 
         need_dummy_logits = (not self.in_profile_run
                              and lmhead_tp_enable())
-
-        if need_dummy_logits:
-            max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
-            dummy_indices = torch.zeros(max_num_reqs_across_dp,
-                                        dtype=torch.int32)
-
-            def dummy_compute_logits(hidden_states):
-                return self.model.compute_logits(
-                    hidden_states[dummy_indices])
+        max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
+        dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                                    dtype=torch.int32)
+
+        def dummy_compute_logits(hidden_states):
+            if not need_dummy_logits:
+                return None
+            return self.model.compute_logits(hidden_states[dummy_indices])
+
+        def dummy_drafter_compute_logits(hidden_states):
+            if not need_dummy_logits or self.drafter is None:
+                return
+            if hasattr(self.drafter, "model") and hasattr(
+                    self.drafter.model, "compute_logits"):
+                return self.drafter.model.compute_logits(
+                    hidden_states[dummy_indices])
 
         with set_ascend_forward_context(
@@ -3105,8 +3112,7 @@ def dummy_compute_logits(hidden_states):
                 with_prefill, is_torchair_compile, input_ids, positions,
                 attn_metadata, num_tokens, intermediate_tensors,
                 inputs_embeds)
-            if need_dummy_logits:
-                dummy_compute_logits(hidden_states)
+            dummy_compute_logits(hidden_states)
 
             if self.drafter:
                 self.drafter.dummy_run(
@@ -3115,10 +3121,8 @@ def dummy_compute_logits(hidden_states):
                     num_reqs=num_reqs,
                     num_tokens_across_dp=num_tokens_across_dp,
                     aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    batch_descriptor=batch_descriptor)
-                if need_dummy_logits:
-                    self.drafter.model.compute_logits(
-                        hidden_states[dummy_indices])
+                    batch_descriptor=batch_descriptor,
+                    dummy_compute_logits=dummy_drafter_compute_logits)
             if self.in_profile_run and self.dynamic_eplb:
                 self.model.clear_all_moe_loads()
             if not self.in_profile_run and self.dynamic_eplb:
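The runner-side refactor moves the need_dummy_logits check inside the closures, so the hot path calls them unconditionally, and adds a drafter variant that duck-types on model.compute_logits instead of assuming every drafter exposes one. A self-contained sketch of that shape, with stub model objects standing in for the real vLLM modules (names here are hypothetical):

from typing import Optional

import torch


class _StubModel:
    """Hypothetical stand-in for the target/draft model."""

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states @ torch.ones(hidden_states.shape[-1], 4)


def make_dummy_logits_hooks(model, drafter, need_dummy_logits: bool,
                            dummy_indices: torch.Tensor):
    def dummy_compute_logits(hidden_states) -> Optional[torch.Tensor]:
        # Guard lives inside the closure: call sites stay unconditional.
        if not need_dummy_logits:
            return None
        return model.compute_logits(hidden_states[dummy_indices])

    def dummy_drafter_compute_logits(hidden_states) -> Optional[torch.Tensor]:
        # Duck-typed: drafters without model.compute_logits become no-ops.
        if not need_dummy_logits or drafter is None:
            return None
        if hasattr(drafter, "model") and hasattr(drafter.model,
                                                 "compute_logits"):
            return drafter.model.compute_logits(hidden_states[dummy_indices])
        return None

    return dummy_compute_logits, dummy_drafter_compute_logits


class _StubDrafter:
    model = _StubModel()


hooks = make_dummy_logits_hooks(_StubModel(), _StubDrafter(), True,
                                torch.zeros(2, dtype=torch.long))
hidden = torch.randn(6, 8)
for hook in hooks:
    out = hook(hidden)
    assert out is not None and out.shape == (2, 4)

The sketch uses torch.long indices to avoid depending on int32 advanced-indexing support; the production code uses torch.int32 on NPU.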