diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 7bec252ba..c2ec7bd77 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -110,9 +110,9 @@ def set_determinism( # reproducibility, since the autotune results may not be deterministic. from torch.nn.attention.flex_attention import flex_attention - from torchtitan.models.attention import FlexAttention + from torchtitan.models.attention import FlexAttentionWrapper - FlexAttention.flex_attn = torch.compile(flex_attention) + FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) if not world_mesh: if seed is not None: @@ -207,14 +207,6 @@ def context(cp_context: Generator[None, None, None] | None = None): torch._dynamo.utils.maybe_enable_compiled_autograd(True) ) - if cp_context is not None: - from torch.nn.attention import SDPBackend - - from torchtitan.models.attention import ScaledDotProductAttention - - if SDPBackend.MATH in ScaledDotProductAttention.backends: - ScaledDotProductAttention.backends.remove(SDPBackend.MATH) - stack.enter_context(cp_context) yield diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 8feb547b7..d3a7d39b8 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -157,15 +157,14 @@ def forward_backward_step( model_parts = self.model_parts parallel_dims = self.parallel_dims - # apply context parallelism if cp is enabled - # ensure CP handles the separate freqs_cis buffer for each pp stage inputs = input_dict["input"] - # Create the FlexAttention mask according to the input + extra_args = {} + if getattr(self.model_args, "use_flex_attn", False): - cp_mesh = ( - parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None + extra_args["attention_masks"] = model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, ) - init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh) optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( @@ -187,11 +186,18 @@ def forward_backward_step( ) if self.pp_has_first_stage: self.pp_schedule.step( - inputs, target=targets, losses=losses, input_batch=inputs + inputs, + **extra_args, + target=targets, + losses=losses, + input_batch=inputs, ) else: self.pp_schedule.step( - target=targets, losses=losses, input_batch=inputs + **extra_args, + target=targets, + losses=losses, + input_batch=inputs, ) # accumulate losses across pipeline microbatches @@ -209,7 +215,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs) + pred = model_parts[0](inputs, **extra_args) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 18b2253b3..c0607318d 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -239,8 +239,8 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), + input_layouts=(Shard(1), None, None), + desired_input_layouts=(Replicate(), None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git 
a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index 871c20726..93ff4e89b 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -9,10 +9,20 @@ import torch import torch.nn.functional as F from torch import nn - -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + get_fixed_block_mask_mod, + ScaledDotProductAttentionWrapper, +) from torchtitan.models.moe import MoE -from torchtitan.protocols import ModelProtocol +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol from .args import RoPEScalingArgs, TransformerModelArgs @@ -192,9 +202,11 @@ def __init__( # values of these two variables. self.use_rope = use_rope - self.sdpa = build_attention( - model_args.use_flex_attn, model_args.attn_mask_type, fixed_block_size - ) + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -205,6 +217,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass of the attention module. @@ -239,7 +252,13 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.sdpa(xq, xk, xv) + if self.use_flex_attn: + assert isinstance(attention_masks, dict), attention_masks + attention_mask = attention_masks["rope" if self.use_rope else "nope"] + output = self.inner_attention(xq, xk, xv, block_mask=attention_mask) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 @@ -372,6 +391,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Perform a forward pass through the TransformerBlock. @@ -384,7 +404,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. 
""" - h = x + self.attention(self.attention_norm(x), freqs_cis) + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) if self.moe_enabled: out = h + self.moe(self.ffn_norm(h)) else: @@ -485,9 +505,40 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_scaling_args, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + B = input_batch.shape[0] + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + + rope_mask_mod = and_masks( + *mask_mods, + get_fixed_block_mask_mod(self.model_args.fixed_attn_block_size), + ) + nope_mask_mod = and_masks(*mask_mods) + + seqlen = input_batch.shape[1] + return { + "rope": create_attention_mask(rope_mask_mod, B, None, seqlen, seqlen), + "nope": create_attention_mask(nope_mask_mod, B, None, seqlen, seqlen), + } + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -511,7 +562,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.freqs_cis, attention_masks) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/experiments/qwen3/model/model.py b/torchtitan/experiments/qwen3/model/model.py index f2a77e99c..0fff490bf 100644 --- a/torchtitan/experiments/qwen3/model/model.py +++ b/torchtitan/experiments/qwen3/model/model.py @@ -10,13 +10,23 @@ import torch import torch.nn.functional as F from torch import nn - -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) from torchtitan.models.moe import MoE +from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol from .args import Qwen3ModelArgs + # Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py def precompute_rope_cache( dim: int, max_seq_len: int, base: float = 1_000_000.0 @@ -133,6 +143,7 @@ def __init__(self, model_args: Qwen3ModelArgs): self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = model_args.head_dim self.scaling = self.head_dim**-0.5 + self.use_flex_attn = getattr(model_args, "use_flex_attn", False) # RMSNorm added here to the here to include the q-k norm # This is one of the main differences between Llama3 and Qwen3 @@ -155,7 +166,11 @@ def __init__(self, model_args: Qwen3ModelArgs): self.wo = nn.Linear( model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -170,6 +185,7 @@ def forward( self, x: 
torch.Tensor, rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass of the attention module. @@ -210,7 +226,12 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.sdpa(xq, xk, xv, scale=self.scaling) + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 @@ -308,6 +329,7 @@ def forward( self, x: torch.Tensor, rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Perform a forward pass through the TransformerBlock. @@ -320,7 +342,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - x = x + self.attention(self.attention_norm(x), rope_cache) + x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) @@ -423,9 +445,31 @@ def _precompute_rope_cache(self) -> torch.Tensor: self.model_args.rope_theta, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -449,7 +493,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.rope_cache) + h = layer(h, self.rope_cache, attention_masks) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 86c3faa5e..19452ac12 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -35,7 +35,7 @@ def _get_dict(obj) -> dict[str, Any]: llama3_siglip2_configs = { "debugmodel": Llama3Siglip2ModelArgs( - **_get_dict(llama3_configs["debugmodel"]), + **_get_dict(llama3_configs["debugmodel_flex_attn"]), encoder=Siglip2ModelArgs( dim=128, ffn_dim=256, diff --git a/torchtitan/experiments/vlm/model/model.py b/torchtitan/experiments/vlm/model/model.py index 71c8a7395..712cd8058 100644 --- a/torchtitan/experiments/vlm/model/model.py +++ b/torchtitan/experiments/vlm/model/model.py @@ -7,8 +7,11 @@ import einops as E import torch from torch import nn +from torch.nn.attention.flex_attention import BlockMask +from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.models.llama3 import Transformer as Llama3 +from torchtitan.protocols.model import AttentionMasksType from ..datasets.mm_datasets import SpecialTokens @@ -71,28 +74,49 @@ def init_weights(self, buffer_device=None): if self.projector is not None: self.projector.init_weights() + def get_attention_masks( + self, + input_batch: torch.Tensor, + 
tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + masks = super().get_attention_masks(input_batch, tokenizer, extra_inputs) + assert isinstance(masks, BlockMask) + if self.encoder is not None: + encoder_masks = self.encoder.get_attention_masks( + input_batch, tokenizer, extra_inputs + ) + assert isinstance(encoder_masks, BlockMask) + return {"llama3_masks": masks, "encoder_masks": encoder_masks} + def forward( self, tokens: torch.Tensor, pixel_values: torch.Tensor, grid_thw: torch.Tensor, special_tokens: SpecialTokens, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h_BSD = self.tok_embeddings(tokens) if self.tok_embeddings else tokens if self.encoder is not None: + assert ( + attention_masks is not None + ), "encoder only allows FlexAttention, so the llama3 must use FlexAttention as well." grid_hw = grid_thw[:, :, 1:] # Siglip2 only support image hw pixel_masks = E.reduce(grid_hw != -1, "n l hw -> n l", reduction="all") - i_NLD = self.encoder(pixel_values, pixel_masks, grid_hw) + i_NLD = self.encoder( + pixel_values, pixel_masks, grid_hw, attention_masks["encoder_masks"] + ) i_NLD = self.projector(i_NLD) h_BSD = _scatter_img_tokens( h_BSD, tokens, i_NLD, pixel_masks, special_tokens.img_id ) for layer in self.layers.values(): - h_BSD = layer(h_BSD, self.freqs_cis) + h_BSD = layer(h_BSD, self.freqs_cis, attention_masks["llama3_masks"]) h_BSD = self.norm(h_BSD) if self.norm else h_BSD output = self.output(h_BSD) if self.output else h_BSD diff --git a/torchtitan/experiments/vlm/model/siglip2.py b/torchtitan/experiments/vlm/model/siglip2.py index a1183f7cb..69278350d 100644 --- a/torchtitan/experiments/vlm/model/siglip2.py +++ b/torchtitan/experiments/vlm/model/siglip2.py @@ -8,8 +8,16 @@ import torch import torch.nn.functional as F from torch import nn +from torch.nn.attention.flex_attention import and_masks, BlockMask -from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, +) +from torchtitan.protocols.model import AttentionMasksType from .args import Siglip2ModelArgs @@ -125,11 +133,9 @@ def __init__(self, args: Siglip2ModelArgs): self.v_proj = nn.Linear(self.dim, self.dim) self.out_proj = nn.Linear(self.dim, self.dim) - self.attn = build_attention( - use_flex_attn=True, attn_mask_type=args.attn_mask_type - ) + self.inner_attention = FlexAttentionWrapper() - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, attention_masks: AttentionMasksType): xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) # Use self.head_dim instead of `n_heads` to infer the actual @@ -139,7 +145,8 @@ def forward(self, x: torch.Tensor): xk = E.rearrange(xk, "b l (h d) -> b h l d", d=self.head_dim) xv = E.rearrange(xv, "b l (h d) -> b h l d", d=self.head_dim) - output = self.attn(xq, xk, xv) + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) output = E.rearrange(output, "b h l d -> b l (h d)").contiguous() return self.out_proj(output) @@ -174,8 +181,10 @@ def __init__(self, args: Siglip2ModelArgs): self.layer_norm2 = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) self.mlp = FeedForward(args) - def 
forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.self_attn(self.layer_norm1(x)) + def forward( + self, x: torch.Tensor, attention_masks: AttentionMasksType + ) -> torch.Tensor: + x = x + self.self_attn(self.layer_norm1(x), attention_masks) x = x + self.mlp(self.layer_norm2(x)) return x @@ -198,18 +207,46 @@ def __init__(self, args: Siglip2ModelArgs): ) self.post_layernorm = nn.LayerNorm(args.dim, eps=args.layer_norm_eps) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + + # TODO: this is duplicated in the main model forward. + # TODO: is this really required? Can we call this `get_attention_masks` + # inside the main model forward? At that time PP should already split the + # grid_thw correctly. + grid_hw = extra_inputs["grid_thw"][:, :, 1:] # Siglip2 only supports image hw + pixel_masks = E.reduce(grid_hw != -1, "n l hw -> n l", reduction="all") + + mask_mods = [get_causal_mask_mod()] + match self.args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = pixel_masks.shape[0] + mask_mods.append(get_document_mask_mod(pixel_masks, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, pixel_masks.shape[1], pixel_masks.shape[1] + ) + def forward( self, pixel_values_NLD: torch.FloatTensor, pixel_masks_NL: torch.BoolTensor, grid_hw: torch.LongTensor, + attention_masks: AttentionMasksType, ): - init_attention_mask(pixel_masks_NL, eos_id=self.eos_id) - h = self.embeddings(pixel_values_NLD, grid_hw) for layer in self.layers.values(): - h = layer(h) + h = layer(h, attention_masks) h = self.post_layernorm(h) return h diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 277d64be1..bf963a5b5 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -6,238 +6,182 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -from typing import Callable, ClassVar +import functools +from collections.abc import Callable +from typing import ClassVar import torch import torch.nn.functional as F from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention.flex_attention import ( _mask_mod_signature, + AuxOutput, BlockMask, create_block_mask, flex_attention, ) -from torchtitan.tools.utils import has_cuda_capability -# FlexAttention mask type. For each mask type, we initialize it at most once per -# batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to -# track the initialized mask. -FLEX_ATTN_MASK_T = tuple[str, int | None] +__all__ = [ + "FlexAttentionWrapper", + "ScaledDotProductAttentionWrapper", + "get_causal_mask_mod", + "get_document_mask_mod", + "get_fixed_block_mask_mod", + "create_attention_mask", +] -class FlexAttention(torch.nn.Module): - """FlexAttention module that uses torch.nn.attention.flex_attention. +class FlexAttentionWrapper(torch.nn.Module): + """Wrapper around `flex_attention` to make it torch.compile and CP compatible. - This module is a wrapper around torch.nn.attention.flex_attention. This module - implements certain common attention types, such as causal and block_causal. + This wrapper serves two purposes: + 1) Invoke `torch.compile` with mode "max-autotune-no-cudagraphs" to + achieve good performance. + 2) Being a wrapper allows us to apply _ContextParallel to it.
- Args: - attn_mask_type (str): The type of attention mask. Currently, we support - "causal" and "block_causal". "causal" means the lower triangle of the - attention matrix is masked. "block_causal" means the attention matrix - is divided into blocks, where block boundary is defined by EOS token, - and the lower triangle of each block is masked. - fixed_block_size (int | None): The block size to be used to perform attention. - If specified, each sequence will be further divided to blocks, where each - block has the maximum size of ``fixed_block_size``. A query will only attend - to the keys within the same block. + Note: + The forward function must have q, k, v as the first three arguments, and + block_mask as a keyword argument to be compatible with _ContextParallel. """ - # We registered flex_attention related attributes as class variables as we - # need to amortize the cost of compilation. - flex_attn: ClassVar[Callable] = torch.compile( + _compiled_flex_attn: ClassVar[Callable] = torch.compile( flex_attention, mode="max-autotune-no-cudagraphs" ) - compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask) - used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set() - # Attention mask type to the created BlockMask. - # This allows us to keep track the created block masks for each - # new batch. We will use this to update the block mask when a - # new batch is created. This also allows user to create different - # block masks for different layers. - block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {} - - # Instance variables. - attn_mask_type: str - - def __init__( - self, attn_mask_type: str, fixed_block_size: int | None = None - ) -> None: - super().__init__() - if attn_mask_type not in ["causal", "block_causal"]: - raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") - self.attn_mask_type = attn_mask_type - self.fixed_block_size = fixed_block_size - - FlexAttention.used_attn_mask_types.add(self.mask_key) - - @property - def mask_key(self) -> FLEX_ATTN_MASK_T: - return (self.attn_mask_type, self.fixed_block_size) def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *, + block_mask: BlockMask, scale: float | None = None, - ) -> torch.Tensor: - block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) - - @staticmethod - def _get_causal_mask_mod() -> _mask_mod_signature: - def causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - return q_idx >= kv_idx - - return causal_mask - - @staticmethod - def _get_block_causal_mask_mod( - batch: torch.Tensor, eos_id: int - ) -> _mask_mod_signature: - # batch is [b, s, h, d] shape - mask = batch == eos_id - mask[:, -1] = True - acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1) - seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32) - seq_idx[:, 1:] = acc_mask[:, :-1] - - def block_causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx) - - return block_causal_mask - - @staticmethod - def _fixed_block_mask_mod( - mask_mod: _mask_mod_signature, fixed_block_size: int - ) -> _mask_mod_signature: - """ - Given an arbitrary mask_mod, divide the input sequence to blocks - and only allow attention within the same block. - - Args: - mask_mod: The mask mod to apply to the documents - fixed_block_size: The number of tokens in each block. 
- """ - - # Credit to @drisspg. - def blocked_mask_mod( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor - ): - # Get the block index of the query and key - q_block = q_idx // fixed_block_size - kv_block = kv_idx // fixed_block_size - # Only allow attention within the same block - same_block = q_block == kv_block - # Apply the original mask mod - inner_mask = mask_mod( - b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size - ) - - return same_block & inner_mask - - blocked_mask_mod.__name__ = ( - f"blocked_mask_mod_{mask_mod.__name__}_fixed_block_size_{fixed_block_size}" + ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]: + # 1. _compiled_flex_attn has to be a class variable, otherwise there will + # be multiple compiled flex_attention instances, which can be slow. + # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in + # as the first argument, which will cause an error. + # `FlexAttentionWrapper._compiled_flex_attn` is correct. + return FlexAttentionWrapper._compiled_flex_attn( + q, k, v, block_mask=block_mask, scale=scale ) - return blocked_mask_mod - - @staticmethod - @torch.no_grad() - def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: - # batch is [b, s, h, d] shape - for mask_key in FlexAttention.used_attn_mask_types: - attn_mask_type, fixed_block_size = mask_key - match attn_mask_type: - case "causal": - if FlexAttention.block_masks.get(mask_key, None) is not None: - continue - # We don't care about batch dimension -- - # all samples have the same lower triangle mask. - batch_dimension = 1 - mask_mod = FlexAttention._get_causal_mask_mod() - case "block_causal": - if eos_id is None: - raise RuntimeError( - "eos_id must be provided for block_causal mask." - ) - batch_dimension = batch.shape[0] - mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id) - case _: - raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}") - - if fixed_block_size is not None and fixed_block_size > 0: - mask_mod = FlexAttention._fixed_block_mask_mod( - mask_mod, fixed_block_size - ) - - seq_len = batch.shape[1] - block_mask = FlexAttention.compiled_create_block_mask( - mask_mod, batch_dimension, None, seq_len, seq_len - ) - FlexAttention.block_masks[mask_key] = block_mask - - -class ScaledDotProductAttention(torch.nn.Module): - backends: ClassVar[list[SDPBackend]] = [] - - def __init__(self, attn_mask_type: str) -> None: + +class ScaledDotProductAttentionWrapper(torch.nn.Module): + """Wrapper around `F.scaled_dot_product_attention` to make it CP compatible. + + This wrapper is needed because `F.scaled_dot_product_attention` is not + a torch.nn.Module, and thus cannot be applied with _ContextParallel. + We need to wrap it into a torch.nn.Module. + + Note: + The forward function must have q, k, v as the first three arguments to be + compatible with _ContextParallel. + """ + + # TODO: remove sdpa_backends after PyTorch 2.9 is released. + sdpa_backends: ClassVar[list[SDPBackend]] = [] + + def __init__(self) -> None: super().__init__() - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." 
- ) - - ScaledDotProductAttention._init_backend() - - @classmethod - def _init_backend(cls) -> None: - if cls.backends: - return - - # Add CuDNN on B200 w/ highest priority - cls.backends = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, - ] - if has_cuda_capability(10, 0): - cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) + if not self.sdpa_backends: + self.sdpa_backends = [ + SDPBackend.CUDNN_ATTENTION, + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + ] def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *, scale: float | None = None, ) -> torch.Tensor: - assert self.backends, "SDPA Backends should not be empty." - with sdpa_kernel(self.backends, set_priority=True): - return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale) - - -def build_attention( - use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None -): - if use_flex_attn: - return FlexAttention(attn_mask_type, fixed_block_size) - else: - if fixed_block_size is not None: - raise ValueError( - "TorchTitan with SDPA currently does not support fixed_block_size." - ) - if attn_mask_type != "causal": - raise ValueError( - "TorchTitan with SDPA currently only supports causal mask." - ) - return ScaledDotProductAttention(attn_mask_type) - - -def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: - FlexAttention.init_attention_mask(batch, eos_id) + with sdpa_kernel(self.sdpa_backends, set_priority=True): + return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True) + + +# We cannot use an inner function/closure because we won't be able to cache it -- +# if we use an inner function, a new closure will be created every time +# `get_causal_mask_mod` is called. +def _causal_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor +) -> torch.Tensor: + """Causal mask that prevents attention to future tokens.""" + return q_idx >= kv_idx + + +def get_causal_mask_mod() -> _mask_mod_signature: + """Returns a causal mask modifier for flex attention. + + Returns: + A mask modifier function that implements causal masking. + """ + return _causal_mask + + +def get_document_mask_mod(batch: torch.Tensor, eos_id: int) -> _mask_mod_signature: + """Creates a document mask that prevents attention across document boundaries. + + Args: + batch: Input batch tensor with shape [b, s, h, d] + eos_id: End-of-sequence token ID that marks document boundaries + + Returns: + A mask modifier function that implements document-level masking. + """ + # batch is [b, s, h, d] shape + eos_mask = batch == eos_id + eos_mask[:, -1] = True + cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1) + sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32) + sequence_indices[:, 1:] = cumulative_mask[:, :-1] + + def document_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx] + + return document_mask + + +def get_fixed_block_mask_mod(fixed_block_size: int) -> _mask_mod_signature: + """ + Divide the input sequence into blocks and only allow attention within the same block. + + Args: + fixed_block_size: The number of tokens in each block. + + Returns: + A mask modifier function that implements block-wise attention masking. + """ + + # Credit to @drisspg.
+ def blocked_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + # Get the block index of the query and key + q_block = q_idx // fixed_block_size + kv_block = kv_idx // fixed_block_size + # Only allow attention within the same block + return q_block == kv_block + + blocked_mask_mod.__name__ = f"blocked_mask_mod_fixed_block_size_{fixed_block_size}" + + return blocked_mask_mod + + +_compiled_create_block_mask = torch.compile(create_block_mask) + + +@functools.lru_cache(4) +def create_attention_mask(*args, **kwargs): + """Create an attention mask using compiled create_block_mask. + + This function is cached to avoid recreating BlockMasks for the same + arguments. + """ + return _compiled_create_block_mask(*args, **kwargs) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 1c73cef79..fc79e5ba4 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -84,6 +84,7 @@ def parallelize_deepseekv3( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, + use_flex_attn=use_flex_attn, ) maybe_enable_async_tp(job_config, world_mesh["tp"]) @@ -181,6 +182,7 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, + use_flex_attn: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -210,6 +212,18 @@ def apply_non_moe_tp( PrepareModuleInput, ) + if use_flex_attn: + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) + else: + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension.
@@ -218,8 +232,8 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate()), - desired_input_layouts=(Replicate(), Replicate()), + input_layouts=(Shard(1), Replicate(), None), + desired_input_layouts=(Replicate(), Replicate(), None), ), # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor # so that the intermedidate results k is generated as a DTensor and its gradient is @@ -228,11 +242,7 @@ def apply_non_moe_tp( "attention.wkv_b": colwise_parallel(use_local_output=False), "attention.kv_norm": NoParallel(use_local_output=False), # NOTE: use_local_output=True so that the inputs to FlexAttention are plain Tensors - "attention.sdpa": prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ), + "attention.inner_attention": attention_kernel_plan, "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index dc612fafb..d5bc9b101 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -5,13 +5,22 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Tuple import torch from torch import nn -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) from torchtitan.models.moe import FeedForward, MoE +from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs @@ -58,7 +67,7 @@ def find_correction_dim( def find_correction_range( low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """ Computes the range of correction dimensions for rotary positional embeddings. @@ -70,7 +79,7 @@ def find_correction_range( max_seq_len (int): Maximum sequence length. Returns: - Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. """ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) @@ -175,12 +184,17 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0 self.softmax_scale = self.softmax_scale * mscale * mscale - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass for the Multi-Head Latent Attention (MLA) Layer. 
@@ -231,7 +245,14 @@ def forward( k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) - output = self.sdpa(q, k, v, scale=self.softmax_scale) + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention( + q, k, v, block_mask=attention_masks, scale=self.softmax_scale + ) + else: + assert attention_masks is None + output = self.inner_attention(q, k, v, scale=self.softmax_scale) # Reshape and project output output = output.transpose( @@ -284,7 +305,12 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 self.layer_id = layer_id - def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): """ Forward pass for the Transformer block. @@ -295,7 +321,7 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): Returns: torch.Tensor: Output tensor with the same shape as the input. """ - x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) else: @@ -360,9 +386,31 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: b=cutoff_factor * final_out_std, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -385,7 +433,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.freqs_cis, attention_masks) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h return output diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 89066e865..4944af569 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -207,8 +207,8 @@ def apply_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), + input_layouts=(Shard(1), None, None), + desired_input_layouts=(Replicate(), None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 753ffae09..6f10719d1 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -11,8 +11,17 @@ import torch import torch.nn.functional as F from torch import nn - -from torchtitan.models.attention import build_attention +from torch.nn.attention.flex_attention 
import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) +from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol from .args import RoPEScalingArgs, TransformerModelArgs @@ -181,7 +190,12 @@ def __init__(self, model_args: TransformerModelArgs): self.wo = nn.Linear( model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -192,6 +206,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Forward pass of the attention module. @@ -225,7 +240,16 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - output = self.sdpa(xq, xk, xv) + assert ( + isinstance(attention_masks, BlockMask) or attention_masks is None + ), attention_masks + + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) output = output.transpose( 1, 2 @@ -321,6 +345,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, ): """ Perform a forward pass through the TransformerBlock. @@ -333,7 +358,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. 
""" - h = x + self.attention(self.attention_norm(x), freqs_cis) + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -428,9 +453,31 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_scaling_args, ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + def forward( self, tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -454,7 +501,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.freqs_cis, attention_masks=attention_masks) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index a4f28bc89..a713bec65 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -11,9 +11,16 @@ import torch import torch.nn as nn +from torch.nn.attention.flex_attention import BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer + from torchtitan.config import JobConfig +AttentionMasksType = dict[str, BlockMask] | BlockMask + + @dataclass class BaseModelArgs: """All ModelArgs should inherit from this class. @@ -53,3 +60,13 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: buffer_device: Optional device to place buffers on during initialization. """ pass + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + raise NotImplementedError( + "This model does not support attention masking/Flex Attention." + ) diff --git a/torchtitan/train.py b/torchtitan/train.py index 287828d86..6441ff0b2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -24,7 +24,6 @@ ) from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils -from torchtitan.models.attention import init_attention_mask from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger @@ -416,12 +415,21 @@ def forward_backward_step( inputs = input_dict["input"] extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} - # Create the FlexAttention mask according to the input + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_args are. 
+ extra_args = {} + if getattr(self.model_args, "use_flex_attn", False): - init_attention_mask(inputs, self.tokenizer.eos_id) + extra_args["attention_masks"] = model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, + extra_inputs=extra_inputs, + ) # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage + cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], @@ -444,13 +452,17 @@ def forward_backward_step( self.pp_schedule.step( inputs, **extra_inputs, + **extra_args, target=targets, losses=losses, input_batch=inputs, ) else: self.pp_schedule.step( - target=targets, losses=losses, input_batch=inputs + **extra_args, + target=targets, + losses=losses, + input_batch=inputs, ) # accumulate losses across pipeline microbatches @@ -468,7 +480,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, **extra_inputs) + pred = model_parts[0](inputs, **extra_inputs, **extra_args) loss = self.loss_fn(pred, labels) # need to free pred before bwd to avoid peaking memory del pred
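For reviewers unfamiliar with the new mask flow, here is a minimal, self-contained sketch of the pattern this patch standardizes: build mask_mods once per batch, combine them with `and_masks`, materialize a `BlockMask` via `create_block_mask`, and pass it into `flex_attention` as `block_mask`. It uses only public `torch.nn.attention.flex_attention` APIs; the shapes, vocabulary size, `eos_id`, device handling, and the helper name `make_document_mask` are illustrative stand-ins for `get_document_mask_mod`/`create_attention_mask` above, not part of this patch, and it assumes a recent PyTorch with FlexAttention support.

```python
import torch
from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
)

device = "cuda" if torch.cuda.is_available() else "cpu"


def causal_mask(b, h, q_idx, kv_idx):
    # Lower-triangular mask: a query attends only to itself and earlier positions.
    return q_idx >= kv_idx


def make_document_mask(batch: torch.Tensor, eos_id: int):
    # batch: [B, S] token ids; documents are delimited by eos_id, and tokens
    # may only attend within their own document (same idea as get_document_mask_mod).
    eos = batch == eos_id
    eos[:, -1] = True
    doc_id = torch.zeros_like(batch, dtype=torch.int32)
    doc_id[:, 1:] = torch.cumsum(eos.to(torch.int32), dim=1)[:, :-1]

    def document_mask(b, h, q_idx, kv_idx):
        return doc_id[b, q_idx] == doc_id[b, kv_idx]

    return document_mask


B, H, S, D, eos_id = 2, 4, 128, 16, 0
tokens = torch.randint(0, 32, (B, S), device=device)

# "block_causal" = causal AND same-document, materialized once per batch.
mask_mod = and_masks(causal_mask, make_document_mask(tokens, eos_id))
block_mask = create_block_mask(mask_mod, B, None, S, S, device=device)

q = k = v = torch.randn(B, H, S, D, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([2, 4, 128, 16])
```

In the patch itself, `create_attention_mask` wraps `create_block_mask` in `torch.compile` plus `functools.lru_cache(4)`, so repeated calls with the same mask_mod and shapes return a cached `BlockMask` instead of rebuilding it, and the trainer forwards the resulting mask(s) to every transformer block through the new `attention_masks` argument.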