
Commit 589ce62

fix flexattn
1 parent 122e93a commit 589ce62

2 files changed: +53 -52 lines changed

torchtitan/experiments/gpt_oss/model/model.py

Lines changed: 22 additions & 19 deletions
@@ -180,16 +180,17 @@ def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor):
     return q_out, k_out
 
 # Torch Attention backup implementation (for debugging and sampling) from HuggingFace
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+    bs, slen, n_kv_heads, head_dim = x.shape
     if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+        return x
+    return (
+        torch.unsqueeze(x, dim=3)
+        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+    )
+
 
 # TODO(jianw): This is eager version from HuggingFace
 def eager_attention_forward(
@@ -200,12 +201,9 @@ def eager_attention_forward(
     attention_mask: torch.Tensor,
     scaling: float,
     dropout: float = 0.0,
-    num_key_value_groups: int = 1,
     **kwargs,
 ):
-    key_states = repeat_kv(key, num_key_value_groups)
-    value_states = repeat_kv(value, num_key_value_groups)
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
     if attention_mask is not None:
         # attention_mask can be [Tq, Tk] or [B, H, Tq, Tk]
         # Convert boolean "allowed" -> additive mask
@@ -230,7 +228,7 @@ def eager_attention_forward(
     probs = nn.functional.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
     scores = probs[..., :-1]  # we drop the sink here
     attn_weights = nn.functional.dropout(scores, p=dropout, training=False)
-    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = torch.matmul(attn_weights, value)
     return attn_output
 
 class Attention(nn.Module):
@@ -243,6 +241,10 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa
 
         self.sliding_window = model_args.sliding_window if use_sliding_attention else None
         self.head_dim = model_args.head_dim
+        self.n_heads = model_args.num_attention_heads
+        self.n_kv_heads = model_args.num_key_value_heads
+
+        self.n_rep = self.n_heads // self.n_kv_heads
 
         self.wq = nn.Linear(
             model_args.hidden_size, model_args.num_attention_heads * model_args.head_dim, bias=True
@@ -294,17 +296,19 @@ def forward(
 
         q, k = apply_rotary_emb(q, k, freqs_cis)
 
+        # repeat k/v heads if n_kv_heads < n_heads
+        keys = repeat_kv(k, self.n_rep)
+        values = repeat_kv(v, self.n_rep)
+
         q = q.transpose(1, 2).contiguous()
-        k = k.transpose(1, 2).contiguous()
-        v = v.transpose(1, 2).contiguous()
+        k = keys.transpose(1, 2).contiguous()
+        v = values.transpose(1, 2).contiguous()
 
         if self.use_flex_attn:
             output = self.attn(
                 q, k, v,
                 scale=None,
                 sink_weights=self.sinks.to_local() if isinstance(self.sinks, DTensor) else self.sinks,
-                # sliding_window=self.sliding_window,
-                enable_gqa=True if self.sliding_window else False,
             )
         else:
             # eager attention forward
@@ -313,7 +317,6 @@ def forward(
                 attention_mask=self.sliding_window_causal(seqlen, x.device),
                 scaling=self.head_dim**-0.5,
                 dropout=0.0,
-                num_key_value_groups=8,
             )
         output = output.transpose(1, 2).contiguous()  # (B, H, T, D) -> (B, T, H, D)

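Note on the model-side change: repeat_kv now works in the (bs, slen, n_kv_heads, head_dim) layout and k/v are repeated up to the full head count before the transpose, which is presumably why the enable_gqa flag and the eager path's num_key_value_groups argument could be dropped. Below is a minimal, self-contained sketch checking the new helper against the torch.repeat_interleave equivalence its docstring claims; the tensor sizes are made up for illustration.

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        torch.unsqueeze(x, dim=3)
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

x = torch.randn(2, 5, 4, 16)  # (bs, slen, n_kv_heads, head_dim), arbitrary sizes
out = repeat_kv(x, n_rep=2)
print(out.shape)                                               # torch.Size([2, 5, 8, 16])
print(torch.equal(out, torch.repeat_interleave(x, 2, dim=2)))  # True

Because the copies of each kv head stay adjacent, downstream code can index the result as if the full set of query heads had been projected.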
torchtitan/models/attention.py

Lines changed: 31 additions & 33 deletions
@@ -25,7 +25,9 @@
 # FlexAttention mask type. For each mask type, we initialize it at most once per
 # batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to
 # track the initialized mask.
-FLEX_ATTN_MASK_T = tuple[str, int | None, int | None]  # (mask_type, fixed_block_size, sliding_window)
+FLEX_ATTN_MASK_T = tuple[
+    str, int | None, int | None
+]  # (mask_type, fixed_block_size, sliding_window)
 
 
 class FlexAttention(torch.nn.Module):
@@ -64,7 +66,10 @@ class FlexAttention(torch.nn.Module):
     attn_mask_type: str
 
     def __init__(
-        self, attn_mask_type: str, fixed_block_size: int | None = None, sliding_window: int | None = None
+        self,
+        attn_mask_type: str,
+        fixed_block_size: int | None = None,
+        sliding_window: int | None = None,
     ) -> None:
         super().__init__()
         if attn_mask_type not in ["causal", "block_causal", "sliding_window"]:
@@ -73,7 +78,6 @@ def __init__(
         self.fixed_block_size = fixed_block_size
         self.sliding_window = sliding_window
 
-        self.mask_cache = {}
         FlexAttention.used_attn_mask_types.add(self.mask_key)
 
     @property
@@ -87,57 +91,44 @@ def forward(
         v: torch.Tensor,
         scale: float | None = None,
         sink_weights: torch.Tensor | None = None,
-        # sliding_window: int = 0,
-        enable_gqa: bool = False,
     ) -> torch.Tensor:
-
+
         # Use sink logic when sliding_window is used and sink_weights is provided
         if self.attn_mask_type == "sliding_window" and sink_weights is not None:
-            return self._forward_with_sink(q, k, v, scale, sink_weights, enable_gqa)
-
-        # Regular path without sink - use pre-compiled block masks
+            return self._forward_with_sink(q, k, v, scale, sink_weights)
+
+        # Regular path without sink
        block_mask = FlexAttention.block_masks[self.mask_key]
         return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale)
-
+
     def _forward_with_sink(
         self,
         q: torch.Tensor,
-        k: torch.Tensor,
+        k: torch.Tensor,
         v: torch.Tensor,
         scale: float | None = None,
         sink_weights: torch.Tensor | None = None,
-        enable_gqa: bool = False,
     ) -> torch.Tensor:
         """Forward pass with attention sink for sliding window attention."""
-        B, H_q, S_q, D = q.shape
-        _, H_kv, S_kv, _ = k.shape
-
-        if self.sliding_window is None or self.sliding_window <= 0:
-            raise RuntimeError("sliding_window must be configured for sliding_window attention type")
-        mask_key = ("sliding_window_sink", self.sliding_window, S_q, S_kv)
-        if mask_key not in self.mask_cache:
-            mask_mod = FlexAttention._get_sliding_window_mask_mod(self.sliding_window)
-            block_mask = create_block_mask(
-                mask_mod, B, H_q, S_q, S_kv,
-                _compile=True, device=q.device
-            )
-            self.mask_cache[mask_key] = block_mask
-        block_mask = self.mask_cache[mask_key]
+        # Use the pre-compiled static block mask
+        block_mask = FlexAttention.block_masks[self.mask_key]
 
         # Run flex_attn and return LSE for sink computation
         out, lse = FlexAttention.flex_attn(
-            q, k, v,
+            q,
+            k,
+            v,
             block_mask=block_mask,
-            enable_gqa=enable_gqa,
             return_lse=True,
-            scale=scale
+            scale=scale,
         )
 
         # Apply attention sink rescaling: rescale by σ(lse - w[h])
         # This is mathematically equivalent to concatenating learnable sink weights
         if sink_weights is not None:
-            w = sink_weights  # [H]
-            sink_scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1)  # [B,H,S,1]
+            sink_scale = torch.sigmoid(lse - sink_weights.view(1, -1, 1)).unsqueeze(
+                -1
+            )  # [B,H,S,1]
             out = out * sink_scale
 
         return out.to(q.dtype)
@@ -149,10 +140,12 @@ def _get_sliding_window_mask_mod(window: int):
         - only allows kv_idx ≤ q_idx (causal)
         - and only if (q_idx - kv_idx) ≤ window
         """
+
         def sliding_mod(b, h, q_idx, kv_idx):
             # causal within window
             keep = (kv_idx <= q_idx) & (q_idx - kv_idx <= window)
             return keep
+
         return sliding_mod
 
     @staticmethod
@@ -248,7 +241,9 @@ def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None:
                     # We don't care about batch dimension --
                     # all samples have the same sliding window mask.
                     batch_dimension = 1
-                    mask_mod = FlexAttention._get_sliding_window_mask_mod(sliding_window)
+                    mask_mod = FlexAttention._get_sliding_window_mask_mod(
+                        sliding_window
+                    )
                 case _:
                     raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")
 
@@ -303,7 +298,10 @@ def forward(
 
 
 def build_attention(
-    use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None, sliding_window: int | None = None
+    use_flex_attn: bool,
+    attn_mask_type: str,
+    fixed_block_size: int | None = None,
+    sliding_window: int | None = None,
 ):
     if use_flex_attn:
         return FlexAttention(attn_mask_type, fixed_block_size, sliding_window)
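Note on the sink rescaling: the comment in _forward_with_sink states that multiplying the masked-softmax output by σ(lse - w[h]) is mathematically equivalent to appending a learnable sink logit, softmaxing, and dropping the sink column (which is what the eager path in model.py does). A small numeric sketch of that equivalence using plain tensor ops; head and sequence sizes are arbitrary and the attention mask is omitted for brevity.

import torch

torch.manual_seed(0)
H, Tq, Tk, D = 2, 4, 4, 8                # heads, query len, kv len, head dim (arbitrary)
q, k, v = torch.randn(H, Tq, D), torch.randn(H, Tk, D), torch.randn(H, Tk, D)
w = torch.randn(H)                        # one learnable sink logit per head

scale = D ** -0.5
logits = torch.einsum("htd,hsd->hts", q, k) * scale            # [H, Tq, Tk]

# Reference (eager path): append the sink logit, softmax, drop the sink column.
sink_col = w.view(H, 1, 1).expand(H, Tq, 1)
probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
ref = torch.einsum("hts,hsd->htd", probs[..., :-1], v)

# Flex path: ordinary softmax output rescaled by sigmoid(lse - w).
lse = torch.logsumexp(logits, dim=-1)                          # [H, Tq]
out = torch.einsum("hts,hsd->htd", torch.softmax(logits, dim=-1), v)
out = out * torch.sigmoid(lse - w.view(H, 1)).unsqueeze(-1)

print(torch.allclose(ref, out, atol=1e-6))                     # True

The identity is p_i = exp(s_i) / (sum_j exp(s_j) + exp(w)) = softmax(s)_i * sigmoid(lse - w), with lse = logsumexp(s), which is why returning the LSE from flex_attn is enough to fold the sink in after the fact.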

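For reference, the sliding-window mask_mod keeps a key position only when kv_idx <= q_idx and q_idx - kv_idx <= window. A toy sketch (sizes are made up) that materializes the predicate densely to show the causal band it produces:

import torch

window, seqlen = 4, 8                     # illustrative sizes only

def sliding_mod(b, h, q_idx, kv_idx):
    # causal within window
    return (kv_idx <= q_idx) & (q_idx - kv_idx <= window)

q_idx = torch.arange(seqlen).view(-1, 1)
kv_idx = torch.arange(seqlen).view(1, -1)
print(sliding_mod(0, 0, q_idx, kv_idx).int())
# Row i allows columns max(0, i - window) .. i: a causal band of width window + 1.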