
Commit e42df72

Authored by rkooo567 and simon-mo
[Test] Add xformer and flash attn tests (vllm-project#3961)
Co-authored-by: Simon Mo <[email protected]>
1 parent: caada5e

File tree: 2 files changed, +15 -0 lines changed


tests/basic_correctness/test_basic_correctness.py (+6)
@@ -4,6 +4,8 @@
 """
 import pytest
 
+from vllm.attention.selector import VLLM_ATTENTION_BACKEND
+
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
@@ -14,6 +16,7 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -22,7 +25,10 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    attn_backend: str,
+    monkeypatch,
 ) -> None:
+    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
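
The test drives the new selector hook through pytest's monkeypatch fixture, which sets the environment variable for one parametrized case and undoes it at teardown. A minimal standalone sketch of that pattern (a hypothetical test, not part of this commit; only pytest is assumed):

import os

import pytest

# Mirrors the constant the diff imports from vllm.attention.selector.
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_backend_env_var_is_scoped(attn_backend: str, monkeypatch) -> None:
    # setenv applies only for the duration of this test case and is
    # reverted automatically, so other tests see the original environment.
    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
    assert os.environ[VLLM_ATTENTION_BACKEND] == attn_backend

Because attn_backend is a parametrize axis, every combination of model, dtype, max_tokens, and enforce_eager now runs once per backend.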

vllm/attention/selector.py (+9)
@@ -1,4 +1,5 @@
 import enum
+import os
 from functools import lru_cache
 from typing import Type
 
@@ -10,6 +11,8 @@
 
 logger = init_logger(__name__)
 
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
+
 
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
@@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
             "Cannot use FlashAttention backend because the flash_attn package "
             "is not found. Please install it for better performance.")
         return _Backend.XFORMERS
+
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var is not None:
+        return _Backend[backend_by_env_var]
+
+    # Default case.
     return _Backend.FLASH_ATTN
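
Note the placement of the override: it is read only after the dtype and flash_attn availability checks have passed, so a missing flash_attn package still returns _Backend.XFORMERS before the variable is consulted. A self-contained sketch of the lookup itself (mirroring the logic added here rather than importing vllm):

import enum
import os

# Same name as the constant added in vllm/attention/selector.py.
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


def which_attn_to_use() -> _Backend:
    # _Backend[name] looks an enum member up by name, so the env var
    # must spell a member exactly; a typo raises KeyError.
    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
    if backend_by_env_var is not None:
        return _Backend[backend_by_env_var]
    # Default case.
    return _Backend.FLASH_ATTN


os.environ[VLLM_ATTENTION_BACKEND] = "XFORMERS"
assert which_attn_to_use() is _Backend.XFORMERS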
