Skip to content

Commit 84253da

Browse files
committed
update
Signed-off-by: Jade Zheng <[email protected]>
1 parent c38c6bb commit 84253da

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3076,24 +3076,23 @@ def _dummy_run(
30763076

30773077
need_dummy_logits = (not self.in_profile_run
30783078
and lmhead_tp_enable())
3079-
3080-
if need_dummy_logits:
3081-
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
3082-
dummy_indices = torch.zeros(max_num_reqs_across_dp,
3083-
dtype=torch.int32)
3084-
3085-
def dummy_compute_logits(hidden_states):
3086-
return self.model.compute_logits(
3087-
hidden_states[dummy_indices])
3088-
3089-
def dummy_drafter_compute_logits(hidden_states):
3079+
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
3080+
dummy_indices = torch.zeros(max_num_reqs_across_dp,
3081+
dtype=torch.int32)
3082+
3083+
def dummy_compute_logits(hidden_states):
3084+
if not need_dummy_logits:
3085+
return None
3086+
return self.model.compute_logits(hidden_states[dummy_indices])
3087+
3088+
def dummy_drafter_compute_logits(hidden_states):
3089+
if not need_dummy_logits:
3090+
return
3091+
if hasattr(self.drafter, "model") and hasattr(
3092+
self.drafter.model, "compute_logits"):
30903093
return self.drafter.model.compute_logits(
30913094
hidden_states[dummy_indices])
30923095

3093-
else:
3094-
dummy_compute_logits = lambda hidden_states: None
3095-
dummy_drafter_compute_logits = lambda hidden_states: None
3096-
30973096
with set_ascend_forward_context(
30983097
attn_metadata,
30993098
self.vllm_config,

0 commit comments

Comments
 (0)