
Commit df777e9

[bugfix] pcp + mtp acl graph bugfix (#4221)
Fix a pcp + mtp bug that appears when ACL graph is enabled. With pcp + mtp, the block_table must be flattened to avoid an irregular attention mask shape. This was previously done in the MLA attn_metadata builder, but we found that it changes the block_table address and leads to incorrect results when ACL graph is enabled. To fix this, we enlarge the block_table buffer and flatten the block_table in the model runner's _prepare_inputs, so the block_table address is left untouched. - vLLM version: v0.11.0 - vLLM main: vllm-project/vllm@2918c1b Signed-off-by: zhangsicheng5 <[email protected]>
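The core of the issue can be shown in a few lines of plain PyTorch (a minimal sketch; the names and sizes are illustrative, not taken from the diff): repeat_interleave allocates a new tensor, so flattening inside the attn_metadata builder hands the captured ACL graph a stale address, while flattening into a preallocated, enlarged buffer keeps the base address stable.

import torch

max_num_reqs, decode_threshold, max_num_blocks = 4, 2, 8

# Old approach: flatten inside the attn_metadata builder.
block_table = torch.zeros(max_num_reqs, max_num_blocks, dtype=torch.int32)
flat = block_table.repeat_interleave(decode_threshold, dim=0)
# repeat_interleave allocates fresh memory, so the flattened table lives at
# a new address; a previously captured graph keeps reading the old one.
assert flat.data_ptr() != block_table.data_ptr()

# New approach: preallocate an enlarged buffer and flatten into it in place.
buf = torch.zeros(max_num_reqs * decode_threshold, max_num_blocks,
                  dtype=torch.int32)
buf[:max_num_reqs * decode_threshold].copy_(
    buf[:max_num_reqs].repeat_interleave(decode_threshold, dim=0))
# Leading slices of the buffer share its base address across every step.
assert buf[:max_num_reqs * decode_threshold].data_ptr() == buf.data_ptr()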
1 parent 9328f37 commit df777e9

File tree: 3 files changed, +69 -24 lines

vllm_ascend/attention/mla_v1.py (15 additions, 6 deletions)

@@ -369,6 +369,12 @@ def build(
         device = self.device

         block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
+        if self.pcp_size > 1:
+            num_decodes_flatten = num_decodes * self.decode_threshold
+            block_table = common_attn_metadata.block_table_tensor[:
+                                                                  num_decodes_flatten
+                                                                  +
+                                                                  num_prefills]

         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens
@@ -546,6 +552,9 @@ def build(
             cos=cos,
             pcp_metadata=pcp_metadata,
         )
+        if self.pcp_size > 1:
+            prefill_metadata.block_table = block_table[
+                num_decodes_flatten:, ...]

         decode_metadata = None
         if num_decodes > 0:
@@ -556,12 +565,12 @@ def build(
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            block_table = block_table[:num_decodes, ...]
-            # For pcp + spec decode, we flatten seq_lens and block_table
-            # to avoid irregular spec_attn_mask shape
-            if self.pcp_size > 1 and self.decode_threshold > 1:
-                block_table = block_table.repeat_interleave(
-                    self.decode_threshold, dim=0)
+            if self.pcp_size > 1:
+                # For pcp + spec decode, we flatten seq_lens and block_table
+                # to avoid irregular spec_attn_mask shape
+                block_table = block_table[:num_decodes_flatten, ...]
+            else:
+                block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()

             if num_computed_tokens_of_pcp_dcp is not None:
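A quick sanity check on the new slicing, with toy numbers not taken from the diff: in the flattened table, the first num_decodes * decode_threshold rows are decode rows and the remainder are prefill rows, which is exactly what the two slices above pick apart.

num_decodes, num_prefills, decode_threshold = 2, 3, 2
num_decodes_flatten = num_decodes * decode_threshold  # 4

rows = ["d0", "d0", "d1", "d1", "p0", "p1", "p2"]  # flattened block_table rows
assert rows[:num_decodes_flatten] == ["d0", "d0", "d1", "d1"]  # decode slice
assert rows[num_decodes_flatten:] == ["p0", "p1", "p2"]        # prefill slice
assert len(rows) == num_decodes_flatten + num_prefills         # builder's slice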

vllm_ascend/worker/block_table.py (23 additions, 18 deletions)

@@ -27,13 +27,29 @@ def __init__(self,
                  pin_memory: bool,
                  device: torch.device,
                  kernel_sizes: Union[list[int], None] = None,
-                 cp_kv_cache_interleave_size: int = 1):
+                 cp_kv_cache_interleave_size: int = 1,
+                 num_speculative_tokens: int = 0):
         self.max_num_reqs = max_num_reqs
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.max_num_batched_tokens = max_num_batched_tokens
         self.pin_memory = pin_memory
         self.device = device
         self.physical_block_size = block_size
+
+        try:
+            self.pcp_world_size = get_pcp_group(
+            ).world_size if prefill_context_parallel_enable() else 1
+            self.pcp_rank = get_pcp_group(
+            ).rank_in_group if self.pcp_world_size > 1 else 0
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+            self.pcp_world_size = 1
+            self.pcp_rank = 0
+
         # If kernel_sizes is None or [0], use physical block size (no splitting)
         if kernel_sizes is None or kernel_sizes == [0]:
             self.block_size = block_size
@@ -69,34 +85,23 @@ def __init__(self,
         else:
             logical_table_size = max_num_blocks_per_req

+        duplicate_size = 1
+        if self.pcp_world_size > 1:
+            duplicate_size += num_speculative_tokens
         self.block_table = torch.zeros(
-            (max_num_reqs, logical_table_size),
+            (max_num_reqs * duplicate_size, logical_table_size),
             device=self.device,
             dtype=torch.int32,
         )
         self.block_table_cpu = torch.zeros(
-            (max_num_reqs, logical_table_size),
+            (max_num_reqs * duplicate_size, logical_table_size),
             device="cpu",
             dtype=torch.int32,
             pin_memory=pin_memory,
         )
         self.block_table_np = self.block_table_cpu.numpy()
         self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

-        try:
-            self.pcp_world_size = get_pcp_group(
-            ).world_size if prefill_context_parallel_enable() else 1
-            self.pcp_rank = get_pcp_group(
-            ).rank_in_group if self.pcp_world_size > 1 else 0
-            self.dcp_world_size = get_dcp_group().world_size
-            self.dcp_rank = get_dcp_group().rank_in_group
-        except AssertionError:
-            # DCP might not be initialized in testing
-            self.dcp_world_size = 1
-            self.dcp_rank = 0
-            self.pcp_world_size = 1
-            self.pcp_rank = 0
-
         self.slot_mapping_cpu = torch.zeros(
             self.max_num_batched_tokens +
             2 * self.pcp_world_size * self.max_num_reqs,
@@ -306,7 +311,7 @@ def __init__(self,
                         block_size * dcp_world_size * pcp_world_size),
                     1 + num_speculative_tokens), max_num_batched_tokens,
                 pin_memory, device, kernel_size_list,
-                cp_kv_cache_interleave_size)
+                cp_kv_cache_interleave_size, num_speculative_tokens)
             for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
         ]
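Why duplicate_size = 1 + num_speculative_tokens is enough rows: the worst case is an all-decode batch, where every one of max_num_reqs rows is repeated decode_threshold times. A small check under the assumption (consistent with vLLM's spec-decode bookkeeping, not stated in this diff) that decode_threshold == 1 + num_speculative_tokens:

max_num_reqs, num_speculative_tokens = 4, 1

duplicate_size = 1
pcp_world_size = 2
if pcp_world_size > 1:
    duplicate_size += num_speculative_tokens  # mirrors the diff above

decode_threshold = 1 + num_speculative_tokens  # assumption, see lead-in
# Worst case: an all-decode batch repeats every row decode_threshold times.
assert max_num_reqs * decode_threshold <= max_num_reqs * duplicate_size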

vllm_ascend/worker/model_runner_v1.py (31 additions, 0 deletions)

@@ -596,6 +596,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 self.is_pooling_model,
                 self.vllm_config.model_config.logits_processors),
             is_pooling_model=self.is_pooling_model,
+            num_speculative_tokens=(
+                self.vllm_config.speculative_config.num_speculative_tokens
+                if self.vllm_config.speculative_config else 0),
             kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
             cp_kv_cache_interleave_size=self.parallel_config.
             cp_kv_cache_interleave_size
@@ -1922,6 +1925,31 @@ def _prepare_inputs(
             prefill_context_parallel_metadata=long_seq_metadata,
         )

+        if self.speculative_config and self.pcp_size > 1:
+            # For pcp + spec decode, we flatten block_table
+            # to avoid irregular spec_attn_mask shape, e.g.,
+            # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
+            # ori block_table: [d0, d1, p0, p1, p2]
+            # (num_reqs_d + num_reqs_p, max_num_blocks),
+            # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
+            # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
+            ori_query_lens = self.query_start_loc_pcp_full_cpu[1:num_reqs + 1] - \
+                self.query_start_loc_pcp_full_cpu[:num_reqs]
+            num_prefill_reqs = (ori_query_lens
+                                > self.decode_threshold).sum().item()
+            num_decode_reqs = num_reqs - num_prefill_reqs
+            num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
+            blk_table_tensor[
+                num_decode_reqs_flatten:num_decode_reqs_flatten +
+                num_prefill_reqs].copy_(
+                    blk_table_tensor[num_decode_reqs:num_decode_reqs +
+                                     num_prefill_reqs].clone())
+            blk_table_tensor[:num_decode_reqs_flatten].copy_(
+                blk_table_tensor[:num_decode_reqs].repeat_interleave(
+                    self.decode_threshold, dim=0))
+            common_attn_metadata.block_table_tensor = \
+                blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
+
         if self.speculative_config and \
             self.spec_decode_common_attn_metadata is None:
             self.spec_decode_common_attn_metadata = common_attn_metadata
@@ -2831,6 +2859,9 @@ def _build_dummy_attn_metadata(
             sin=self.sin,
             prefill_context_parallel_metadata=long_seq_metadata,
         )
+        if self.pcp_size > 1:
+            common_attn_metadata.block_table_tensor = \
+                block_table_tensor[:num_reqs * self.decode_threshold]
         attn_state = AscendAttentionState.DecodeOnly
         if self.speculative_config and \
             self.speculative_config.method == "deepseek_mtp":
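The two copy_ calls in _prepare_inputs rearrange rows within the same buffer, so their order matters: prefill rows are moved to their new offset first (with a clone(), since the source and destination ranges can overlap), and only then are the decode rows expanded over the vacated range; doing it the other way around would overwrite the prefill rows. A toy reproduction with illustrative block IDs in a one-column table:

import torch

decode_threshold = 2                       # 1 + num_speculative_tokens
num_decode_reqs, num_prefill_reqs = 2, 3
num_decode_reqs_flatten = num_decode_reqs * decode_threshold  # 4

# Rows 0..4 hold [d0, d1, p0, p1, p2]; the buffer is already enlarged to
# 7 rows, as in the BlockTable change above.
blk = torch.tensor([[0], [1], [10], [11], [12], [0], [0]], dtype=torch.int32)

# Step 1: move prefill rows to their new offset (clone() because the
# source and destination ranges overlap at row 4).
blk[num_decode_reqs_flatten:num_decode_reqs_flatten + num_prefill_reqs].copy_(
    blk[num_decode_reqs:num_decode_reqs + num_prefill_reqs].clone())
# Step 2: expand decode rows over the vacated range.
blk[:num_decode_reqs_flatten].copy_(
    blk[:num_decode_reqs].repeat_interleave(decode_threshold, dim=0))

assert blk[:, 0].tolist() == [0, 0, 1, 1, 10, 11, 12]  # [d0,d0,d1,d1,p0,p1,p2]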
