torchtitan/distributed/utils.py (12 changes: 2 additions & 10 deletions)
@@ -110,9 +110,9 @@ def set_determinism(
         # reproducibility, since the autotune results may not be deterministic.
         from torch.nn.attention.flex_attention import flex_attention

-        from torchtitan.models.attention import FlexAttention
+        from torchtitan.models.attention import FlexAttentionWrapper

-        FlexAttention.flex_attn = torch.compile(flex_attention)
+        FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)

     if not world_mesh:
         if seed is not None:
@@ -207,14 +207,6 @@ def context(cp_context: Generator[None, None, None] | None = None):
                 torch._dynamo.utils.maybe_enable_compiled_autograd(True)
             )

             if cp_context is not None:
-                from torch.nn.attention import SDPBackend
-
-                from torchtitan.models.attention import ScaledDotProductAttention
-
-                if SDPBackend.MATH in ScaledDotProductAttention.backends:
-                    ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
-
                 stack.enter_context(cp_context)

             yield
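For context (not part of the diff): _compiled_flex_attn is assigned on the class, so overriding it swaps the compiled kernel for every attention layer that goes through FlexAttentionWrapper at once, which is what set_determinism() relies on for reproducible autotuning (assuming the wrapper dispatches through this attribute). A minimal sketch of the same override:

    # Sketch: install a freshly compiled flex_attention kernel process-wide.
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    from torchtitan.models.attention import FlexAttentionWrapper

    # Class-level assignment: every FlexAttentionWrapper instance picks this up.
    FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)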
torchtitan/experiments/forge/example_train.py (24 changes: 15 additions & 9 deletions)
@@ -157,15 +157,14 @@ def forward_backward_step(
         model_parts = self.model_parts
         parallel_dims = self.parallel_dims

-        # apply context parallelism if cp is enabled
-        # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
+        # Create the FlexAttention mask according to the input
+        extra_args = {}

         if getattr(self.model_args, "use_flex_attn", False):
-            cp_mesh = (
-                parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
+            extra_args["attention_masks"] = model_parts[0].get_attention_masks(
+                input_batch=inputs,
+                tokenizer=self.tokenizer,
             )
-            init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh)

         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
@@ -187,11 +186,18 @@
                 )
                 if self.pp_has_first_stage:
                     self.pp_schedule.step(
-                        inputs, target=targets, losses=losses, input_batch=inputs
+                        inputs,
+                        **extra_args,
+                        target=targets,
+                        losses=losses,
+                        input_batch=inputs,
                     )
                 else:
                     self.pp_schedule.step(
-                        target=targets, losses=losses, input_batch=inputs
+                        **extra_args,
+                        target=targets,
+                        losses=losses,
+                        input_batch=inputs,
                     )

             # accumulate losses across pipeline microbatches
@@ -209,7 +215,7 @@
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs)
+                    pred = model_parts[0](inputs, **extra_args)
                     loss = self.loss_fn(pred, labels)
                     # need to free to before bwd to avoid peaking memory
                     del pred
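Not part of the diff: the SDPA path stays untouched because extra_args remains empty when use_flex_attn is off, so every call above degenerates to the old one and the models' new attention_masks parameter keeps its default of None. The new control flow condenses to roughly this (names taken from the hunks above):

    # Sketch of the mask plumbing in forward_backward_step().
    extra_args = {}
    if getattr(self.model_args, "use_flex_attn", False):
        # Built once per batch on the first model part, then reused by every stage.
        extra_args["attention_masks"] = model_parts[0].get_attention_masks(
            input_batch=inputs,
            tokenizer=self.tokenizer,
        )

    # Non-PP path; identical to the old call whenever extra_args is empty.
    pred = model_parts[0](inputs, **extra_args)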
torchtitan/experiments/llama4/infra/parallelize.py (4 changes: 2 additions & 2 deletions)
@@ -239,8 +239,8 @@ def apply_non_moe_tp(
     layer_plan = {
         "attention_norm": SequenceParallel(),
         "attention": prepare_module_input(
-            input_layouts=(Shard(1), None),
-            desired_input_layouts=(Replicate(), None),
+            input_layouts=(Shard(1), None, None),
+            desired_input_layouts=(Replicate(), None, None),
         ),
         "attention.wq": colwise_parallel(),
         "attention.wk": colwise_parallel(),
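Aside (not part of the diff): the extra None slot exists because the TP input plan lists one layout per positional input of the attention module, which now takes attention_masks as a third argument (see the model change below); None entries are passed through without DTensor conversion. A sketch using the public DTensor API directly, assuming torchtitan's prepare_module_input is an alias for PrepareModuleInput (or a float8 variant of it):

    # Sketch: one layout slot per positional input of Attention.forward.
    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import PrepareModuleInput

    attention_input_plan = PrepareModuleInput(
        #               x         freqs_cis  attention_masks
        input_layouts=(Shard(1), None, None),
        desired_input_layouts=(Replicate(), None, None),
        # None: leave that argument alone (freqs_cis and BlockMasks are not DTensors).
    )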
torchtitan/experiments/llama4/model/model.py (69 changes: 60 additions & 9 deletions)
@@ -9,10 +9,20 @@
 import torch
 import torch.nn.functional as F
 from torch import nn

-from torchtitan.models.attention import build_attention
+from torch.nn.attention.flex_attention import and_masks

+from torchtitan.components.tokenizer import BaseTokenizer
+from torchtitan.models.attention import (
+    create_attention_mask,
+    FlexAttentionWrapper,
+    get_causal_mask_mod,
+    get_document_mask_mod,
+    get_fixed_block_mask_mod,
+    ScaledDotProductAttentionWrapper,
+)
 from torchtitan.models.moe import MoE
-from torchtitan.protocols import ModelProtocol
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.protocols.train_spec import ModelProtocol

 from .args import RoPEScalingArgs, TransformerModelArgs

@@ -192,9 +202,11 @@ def __init__(
         # values of these two variables.
         self.use_rope = use_rope

-        self.sdpa = build_attention(
-            model_args.use_flex_attn, model_args.attn_mask_type, fixed_block_size
-        )
+        self.use_flex_attn = model_args.use_flex_attn
+        if self.use_flex_attn:
+            self.inner_attention = FlexAttentionWrapper()
+        else:
+            self.inner_attention = ScaledDotProductAttentionWrapper()

     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -205,6 +217,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
         Forward pass of the attention module.
@@ -239,7 +252,13 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

-        output = self.sdpa(xq, xk, xv)
+        if self.use_flex_attn:
+            assert isinstance(attention_masks, dict), attention_masks
+            attention_mask = attention_masks["rope" if self.use_rope else "nope"]
+            output = self.inner_attention(xq, xk, xv, block_mask=attention_mask)
+        else:
+            assert attention_masks is None
+            output = self.inner_attention(xq, xk, xv)

         output = output.transpose(
             1, 2
@@ -372,6 +391,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -384,7 +404,7 @@
             torch.Tensor: Output tensor after applying attention and feedforward layers.

         """
-        h = x + self.attention(self.attention_norm(x), freqs_cis)
+        h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
         if self.moe_enabled:
             out = h + self.moe(self.ffn_norm(h))
         else:
@@ -485,9 +505,40 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_scaling_args,
         )

+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        mask_mods = [get_causal_mask_mod()]
+        match self.model_args.attn_mask_type:
+            case "causal":
+                B = 1
+            case "block_causal":
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+                B = input_batch.shape[0]
+            case _:
+                raise ValueError(
+                    f"Unknown attention mask type: {self.model_args.attn_mask_type}"
+                )
+
+        rope_mask_mod = and_masks(
+            *mask_mods,
+            get_fixed_block_mask_mod(self.model_args.fixed_attn_block_size),
+        )
+        nope_mask_mod = and_masks(*mask_mods)
+
+        seqlen = input_batch.shape[1]
+        return {
+            "rope": create_attention_mask(rope_mask_mod, B, None, seqlen, seqlen),
+            "nope": create_attention_mask(nope_mask_mod, B, None, seqlen, seqlen),
+        }

     def forward(
         self,
         tokens: torch.Tensor,
+        attention_masks: AttentionMasksType | None = None,
+        input_batch: torch.Tensor | None = None,
     ):
         """
@@ -511,7 +562,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, self.freqs_cis, attention_masks)

         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
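For context (not part of the diff): create_attention_mask, get_causal_mask_mod, get_document_mask_mod, and get_fixed_block_mask_mod are imported from torchtitan.models.attention and are not shown in this PR excerpt. Judging by the call above, create_attention_mask mirrors torch's create_block_mask, and the "rope" mask simply ANDs a fixed-block locality constraint (sized by fixed_attn_block_size, for the chunked-attention RoPE layers) onto the same causal or block-causal base used by the global "nope" mask. A standalone sketch of the block-causal case using only the public FlexAttention API:

    # Sketch (assumptions noted): block_causal == causal AND same-document.
    import torch
    from torch.nn.attention.flex_attention import and_masks, create_block_mask

    def causal_mask_mod(b, h, q_idx, kv_idx):
        # A query may attend only to keys at or before its own position.
        return q_idx >= kv_idx

    def make_document_mask_mod(input_batch: torch.Tensor, eos_id: int):
        # Document id = number of EOS tokens seen up to each position.
        doc_ids = (input_batch == eos_id).cumsum(dim=-1)

        def document_mask_mod(b, h, q_idx, kv_idx):
            # Forbid attention across packed-document boundaries.
            return doc_ids[b, q_idx] == doc_ids[b, kv_idx]

        return document_mask_mod

    def build_block_causal_mask(input_batch: torch.Tensor, eos_id: int):
        B, seqlen = input_batch.shape
        mask_mod = and_masks(causal_mask_mod, make_document_mask_mod(input_batch, eos_id))
        # H=None broadcasts the mask over all attention heads.
        return create_block_mask(mask_mod, B, None, seqlen, seqlen, device=input_batch.device)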
torchtitan/experiments/qwen3/model/model.py (56 changes: 50 additions & 6 deletions)
@@ -10,13 +10,23 @@
 import torch
 import torch.nn.functional as F
 from torch import nn

-from torchtitan.models.attention import build_attention
+from torch.nn.attention.flex_attention import and_masks, BlockMask

+from torchtitan.components.tokenizer import BaseTokenizer
+from torchtitan.models.attention import (
+    create_attention_mask,
+    FlexAttentionWrapper,
+    get_causal_mask_mod,
+    get_document_mask_mod,
+    ScaledDotProductAttentionWrapper,
+)
 from torchtitan.models.moe import MoE
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.protocols.train_spec import ModelProtocol

 from .args import Qwen3ModelArgs


 # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py
 def precompute_rope_cache(
     dim: int, max_seq_len: int, base: float = 1_000_000.0
@@ -133,6 +143,7 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.n_rep = self.n_heads // self.n_kv_heads
         self.head_dim = model_args.head_dim
         self.scaling = self.head_dim**-0.5
+        self.use_flex_attn = getattr(model_args, "use_flex_attn", False)

         # RMSNorm added here to the here to include the q-k norm
         # This is one of the main differences between Llama3 and Qwen3
@@ -155,7 +166,11 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.wo = nn.Linear(
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
-        self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)
+
+        if self.use_flex_attn:
+            self.inner_attention = FlexAttentionWrapper()
+        else:
+            self.inner_attention = ScaledDotProductAttentionWrapper()

     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -170,6 +185,7 @@ def forward(
         self,
         x: torch.Tensor,
         rope_cache: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
         Forward pass of the attention module.
@@ -210,7 +226,12 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

-        output = self.sdpa(xq, xk, xv, scale=self.scaling)
+        if self.use_flex_attn:
+            assert isinstance(attention_masks, BlockMask), attention_masks
+            output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+        else:
+            assert attention_masks is None
+            output = self.inner_attention(xq, xk, xv)

         output = output.transpose(
             1, 2
@@ -308,6 +329,7 @@ def forward(
         self,
         x: torch.Tensor,
         rope_cache: torch.Tensor,
+        attention_masks: AttentionMasksType | None,
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -320,7 +342,7 @@ def forward(
             torch.Tensor: Output tensor after applying attention and feedforward layers.

         """
-        x = x + self.attention(self.attention_norm(x), rope_cache)
+        x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)

         if self.moe_enabled:
             x = x + self.moe(self.ffn_norm(x))
@@ -423,9 +445,31 @@ def _precompute_rope_cache(self) -> torch.Tensor:
             self.model_args.rope_theta,
         )

+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        mask_mods = [get_causal_mask_mod()]
+        match self.model_args.attn_mask_type:
+            case "causal":
+                B = 1
+            case "block_causal":
+                B = input_batch.shape[0]
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+            case _:
+                raise ValueError(
+                    f"Unknown attention mask type: {self.model_args.attn_mask_type}"
+                )
+        return create_attention_mask(
+            and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
+        )

     def forward(
         self,
         tokens: torch.Tensor,
+        attention_masks: AttentionMasksType | None = None,
+        input_batch: torch.Tensor | None = None,
     ):
         """
@@ -449,7 +493,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

         for layer in self.layers.values():
-            h = layer(h, self.rope_cache)
+            h = layer(h, self.rope_cache, attention_masks)

         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
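Aside (not part of the diff): unlike the Llama 4 model above, Qwen3's get_attention_masks returns a single BlockMask (there is no RoPE/NoPE split), which is why its Attention.forward asserts isinstance(attention_masks, BlockMask) rather than dict. A hypothetical end-to-end call, with the trainer wiring simplified and the variable names illustrative:

    # Sketch: build the mask once per batch, then thread it through the forward pass.
    tokens = batch["input"]                   # (batch, seqlen) token ids
    masks = model.get_attention_masks(
        input_batch=tokens,
        tokenizer=tokenizer,                  # only eos_id is used, for "block_causal"
    )
    logits = model(tokens, attention_masks=masks)
    # With use_flex_attn disabled, pass no masks at all: logits = model(tokens)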
torchtitan/experiments/vlm/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -35,7 +35,7 @@ def _get_dict(obj) -> dict[str, Any]:

 llama3_siglip2_configs = {
     "debugmodel": Llama3Siglip2ModelArgs(
-        **_get_dict(llama3_configs["debugmodel"]),
+        **_get_dict(llama3_configs["debugmodel_flex_attn"]),
         encoder=Siglip2ModelArgs(
             dim=128,
             ffn_dim=256,