2 files changed: +15 −0 lines changed
Test file (test_models):

 """
 import pytest
 
+from vllm.attention.selector import VLLM_ATTENTION_BACKEND
+
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
⋮
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -22,7 +25,10 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    attn_backend: str,
+    monkeypatch,
 ) -> None:
+    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
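
The test change leans on pytest's built-in monkeypatch fixture, which reverts the environment change after each parametrized case, so the XFORMERS and FLASH_ATTN runs cannot leak state into each other. A minimal, self-contained sketch of the same pattern (the resolve_backend helper below is illustrative only, not part of this diff; the real lookup lives in vllm.attention.selector):

import os

import pytest


def resolve_backend() -> str:
    # Illustrative stand-in for the selector logic: honor the environment
    # variable if set, otherwise fall back to FLASH_ATTN.
    return os.getenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")


@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_backend_env_override(attn_backend: str, monkeypatch) -> None:
    # monkeypatch.setenv is undone automatically at the end of each case.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    assert resolve_backend() == attn_backend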
vllm/attention/selector.py:

 import enum
+import os
 from functools import lru_cache
 from typing import Type
⋮
 logger = init_logger(__name__)
 
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
+
 
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
@@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
             "Cannot use FlashAttention backend because the flash_attn package "
             "is not found. Please install it for better performance.")
         return _Backend.XFORMERS
+
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var is not None:
+        return _Backend[backend_by_env_var]
+
+    # Default case.
     return _Backend.FLASH_ATTN
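
With this change, the selector checks the VLLM_ATTENTION_BACKEND environment variable before falling back to the FLASH_ATTN default; an unrecognized value raises a KeyError from the _Backend[...] enum lookup, so only the defined backend names are valid. A usage sketch, assuming the standard vLLM offline-inference API (LLM / SamplingParams) and that the variable is set before backend selection runs, i.e. before the engine is constructed:

import os

# Force the xFormers backend; must be set before the engine picks a backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=5))
print(outputs[0].outputs[0].text)

Exporting the variable in the shell before launching the process is equivalent.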