@@ -276,17 +276,12 @@ def get_attn_backend_cls(
                     "FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
                     "VLLM_MLA_DISABLE=1 to disable MLA for this model."
                 )
-            if not use_v1:
-                raise RuntimeError(
-                    "MLA attention backends require the V1 engine. "
-                    "Set VLLM_USE_V1=1 to enable them."
-                )

             from vllm.attention.ops.flashmla import is_flashmla_dense_supported
             from vllm.attention.utils.fa_utils import flash_attn_supports_mla

             if use_sparse:
-                logger.info_once("Using Sparse MLA backend on V1 engine.")
+                logger.info_once("Using Sparse MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashmla_sparse."
                     "FlashMLASparseBackend"
@@ -313,15 +308,13 @@ def get_attn_backend_cls(
             )

             if use_cutlassmla:
-                logger.info_once(
-                    "Using Cutlass MLA backend on V1 engine.", scope="local"
-                )
+                logger.info_once("Using Cutlass MLA backend.", scope="local")
                 return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
             if use_flashinfermla:
                 from vllm.v1.attention.backends.utils import set_kv_cache_layout

                 set_kv_cache_layout("HND")
-                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
+                logger.info_once("Using FlashInfer MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
                 )
@@ -333,116 +326,107 @@ def get_attn_backend_cls(
                         block_size,
                     )
                 else:
-                    logger.info_once("Using FlashMLA backend on V1 engine.")
+                    logger.info_once("Using FlashMLA backend.")
                     return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
             if use_flashattn:
-                logger.info_once("Using FlashAttention MLA backend on V1 engine.")
+                logger.info_once("Using FlashAttention MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
                 )
             if use_triton:
-                logger.info_once("Using Triton MLA backend on V1 engine.")
+                logger.info_once("Using Triton MLA backend.")
                 return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
-        if use_v1:
-            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
-            FLEX_ATTENTION_V1 = (
-                "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            )
-            TRITON_ATTN = (
-                "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
-            )
-            FLASH_ATTN_V1 = (
-                "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
-            )
-            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
-            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

-            use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
-                "fp8"
-            )
+        FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
+        FLEX_ATTENTION_V1 = (
+            "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
+        )
+        TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
+        XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

-            if selected_backend == _Backend.FLASHINFER:
-                logger.info_once("Using FlashInfer backend on V1 engine.")
-                if cls.has_device_capability(100):
-                    from vllm.v1.attention.backends.utils import set_kv_cache_layout
+        use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
+            "fp8"
+        )

-                    set_kv_cache_layout("HND")
-                return FLASHINFER_V1
-            elif selected_backend == _Backend.FLEX_ATTENTION:
-                logger.info_once("Using FlexAttention backend on V1 engine.")
-                return FLEX_ATTENTION_V1
-            elif selected_backend == _Backend.TRITON_ATTN:
-                logger.info_once("Using Triton backend on V1 engine.")
-                return TRITON_ATTN
-            elif selected_backend == _Backend.FLASH_ATTN:
-                logger.info_once("Using Flash Attention backend on V1 engine.")
-                return FLASH_ATTN_V1
-            elif selected_backend == _Backend.TREE_ATTN:
-                logger.info_once("Using Tree Attention backend on V1 engine.")
-                return TREE_ATTN_V1
-            elif selected_backend == _Backend.XFORMERS:
-                logger.info_once("Using XFormers backend on V1 engine.")
-                return XFORMERS_V1
+        if selected_backend == _Backend.FLASHINFER:
+            logger.info_once("Using FlashInfer backend.")
+            if cls.has_device_capability(100):
+                from vllm.v1.attention.backends.utils import set_kv_cache_layout

-            from vllm.attention.selector import is_attn_backend_supported
+                set_kv_cache_layout("HND")
+            return FLASHINFER_V1
+        elif selected_backend == _Backend.FLEX_ATTENTION:
+            logger.info_once("Using FlexAttention backend.")
+            return FLEX_ATTENTION_V1
+        elif selected_backend == _Backend.TRITON_ATTN:
+            logger.info_once("Using Triton backend.")
+            return TRITON_ATTN
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info_once("Using Flash Attention backend.")
+            return FLASH_ATTN_V1
+        elif selected_backend == _Backend.TREE_ATTN:
+            logger.info_once("Using Tree Attention backend.")
+            return TREE_ATTN_V1
+        elif selected_backend == _Backend.XFORMERS:
+            logger.info_once("Using XFormers backend.")
+            return XFORMERS_V1
+
+        from vllm.attention.selector import is_attn_backend_supported
+
+        # Default backends for V1 engine
+        # Prefer FlashInfer for Blackwell GPUs if installed
+        if cls.is_device_capability(100):
+            if is_default_backend_supported := is_attn_backend_supported(
+                FLASHINFER_V1, head_size, dtype
+            ):
+                from vllm.v1.attention.backends.utils import set_kv_cache_layout

-            # Default backends for V1 engine
-            # Prefer FlashInfer for Blackwell GPUs if installed
-            if cls.is_device_capability(100):
-                if is_default_backend_supported := is_attn_backend_supported(
-                    FLASHINFER_V1, head_size, dtype
-                ):
-                    from vllm.v1.attention.backends.utils import set_kv_cache_layout
-
-                    logger.info_once(
-                        "Using FlashInfer backend with HND KV cache layout on "
-                        "V1 engine by default for Blackwell (SM 10.0) GPUs."
-                    )
-                    set_kv_cache_layout("HND")
+                logger.info_once(
+                    "Using FlashInfer backend with HND KV cache layout on "
+                    "V1 engine by default for Blackwell (SM 10.0) GPUs."
+                )
+                set_kv_cache_layout("HND")

-                    return FLASHINFER_V1
+                return FLASHINFER_V1

-                if not is_default_backend_supported.can_import:
-                    logger.warning_once(
-                        "FlashInfer failed to import for V1 engine on "
-                        "Blackwell (SM 10.0) GPUs; it is recommended to "
-                        "install FlashInfer for better performance."
-                    )
+            if not is_default_backend_supported.can_import:
+                logger.warning_once(
+                    "FlashInfer failed to import on Blackwell (SM 10.0) GPUs; "
+                    "it is recommended to install FlashInfer for better "
+                    "performance."
+                )

-            # FlashAttention is the default for SM 8.0+ GPUs
-            if cls.has_device_capability(80):
-                if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
-                    logger.info_once("Using Triton backend on V1 engine.")
-                    return TRITON_ATTN
-                elif is_default_backend_supported := is_attn_backend_supported(
-                    FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
-                ):
-                    logger.info_once("Using Flash Attention backend on V1 engine.")
-                    return FLASH_ATTN_V1
-
-            # FlexAttention is the default for older GPUs
-            else:
-                logger.info_once("Using FlexAttention backend on V1 engine.")
-                return FLEX_ATTENTION_V1
+        # FlashAttention is the default for SM 8.0+ GPUs
+        if cls.has_device_capability(80):
+            if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
+                logger.info_once("Using Triton backend.")
+                return TRITON_ATTN
+            elif is_default_backend_supported := is_attn_backend_supported(
+                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
+            ):
+                logger.info_once("Using Flash Attention backend.")
+                return FLASH_ATTN_V1

-            assert not is_default_backend_supported
+        # FlexAttention is the default for older GPUs
+        else:
+            logger.info_once("Using FlexAttention backend.")
+            return FLEX_ATTENTION_V1

-            use_flex_attention_reason = {}
-            if not is_default_backend_supported.head_size:
-                use_flex_attention_reason["head_size"] = head_size
-            if not is_default_backend_supported.dtype:
-                use_flex_attention_reason["dtype"] = dtype
+        assert not is_default_backend_supported

-            logger.info_once(
-                "Using FlexAttention backend for %s on V1 engine.",
-                ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
-            )
-            return FLEX_ATTENTION_V1
+        use_flex_attention_reason = {}
+        if not is_default_backend_supported.head_size:
+            use_flex_attention_reason["head_size"] = head_size
+        if not is_default_backend_supported.dtype:
+            use_flex_attention_reason["dtype"] = dtype

-        raise RuntimeError(
-            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-            "to select a supported backend."
+        logger.info_once(
+            "Using FlexAttention backend for %s.",
+            ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
         )
+        return FLEX_ATTENTION_V1

     @classmethod
     def get_punica_wrapper(cls) -> str:
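
For context, `get_attn_backend_cls` returns the selected backend as a fully qualified class path string (for example "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend") rather than a class object. The sketch below is a minimal, hypothetical illustration of how such a dotted path can be resolved with only the Python standard library; `resolve_backend_cls` is not a vLLM API, and vLLM uses its own internal utilities for this step.

# Minimal sketch (hypothetical helper, not part of vLLM) showing how a
# dotted backend path returned by get_attn_backend_cls could be turned
# into the backend class using importlib.
import importlib


def resolve_backend_cls(qualname: str) -> type:
    # Split "package.module.ClassName" into the module path and class name.
    module_name, _, class_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Example usage (assumes vLLM is installed):
# backend_cls = resolve_backend_cls(
#     "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
# )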