Commit afba179
[RFC] Refactor attention and make attention mask an argument to the model
**Status**

The PR is not landable yet but serves as an RFC. If people are okay with this design, this PR requires the following changes and verifications:

1. Change all models, including the experimental ones.
2. E2E loss verification (a functional check has been done, but loss verification has not been done yet).
3. Add a unit test for attention. Since we don't have GPU unit tests, this can be done in a separate PR.

**Summary**

This PR refactors how TorchTitan builds the attention masks and passes them to the model. Before this PR, `init_attention_masks()` was called in the Trainer, but the masks were stored as a class variable of `FlexAttentionWrapper()`. We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one in particular being #1723. Now that pytorch/pytorch#164111 shows that we can let PP split a BlockMask, this PR performs the refactor to pass masks as an argument of `model.forward()`.

The new design:

1. The model needs to provide `get_attention_masks()`, which accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, this API should return None, since SDPA currently doesn't support varlen; once it does, we may have to return some tuple of ints that represents the mask. Justification: attention logic is technically part of the model, but it requires information from the trainer/dataloader, so it is the model author's responsibility to provide an API the trainer can call to get the masks.
2. `get_attention_masks()` is called from the trainer, and the resulting masks are passed to `model.forward()` (see the sketch below). Justification: this allows us to fix #1723 with pytorch/pytorch#164111 and this PR.
3. Provide a single `AttentionOp` instead of two. Justification: since the masking logic is moved outside, we no longer need to do bookkeeping of masks in `FlexAttentionWrapper`, and the remaining logic is simple enough that one `AttentionOp` makes things cleaner. Note: we still have two very thin op wrappers that are used for CP; I keep these two for CP education purposes, but this can certainly be confusing for Titan's users, and I'm open to merging them into `AttentionOp`. See the discussion in #1723.

ghstack-source-id: 35aa425
Pull Request resolved: #1776
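A minimal sketch of the intended trainer-side flow, assuming a model that implements the new `get_attention_masks()` API. `train_step`, `input_ids`, and `eos_id` are illustrative names (not the actual Trainer code), and `create_block_mask` is FlexAttention's mask constructor from `torch.nn.attention.flex_attention`:

```python
# Sketch only: not the actual TorchTitan Trainer code.
import torch
from torch.nn.attention.flex_attention import create_block_mask


def train_step(model: torch.nn.Module, input_ids: torch.Tensor, eos_id: int) -> torch.Tensor:
    # 1. The trainer asks the model for its attention masks. SDPA-based models
    #    return None; FlexAttention-based models return one or more BlockMasks.
    attention_masks = model.get_attention_masks(
        create_mask_fn=create_block_mask, batch=input_ids, eos_id=eos_id
    )
    # 2. The masks are passed explicitly through forward() instead of being
    #    stored as class-level state on an attention wrapper.
    return model(input_ids, attention_masks=attention_masks)
```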

File tree: 6 files changed, +323 -226 lines
torchtitan/experiments/llama4/model/model.py

Lines changed: 47 additions & 8 deletions
@@ -4,14 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Callable
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
-from torchtitan.models.attention import build_attention
+from torchtitan.models.attention import (
+    AttentionOp,
+    get_causal_mask_mod,
+    get_document_mask_mod,
+    get_fixed_block_mask_mod,
+)
 from torchtitan.models.moe import MoE
-from torchtitan.protocols import ModelProtocol
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import TransformerModelArgs
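`AttentionMasksType` is imported from `torchtitan.protocols.model`, whose definition is not part of this diff. From its usage in this file (a dict of BlockMasks under FlexAttention, or None under SDPA), it plausibly amounts to an alias along these lines (an assumption, not the actual source):

```python
# Assumed shape of the alias, inferred from how it is used in this file;
# the real definition lives in torchtitan/protocols/model.py and may differ.
from torch.nn.attention.flex_attention import BlockMask

AttentionMasksType = dict[str, BlockMask] | BlockMask | None
```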

@@ -155,9 +162,7 @@ def __init__(
         # values of these two variables.
         self.use_rope = use_rope
 
-        self.sdpa = build_attention(
-            model_args.use_flex_attn, model_args.attn_mask_type, fixed_block_size
-        )
+        self.attention_op = AttentionOp(model_args.use_flex_attn)
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
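The new `AttentionOp` replaces the old `build_attention()` wrapper pair. Its implementation (presumably in `torchtitan/models/attention.py`, per the import above) is not part of this hunk; below is a minimal sketch of what such a single op could look like, assuming it simply dispatches on `use_flex_attn`:

```python
# Hypothetical sketch of a single attention op; the real torchtitan
# AttentionOp may differ (e.g. it may wrap flex_attention in torch.compile).
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import BlockMask, flex_attention


class AttentionOp(torch.nn.Module):
    def __init__(self, use_flex_attn: bool) -> None:
        super().__init__()
        self.use_flex_attn = use_flex_attn

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: BlockMask | None = None,
    ) -> torch.Tensor:
        if self.use_flex_attn:
            # FlexAttention consumes a BlockMask that the trainer built once.
            return flex_attention(q, k, v, block_mask=attention_mask)
        # SDPA path: no varlen/BlockMask support yet, so the mask is None and
        # we fall back to plain causal attention.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```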
@@ -168,6 +173,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType,
     ):
         """
         Forward pass of the attention module.
@@ -202,7 +208,12 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        output = self.sdpa(xq, xk, xv)
+        assert isinstance(attention_masks, dict) or attention_masks is None
+        if self.attention_op.use_flex_attn:
+            attention_mask = attention_masks["rope" if self.use_rope else "nope"]
+            output = self.attention_op(xq, xk, xv, attention_mask=attention_mask)
+        else:
+            output = self.attention_op(xq, xk, xv, attention_mask=None)
 
         output = output.transpose(
             1, 2
@@ -335,6 +346,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType,
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -347,7 +359,7 @@ def forward(
             torch.Tensor: Output tensor after applying attention and feedforward layers.
 
         """
-        h = x + self.attention(self.attention_norm(x), freqs_cis)
+        h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
         if self.moe_enabled:
             out = h + self.moe(self.ffn_norm(h))
         else:
@@ -447,9 +459,36 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )
 
+    def get_attention_masks(
+        self, create_mask_fn: Callable, batch: torch.Tensor, eos_id: int
+    ) -> AttentionMasksType:
+        if not self.model_args.use_flex_attn:
+            return None
+
+        nope_mask_mod = get_causal_mask_mod()
+        match self.model_args.attn_mask_type:
+            case "causal":
+                B = 1
+            case "block_causal":
+                B = batch.shape[0]
+                rope_mask_mod = get_document_mask_mod(nope_mask_mod, batch, eos_id)
+            case _:
+                raise ValueError(f"Unknown attention mask type: {self.attn_mask_type}")
+
+        rope_mask_mod = get_fixed_block_mask_mod(
+            nope_mask_mod, self.model_args.fixed_attn_block_size
+        )
+
+        seqlen = batch.shape[1]
+        return {
+            "rope": create_mask_fn(rope_mask_mod, B, None, seqlen, seqlen),
+            "nope": create_mask_fn(nope_mask_mod, B, None, seqlen, seqlen),
+        }
+
     def forward(
         self,
         tokens: torch.Tensor,
+        attention_masks: AttentionMasksType,
         input_batch: torch.Tensor | None = None,
     ):
         """
@@ -473,7 +512,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, self.freqs_cis, attention_masks)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
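For intuition about the masks that `get_attention_masks()` builds above, here is an illustrative, hand-rolled sketch of the kinds of mask_mods involved (causal, document, and fixed-block) in raw FlexAttention terms. This is for exposition only; the torchtitan helpers `get_causal_mask_mod`, `get_document_mask_mod`, and `get_fixed_block_mask_mod` are the actual implementation and may differ in detail. `batch`, `eos_id`, and `block_size` are placeholder inputs:

```python
# Illustrative only: hand-rolled mask_mods of the kinds composed in
# get_attention_masks(); the torchtitan helpers encapsulate similar logic.
import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

batch = torch.randint(1, 100, (2, 128))  # (B, seqlen) packed token ids, placeholder
eos_id, block_size = 0, 32


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Document mask: tokens only attend within their own document, where documents
# in the packed batch are delimited by EOS tokens.
doc_ids = (batch == eos_id).cumsum(dim=-1)


def same_document(b, h, q_idx, kv_idx):
    return doc_ids[b, q_idx] == doc_ids[b, kv_idx]


# Fixed-block ("chunked") mask: attention is confined to fixed-size chunks.
def same_block(b, h, q_idx, kv_idx):
    return (q_idx // block_size) == (kv_idx // block_size)


B, seqlen = batch.shape
block_causal_mask = create_block_mask(
    and_masks(causal, same_document), B, None, seqlen, seqlen, device="cpu"
)
chunked_causal_mask = create_block_mask(
    and_masks(causal, same_block), B, None, seqlen, seqlen, device="cpu"
)
```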
