
Commit 30a14b0

wangxiyuan and hmellor authored
[V0 deprecation] Remove VLLM_USE_V1 usage in platform and v1 module (vllm-project#27798)
Signed-off-by: wangxiyuan <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
1 parent 799ce45 commit 30a14b0

8 files changed: +125 additions, -201 deletions
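The change itself is mechanical: code paths that previously branched on the VLLM_USE_V1 environment variable now take the V1 path unconditionally. The snippet below is a standalone schematic of that pattern, not vLLM's actual function signatures; the error message is the one removed in the diffs that follow.

import os

# Before this commit (schematic): platform hooks consulted VLLM_USE_V1 and
# raised for the removed V0 engine.
def v0_guarded_path() -> str:
    if os.environ.get("VLLM_USE_V1", "1") != "1":
        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend."
        )
    return "v1-backend-path"

# After this commit (schematic): the environment variable is no longer read on
# these paths; the V1 behavior is the only behavior.
def v1_only_path() -> str:
    return "v1-backend-path"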

vllm/platforms/cuda.py

Lines changed: 84 additions & 100 deletions
@@ -276,17 +276,12 @@ def get_attn_backend_cls(
                     "FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
                     "VLLM_MLA_DISABLE=1 to disable MLA for this model."
                 )
-            if not use_v1:
-                raise RuntimeError(
-                    "MLA attention backends require the V1 engine. "
-                    "Set VLLM_USE_V1=1 to enable them."
-                )
 
             from vllm.attention.ops.flashmla import is_flashmla_dense_supported
             from vllm.attention.utils.fa_utils import flash_attn_supports_mla
 
             if use_sparse:
-                logger.info_once("Using Sparse MLA backend on V1 engine.")
+                logger.info_once("Using Sparse MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashmla_sparse."
                     "FlashMLASparseBackend"
@@ -313,15 +308,13 @@ def get_attn_backend_cls(
                 )
 
             if use_cutlassmla:
-                logger.info_once(
-                    "Using Cutlass MLA backend on V1 engine.", scope="local"
-                )
+                logger.info_once("Using Cutlass MLA backend.", scope="local")
                 return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
             if use_flashinfermla:
                 from vllm.v1.attention.backends.utils import set_kv_cache_layout
 
                 set_kv_cache_layout("HND")
-                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
+                logger.info_once("Using FlashInfer MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
                 )
@@ -333,116 +326,107 @@ def get_attn_backend_cls(
                         block_size,
                     )
                 else:
-                    logger.info_once("Using FlashMLA backend on V1 engine.")
+                    logger.info_once("Using FlashMLA backend.")
                     return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
             if use_flashattn:
-                logger.info_once("Using FlashAttention MLA backend on V1 engine.")
+                logger.info_once("Using FlashAttention MLA backend.")
                 return (
                     "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
                 )
             if use_triton:
-                logger.info_once("Using Triton MLA backend on V1 engine.")
+                logger.info_once("Using Triton MLA backend.")
                 return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
-        if use_v1:
-            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
-            FLEX_ATTENTION_V1 = (
-                "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            )
-            TRITON_ATTN = (
-                "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
-            )
-            FLASH_ATTN_V1 = (
-                "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
-            )
-            TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
-            XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501
 
-            use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
-                "fp8"
-            )
+        FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
+        FLEX_ATTENTION_V1 = (
+            "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
+        )
+        TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
+        XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501
 
-            if selected_backend == _Backend.FLASHINFER:
-                logger.info_once("Using FlashInfer backend on V1 engine.")
-                if cls.has_device_capability(100):
-                    from vllm.v1.attention.backends.utils import set_kv_cache_layout
+        use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
+            "fp8"
+        )
 
-                    set_kv_cache_layout("HND")
-                return FLASHINFER_V1
-            elif selected_backend == _Backend.FLEX_ATTENTION:
-                logger.info_once("Using FlexAttention backend on V1 engine.")
-                return FLEX_ATTENTION_V1
-            elif selected_backend == _Backend.TRITON_ATTN:
-                logger.info_once("Using Triton backend on V1 engine.")
-                return TRITON_ATTN
-            elif selected_backend == _Backend.FLASH_ATTN:
-                logger.info_once("Using Flash Attention backend on V1 engine.")
-                return FLASH_ATTN_V1
-            elif selected_backend == _Backend.TREE_ATTN:
-                logger.info_once("Using Tree Attention backend on V1 engine.")
-                return TREE_ATTN_V1
-            elif selected_backend == _Backend.XFORMERS:
-                logger.info_once("Using XFormers backend on V1 engine.")
-                return XFORMERS_V1
+        if selected_backend == _Backend.FLASHINFER:
+            logger.info_once("Using FlashInfer backend.")
+            if cls.has_device_capability(100):
+                from vllm.v1.attention.backends.utils import set_kv_cache_layout
 
-            from vllm.attention.selector import is_attn_backend_supported
+                set_kv_cache_layout("HND")
+            return FLASHINFER_V1
+        elif selected_backend == _Backend.FLEX_ATTENTION:
+            logger.info_once("Using FlexAttention backend.")
+            return FLEX_ATTENTION_V1
+        elif selected_backend == _Backend.TRITON_ATTN:
+            logger.info_once("Using Triton backend.")
+            return TRITON_ATTN
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info_once("Using Flash Attention backend.")
+            return FLASH_ATTN_V1
+        elif selected_backend == _Backend.TREE_ATTN:
+            logger.info_once("Using Tree Attention backend.")
+            return TREE_ATTN_V1
+        elif selected_backend == _Backend.XFORMERS:
+            logger.info_once("Using XFormers backend.")
+            return XFORMERS_V1
+
+        from vllm.attention.selector import is_attn_backend_supported
+
+        # Default backends for V1 engine
+        # Prefer FlashInfer for Blackwell GPUs if installed
+        if cls.is_device_capability(100):
+            if is_default_backend_supported := is_attn_backend_supported(
+                FLASHINFER_V1, head_size, dtype
+            ):
+                from vllm.v1.attention.backends.utils import set_kv_cache_layout
 
-            # Default backends for V1 engine
-            # Prefer FlashInfer for Blackwell GPUs if installed
-            if cls.is_device_capability(100):
-                if is_default_backend_supported := is_attn_backend_supported(
-                    FLASHINFER_V1, head_size, dtype
-                ):
-                    from vllm.v1.attention.backends.utils import set_kv_cache_layout
-
-                    logger.info_once(
-                        "Using FlashInfer backend with HND KV cache layout on "
-                        "V1 engine by default for Blackwell (SM 10.0) GPUs."
-                    )
-                    set_kv_cache_layout("HND")
+                logger.info_once(
+                    "Using FlashInfer backend with HND KV cache layout on "
+                    "V1 engine by default for Blackwell (SM 10.0) GPUs."
+                )
+                set_kv_cache_layout("HND")
 
-                    return FLASHINFER_V1
+                return FLASHINFER_V1
 
-                if not is_default_backend_supported.can_import:
-                    logger.warning_once(
-                        "FlashInfer failed to import for V1 engine on "
-                        "Blackwell (SM 10.0) GPUs; it is recommended to "
-                        "install FlashInfer for better performance."
-                    )
+            if not is_default_backend_supported.can_import:
+                logger.warning_once(
+                    "FlashInfer failed to import on Blackwell (SM 10.0) GPUs; "
+                    "it is recommended to install FlashInfer for better "
+                    "performance."
+                )
 
-            # FlashAttention is the default for SM 8.0+ GPUs
-            if cls.has_device_capability(80):
-                if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
-                    logger.info_once("Using Triton backend on V1 engine.")
-                    return TRITON_ATTN
-                elif is_default_backend_supported := is_attn_backend_supported(
-                    FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
-                ):
-                    logger.info_once("Using Flash Attention backend on V1 engine.")
-                    return FLASH_ATTN_V1
-
-            # FlexAttention is the default for older GPUs
-            else:
-                logger.info_once("Using FlexAttention backend on V1 engine.")
-                return FLEX_ATTENTION_V1
+        # FlashAttention is the default for SM 8.0+ GPUs
+        if cls.has_device_capability(80):
+            if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
+                logger.info_once("Using Triton backend.")
+                return TRITON_ATTN
+            elif is_default_backend_supported := is_attn_backend_supported(
+                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
+            ):
+                logger.info_once("Using Flash Attention backend.")
+                return FLASH_ATTN_V1
 
-            assert not is_default_backend_supported
+        # FlexAttention is the default for older GPUs
+        else:
+            logger.info_once("Using FlexAttention backend.")
+            return FLEX_ATTENTION_V1
 
-            use_flex_attention_reason = {}
-            if not is_default_backend_supported.head_size:
-                use_flex_attention_reason["head_size"] = head_size
-            if not is_default_backend_supported.dtype:
-                use_flex_attention_reason["dtype"] = dtype
+        assert not is_default_backend_supported
 
-            logger.info_once(
-                "Using FlexAttention backend for %s on V1 engine.",
-                ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
-            )
-            return FLEX_ATTENTION_V1
+        use_flex_attention_reason = {}
+        if not is_default_backend_supported.head_size:
+            use_flex_attention_reason["head_size"] = head_size
+        if not is_default_backend_supported.dtype:
+            use_flex_attention_reason["dtype"] = dtype
 
-        raise RuntimeError(
-            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-            "to select a supported backend."
+        logger.info_once(
+            "Using FlexAttention backend for %s.",
+            ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
         )
+        return FLEX_ATTENTION_V1
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
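For orientation, the default selection order that remains after this change can be read straight off the hunk above. The sketch below restates it as standalone Python; the parameters (capability, flashinfer_ok, flash_attn_ok) are hypothetical stand-ins for cls.is_device_capability / is_attn_backend_supported checks in the real code, not vLLM APIs.

# Hedged sketch of the default CUDA backend cascade after this commit.
# The boolean inputs stand in for vLLM's capability and support checks.
def pick_default_cuda_backend(
    capability: int,            # SM version * 10, e.g. 80, 90, 100
    has_sink: bool,
    kv_cache_dtype: str | None,
    flashinfer_ok: bool,        # FlashInfer importable and supports head_size/dtype
    flash_attn_ok: bool,        # FlashAttention supports head_size/dtype
) -> str:
    use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8")

    # Blackwell (SM 10.0): prefer FlashInfer with the HND KV cache layout.
    if capability == 100 and flashinfer_ok:
        return "FlashInferBackend"

    # SM 8.0+: Triton for sinks / fp8 KV cache off Hopper, otherwise FlashAttention.
    if capability >= 80:
        if (has_sink or use_fp8_kv_cache) and capability != 90:
            return "TritonAttentionBackend"
        if flash_attn_ok:
            return "FlashAttentionBackend"
    else:
        # Older GPUs default to FlexAttention.
        return "FlexAttentionBackend"

    # FlashAttention unsupported for this head_size/dtype: fall back to FlexAttention.
    return "FlexAttentionBackend"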

vllm/platforms/interface.py

Lines changed: 1 addition & 8 deletions
@@ -467,14 +467,7 @@ def use_all_gather(cls) -> bool:
         """
         Whether to use allgather in LogitsProcessor to gather the logits.
         """
-        import vllm.envs as envs
-        from vllm.config import get_current_vllm_config
-
-        parallel_config = get_current_vllm_config().parallel_config
-        return (
-            envs.VLLM_USE_V1
-            or parallel_config.distributed_executor_backend == "external_launcher"
-        )
+        return True
 
     @classmethod
     def use_custom_allreduce(cls) -> bool:
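The interface change is behavioral, not just cosmetic: use_all_gather no longer depends on the engine or executor configuration. A minimal before/after sketch, with the config lookup stubbed out as plain parameters rather than vLLM's config objects:

# Before (as removed above): all-gather was used only with V1 or with the
# external_launcher distributed executor backend.
def use_all_gather_old(use_v1_env: bool, distributed_executor_backend: str | None) -> bool:
    return use_v1_env or distributed_executor_backend == "external_launcher"

# After: with V0 gone, LogitsProcessor always gathers logits via all-gather.
def use_all_gather_new() -> bool:
    return True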

vllm/platforms/rocm.py

Lines changed: 32 additions & 52 deletions
@@ -149,7 +149,7 @@ def use_rocm_custom_paged_attention(
     # disabled due to observed numerical discrepancy.
     if ON_GFX9:
         return (
-            (not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1))
+            (sliding_window == 0 or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
@@ -163,11 +163,7 @@ def use_rocm_custom_paged_attention(
     else:
         return (
             ON_GFX11_GFX12
-            and (
-                not envs.VLLM_USE_V1
-                or sliding_window == 0
-                or sliding_window == (-1, -1)
-            )
+            and (sliding_window == 0 or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
             and head_size == 128
             and block_size == 16
@@ -236,12 +232,6 @@ def get_attn_backend_cls(
         if use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on ROCm.")
         if use_mla:
-            if not use_v1:
-                raise RuntimeError(
-                    "MLA attention backends require the V1 engine. "
-                    "Set VLLM_USE_V1=1 to enable them."
-                )
-
             from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
                 is_aiter_mla_enabled,
             )
@@ -255,15 +245,15 @@ def get_attn_backend_cls(
 
             if selected_backend == _Backend.TRITON_MLA:
                 if block_size != 1:
-                    logger.info_once("Using Triton MLA backend on V1 engine.")
+                    logger.info_once("Using Triton MLA backend.")
                     return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
                 raise ValueError(
                     f" The selected backend, {selected_backend.name},"
                     f"does not support block size {block_size}."
                 )
             if selected_backend == _Backend.ROCM_AITER_MLA:
                 if block_size == 1:
-                    logger.info("Using AITER MLA backend on V1 engine.")
+                    logger.info("Using AITER MLA backend.")
                     return (
                         "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                     )
@@ -277,41 +267,33 @@ def get_attn_backend_cls(
                 f"is not MLA type while requested for MLA backend."
             )
 
-        if envs.VLLM_USE_V1:
-            if selected_backend == _Backend.FLEX_ATTENTION:
-                logger.info("Using FlexAttention backend on V1 engine.")
-                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
-            if (
-                envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
-            ) or selected_backend == _Backend.ROCM_AITER_FA:
-                logger.info("Using Aiter Flash Attention backend on V1 engine.")
-                return (
-                    "vllm.v1.attention.backends."
-                    "rocm_aiter_fa.AiterFlashAttentionBackend"
-                )
-            if (
-                envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
-            ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
-                logger.info("Using Aiter Unified Attention backend on V1 engine.")
-                return (
-                    "vllm.v1.attention.backends."
-                    "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
-                )
-            if (
-                envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
-                or selected_backend == _Backend.ROCM_ATTN
-            ):
-                # rocm specific backend, with aiter and/or
-                # triton prefix-prefill
-                logger.info("Using Rocm Attention backend on V1 engine.")
-                return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
-            # default case, using triton unified attention
-            logger.info("Using Triton Attention backend on V1 engine.")
-            return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
-        raise RuntimeError(
-            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
-            "to select a supported backend."
-        )
+        if selected_backend == _Backend.FLEX_ATTENTION:
+            logger.info("Using FlexAttention backend.")
+            return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
+        if (
+            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
+        ) or selected_backend == _Backend.ROCM_AITER_FA:
+            logger.info("Using Aiter Flash Attention backend.")
+            return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
+        if (
+            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
+        ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
+            logger.info("Using Aiter Unified Attention backend.")
+            return (
+                "vllm.v1.attention.backends."
+                "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
+            )
+        if (
+            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
+            or selected_backend == _Backend.ROCM_ATTN
+        ):
+            # rocm specific backend, with aiter and/or
+            # triton prefix-prefill
+            logger.info("Using Rocm Attention backend.")
+            return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
+        # default case, using triton unified attention
+        logger.info("Using Triton Attention backend.")
+        return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
 
     @classmethod
     def set_device(cls, device: torch.device) -> None:
@@ -372,7 +354,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         parallel_config = vllm_config.parallel_config
         is_eager_execution = compilation_config == CUDAGraphMode.NONE
 
-        use_v1 = envs.VLLM_USE_V1
         use_aiter_rms_norm = (
            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM
        )
@@ -384,8 +365,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
         # Aiter rms norm perform best when CUDA Graph capture is enabled.
         if (
-            use_v1
-            and use_aiter_rms_norm
+            use_aiter_rms_norm
             and not is_eager_execution
             and "-rms_norm" not in compilation_config.custom_ops
         ):
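As with the CUDA hunk, the surviving ROCm selection order is easier to see flattened out. The sketch below mirrors the branch order of the new code; the flag parameters are hypothetical stand-ins for the VLLM_ROCM_USE_AITER* and VLLM_V1_USE_PREFILL_DECODE_ATTENTION environment variables and the _Backend enum, not vLLM APIs.

# Hedged sketch of the ROCm backend order after this commit.
def pick_rocm_backend(
    selected_backend: str | None,   # e.g. "FLEX_ATTENTION", "ROCM_AITER_FA"
    use_aiter: bool,                # VLLM_ROCM_USE_AITER
    use_aiter_mha: bool,            # VLLM_ROCM_USE_AITER_MHA
    use_aiter_unified: bool,        # VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
    use_prefill_decode: bool,       # VLLM_V1_USE_PREFILL_DECODE_ATTENTION
    on_gfx9: bool,
) -> str:
    if selected_backend == "FLEX_ATTENTION":
        return "FlexAttentionBackend"
    if (use_aiter and use_aiter_mha and on_gfx9) or selected_backend == "ROCM_AITER_FA":
        return "AiterFlashAttentionBackend"
    if (use_aiter and use_aiter_unified) or selected_backend == "ROCM_AITER_UNIFIED_ATTN":
        return "RocmAiterUnifiedAttentionBackend"
    if use_prefill_decode or selected_backend == "ROCM_ATTN":
        return "RocmAttentionBackend"
    # Default: Triton unified attention.
    return "TritonAttentionBackend"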

0 commit comments
