# FlexAttention mask type. For each mask type, we initialize it at most once per
# batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to
# track the initialized mask.
- FLEX_ATTN_MASK_T = tuple[str, int | None, int | None]  # (mask_type, fixed_block_size, sliding_window)
+ FLEX_ATTN_MASK_T = tuple[
+     str, int | None, int | None
+ ]  # (mask_type, fixed_block_size, sliding_window)


class FlexAttention(torch.nn.Module):
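As a minimal illustration (not part of the patch), a mask key under the widened tuple could look like the following; the concrete values are made up for the example.

# Hypothetical key values for illustration only:
causal_key: FLEX_ATTN_MASK_T = ("causal", None, None)
sliding_key: FLEX_ATTN_MASK_T = ("sliding_window", None, 1024)  # 1024-token window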
@@ -64,7 +66,10 @@ class FlexAttention(torch.nn.Module):
    attn_mask_type: str

    def __init__(
-         self, attn_mask_type: str, fixed_block_size: int | None = None, sliding_window: int | None = None
+         self,
+         attn_mask_type: str,
+         fixed_block_size: int | None = None,
+         sliding_window: int | None = None,
    ) -> None:
        super().__init__()
        if attn_mask_type not in ["causal", "block_causal", "sliding_window"]:
@@ -73,7 +78,6 @@ def __init__(
        self.fixed_block_size = fixed_block_size
        self.sliding_window = sliding_window

-         self.mask_cache = {}
        FlexAttention.used_attn_mask_types.add(self.mask_key)

    @property
@@ -87,57 +91,44 @@ def forward(
        v: torch.Tensor,
        scale: float | None = None,
        sink_weights: torch.Tensor | None = None,
-         # sliding_window: int = 0,
-         enable_gqa: bool = False,
    ) -> torch.Tensor:
-
+
        # Use sink logic when sliding_window is used and sink_weights is provided
        if self.attn_mask_type == "sliding_window" and sink_weights is not None:
-             return self._forward_with_sink(q, k, v, scale, sink_weights, enable_gqa)
-
-         # Regular path without sink - use pre-compiled block masks
+             return self._forward_with_sink(q, k, v, scale, sink_weights)
+
+         # Regular path without sink
        block_mask = FlexAttention.block_masks[self.mask_key]
        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale)
-
+
    def _forward_with_sink(
        self,
        q: torch.Tensor,
-         k: torch.Tensor,
+         k: torch.Tensor,
        v: torch.Tensor,
        scale: float | None = None,
        sink_weights: torch.Tensor | None = None,
-         enable_gqa: bool = False,
    ) -> torch.Tensor:
        """Forward pass with attention sink for sliding window attention."""
-         B, H_q, S_q, D = q.shape
-         _, H_kv, S_kv, _ = k.shape
-
-         if self.sliding_window is None or self.sliding_window <= 0:
-             raise RuntimeError("sliding_window must be configured for sliding_window attention type")
-         mask_key = ("sliding_window_sink", self.sliding_window, S_q, S_kv)
-         if mask_key not in self.mask_cache:
-             mask_mod = FlexAttention._get_sliding_window_mask_mod(self.sliding_window)
-             block_mask = create_block_mask(
-                 mask_mod, B, H_q, S_q, S_kv,
-                 _compile=True, device=q.device
-             )
-             self.mask_cache[mask_key] = block_mask
-         block_mask = self.mask_cache[mask_key]
+         # Use the pre-compiled static block mask
+         block_mask = FlexAttention.block_masks[self.mask_key]

        # Run flex_attn and return LSE for sink computation
        out, lse = FlexAttention.flex_attn(
-             q, k, v,
+             q,
+             k,
+             v,
            block_mask=block_mask,
-             enable_gqa=enable_gqa,
            return_lse=True,
-             scale=scale
+             scale=scale,
        )

        # Apply attention sink rescaling: rescale by σ(lse - w[h])
        # This is mathematically equivalent to concatenating learnable sink weights
        if sink_weights is not None:
-             w = sink_weights  # [H]
-             sink_scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1)  # [B,H,S,1]
+             sink_scale = torch.sigmoid(lse - sink_weights.view(1, -1, 1)).unsqueeze(
+                 -1
+             )  # [B,H,S,1]
            out = out * sink_scale

        return out.to(q.dtype)
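The rescaling step above relies on an identity: appending a learnable sink logit w to the attention scores and softmaxing is the same as scaling the no-sink output by σ(lse - w). A small self-contained sketch (not part of the patch; shapes and values are illustrative) that checks the identity numerically:

import torch

torch.manual_seed(0)
scores = torch.randn(4)   # attention logits for one query position
v = torch.randn(4, 8)     # value vectors
w = torch.tensor(0.3)     # learnable sink logit (no value vector attached)

# Path 1: concatenate the sink logit, softmax, then drop the sink column.
probs_with_sink = torch.softmax(torch.cat([scores, w[None]]), dim=-1)[:-1]
out_sink = probs_with_sink @ v

# Path 2: regular attention output rescaled by sigmoid(lse - w).
out = torch.softmax(scores, dim=-1) @ v
lse = torch.logsumexp(scores, dim=-1)
out_rescaled = out * torch.sigmoid(lse - w)

assert torch.allclose(out_sink, out_rescaled, atol=1e-6)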
@@ -149,10 +140,12 @@ def _get_sliding_window_mask_mod(window: int):
        - only allows kv_idx ≤ q_idx (causal)
        - and only if (q_idx - kv_idx) ≤ window
        """
+
        def sliding_mod(b, h, q_idx, kv_idx):
            # causal within window
            keep = (kv_idx <= q_idx) & (q_idx - kv_idx <= window)
            return keep
+
        return sliding_mod

    @staticmethod
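For context, a mask_mod of this shape is what torch.nn.attention.flex_attention consumes. A rough standalone sketch (not part of the patch; the window size and tensor shapes are illustrative, and running it may require a recent PyTorch build and/or a GPU):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

window = 4

def sliding_mod(b, h, q_idx, kv_idx):
    # causal, and at most `window` positions back
    return (kv_idx <= q_idx) & (q_idx - kv_idx <= window)

B, H, S, D = 1, 2, 128, 16
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# B=None / H=None broadcasts the mask over batch and heads.
block_mask = create_block_mask(sliding_mod, None, None, S, S, device=q.device)
out = flex_attention(q, k, v, block_mask=block_mask)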
@@ -248,7 +241,9 @@ def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None:
                    # We don't care about batch dimension --
                    # all samples have the same sliding window mask.
                    batch_dimension = 1
-                     mask_mod = FlexAttention._get_sliding_window_mask_mod(sliding_window)
+                     mask_mod = FlexAttention._get_sliding_window_mask_mod(
+                         sliding_window
+                     )
                case _:
                    raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")
@@ -303,7 +298,10 @@ def forward(


def build_attention(
-     use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None, sliding_window: int | None = None
+     use_flex_attn: bool,
+     attn_mask_type: str,
+     fixed_block_size: int | None = None,
+     sliding_window: int | None = None,
):
    if use_flex_attn:
        return FlexAttention(attn_mask_type, fixed_block_size, sliding_window)
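A rough usage sketch of the reformatted signature (not part of the patch; argument values are illustrative):

attn = build_attention(
    use_flex_attn=True,
    attn_mask_type="sliding_window",
    sliding_window=1024,
)
# When sink_weights (one learnable logit per head, shape [H]) is passed to
# forward(), the sliding-window path applies the sigmoid rescaling shown above.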