
Commit f29aeb5

Add FLASHINFER_MLA to test_mla_backends and add B200 CI run (vllm-project#27663)
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent 5e8862e commit f29aeb5

4 files changed: 208 additions, 64 deletions

.buildkite/test-pipeline.yaml

Lines changed: 10 additions & 0 deletions
@@ -340,6 +340,16 @@ steps:
   commands:
   - pytest -v -s v1/attention
 
+- label: V1 Test attention (B200) # 10min
+  timeout_in_minutes: 30
+  gpu: b200
+  source_file_dependencies:
+  - vllm/v1/attention
+  - tests/v1/attention
+  commands:
+  - export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
+  - pytest -v -s v1/attention
+
 - label: V1 Test others (CPU) # 5 mins
   source_file_dependencies:
   - vllm/

tests/v1/attention/test_mla_backends.py

Lines changed: 182 additions & 62 deletions
@@ -14,34 +14,66 @@
 from tests.v1.attention.utils import (
     BatchSpec,
     create_common_attn_metadata,
-    create_standard_kv_cache_spec,
     create_vllm_config,
     try_get_attention_backend,
 )
 from vllm import _custom_ops as ops
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import _Backend, backend_to_class_str
 from vllm.attention.ops.flashmla import is_flashmla_dense_supported
+from vllm.attention.utils.fa_utils import flash_attn_supports_mla
 from vllm.config.vllm import set_current_vllm_config
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.v1.attention.backends.mla.common import QueryLenSupport
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
 BACKENDS_TO_TEST = [
     _Backend.CUTLASS_MLA,
     _Backend.FLASHMLA,
     _Backend.FLASH_ATTN_MLA,
+    _Backend.FLASHINFER_MLA,
     _Backend.TRITON_MLA,
 ]
 
-# Remove CUTLASS_MLA from the list if not using sm100
+# Remove sm100 backends from the list if not using sm100
 if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
     BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
+    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
+
+# Remove FLASH_ATTN_MLA from the list if not supported
+if not flash_attn_supports_mla():
+    BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
 
 # Remove FLASHMLA from the list if not supported
 if not is_flashmla_dense_supported()[0]:
     BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
 
+SPEC_DECODE_BACKENDS = []
+for backend in BACKENDS_TO_TEST:
+    builder_cls, _ = try_get_attention_backend(backend)
+    query_len_support = getattr(
+        builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+    )
+    if query_len_support != QueryLenSupport.SINGLE_ONLY:
+        SPEC_DECODE_BACKENDS.append(backend)
+
+BACKEND_BLOCK_SIZES = {}
+for backend in BACKENDS_TO_TEST:
+    backend_class_str = backend_to_class_str(backend)
+    backend_class = resolve_obj_by_qualname(backend_class_str)
+    supported_sizes = backend_class.get_supported_kernel_block_size()
+    if supported_sizes:
+        default_size = supported_sizes[0]
+        block_size = (
+            default_size if isinstance(default_size, int) else default_size.base
+        )
+    else:
+        block_size = 16
+    BACKEND_BLOCK_SIZES[backend] = block_size
+
 torch.manual_seed(42)

@@ -236,6 +268,26 @@ def __init__(self, device: torch.device):
         self._q_scale = torch.tensor(1.0, device=device)
         self._k_scale = torch.tensor(1.0, device=device)
         self._v_scale = torch.tensor(1.0, device=device)
+        self._prob_scale = torch.tensor(1.0, device=device)
+        self._q_scale_float = 1.0
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0
+
+    def forward(self, *_args, **_kwargs):
+        raise NotImplementedError
+
+
+class MockMLAAttentionLayer(AttentionLayerBase):
+    """A mock MLA attention layer for populating static_forward_context."""
+
+    def __init__(self, impl):
+        self.impl = impl
+
+    def get_attn_backend(self):
+        raise NotImplementedError
+
+    def get_kv_cache_spec(self, vllm_config):
+        raise NotImplementedError
 
 
 def run_attention_backend(
@@ -262,13 +314,6 @@ def run_attention_backend(
     # Set the current vllm config so that get_current_vllm_config() works
     # in the backend implementations
     with set_current_vllm_config(vllm_config):
-        # Build metadata
-        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
-        attn_metadata = builder.build(
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,
-        )
-
         # Instantiate MLA implementation
         num_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config
@@ -302,6 +347,19 @@ def run_attention_backend(
         act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
         impl.process_weights_after_loading(act_dtype)
 
+        # Populate static_forward_context with mock attention layers
+        for layer_name in layer_names:
+            vllm_config.compilation_config.static_forward_context[layer_name] = (
+                MockMLAAttentionLayer(impl)
+            )
+
+        # Build metadata
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+        )
+
         # Create mock layer and output buffer
         mock_layer = MockAttentionLayer(device)
         num_tokens = query.shape[0]
@@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    from vllm.v1.attention.backends.mla.common import QueryLenSupport
 
     batch_spec = BATCH_SPECS[batch_spec_name]
     is_spec_decode_test = batch_spec_name.startswith("spec_decode")
-    spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
-
-    block_size = 16
+    unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
+    default_block_size = unique_block_sizes[0]
     required_blocks = sum(
-        (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        (seq_len + default_block_size - 1) // default_block_size
+        for seq_len in batch_spec.seq_lens
     )
     # Add 1 for null block at index 0, and some buffer
     num_gpu_blocks = required_blocks + 1 + 100
@@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         model_name=model,
         max_model_len=max(batch_spec.seq_lens),
         num_gpu_blocks=num_gpu_blocks,
-        block_size=block_size,
+        block_size=default_block_size,
     )
 
     # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
@@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
 
     device = torch.device("cuda:0")
 
-    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
-
     # 1. Setup
     batch_size = batch_spec.batch_size
     seq_lens = batch_spec.seq_lens
@@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     )
     head_size = vllm_config.model_config.get_head_size()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
-    block_size = vllm_config.cache_config.block_size
     kv_lora_rank = 512
     qk_rope_head_dim = 64
     qk_nope_head_dim = 128
@@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     )
     mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
 
-    # Create metadata using original batch spec
-    common_attn_metadata = create_common_attn_metadata(
-        batch_spec, vllm_config.cache_config.block_size, device
-    )
+    # 3. Create metadata and KV caches for each block size
+    # Group backends by block size and test each group
+    metadata_per_block_size = {}
+    kv_cache_per_block_size = {}
 
-    # 3. Simulate Paged KV Cache and a realistic slot_mapping
-    kv_cache = create_and_prepopulate_kv_cache(
-        kv_c_contexts=kv_c_contexts,
-        k_pe_contexts=k_pe_contexts,
-        block_size=block_size,
-        head_size=head_size,
-        dtype=dtype,
-        device=device,
-        num_blocks=vllm_config.cache_config.num_gpu_blocks,
-        common_attn_metadata=common_attn_metadata,
-        randomize_blocks=True,
-    )
+    for block_size in unique_block_sizes:
+        # Create metadata for this block size
+        common_attn_metadata = create_common_attn_metadata(
+            batch_spec, block_size, device
+        )
+
+        # Pad block table to meet requirement:
+        # block_num % (128 / block_size) == 0
+        required_divisor = int(128 / block_size)
+        current_block_num = common_attn_metadata.block_table_tensor.shape[1]
+        if current_block_num % required_divisor != 0:
+            # Pad to next multiple of required_divisor
+            padded_block_num = (
+                (current_block_num + required_divisor - 1) // required_divisor
+            ) * required_divisor
+            padding_cols = padded_block_num - current_block_num
+            padding = torch.zeros(
+                (common_attn_metadata.block_table_tensor.shape[0], padding_cols),
+                dtype=torch.int32,
+                device=device,
+            )
+            common_attn_metadata.block_table_tensor = torch.cat(
+                [common_attn_metadata.block_table_tensor, padding], dim=1
+            )
+
+        metadata_per_block_size[block_size] = common_attn_metadata
+
+        # Create KV cache for this block size
+        required_blocks_for_size = sum(
+            (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        )
+        num_blocks_for_size = required_blocks_for_size + 1 + 100
+
+        kv_cache = create_and_prepopulate_kv_cache(
+            kv_c_contexts=kv_c_contexts,
+            k_pe_contexts=k_pe_contexts,
+            block_size=block_size,
+            head_size=head_size,
+            dtype=dtype,
+            device=device,
+            num_blocks=num_blocks_for_size,
+            common_attn_metadata=common_attn_metadata,
+            randomize_blocks=True,
+        )
+        kv_cache_per_block_size[block_size] = kv_cache
 
     # 4. Run vLLM backends and compare
+    failures = []
     for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
         # Skip backends that don't support spec decode for spec decode tests
-        if is_spec_decode_test and backend_name not in spec_decode_backends:
+        if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
             continue
 
+        # Get the appropriate block_size, metadata, and cache for this backend
+        block_size = BACKEND_BLOCK_SIZES[backend_name]
+        common_attn_metadata = metadata_per_block_size[block_size]
+        kv_cache = kv_cache_per_block_size[block_size]
+
+        # Create kv_cache_spec with the correct block_size for this backend
+        backend_kv_cache_spec = FullAttentionSpec(
+            block_size=block_size,
+            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
+                vllm_config.parallel_config
+            ),
+            head_size=vllm_config.model_config.get_head_size(),
+            dtype=vllm_config.model_config.dtype,
+            sliding_window=vllm_config.model_config.get_sliding_window(),
+        )
+
         backend_output = run_attention_backend(
             backend_name,
-            kv_cache_spec,
+            backend_kv_cache_spec,
             ["placeholder"],
             vllm_config,
             device,
@@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         expected_output = sdpa_outputs[backend_name]
 
         # Check shape and dtype consistency
-        assert backend_output.shape == expected_output.shape, (
-            f"[{backend_name}] shape {backend_output.shape} != "
-            f"SDPA shape {expected_output.shape}"
-        )
-        assert backend_output.dtype == expected_output.dtype, (
-            f"[{backend_name}] dtype {backend_output.dtype} != "
-            f"SDPA dtype {expected_output.dtype}"
-        )
+        try:
+            assert backend_output.shape == expected_output.shape, (
+                f"[{backend_name}] shape {backend_output.shape} != "
+                f"SDPA shape {expected_output.shape}"
+            )
+            assert backend_output.dtype == expected_output.dtype, (
+                f"[{backend_name}] dtype {backend_output.dtype} != "
+                f"SDPA dtype {expected_output.dtype}"
+            )
 
-        assert torch.isfinite(backend_output).all(), (
-            f"[{backend_name}] produced non-finite values"
-        )
+            assert torch.isfinite(backend_output).all(), (
+                f"[{backend_name}] produced non-finite values"
+            )
 
-        # Check numerical similarity
-        rtol = 1e-2
-        atol = 5e-1
+            # Check numerical similarity
+            rtol = 1e-2
+            atol = 5e-1
 
-        max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
-        max_rel_diff = torch.max(
-            torch.abs(backend_output - expected_output) / torch.abs(expected_output)
-        ).item()
-        all_close = torch.allclose(
-            backend_output, expected_output, rtol=rtol, atol=atol
-        )
+            max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
+            max_rel_diff = torch.max(
+                torch.abs(backend_output - expected_output) / torch.abs(expected_output)
+            ).item()
+            all_close = torch.allclose(
+                backend_output, expected_output, rtol=rtol, atol=atol
+            )
 
-        assert all_close, (
-            f"[{backend_name}] output differs from SDPA baseline. "
-            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
-        )
+            assert all_close, (
+                f"[{backend_name}] output differs from SDPA baseline. "
+                f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
+            )
+        except AssertionError as e:
+            failures.append(str(e))
+
+    # Report all failures at once
+    if failures:
+        # Create a summary for the single-line failure message
+        backend_names = []
+        for f in failures:
+            if "[_Backend." in f:
+                backend_name = f.split("[")[1].split("]")[0]
+                backend_names.append(backend_name)
+
+        summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
+        detailed_msg = "\n".join(failures)
+        pytest.fail(f"{summary}\n{detailed_msg}")

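Note on the failure handling added above: instead of aborting at the first mismatching backend, the test now wraps each backend's checks in a try/except, collects the assertion messages, and reports them together via pytest.fail. A minimal standalone sketch of that pattern follows; the function and variable names are illustrative and not taken from the vLLM test suite.

import pytest
import torch


def compare_backend_outputs(outputs: dict, expected: torch.Tensor) -> None:
    """Collect per-backend assertion failures and fail once with a combined report."""
    failures = []
    for backend_name, backend_output in outputs.items():
        try:
            # Same tolerances as the test above.
            assert torch.allclose(backend_output, expected, rtol=1e-2, atol=5e-1), (
                f"[{backend_name}] output differs from baseline"
            )
        except AssertionError as e:
            failures.append(str(e))
    if failures:
        summary = f"{len(failures)} backend(s) failed"
        pytest.fail(summary + "\n" + "\n".join(failures))
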
tests/v1/attention/utils.py

Lines changed: 11 additions & 1 deletion
@@ -285,7 +285,17 @@ class BackendConfig:
         name="CutlassMLA",
         env_vars={
             "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-            "FORCE_NUM_KV_SPLITS": "1", # TODO: remove this when hang issue is fixed
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+        },
+        specific_gpu_arch=(10, 0),
+    ),
+    # FlashInfer MLA on Blackwell
+    "FlashInferMLA": BackendConfig(
+        name="FlashInferMLA",
+        env_vars={
+            "VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
         },
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",

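For context, the new "FlashInferMLA" entry above follows the same shape as the other backend configs in this file: a name, the environment variable that selects the backend, a compilation config, and a required GPU architecture. Below is a rough standalone sketch of such an entry; the dataclass is a stand-in whose field names mirror the diff, and the specific_gpu_arch value for FlashInferMLA is an assumption based on the "on Blackwell" comment, since it is not shown in the hunk.

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class BackendConfigSketch:
    """Stand-in for the BackendConfig entries in tests/v1/attention/utils.py."""

    name: str
    env_vars: dict[str, str] = field(default_factory=dict)
    comp_config: dict[str, str] = field(default_factory=dict)
    specific_gpu_arch: tuple[int, int] | None = None


# Mirrors the entry added by this commit.
flashinfer_mla = BackendConfigSketch(
    name="FlashInferMLA",
    env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA"},
    comp_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
    specific_gpu_arch=(10, 0),  # assumed: Blackwell / sm100, not shown in the hunk
)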