from tests.v1.attention.utils import (
    BatchSpec,
    create_common_attn_metadata,
-    create_standard_kv_cache_spec,
    create_vllm_config,
    try_get_attention_backend,
)
from vllm import _custom_ops as ops
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import _Backend, backend_to_class_str
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
+from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.config.vllm import set_current_vllm_config
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec

BACKENDS_TO_TEST = [
    _Backend.CUTLASS_MLA,
    _Backend.FLASHMLA,
    _Backend.FLASH_ATTN_MLA,
+    _Backend.FLASHINFER_MLA,
    _Backend.TRITON_MLA,
]

-# Remove CUTLASS_MLA from the list if not using sm100
+# Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
    BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
+    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
+
+# Remove FLASH_ATTN_MLA from the list if not supported
+if not flash_attn_supports_mla():
+    BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)

# Remove FLASHMLA from the list if not supported
if not is_flashmla_dense_supported()[0]:
    BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)

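+# Backends whose metadata builders accept multi-token queries, i.e. the ones
+# that can run the spec-decode batch specs below.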
+SPEC_DECODE_BACKENDS = []
+for backend in BACKENDS_TO_TEST:
+    builder_cls, _ = try_get_attention_backend(backend)
+    query_len_support = getattr(
+        builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+    )
+    if query_len_support != QueryLenSupport.SINGLE_ONLY:
+        SPEC_DECODE_BACKENDS.append(backend)
+
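+# Preferred kernel block size per backend: the first size the backend class
+# advertises via get_supported_kernel_block_size(), falling back to 16.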
+BACKEND_BLOCK_SIZES = {}
+for backend in BACKENDS_TO_TEST:
+    backend_class_str = backend_to_class_str(backend)
+    backend_class = resolve_obj_by_qualname(backend_class_str)
+    supported_sizes = backend_class.get_supported_kernel_block_size()
+    if supported_sizes:
+        default_size = supported_sizes[0]
+        block_size = (
+            default_size if isinstance(default_size, int) else default_size.base
+        )
+    else:
+        block_size = 16
+    BACKEND_BLOCK_SIZES[backend] = block_size
+
torch.manual_seed(42)


@@ -236,6 +268,26 @@ def __init__(self, device: torch.device):
        self._q_scale = torch.tensor(1.0, device=device)
        self._k_scale = torch.tensor(1.0, device=device)
        self._v_scale = torch.tensor(1.0, device=device)
+        self._prob_scale = torch.tensor(1.0, device=device)
+        self._q_scale_float = 1.0
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0
+
+    def forward(self, *_args, **_kwargs):
+        raise NotImplementedError
+
+
+class MockMLAAttentionLayer(AttentionLayerBase):
+    """A mock MLA attention layer for populating static_forward_context."""
+
+    def __init__(self, impl):
+        self.impl = impl
+
+    def get_attn_backend(self):
+        raise NotImplementedError
+
+    def get_kv_cache_spec(self, vllm_config):
+        raise NotImplementedError


def run_attention_backend(
@@ -262,13 +314,6 @@ def run_attention_backend(
    # Set the current vllm config so that get_current_vllm_config() works
    # in the backend implementations
    with set_current_vllm_config(vllm_config):
-        # Build metadata
-        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
-        attn_metadata = builder.build(
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,
-        )
-
        # Instantiate MLA implementation
        num_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config
@@ -302,6 +347,19 @@ def run_attention_backend(
        act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
        impl.process_weights_after_loading(act_dtype)

+        # Populate static_forward_context with mock attention layers
+        for layer_name in layer_names:
+            vllm_config.compilation_config.static_forward_context[layer_name] = (
+                MockMLAAttentionLayer(impl)
+            )
+
+        # Build metadata
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+        )
+
        # Create mock layer and output buffer
        mock_layer = MockAttentionLayer(device)
        num_tokens = query.shape[0]
@@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
       simulated paged KV cache.
    5. Comparing the vLLM backend's output to the ground-truth SDPA output.
    """
-    from vllm.v1.attention.backends.mla.common import QueryLenSupport

    batch_spec = BATCH_SPECS[batch_spec_name]
    is_spec_decode_test = batch_spec_name.startswith("spec_decode")
-    spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
-
-    block_size = 16
+    unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
+    default_block_size = unique_block_sizes[0]
    required_blocks = sum(
-        (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        (seq_len + default_block_size - 1) // default_block_size
+        for seq_len in batch_spec.seq_lens
    )
    # Add 1 for null block at index 0, and some buffer
    num_gpu_blocks = required_blocks + 1 + 100
@@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        model_name=model,
        max_model_len=max(batch_spec.seq_lens),
        num_gpu_blocks=num_gpu_blocks,
-        block_size=block_size,
+        block_size=default_block_size,
    )

    # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
@@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):

    device = torch.device("cuda:0")

-    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
-
    # 1. Setup
    batch_size = batch_spec.batch_size
    seq_lens = batch_spec.seq_lens
@@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
    )
    head_size = vllm_config.model_config.get_head_size()
    dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
-    block_size = vllm_config.cache_config.block_size
    kv_lora_rank = 512
    qk_rope_head_dim = 64
    qk_nope_head_dim = 128
@@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
    )
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)

-    # Create metadata using original batch spec
-    common_attn_metadata = create_common_attn_metadata(
-        batch_spec, vllm_config.cache_config.block_size, device
-    )
+    # 3. Create metadata and KV caches for each block size
+    # Group backends by block size and test each group
+    metadata_per_block_size = {}
+    kv_cache_per_block_size = {}

-    # 3. Simulate Paged KV Cache and a realistic slot_mapping
-    kv_cache = create_and_prepopulate_kv_cache(
-        kv_c_contexts=kv_c_contexts,
-        k_pe_contexts=k_pe_contexts,
-        block_size=block_size,
-        head_size=head_size,
-        dtype=dtype,
-        device=device,
-        num_blocks=vllm_config.cache_config.num_gpu_blocks,
-        common_attn_metadata=common_attn_metadata,
-        randomize_blocks=True,
-    )
+    for block_size in unique_block_sizes:
+        # Create metadata for this block size
+        common_attn_metadata = create_common_attn_metadata(
+            batch_spec, block_size, device
+        )
+
+        # Pad block table to meet requirement:
+        # block_num % (128 / block_size) == 0
+        required_divisor = int(128 / block_size)
+        current_block_num = common_attn_metadata.block_table_tensor.shape[1]
+        if current_block_num % required_divisor != 0:
+            # Pad to next multiple of required_divisor
+            padded_block_num = (
+                (current_block_num + required_divisor - 1) // required_divisor
+            ) * required_divisor
+            padding_cols = padded_block_num - current_block_num
+            padding = torch.zeros(
+                (common_attn_metadata.block_table_tensor.shape[0], padding_cols),
+                dtype=torch.int32,
+                device=device,
+            )
+            common_attn_metadata.block_table_tensor = torch.cat(
+                [common_attn_metadata.block_table_tensor, padding], dim=1
+            )
+
+        metadata_per_block_size[block_size] = common_attn_metadata
+
+        # Create KV cache for this block size
+        required_blocks_for_size = sum(
+            (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        )
+        num_blocks_for_size = required_blocks_for_size + 1 + 100
+
+        kv_cache = create_and_prepopulate_kv_cache(
+            kv_c_contexts=kv_c_contexts,
+            k_pe_contexts=k_pe_contexts,
+            block_size=block_size,
+            head_size=head_size,
+            dtype=dtype,
+            device=device,
+            num_blocks=num_blocks_for_size,
+            common_attn_metadata=common_attn_metadata,
+            randomize_blocks=True,
+        )
+        kv_cache_per_block_size[block_size] = kv_cache

    # 4. Run vLLM backends and compare
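+    # Collect per-backend assertion failures so every backend is exercised
+    # before the test reports a combined failure.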
+    failures = []
    for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
        # Skip backends that don't support spec decode for spec decode tests
-        if is_spec_decode_test and backend_name not in spec_decode_backends:
+        if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
            continue

+        # Get the appropriate block_size, metadata, and cache for this backend
+        block_size = BACKEND_BLOCK_SIZES[backend_name]
+        common_attn_metadata = metadata_per_block_size[block_size]
+        kv_cache = kv_cache_per_block_size[block_size]
+
+        # Create kv_cache_spec with the correct block_size for this backend
+        backend_kv_cache_spec = FullAttentionSpec(
+            block_size=block_size,
+            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
+                vllm_config.parallel_config
+            ),
+            head_size=vllm_config.model_config.get_head_size(),
+            dtype=vllm_config.model_config.dtype,
+            sliding_window=vllm_config.model_config.get_sliding_window(),
+        )
+
        backend_output = run_attention_backend(
            backend_name,
-            kv_cache_spec,
+            backend_kv_cache_spec,
            ["placeholder"],
            vllm_config,
            device,
@@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        expected_output = sdpa_outputs[backend_name]

        # Check shape and dtype consistency
-        assert backend_output.shape == expected_output.shape, (
-            f"[{backend_name}] shape {backend_output.shape} != "
-            f"SDPA shape {expected_output.shape}"
-        )
-        assert backend_output.dtype == expected_output.dtype, (
-            f"[{backend_name}] dtype {backend_output.dtype} != "
-            f"SDPA dtype {expected_output.dtype}"
-        )
+        try:
+            assert backend_output.shape == expected_output.shape, (
+                f"[{backend_name}] shape {backend_output.shape} != "
+                f"SDPA shape {expected_output.shape}"
+            )
+            assert backend_output.dtype == expected_output.dtype, (
+                f"[{backend_name}] dtype {backend_output.dtype} != "
+                f"SDPA dtype {expected_output.dtype}"
+            )

-        assert torch.isfinite(backend_output).all(), (
-            f"[{backend_name}] produced non-finite values"
-        )
+            assert torch.isfinite(backend_output).all(), (
+                f"[{backend_name}] produced non-finite values"
+            )

-        # Check numerical similarity
-        rtol = 1e-2
-        atol = 5e-1
+            # Check numerical similarity
+            rtol = 1e-2
+            atol = 5e-1

-        max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
-        max_rel_diff = torch.max(
-            torch.abs(backend_output - expected_output) / torch.abs(expected_output)
-        ).item()
-        all_close = torch.allclose(
-            backend_output, expected_output, rtol=rtol, atol=atol
-        )
+            max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
+            max_rel_diff = torch.max(
+                torch.abs(backend_output - expected_output) / torch.abs(expected_output)
+            ).item()
+            all_close = torch.allclose(
+                backend_output, expected_output, rtol=rtol, atol=atol
+            )

-        assert all_close, (
-            f"[{backend_name}] output differs from SDPA baseline. "
-            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
-        )
+            assert all_close, (
+                f"[{backend_name}] output differs from SDPA baseline. "
+                f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f}"
+            )
+        except AssertionError as e:
+            failures.append(str(e))
+
+    # Report all failures at once
+    if failures:
+        # Create a summary for the single-line failure message
+        backend_names = []
+        for f in failures:
+            if "[_Backend." in f:
+                backend_name = f.split("[")[1].split("]")[0]
+                backend_names.append(backend_name)
+
+        summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
+        detailed_msg = "\n".join(failures)
+        pytest.fail(f"{summary}\n{detailed_msg}")