
Commit be493e0

[BugFix] Fix new nightly failures (#29578)
Signed-off-by: Lucas Wilkinson <[email protected]>
Parent: ae0ce1b

File tree: 2 files changed (+37, -1)


vllm/v1/attention/backends/utils.py (26 additions, 0 deletions)

@@ -100,6 +100,32 @@ class CommonAttentionMetadata:
     dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""
 
+    # TODO(lucas): remove once we have FULL-CG spec-decode support
+    def unpadded(
+        self, num_actual_tokens: int, num_actual_reqs: int
+    ) -> "CommonAttentionMetadata":
+        maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
+        return CommonAttentionMetadata(
+            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
+            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
+            seq_lens=self.seq_lens[:num_actual_reqs],
+            seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
+            num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
+            num_reqs=num_actual_reqs,
+            num_actual_tokens=num_actual_tokens,
+            max_query_len=self.max_query_len,
+            max_seq_len=self.max_seq_len,
+            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
+            slot_mapping=self.slot_mapping[:num_actual_tokens],
+            causal=self.causal,
+            logits_indices_padded=self.logits_indices_padded,
+            num_logits_indices=self.num_logits_indices,
+            encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
+            encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
+            dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
+            dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
+        )
 
 def slice_query_start_locs(
     query_start_loc: torch.Tensor,
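For context, here is a minimal sketch of the slicing pattern unpadded() implements: per-request tensors are cut to the real request count, per-token tensors to the real token count. ToyMetadata is a hypothetical, simplified stand-in for CommonAttentionMetadata with only a few representative fields, not vLLM's actual API:

# Hypothetical, simplified stand-in for CommonAttentionMetadata;
# it models only the slicing pattern used by unpadded() above.
from dataclasses import dataclass

import torch


@dataclass
class ToyMetadata:
    query_start_loc: torch.Tensor  # shape: [num_reqs + 1]
    seq_lens: torch.Tensor         # shape: [num_reqs]
    slot_mapping: torch.Tensor     # shape: [num_tokens]
    num_reqs: int
    num_actual_tokens: int

    def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "ToyMetadata":
        # Per-request tensors are sliced to the real request count,
        # per-token tensors to the real token count.
        return ToyMetadata(
            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            num_reqs=num_actual_reqs,
            num_actual_tokens=num_actual_tokens,
        )


# Batch padded to 4 requests / 8 tokens for CUDA-graph capture,
# but only 2 requests / 5 tokens are real.
padded = ToyMetadata(
    query_start_loc=torch.tensor([0, 3, 5, 5, 5]),
    seq_lens=torch.tensor([3, 2, 0, 0]),
    slot_mapping=torch.arange(8),
    num_reqs=4,
    num_actual_tokens=8,
)
real = padded.unpadded(num_actual_tokens=5, num_actual_reqs=2)
assert real.query_start_loc.tolist() == [0, 3, 5]
assert real.seq_lens.tolist() == [3, 2]
assert real.slot_mapping.tolist() == [0, 1, 2, 3, 4]

Note the asymmetry visible in the diff: query_start_loc keeps num_actual_reqs + 1 entries because it stores cumulative offsets, while slot_mapping is per-token and is cut at num_actual_tokens.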

vllm/v1/worker/gpu_model_runner.py (11 additions, 1 deletion)

@@ -1551,7 +1551,7 @@ def _build_attention_metadata(
             # Encoder-only layers do not have KV cache, so we need to
             # create a dummy block table and slot mapping for them.
             blk_table_tensor = torch.zeros(
-                (num_tokens_padded, 1),
+                (num_reqs_padded, 1),
                 dtype=torch.int32,
                 device=self.device,
            )
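The one-line fix in the hunk above corrects a dimension mix-up: a block table has one row per request (each row holds that request's KV-cache block ids), so the dummy tensor must be sized by the padded request count, not the padded token count. A minimal illustration, with made-up sizes:

# Illustrative sizes only; in the runner these come from CUDA-graph
# padding of the batch.
import torch

num_reqs_padded, num_tokens_padded = 4, 8

# Before: one row per *token* -- the wrong axis for a per-request table.
# wrong = torch.zeros((num_tokens_padded, 1), dtype=torch.int32)

# After: one row per request, matching how block tables are indexed
# (e.g. block_table_tensor[:num_actual_reqs] in unpadded() above).
blk_table_tensor = torch.zeros((num_reqs_padded, 1), dtype=torch.int32)
assert blk_table_tensor.shape == (num_reqs_padded, 1)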
@@ -1652,6 +1652,16 @@ def _build_attention_metadata(
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
+        if spec_decode_common_attn_metadata is not None and (
+            num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
+        ):
+            # Currently the drafter still only uses piecewise cudagraphs
+            # (and modifies the attention metadata in place), and therefore
+            # does not want to use padded attention metadata.
+            spec_decode_common_attn_metadata = (
+                spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
+            )
+
         return attn_metadata, spec_decode_common_attn_metadata
 
     def _compute_cascade_attn_prefix_lens(
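A sketch of the resulting control flow, under the assumption (stated in the comment above) that the drafter mutates its attention metadata in place and therefore must see true, unpadded sizes. maybe_unpad and its parameters are illustrative names, not the runner's API:

# Illustrative helper mirroring the guard above; `metadata` is any
# object exposing an unpadded(num_tokens, num_reqs) method, such as
# CommonAttentionMetadata from the first file in this commit.
def maybe_unpad(metadata, num_tokens, num_reqs, num_tokens_padded, num_reqs_padded):
    if metadata is not None and (
        num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
    ):
        # Padding was applied for full CUDA-graph capture; give the
        # drafter a view sliced back to the real batch.
        return metadata.unpadded(num_tokens, num_reqs)
    # No padding happened (or no spec-decode metadata): pass through.
    return metadata

Per the TODO in the first file, this unpadding step is a stopgap until full-CUDA-graph spec-decode support lands.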
