Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions nemo_automodel/_transformers/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
get_hf_config,
get_is_hf_model,
no_hf_meta_device,
resolve_sdpa_method,
)

if not hasattr(_gen_utils, "NEED_SETUP_CACHE_CLASSES_MAPPING"):
Expand Down Expand Up @@ -371,7 +372,7 @@ def from_pretrained(
*model_args,
use_liger_kernel: bool = True,
use_sdpa_patching: bool = True,
sdpa_method: Optional[List[SDPBackend]] = None,
sdpa_method: Optional[List[Union[SDPBackend, str]]] = None,
torch_dtype="auto",
attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION,
quantization_config=None,
Expand Down Expand Up @@ -408,8 +409,11 @@ def from_pretrained(
the model with Liger kernels for faster inference/training.
use_sdpa_patching (bool, default=True): If `True`, patch the
model with SDPA-based attention optimizations.
sdpa_method (list[SDPBackend] | None, optional): Explicit list of
sdpa_method (list[SDPBackend | str] | None, optional): Explicit list of
SDPA back-ends to consider when `use_sdpa_patching=True`.
Accepts both SDPBackend enum values and string names (e.g.
``["flash_attention", "efficient_attention"]``). When ``None``,
auto-selects based on CP and activation checkpointing.
torch_dtype (str | torch.dtype | Literal["auto"], default="auto"):
Data type passed to the underlying `from_pretrained` call.
attn_implementation (str, optional):
Expand Down Expand Up @@ -480,6 +484,8 @@ def from_pretrained(
raise
is_hf_model = get_is_hf_model(hf_config, force_hf)

sdpa_method = resolve_sdpa_method(sdpa_method, device_mesh, activation_checkpointing)

return cls._build_model(
pretrained_model_name_or_path,
*model_args,
Expand Down Expand Up @@ -511,7 +517,7 @@ def from_config(
*model_args,
use_liger_kernel: bool = True,
use_sdpa_patching: bool = True,
sdpa_method: Optional[List[SDPBackend]] = None,
sdpa_method: Optional[List[Union[SDPBackend, str]]] = None,
torch_dtype: Union[str, torch.dtype] = "auto",
attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION,
quantization_config=None,
Expand Down Expand Up @@ -582,6 +588,8 @@ def from_config(
_consume_config_overrides(config, kwargs)
is_hf_model = get_is_hf_model(config, force_hf)

sdpa_method = resolve_sdpa_method(sdpa_method, device_mesh, activation_checkpointing)

return cls._build_model(
config,
*model_args,
Expand Down
27 changes: 27 additions & 0 deletions nemo_automodel/_transformers/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,23 @@ def instantiate_infrastructure(
return model_wrapper, autopipeline, parallelize_fn, qat_quantizer


def _uses_te_attention(model) -> bool:
"""Return True if any self_attn module uses TE's DotProductAttention."""
try:
from transformer_engine.pytorch.attention import DotProductAttention
except ImportError:
return False

model_parts = model.parts if hasattr(model, "parts") else [model]
for part in model_parts:
for name, module in part.named_modules():
if name.endswith("self_attn"):
attn_module = getattr(module, "attn_module", None)
if isinstance(attn_module, DotProductAttention):
return True
return False


# apply_model_infrastructure -- the main post-init orchestration function
def apply_model_infrastructure(
model,
Expand Down Expand Up @@ -509,4 +526,14 @@ def apply_model_infrastructure(
raise
print_trainable_parameters(model) # Once model's been sharded

# Attach CP attention-mask hooks for dense (non-TE) context parallelism.
# These hooks strip attention_mask and set is_causal=True on self_attn modules
# so that SDPA handles causal masking internally (compatible with DTensor sharding).
if mesh.cp_size > 1 and not _uses_te_attention(model):
from nemo_automodel.components.distributed.cp_utils import attach_context_parallel_hooks

model_parts = model.parts if hasattr(model, "parts") else [model]
for mp in model_parts:
attach_context_parallel_hooks(mp)

return model
60 changes: 60 additions & 0 deletions nemo_automodel/_transformers/model_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,63 @@ def _filter_kwargs_for_init(model_cls, kwargs: dict) -> dict:
# We pass `config` positionally.
allowed.discard("config")
return {k: v for k, v in kwargs.items() if k in allowed}


def resolve_sdpa_method(
sdpa_method: list | None = None,
device_mesh=None,
activation_checkpointing: bool = False,
) -> list["SDPBackend"] | None: # noqa: F821
"""Resolve SDPA backend list from config strings or runtime constraints.

When *sdpa_method* is provided (e.g. from YAML), string values are
converted to :class:`torch.nn.attention.SDPBackend` enum members.
Already-resolved ``SDPBackend`` values are passed through unchanged.
When ``None``, automatic defaults are applied based on context
parallelism and activation checkpointing settings.

Valid string values (case-insensitive): ``flash_attention``,
``efficient_attention``, ``math``, ``cudnn_attention``.

Args:
sdpa_method: List of backend name strings or SDPBackend enum values,
or ``None`` to use automatic defaults.
device_mesh: Device mesh for distributed training.
activation_checkpointing: Whether activation checkpointing is enabled.

Returns:
Ordered list of :class:`SDPBackend` members, or ``None`` to use
PyTorch's default selection.
"""
from torch.nn.attention import SDPBackend

_NAME_TO_BACKEND = dict(SDPBackend.__members__)

if sdpa_method is not None:
backends = []
for entry in sdpa_method:
if isinstance(entry, str):
key = entry.upper()
if key not in _NAME_TO_BACKEND:
raise ValueError(f"Unknown SDPA backend '{entry}'. Valid values: {sorted(_NAME_TO_BACKEND.keys())}")
backends.append(_NAME_TO_BACKEND[key])
else:
backends.append(entry)
return backends

# Auto-select based on runtime constraints
cp_size = 1
if device_mesh is not None and "cp" in device_mesh.mesh_dim_names:
cp_size = device_mesh["cp"].size()

if cp_size > 1:
# CP with DTensor only supports flash and efficient backends;
# MATH is not compatible with DTensor.
return [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
elif activation_checkpointing:
# For activation checkpointing, disable cudnn SDPA backend because
# it may not be selected during recomputation, causing:
# "Recomputed values have different metadata than during forward pass."
return [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

return None
9 changes: 6 additions & 3 deletions nemo_automodel/components/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Callable

import torch
Expand Down Expand Up @@ -48,13 +47,17 @@ def initialize_attn_module_and_func(
attn_func = attn_module.__call__
return attn_module, attn_func
elif attn_impl == "sdpa":
attn_func = functools.partial(
F.scaled_dot_product_attention,
defaults = dict(
scale=softmax_scale,
is_causal=attn_mask_type == "causal",
enable_gqa=num_gqa_groups is not None,
**kwargs,
)

def attn_func(*args, **call_kwargs):
merged = {**defaults, **call_kwargs}
return F.scaled_dot_product_attention(*args, **merged)

return None, attn_func
elif attn_impl == "flex":
attn_module = FlexAttention()
Expand Down
29 changes: 28 additions & 1 deletion nemo_automodel/components/distributed/cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,29 @@ def create_context_parallel_ctx(
)


def attach_context_parallel_hooks(model: torch.nn.Module):
    """Attach forward pre-hooks to self_attn modules to fix attention masks for context parallelism.

    Context parallelism shards Q/K/V on the sequence dimension as DTensors, so an
    explicit 4D attention mask would have a mismatched shape. Every ``self_attn``
    sub-module therefore gets a pre-hook that drops the ``attention_mask`` kwarg
    and sets ``is_causal=True``, letting SDPA apply causal masking internally.

    Based on ``accelerate.big_modeling._attach_context_parallel_hooks``.
    """

    def _strip_attention_mask(_module, args, kwargs):
        # Only rewrite calls that actually pass a mask; leave others untouched.
        if "attention_mask" in kwargs:
            kwargs = {**kwargs, "attention_mask": None, "is_causal": True}
        return args, kwargs

    attn_modules = (m for n, m in model.named_modules() if n.endswith("self_attn"))
    for attn in attn_modules:
        attn.register_forward_pre_hook(_strip_attention_mask, with_kwargs=True, prepend=True)


def make_cp_batch_and_ctx(
device_mesh,
batch,
Expand Down Expand Up @@ -152,7 +175,11 @@ def _get_mesh_size(mesh):
if _get_mesh_size(cp_mesh) <= 1:
return nullcontext, batch

# CP doesn't support packed sequence currently. Let torch SDPA handle attention mask.
# Remove attention_mask from the batch so the model does not attempt to
# build a 4D causal mask (which would have mismatched shapes with
# DTensor-sharded Q/K/V). Each self_attn module's forward_pre_hook
# (registered by attach_context_parallel_hooks) will set is_causal=True
# so that SDPA handles causal masking internally.
batch.pop("attention_mask", None)

# Skip 1D injection if position_ids already in batch (e.g. mRoPE pre-computed)
Expand Down
16 changes: 7 additions & 9 deletions nemo_automodel/components/moe/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,13 @@ def apply_cp(model: torch.nn.Module, cp_mesh: DeviceMesh, cp_comm_type: str = "p

for _, block in _model.layers.named_children():
attn_module = block.self_attn.attn_module
assert isinstance(attn_module, DotProductAttention), (
"Context parallelism is only supported for TransformerEngine's DotProductAttention"
)
attn_module.set_context_parallel_group(
cp_mesh.get_group(),
torch.distributed.get_process_group_ranks(cp_mesh.get_group()),
_get_cp_stream(),
cp_comm_type=cp_comm_type,
)
if isinstance(attn_module, DotProductAttention):
attn_module.set_context_parallel_group(
cp_mesh.get_group(),
torch.distributed.get_process_group_ranks(cp_mesh.get_group()),
_get_cp_stream(),
cp_comm_type=cp_comm_type,
)

moe_module = block.moe if hasattr(block, "moe") else block.mlp
if isinstance(moe_module, MoE):
Expand Down
10 changes: 10 additions & 0 deletions nemo_automodel/recipes/llm/train_ft.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we move the changes in this file into model_init.py

Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def build_model(
cfg_moe=None,
activation_checkpointing=False,
unfreeze_modules: list[str] | None = None,
sdpa_method: list[str] | None = None,
) -> tuple[nn.Module | AutoPipeline, list["Optimizer"]]: # noqa: F821
"""Build and initialize a model.

Expand All @@ -170,6 +171,9 @@ def build_model(
cfg_moe: MoEParallelizerConfig instance, or ConfigNode to be converted.
activation_checkpointing: Whether to enable activation checkpointing.
unfreeze_modules: List of module names/substrings to unfreeze.
sdpa_method: Explicit list of SDPA backend name strings (e.g.
``["flash_attention", "efficient_attention"]``), or ``None`` to
auto-select based on CP / activation checkpointing.
"""
with ScopedRNG(seed=seed, ranked=True):
kwargs = {
Expand All @@ -179,6 +183,7 @@ def build_model(
"moe_mesh": moe_mesh,
"distributed_config": distributed_config,
"pipeline_config": pipeline_config,
"sdpa_method": sdpa_method,
}

if cfg_qat is not None and cfg_qat.get("enabled", False):
Expand Down Expand Up @@ -229,6 +234,10 @@ def build_model(
else:
# For non-NemoAutoModel entry points (e.g., build_gpt2_model),
# instantiate the model first, then apply infrastructure separately.
# Note: sdpa_method is not supported here — SDPA patching only runs
# inside NeMoAutoModel._build_model.
if sdpa_method is not None:
logger.warning("sdpa_method is ignored for non-NeMoAutoModel targets.")
# We must convert config objects into runtime objects (model_wrapper,
# autopipeline, parallelize_fn, etc.) via instantiate_infrastructure,
# exactly as from_pretrained/from_config do internally.
Expand Down Expand Up @@ -965,6 +974,7 @@ def setup(self):
cfg_qat=self.cfg.get("qat", None),
cfg_moe=self.dist_setup.moe_config,
activation_checkpointing=self.dist_setup.activation_checkpointing,
sdpa_method=self.cfg.get("sdpa_method", None),
)
self.optimizer = build_optimizer(model, self.cfg.optimizer, self.distributed_config, self.device_mesh)

Expand Down
58 changes: 57 additions & 1 deletion tests/unit_tests/attention/test_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from nemo_automodel.components.attention.utils import (
initialize_attn_module_and_func,
Expand Down Expand Up @@ -105,6 +108,53 @@ def test_unsupported_attention_implementation(self):
softmax_scale=0.125,
)

    def test_sdpa_late_binding_picks_up_monkey_patch(self):
        """Test that SDPA attn_func uses late-bound lookup of F.scaled_dot_product_attention.

        Context Parallelism monkey-patches F.scaled_dot_product_attention at runtime.
        The returned attn_func must resolve the function at call time (not init time)
        so that CP's patched version is used.
        """
        _, attn_func = initialize_attn_module_and_func(
            attn_impl="sdpa",
            num_attention_heads=8,
            num_qk_channels=64,
            num_v_channels=64,
            softmax_scale=0.125,
            attn_mask_type="causal",
            num_gqa_groups=4,
        )

        # Keep a reference to the real implementation so it can be restored.
        original_sdpa = F.scaled_dot_product_attention
        sentinel = object()
        wrapper = mock.MagicMock(return_value=sentinel)

        # Simulate CP monkey-patching F.scaled_dot_product_attention
        F.scaled_dot_product_attention = wrapper
        try:
            q = torch.randn(1, 1, 4, 8)
            k = torch.randn(1, 1, 4, 8)
            v = torch.randn(1, 1, 4, 8)
            result = attn_func(q, k, v)

            # Late binding means the patched wrapper receives the call and
            # the positional q/k/v tensors are forwarded unchanged.
            assert result is sentinel, "attn_func should call the patched function"
            wrapper.assert_called_once()
            args, kwargs = wrapper.call_args
            assert torch.equal(args[0], q)
            assert torch.equal(args[1], k)
            assert torch.equal(args[2], v)
        finally:
            # Always restore the real SDPA so other tests are unaffected.
            F.scaled_dot_product_attention = original_sdpa

        # After restoring, verify original is called again
        original_wrapper = mock.MagicMock(wraps=original_sdpa)
        F.scaled_dot_product_attention = original_wrapper
        try:
            attn_func(q, k, v)
            original_wrapper.assert_called_once()
        finally:
            F.scaled_dot_product_attention = original_sdpa


class TestPreprocessArgsAndKwargsForAttn:
"""Tests for preprocess_args_and_kwargs_for_attn function."""
Expand Down Expand Up @@ -175,7 +225,13 @@ def test_te_with_cu_seqlens_q_and_kv(self):
v_gpu = self.v.to(device)

q_out, k_out, v_out, attn_kwargs = preprocess_args_and_kwargs_for_attn(
q_gpu, k_gpu, v_gpu, attention_mask=None, attn_impl="te", cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv
q_gpu,
k_gpu,
v_gpu,
attention_mask=None,
attn_impl="te",
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)

assert "cu_seqlens_q" in attn_kwargs
Expand Down
Loading
Loading