diff --git a/nemo_automodel/components/distributed/cp_utils.py b/nemo_automodel/components/distributed/cp_utils.py index c12ff9f39..f7efc9416 100644 --- a/nemo_automodel/components/distributed/cp_utils.py +++ b/nemo_automodel/components/distributed/cp_utils.py @@ -101,6 +101,52 @@ def create_context_parallel_ctx( ) +def _dual_chunk_swap_select(tensor, cp_size, cp_rank, seq_dim=1): + """Select DualChunkSwap chunks for this CP rank along seq_dim. + + Splits into 2*cp_size chunks, selects chunks [cp_rank, 2*cp_size - cp_rank - 1]. + Requires seq_len divisible by 2*cp_size. + """ + seq_len = tensor.shape[seq_dim] + assert seq_len % (2 * cp_size) == 0, ( + f"Sequence length {seq_len} must be divisible by 2*cp_size={2 * cp_size} for DualChunkSwap CP" + ) + shape = list(tensor.shape) + shape[seq_dim : seq_dim + 1] = [2 * cp_size, seq_len // (2 * cp_size)] + tensor = tensor.view(*shape) + + index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], dtype=torch.int64, device=tensor.device) + tensor = tensor.index_select(seq_dim, index) + + shape = list(tensor.shape) + shape[seq_dim : seq_dim + 2] = [shape[seq_dim] * shape[seq_dim + 1]] + return tensor.reshape(*shape).contiguous() + + +def _split_batch_bshd_for_cp(batch, cp_mesh): + """Split BSHD batch with DualChunkSwap for hybrid Mamba-Attention CP. + + Each CP rank gets 2 non-contiguous chunks of the sequence for load-balanced + causal attention. For cp_size=2, the sequence is split into 4 chunks: + rank 0 gets chunks [0, 3], rank 1 gets chunks [1, 2]. 
+ """ + cp_size = cp_mesh.size() + cp_rank = torch.distributed.get_rank(group=cp_mesh.get_group()) + + for key in ("input_ids", "labels", "position_ids"): + if key in batch and isinstance(batch[key], torch.Tensor) and batch[key].dim() >= 2: + batch[key] = _dual_chunk_swap_select(batch[key], cp_size, cp_rank, seq_dim=1) + + if "attention_mask" in batch and isinstance(batch["attention_mask"], torch.Tensor): + mask = batch["attention_mask"] + # attention_mask may be 2D [B, S] or 3D+ — seq dim varies + seq_dim = 2 if mask.dim() > 2 else 1 + batch["attention_mask"] = _dual_chunk_swap_select(mask, cp_size, cp_rank, seq_dim=seq_dim) + + batch.pop("causal_mask_mapping", None) + return batch + + def make_cp_batch_and_ctx( device_mesh, batch, @@ -109,6 +155,7 @@ def make_cp_batch_and_ctx( padding_token_id: int = 0, num_chunks: int = 1, seq_lens_padding_value: int = -1000, + use_hybrid_cp: bool = False, ): """ Build a CP context manager and shards a batch. If the input device_mesh is None or the size @@ -139,7 +186,7 @@ def _get_mesh_size(mesh): cp_mesh = _get_submesh(device_mesh, "cp") tp_mesh = _get_submesh(device_mesh, "tp") - if use_te: + if use_te and not use_hybrid_cp: return nullcontext, make_cp_batch_for_te( cp_mesh, batch, @@ -149,6 +196,11 @@ def _get_mesh_size(mesh): seq_lens_padding_value=seq_lens_padding_value, ) + if use_hybrid_cp: + if cp_mesh is None or cp_mesh.size() <= 1: + return nullcontext, batch + return nullcontext, _split_batch_bshd_for_cp(batch, cp_mesh) + if _get_mesh_size(cp_mesh) <= 1: return nullcontext, batch diff --git a/nemo_automodel/components/distributed/mamba_cp.py b/nemo_automodel/components/distributed/mamba_cp.py new file mode 100644 index 000000000..4cd934a86 --- /dev/null +++ b/nemo_automodel/components/distributed/mamba_cp.py @@ -0,0 +1,378 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Context parallelism for Mamba/SSM layers using a hidden-parallel strategy. + +Instead of splitting the sequence across CP ranks (as attention CP does), this module +uses an all-to-all redistribution so that each CP rank processes the *full* sequence +but only a *subset* of heads (d_inner / cp_size). The data flow is:: + + [B, L_local, D] --> all-to-all --> [B, L_global, D/cp] + --> conv1d + SSM kernel --> + [B, L_global, D/cp] --> all-to-all --> [B, L_local, D] + +This module is intentionally **not** a subclass of ``nn.Module`` because it owns +no trainable parameters. It holds *references* to the Mamba mixer's parameters +and slices them in the forward path so that gradients flow back to the full +(unsliced) parameters. +""" + +import torch +import torch.distributed +import torch.nn as nn + +# --------------------------------------------------------------------------- +# Autograd-aware all-to-all primitive +# --------------------------------------------------------------------------- + + +class _AllToAll(torch.autograd.Function): + """Autograd wrapper around ``torch.distributed.all_to_all_single``. + + For equal-sized splits the all-to-all operation is its own inverse, + so the backward pass is simply another all-to-all on the same group. 
+ """ + + @staticmethod + def forward(ctx, input_: torch.Tensor, group: torch.distributed.ProcessGroup) -> torch.Tensor: + ctx.group = group + output = torch.empty_like(input_) + torch.distributed.all_to_all_single(output, input_, group=group) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + group = ctx.group + grad_input = torch.empty_like(grad_output) + torch.distributed.all_to_all_single(grad_input, grad_output, group=group) + return grad_input, None + + +def _all_to_all(input_: torch.Tensor, group: torch.distributed.ProcessGroup) -> torch.Tensor: + """Functional entry-point for the autograd-aware all-to-all.""" + return _AllToAll.apply(input_, group) + + +# --------------------------------------------------------------------------- +# Sequence-sharded <-> Hidden-sharded layout transformations (batch-first) +# --------------------------------------------------------------------------- + + +def _all_to_all_cp2hp( + input_: torch.Tensor, + cp_group: torch.distributed.ProcessGroup, + batch_size: int, +) -> torch.Tensor: + """Transform from sequence-sharded to hidden-sharded layout (batch-first). + + Args: + input_: Tensor of shape ``[B, L_local, H]`` where H is the full hidden + dimension on this rank. + cp_group: Context-parallel process group. + batch_size: Batch size ``B`` (needed to recover dimensions after reshape). + + Returns: + Tensor of shape ``[B, L_global, H / cp_size]``. 
+ """ + cp_size = cp_group.size() + B, L_local, H = input_.shape + H_local = H // cp_size + + # [B*L_local, cp, H_local] -> [cp, B*L_local, H_local] -> flatten for all-to-all + send_tensor = ( + input_.reshape(B * L_local, cp_size, H_local) + .permute(1, 0, 2) + .contiguous() + .reshape(cp_size * B * L_local, H_local) + ) + + recv_tensor = _all_to_all(send_tensor, cp_group) + + # [cp, B, L_local, H_local] -> [B, cp*L_local, H_local] + return ( + recv_tensor.reshape(cp_size, B, L_local, H_local) + .permute(1, 0, 2, 3) + .contiguous() + .reshape(B, cp_size * L_local, H_local) + ) + + +def _all_to_all_hp2cp( + input_: torch.Tensor, + cp_group: torch.distributed.ProcessGroup, + batch_size: int, +) -> torch.Tensor: + """Transform from hidden-sharded to sequence-sharded layout (batch-first). + + This is the inverse of :func:`_all_to_all_cp2hp`. + + Args: + input_: Tensor of shape ``[B, L_global, H_local]`` where ``H_local = H / cp_size``. + cp_group: Context-parallel process group. + batch_size: Batch size ``B``. + + Returns: + Tensor of shape ``[B, L_local, H]`` where ``L_local = L_global / cp_size`` + and ``H = H_local * cp_size``. + """ + cp_size = cp_group.size() + B, L_global, H_local = input_.shape + L_local = L_global // cp_size + + # [B, cp, L_local, H_local] -> [cp, B, L_local, H_local] -> flatten for all-to-all + send_tensor = ( + input_.reshape(B, cp_size, L_local, H_local) + .permute(1, 0, 2, 3) + .contiguous() + .reshape(cp_size * B * L_local, H_local) + ) + + recv_tensor = _all_to_all(send_tensor, cp_group) + + # [cp, B*L_local, H_local] -> [B*L_local, cp, H_local] -> [B, L_local, H] + return ( + recv_tensor.reshape(cp_size, B * L_local, H_local) + .permute(1, 0, 2) + .contiguous() + .reshape(B, L_local, cp_size * H_local) + ) + + +def _undo_attention_load_balancing(input_: torch.Tensor, cp_size: int) -> torch.Tensor: + """Reorder from DualChunkSwap to sequential for SSM processing. + + Operates on dim=1 (sequence) of [B, L_global, H_local]. 
+ For cp_size=2 (4 chunks): reorder [0,3,1,2] → [0,1,2,3]. + """ + num_chunks = 2 * cp_size + chunks = torch.chunk(input_, chunks=num_chunks, dim=1) + # Even indices select from first halves, odd indices (reversed) select from second halves + order = [2 * i for i in range(cp_size)] + [num_chunks - 1 - 2 * i for i in range(cp_size)] + return torch.cat([chunks[i] for i in order], dim=1) + + +def _redo_attention_load_balancing(input_: torch.Tensor, cp_size: int) -> torch.Tensor: + """Reorder from sequential back to DualChunkSwap for attention. + + Inverse of _undo_attention_load_balancing. + """ + num_chunks = 2 * cp_size + chunks = torch.chunk(input_, chunks=num_chunks, dim=1) + order = [None] * num_chunks + order[::2] = range(cp_size) + order[1::2] = reversed(range(cp_size, num_chunks)) + return torch.cat([chunks[i] for i in order], dim=1) + + +# --------------------------------------------------------------------------- +# MambaContextParallel – orchestrates CP for a single Mamba mixer layer +# --------------------------------------------------------------------------- + + +class MambaContextParallel: + """Hidden-parallel context parallelism for a Mamba2 mixer layer. + + This class does **not** own trainable parameters. It stores *references* + to the mixer's parameters (conv1d, dt_bias, A_log, D) and slices them on + the fly so that gradients propagate to the original (full) parameters. + + Args: + cp_group: Context-parallel process group. + num_heads: Total number of SSM heads (before any parallelism). + head_dim: Dimension per head. + n_groups: Number of SSM groups (for grouped B/C states). + d_state: SSM state dimension. + conv1d: Reference to the mixer's ``nn.Conv1d`` module. + dt_bias: Reference to the mixer's ``dt_bias`` parameter. + A_log: Reference to the mixer's ``A_log`` parameter. + D: Reference to the mixer's ``D`` parameter. 
+ """ + + def __init__( + self, + cp_group: torch.distributed.ProcessGroup, + num_heads: int, + head_dim: int, + n_groups: int, + d_state: int, + conv1d: nn.Conv1d, + dt_bias: torch.Tensor, + A_log: torch.Tensor, + D: torch.Tensor, + ) -> None: + self.cp_group = cp_group + self.num_heads = num_heads + self.head_dim = head_dim + self.n_groups = n_groups + self.d_state = d_state + self.conv1d = conv1d + self.dt_bias = dt_bias + self.A_log = A_log + self.D = D + + self.cp_size = cp_group.size() + self.cp_rank = cp_group.rank() + + self.d_inner = num_heads * head_dim + + # --- Validate and compute per-rank sizes --- + + # Each CP rank must get at least one head. + assert num_heads % self.cp_size == 0, f"num_heads ({num_heads}) must be divisible by cp_size ({self.cp_size})" + self.num_heads_local = num_heads // self.cp_size + self.d_inner_local = self.num_heads_local * head_dim + + # Groups: when n_groups < cp_size we need to replicate B/C states. + if n_groups < self.cp_size: + assert self.cp_size % n_groups == 0, ( + f"cp_size ({self.cp_size}) must be divisible by n_groups ({n_groups}) when n_groups < cp_size" + ) + self.group_repeat_count = self.cp_size // n_groups + self.n_groups_local = 1 + else: + assert n_groups % self.cp_size == 0, f"n_groups ({n_groups}) must be divisible by cp_size ({self.cp_size})" + self.group_repeat_count = 1 + self.n_groups_local = n_groups // self.cp_size + + # ------------------------------------------------------------------ # + # Activation transforms (before / after conv+SSM) # + # ------------------------------------------------------------------ # + + def pre_conv_ssm(self, projected_states: torch.Tensor) -> torch.Tensor: + """Redistribute ``[B, L_local, proj_dim]`` to ``[B, L_global, proj_dim_local]``. + + Splits the packed projection into [z, x, B, C, dt], optionally replicates + B/C states, performs cp2hp all-to-all on each, and re-concatenates. 
+ """ + if self.cp_size == 1: + return projected_states + + B = projected_states.shape[0] + groups_state_size = self.n_groups * self.d_state + + z, x, B_state, C_state, dt = torch.split( + projected_states, + [self.d_inner, self.d_inner, groups_state_size, groups_state_size, self.num_heads], + dim=-1, + ) + + # Replicate B and C group states when n_groups < cp_size so that + # replicas land on consecutive CP ranks with their associated heads. + if self.group_repeat_count > 1: + B_state = self._repeat_group_state(B_state) + C_state = self._repeat_group_state(C_state) + + z = _all_to_all_cp2hp(z, self.cp_group, B) + x = _all_to_all_cp2hp(x, self.cp_group, B) + B_state = _all_to_all_cp2hp(B_state, self.cp_group, B) + C_state = _all_to_all_cp2hp(C_state, self.cp_group, B) + dt = _all_to_all_cp2hp(dt, self.cp_group, B) + + result = torch.cat([z, x, B_state, C_state, dt], dim=-1) + return _undo_attention_load_balancing(result, self.cp_size) + + def post_conv_ssm(self, output: torch.Tensor) -> torch.Tensor: + """Redistribute SSM output from hidden-sharded back to sequence-sharded layout. + + Args: + output: ``[B, L_global, d_inner / cp_size]`` — the (already gated) + SSM output on this rank. + + Returns: + ``[B, L_local, d_inner]`` — sequence-sharded output. + """ + if self.cp_size == 1: + return output + + B = output.shape[0] + output = _redo_attention_load_balancing(output, self.cp_size) + return _all_to_all_hp2cp(output, self.cp_group, B) + + # ------------------------------------------------------------------ # + # Parameter slicing (returns views so grads flow to full params) # + # ------------------------------------------------------------------ # + + def get_conv1d_weight(self) -> torch.Tensor: + """Slice ``conv1d.weight`` for the current CP rank. + + Weight shape: ``[conv_dim, 1, kernel_size]`` where + ``conv_dim = d_inner + 2 * n_groups * d_state``. + Returns ``[conv_dim_local, kernel_size]`` (squeezed for causal_conv1d kernel). 
+ """ + return self._slice_conv_param(self.conv1d.weight).squeeze(1) + + def get_conv1d_bias(self) -> torch.Tensor: + """Slice ``conv1d.bias`` for the current CP rank. + + Bias shape: ``[conv_dim]``. Returns ``[conv_dim_local]``. + """ + if self.conv1d.bias is None: + return None + return self._slice_conv_param(self.conv1d.bias) + + def get_dt_bias(self) -> torch.Tensor: + """Slice ``dt_bias`` for the current CP rank.""" + return self._slice_vector_param(self.dt_bias) + + def get_A_log(self) -> torch.Tensor: + """Slice ``A_log`` for the current CP rank.""" + return self._slice_vector_param(self.A_log) + + def get_D(self) -> torch.Tensor: + """Slice ``D`` for the current CP rank.""" + return self._slice_vector_param(self.D) + + # ------------------------------------------------------------------ # + # Internal helpers # + # ------------------------------------------------------------------ # + + def _repeat_group_state(self, state: torch.Tensor) -> torch.Tensor: + """Repeat group states for CP ranks when n_groups < cp_size. + + ``[B, L, n_groups * d_state]`` -> ``[B, L, n_groups * repeat * d_state]`` + """ + return ( + state.reshape(*state.shape[:-1], self.n_groups, self.d_state) + .unsqueeze(-2) + .expand(-1, -1, -1, self.group_repeat_count, -1) + .reshape(*state.shape[:-1], self.n_groups * self.group_repeat_count * self.d_state) + ) + + def _slice_vector_param(self, param: torch.Tensor) -> torch.Tensor: + """Slice a per-head vector parameter for the current CP rank.""" + start = self.cp_rank * self.num_heads_local + return param[start : start + self.num_heads_local] + + def _slice_conv_param(self, param: torch.Tensor) -> torch.Tensor: + """Slice a conv1d parameter (weight or bias) along its channel dimension. + + Parameter slicing is done in the forward path so that gradients + backpropagate to the original (full) parameters. 
+ """ + groups_state_size = self.n_groups * self.d_state + x, B_param, C_param = torch.split( + param, + [self.d_inner, groups_state_size, groups_state_size], + dim=0, + ) + + x_start = self.cp_rank * self.d_inner_local + x_sliced = x[x_start : x_start + self.d_inner_local] + + bc_size = self.n_groups_local * self.d_state + bc_start = (self.cp_rank // self.group_repeat_count) * bc_size + B_sliced = B_param[bc_start : bc_start + bc_size] + C_sliced = C_param[bc_start : bc_start + bc_size] + + return torch.cat([x_sliced, B_sliced, C_sliced], dim=0).contiguous() diff --git a/nemo_automodel/components/distributed/parallelizer.py b/nemo_automodel/components/distributed/parallelizer.py index a45e00642..dbbd0ca30 100644 --- a/nemo_automodel/components/distributed/parallelizer.py +++ b/nemo_automodel/components/distributed/parallelizer.py @@ -251,6 +251,38 @@ def parallelize( if layer.block_type == "mlp": parallelize_module(layer, tp_mesh, mlp_tp_plan) + # Set up context parallel for Mamba and Attention layers + cp_mesh = device_mesh["cp"] if "cp" in device_mesh.mesh_dim_names else None + if cp_mesh is not None and cp_mesh.size() > 1: + cp_group = cp_mesh.get_group() + for layer in layers: + if hasattr(layer, "block_type") and layer.block_type == "mamba": + from nemo_automodel.components.distributed.mamba_cp import MambaContextParallel + + mixer = layer.mixer + mixer.cp = MambaContextParallel( + cp_group=cp_group, + num_heads=mixer.num_heads, + head_dim=mixer.head_dim, + n_groups=mixer.n_groups, + d_state=mixer.ssm_state_size, + conv1d=mixer.conv1d, + dt_bias=mixer.dt_bias, + A_log=mixer.A_log, + D=mixer.D, + ) + elif hasattr(layer, "block_type") and layer.block_type == "attention": + from transformer_engine.pytorch.attention import DotProductAttention + + attn_module = layer.mixer.attn_module + if isinstance(attn_module, DotProductAttention): + attn_module.set_context_parallel_group( + cp_group, + torch.distributed.get_process_group_ranks(cp_group), + torch.cuda.Stream(), 
+ cp_comm_type="p2p", + ) + if activation_checkpointing: for i in range(len(layers)): if layers[i].block_type == "mlp": diff --git a/nemo_automodel/components/models/nemotron_v3/layers.py b/nemo_automodel/components/models/nemotron_v3/layers.py index d86e530c2..a8f8edeb6 100644 --- a/nemo_automodel/components/models/nemotron_v3/layers.py +++ b/nemo_automodel/components/models/nemotron_v3/layers.py @@ -19,19 +19,24 @@ from torch import nn from torch.distributed.tensor import DTensor -from nemo_automodel.components.models.common import initialize_rms_norm_module +from nemo_automodel.components.attention.utils import ( + initialize_attn_module_and_func, + postprocess_output_for_attn, + preprocess_args_and_kwargs_for_attn, +) +from nemo_automodel.components.models.common import ( + BackendConfig, + initialize_linear_module, + initialize_rms_norm_module, +) class NemotronV3Attention(nn.Module): - """Multi-headed attention for NemotronV3 (Nano-v3). + """GQA attention for NemotronV3 (no RoPE), compatible with TE/SDPA backends.""" - This is a standard GQA attention module following the NemotronH architecture. - Uses PyTorch's scaled_dot_product_attention (SDPA) for the attention computation. - Note: RoPE is not applied in this module, matching the HF NemotronHAttention implementation. 
- """ - - def __init__(self, config): + def __init__(self, config, backend: BackendConfig | None = None): super().__init__() + self.backend = backend or BackendConfig() self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads @@ -40,26 +45,27 @@ def __init__(self, config): self.attention_bias = getattr(config, "attention_bias", False) self.attention_dropout = getattr(config, "attention_dropout", 0.0) - # Q, K, V, O projections - self.q_proj = nn.Linear( - self.hidden_size, - self.num_attention_heads * self.head_dim, - bias=self.attention_bias, + self.q_proj = initialize_linear_module( + self.backend.linear, self.hidden_size, self.num_attention_heads * self.head_dim, self.attention_bias ) - self.k_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=self.attention_bias, + self.k_proj = initialize_linear_module( + self.backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, self.attention_bias ) - self.v_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=self.attention_bias, + self.v_proj = initialize_linear_module( + self.backend.linear, self.hidden_size, self.num_key_value_heads * self.head_dim, self.attention_bias ) - self.o_proj = nn.Linear( - self.num_attention_heads * self.head_dim, - self.hidden_size, - bias=self.attention_bias, + self.o_proj = initialize_linear_module( + self.backend.linear, self.num_attention_heads * self.head_dim, self.hidden_size, self.attention_bias + ) + + softmax_scale = self.head_dim**-0.5 + self.attn_module, self.attn_func = initialize_attn_module_and_func( + attn_impl=self.backend.attn, + num_attention_heads=self.num_attention_heads, + num_qk_channels=self.head_dim, + num_v_channels=self.head_dim, + softmax_scale=softmax_scale, + num_gqa_groups=self.num_key_value_heads, ) def forward( @@ -68,44 +74,43 @@ def forward( attention_mask: torch.Tensor | None = None, past_key_values=None, 
layer_idx: int | None = None, + **attn_kwargs, ) -> torch.Tensor: bsz, seqlen, _ = hidden_states.size() + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + + # Inference path: KV cache with SDPA + if past_key_values is not None: + q = q.view(bsz, seqlen, self.num_attention_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seqlen, self.num_key_value_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seqlen, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if layer_idx is not None: + k, v = past_key_values.update(k, v, layer_idx) + is_causal = attention_mask is None and q.shape[2] > 1 and q.shape[2] == k.shape[2] + output = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + enable_gqa=self.num_key_value_heads != self.num_attention_heads, + ) + output = output.transpose(1, 2).contiguous() + output = output.view(bsz, seqlen, self.num_attention_heads * self.head_dim) + return self.o_proj(output) - # Compute Q, K, V - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - # Reshape to (B, H, S, D) for SDPA - q = q.view(bsz, seqlen, self.num_attention_heads, self.head_dim).transpose(1, 2) - k = k.view(bsz, seqlen, self.num_key_value_heads, self.head_dim).transpose(1, 2) - v = v.view(bsz, seqlen, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - # Update KV cache if present - if past_key_values is not None and layer_idx is not None: - k, v = past_key_values.update(k, v, layer_idx) - - # Run attention with SDPA - # During cached decode (q has 1 token, k/v have many), use explicit mask - # instead of is_causal since SDPA's causal mask requires matching seq lengths. 
- is_causal = attention_mask is None and q.shape[2] > 1 and q.shape[2] == k.shape[2] - output = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - enable_gqa=self.num_key_value_heads != self.num_attention_heads, - ) - - # Reshape back to (B, S, H * D) - output = output.transpose(1, 2).contiguous() - output = output.view(bsz, seqlen, self.num_attention_heads * self.head_dim) - - # Output projection - output = self.o_proj(output) + # Training path: backend-aware attention (TE or SDPA) + q = q.view(bsz, seqlen, self.num_attention_heads, self.head_dim) + k = k.view(bsz, seqlen, self.num_key_value_heads, self.head_dim) + v = v.view(bsz, seqlen, self.num_key_value_heads, self.head_dim) + q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn( + q, k, v, attention_mask, self.backend.attn, **attn_kwargs + ) + output = self.attn_func(q, k, v, **_attn_kwargs) + output = postprocess_output_for_attn(output, self.backend.attn) + output = self.o_proj(output.flatten(2)) return output @torch.no_grad() @@ -115,13 +120,11 @@ def init_weights( rescale_prenorm_residual: bool = True, buffer_device: torch.device | None = None, ) -> None: - """Initialize attention weights following NemotronV3 spec.""" with buffer_device: for proj in [self.q_proj, self.k_proj, self.v_proj, self.o_proj]: if proj.bias is not None: nn.init.zeros_(proj.bias) - # Rescale o_proj for stable residual stream if rescale_prenorm_residual: self.o_proj.weight /= math.sqrt(num_hidden_layers) @@ -221,6 +224,9 @@ def __init__(self, config, layer_idx: int): # Output projection self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + # Context parallelism — set post-construction by the parallelizer + self.cp = None + def forward( self, hidden_states: torch.Tensor, @@ -262,30 +268,61 @@ def forward( hidden_states = hidden_states * attention_mask.unsqueeze(-1) projected_states 
= self.in_proj(hidden_states) - A = -torch.exp(self.A_log.float()) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} - out = mamba_split_conv1d_scan_combined( - projected_states, - self.conv1d.weight.squeeze(1), - self.conv1d.bias, - self.dt_bias, - A, - D=self.D, - chunk_size=self.chunk_size, - seq_idx=None, - activation=self.activation, - rmsnorm_weight=self.norm.weight, - rmsnorm_eps=self.norm.variance_epsilon, - outproj_weight=self.out_proj.weight, - outproj_bias=self.out_proj.bias, - headdim=self.head_dim, - ngroups=self.n_groups, - norm_before_gate=False, - return_final_states=False, - **dt_limit_kwargs, - ) - return out + if self.cp is not None: + projected_states = self.cp.pre_conv_ssm(projected_states) + A = -torch.exp(self.cp.get_A_log().float()) + + out = mamba_split_conv1d_scan_combined( + projected_states, + self.cp.get_conv1d_weight(), + self.cp.get_conv1d_bias(), + self.cp.get_dt_bias(), + A, + D=self.cp.get_D(), + chunk_size=self.chunk_size, + activation=self.activation, + rmsnorm_weight=None, + outproj_weight=None, + headdim=self.head_dim, + ngroups=self.cp.n_groups_local, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + if out.ndim == 4: + out = out.reshape(out.shape[0], out.shape[1], -1) + + out = self.cp.post_conv_ssm(out) + out = self.norm(out, gate=None) + out = self.out_proj(out) + return out + else: + A = -torch.exp(self.A_log.float()) + + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + return out # 
--- Path C: Decode (single token, cached state) --- if use_precomputed_states: @@ -491,7 +528,7 @@ def __init__(self, config, layer_idx: int, moe_config=None, backend=None): if self.block_type == "mamba": self.mixer = NemotronV3Mamba2Mixer(config, layer_idx=layer_idx) elif self.block_type == "attention": - self.mixer = NemotronV3Attention(config) + self.mixer = NemotronV3Attention(config, backend=backend) elif self.block_type == "mlp": from nemo_automodel.components.moe.layers import MLP from nemo_automodel.shared.utils import dtype_from_str @@ -529,6 +566,7 @@ def forward( attention_mask: torch.Tensor | None = None, past_key_values=None, cache_position: torch.LongTensor | None = None, + **attn_kwargs, ) -> torch.Tensor: """Forward pass through the block. @@ -540,6 +578,8 @@ def forward( - For mlp/moe: None past_key_values: Optional NemotronHybridCache for KV/SSM caching. cache_position: Token position indices for cache updates. + **attn_kwargs: Additional keyword arguments forwarded to attention layers + only (e.g. cu_seqlens, cp_size, cp_rank for Context Parallelism). Returns: Output tensor of shape (batch, seq_len, hidden_size) @@ -568,6 +608,7 @@ def forward( attention_mask=attention_mask, past_key_values=past_key_values, layer_idx=self.layer_idx, + **attn_kwargs, ) elif self.block_type in ["mlp", "moe"]: hidden_states = self.mixer(hidden_states) diff --git a/nemo_automodel/components/models/nemotron_v3/model.py b/nemo_automodel/components/models/nemotron_v3/model.py index 1c65e82e3..21a539c0c 100644 --- a/nemo_automodel/components/models/nemotron_v3/model.py +++ b/nemo_automodel/components/models/nemotron_v3/model.py @@ -119,7 +119,7 @@ def forward( inputs_embeds: Input embeddings [batch_size, seq_len, hidden_size] (optional) past_key_values: Optional NemotronHybridCache for incremental decoding. cache_position: Token position indices for cache updates. - **kwargs: Additional arguments (ignored) + **kwargs: Forwarded to attention layers (e.g. CP kwargs). 
Returns: Hidden states tensor [batch_size, seq_len, hidden_size] @@ -141,8 +141,10 @@ def forward( for layer in self.layers.values(): # Pass appropriate mask based on layer type if layer.block_type == "attention": - # Attention layers use 4D causal mask - mask = causal_mask + # Attention layers use 4D causal mask; fall back to 2D attention_mask + # when causal_mask is None (e.g. during TE+CP training where CP split + # removes the precomputed 4D mask) so TE can use padding_causal mode. + mask = causal_mask if causal_mask is not None else attention_mask elif layer.block_type == "mamba": # Mamba layers use 2D padding mask during prefill, None during decode mask = None if (past_key_values is not None and past_key_values.has_previous_state) else attention_mask @@ -155,6 +157,7 @@ def forward( attention_mask=mask, past_key_values=past_key_values, cache_position=cache_position, + **kwargs, ) # Final norm diff --git a/nemo_automodel/recipes/llm/train_ft.py b/nemo_automodel/recipes/llm/train_ft.py index f3482c406..8319209c4 100644 --- a/nemo_automodel/recipes/llm/train_ft.py +++ b/nemo_automodel/recipes/llm/train_ft.py @@ -135,6 +135,17 @@ def _get_num_thd_chunks(pp_enabled, cfg): return 1 +def _is_hybrid_mamba_attention(cfg_model): + """Check if model has mixed Mamba and Attention layers (e.g., NemotronV3).""" + config = getattr(cfg_model, "config", None) + if config is None: + return False + block_types = getattr(config, "layers_block_type", None) + if block_types is None: + return False + return "mamba" in block_types and "attention" in block_types + + def build_model( cfg_model, cfg_peft, @@ -1225,6 +1236,7 @@ def _forward_backward_step( self.device_mesh, batch, use_te=_uses_te_dot_product_attention(self.cfg.model) and _uses_thd_collater(self.cfg.dataloader), + use_hybrid_cp=_is_hybrid_mamba_attention(self.cfg.model), padding_token_id=self.tokenizer.pad_token_id if self.tokenizer else 0, num_chunks=_get_num_thd_chunks(self.pp_enabled, self.cfg), ) diff --git 
a/tests/functional_tests/context_parallel/L2_CP_NemotronV3_Attention_Test.sh b/tests/functional_tests/context_parallel/L2_CP_NemotronV3_Attention_Test.sh new file mode 100644 index 000000000..6ec3e1d92 --- /dev/null +++ b/tests/functional_tests/context_parallel/L2_CP_NemotronV3_Attention_Test.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeuo pipefail + +export PYTHONPATH=${PYTHONPATH:-}:$(pwd) +export CUDA_VISIBLE_DEVICES="0,1" + +torchrun --nproc_per_node=2 --nnodes=1 \ + tests/functional_tests/context_parallel/run_nemotron_v3_attention_cp.py diff --git a/tests/functional_tests/context_parallel/L2_CP_NemotronV3_Mamba_Test.sh b/tests/functional_tests/context_parallel/L2_CP_NemotronV3_Mamba_Test.sh new file mode 100755 index 000000000..bf9f7573d --- /dev/null +++ b/tests/functional_tests/context_parallel/L2_CP_NemotronV3_Mamba_Test.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeuo pipefail # Trace commands; exit on any error, unset variable, or pipeline failure + +export PYTHONPATH=${PYTHONPATH:-}:$(pwd) +export CUDA_VISIBLE_DEVICES="0,1" + +# Run NemotronV3 Mamba2Mixer layer CP test with 2 GPUs +torchrun --nproc_per_node=2 --nnodes=1 \ + tests/functional_tests/context_parallel/run_mamba_cp.py diff --git a/tests/functional_tests/context_parallel/run_hybrid_nemotron_v3_cp.py b/tests/functional_tests/context_parallel/run_hybrid_nemotron_v3_cp.py new file mode 100644 index 000000000..efb7ec8e7 --- /dev/null +++ b/tests/functional_tests/context_parallel/run_hybrid_nemotron_v3_cp.py @@ -0,0 +1,323 @@ +#!/usr/bin/env python +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end hybrid NemotronV3 CP test. + +Validates that a hybrid model with interleaved attention (TE p2p CP) and +mamba (hidden-parallel CP) layers produces matching outputs/gradients +between CP=1 and CP=2 with DualChunkSwap sequence distribution. 
+ +Usage: + torchrun --nproc_per_node=2 tests/functional_tests/context_parallel/run_hybrid_nemotron_v3_cp.py +""" + +import os +import sys + +import torch +import torch.distributed as dist + +from nemo_automodel.components.distributed.cp_utils import _dual_chunk_swap_select + + +def dual_chunk_swap_unsplit(chunks_per_rank, cp_size, seq_dim=1): + """Reconstruct full sequence from DualChunkSwap-ordered rank outputs.""" + all_chunks = [None] * (2 * cp_size) + for rank_idx, rank_output in enumerate(chunks_per_rank): + c0, c1 = torch.chunk(rank_output, 2, dim=seq_dim) + all_chunks[rank_idx] = c0 + all_chunks[2 * cp_size - rank_idx - 1] = c1 + return torch.cat(all_chunks, dim=seq_dim) + + +def init_distributed(): + """Initialize distributed environment from torchrun env vars.""" + if not (dist.is_available() and dist.is_initialized()): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + dist.init_process_group(backend="nccl") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + +class MockHybridConfig: + """Mock configuration for a hybrid NemotronV3 model (attention + mamba layers). + + Provides only the fields required by NemotronV3Model and its block types. + MoE-related fields are still required because NemotronV3Model constructs + a MoEConfig in __init__ regardless of layer types; they are set to minimal + values that avoid errors without activating MoE layers. 
+ """ + + def __init__(self): + # Attention config + self.num_attention_heads = 8 + self.num_key_value_heads = 4 + self.head_dim = 32 + self.hidden_size = 256 # num_attention_heads * head_dim + self.attention_bias = False + self.attention_dropout = 0.0 + + # Mamba config + self.mamba_num_heads = 8 + self.mamba_head_dim = 32 + self.ssm_state_size = 16 + self.n_groups = 2 # must be >= cp_size for non-replicated mode + self.chunk_size = 256 + self.conv_kernel = 4 + self.use_conv_bias = True + self.mamba_hidden_act = "silu" + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + self.time_step_floor = 1e-4 + self.use_bias = False + + # Shared norm / model config + self.layer_norm_epsilon = 1e-5 + self.num_hidden_layers = 4 + self.vocab_size = 128 + self.torch_dtype = "bfloat16" + self.initializer_range = 0.02 + self.rescale_prenorm_residual = True + self.residual_in_fp32 = False + + # Hybrid layer schedule: interleaved attention and mamba + self.layers_block_type = ["attention", "mamba", "attention", "mamba"] + + # MLP config (required by MLP block type, kept here for completeness) + self.intermediate_size = 512 + self.mlp_bias = False + self.mlp_hidden_act = "silu" + + # MoE config fields — required by NemotronV3Model.__init__ even when + # no MoE layers are present in layers_block_type. 
+ self.n_routed_experts = 1 + self.num_experts_per_tok = 1 + self.n_group = 1 + self.topk_group = 1 + self.routed_scaling_factor = 1.0 + self.moe_intermediate_size = self.intermediate_size + self.norm_topk_prob = False + self.moe_shared_expert_intermediate_size = self.intermediate_size + + +def run_test(): + """Run the end-to-end CP validation test for hybrid NemotronV3.""" + world_size = dist.get_world_size() + rank = dist.get_rank() + + if world_size != 2: + if rank == 0: + print(f"ERROR: This test requires exactly 2 GPUs, got {world_size}", file=sys.stderr) + return 1 + + try: + import transformer_engine.pytorch # noqa: F401 + except ImportError: + if rank == 0: + print("ERROR: transformer_engine is required but not installed", file=sys.stderr) + return 1 + + device = torch.device(f"cuda:{rank}") + + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + + from torch.distributed.device_mesh import init_device_mesh + from transformer_engine.pytorch.attention import DotProductAttention + + from nemo_automodel.components.distributed.mamba_cp import MambaContextParallel + from nemo_automodel.components.models.common import BackendConfig + from nemo_automodel.components.models.nemotron_v3.model import NemotronV3Model + + config = MockHybridConfig() + # Use torch linear and rms_norm to avoid TE internal buffers that would + # complicate state_dict transfers, but keep TE attention for p2p CP. 
+ backend = BackendConfig( + linear="torch", + attn="te", + rms_norm="torch", + enable_hf_state_dict_adapter=False, + ) + + # ===== Baseline: CP=1 (no context parallelism) ===== + model_baseline = NemotronV3Model(config, backend=backend).to(device=device, dtype=torch.bfloat16) + model_baseline.train() + + # Sync weights across ranks so both start from identical parameters + for p in model_baseline.parameters(): + dist.broadcast(p.data, src=0) + + batch_size = 2 + seq_len = 128 # must be divisible by 2 * cp_size = 4 + + torch.manual_seed(42) + input_ids_full = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device) + dist.broadcast(input_ids_full, src=0) + + output_baseline = model_baseline(input_ids=input_ids_full) + loss_baseline = output_baseline.sum() + loss_baseline.backward() + + # Save baseline results before any further operations + output_baseline_detached = output_baseline.detach().clone() + embed_grad_baseline = model_baseline.embed_tokens.weight.grad.detach().clone() + + dist.barrier() + + # ===== Test: CP=2 (context parallelism with p2p for attn, hidden-parallel for mamba) ===== + model_cp = NemotronV3Model(config, backend=backend).to(device=device, dtype=torch.bfloat16) + model_cp.train() + + # Copy weights from baseline model; use strict=False to tolerate any TE + # internal buffers that may not appear in the baseline's state_dict + model_cp.load_state_dict(model_baseline.state_dict(), strict=False) + + # Zero out any gradients that may have accumulated during load + model_cp.zero_grad() + + # Build CP process group + cp_size = world_size # 2 + cp_mesh = init_device_mesh("cuda", (cp_size,), mesh_dim_names=("cp",)) + cp_group = cp_mesh["cp"].get_group() + + # Wire CP on each hybrid layer + for layer in model_cp.layers.values(): + if layer.block_type == "mamba": + mixer = layer.mixer + mixer.cp = MambaContextParallel( + cp_group=cp_group, + num_heads=mixer.num_heads, + head_dim=mixer.head_dim, + n_groups=mixer.n_groups, + 
d_state=mixer.ssm_state_size, + conv1d=mixer.conv1d, + dt_bias=mixer.dt_bias, + A_log=mixer.A_log, + D=mixer.D, + ) + elif layer.block_type == "attention": + attn_module = layer.mixer.attn_module + if isinstance(attn_module, DotProductAttention): + attn_module.set_context_parallel_group( + cp_group, + torch.distributed.get_process_group_ranks(cp_group), + torch.cuda.Stream(), + cp_comm_type="p2p", + ) + + # DualChunkSwap: each rank gets two non-contiguous chunks + input_ids_local = _dual_chunk_swap_select( + input_ids_full, cp_size=cp_size, cp_rank=rank, seq_dim=1 + ) + + output_cp_local = model_cp(input_ids=input_ids_local) + loss_cp = output_cp_local.sum() + loss_cp.backward() + + # Gather local outputs from all CP ranks + local_seq = output_cp_local.shape[1] + output_gathered = [ + torch.zeros(batch_size, local_seq, config.hidden_size, device=device, dtype=torch.bfloat16) + for _ in range(cp_size) + ] + dist.all_gather(output_gathered, output_cp_local.detach().contiguous(), group=cp_group) + + # Reconstruct full-sequence output from DualChunkSwap ordering + output_cp_full = dual_chunk_swap_unsplit(output_gathered, cp_size=cp_size, seq_dim=1) + + # Embedding weight gradient: each rank only sees a subset of the sequence, + # so the embedding grad is only partial — all-reduce to get the full gradient. 
+ embed_grad_cp = model_cp.embed_tokens.weight.grad.detach().clone() + dist.all_reduce(embed_grad_cp, op=dist.ReduceOp.SUM, group=cp_group) + + # ===== Comparison ===== + output_diff = (output_cp_full - output_baseline_detached).abs() + grad_diff = (embed_grad_cp - embed_grad_baseline).abs() + + if rank == 0: + print(f"\n{'='*70}") + print("End-to-End Hybrid CP Test - NemotronV3 (Attention + Mamba)") + print(f"{'='*70}") + print(f"Config: {config.num_hidden_layers} layers {config.layers_block_type}") + print( + f"Sequence: batch={batch_size}, seq_len={seq_len} -> " + f"{local_seq} tokens/rank with CP={cp_size}" + ) + print(f"\nForward output:") + print(f" Shape: CP={output_cp_full.shape}, Baseline={output_baseline_detached.shape}") + print( + f" Diff - mean: {output_diff.mean().item():.6f}, " + f"max: {output_diff.max().item():.6f}, " + f"std: {output_diff.std().item():.6f}" + ) + print( + f" Relative diff - mean: " + f"{(output_diff / (output_baseline_detached.abs() + 1e-8)).mean().item():.6f}" + ) + print(f"\nEmbedding weight gradient:") + print( + f" Baseline - mean: {embed_grad_baseline.abs().mean().item():.6f}, " + f"max: {embed_grad_baseline.abs().max().item():.6f}" + ) + print( + f" CP - mean: {embed_grad_cp.abs().mean().item():.6f}, " + f"max: {embed_grad_cp.abs().max().item():.6f}" + ) + print( + f" Diff - mean: {grad_diff.mean().item():.6f}, " + f"max: {grad_diff.max().item():.6f}" + ) + + try: + torch.testing.assert_close( + output_cp_full, + output_baseline_detached, + rtol=1e-2, + atol=5e-2, + msg=f"[Rank {rank}] Forward outputs differ between CP=1 and CP=2", + ) + + torch.testing.assert_close( + embed_grad_cp, + embed_grad_baseline, + rtol=5e-2, + atol=1e-1, + msg=f"[Rank {rank}] embed_tokens.weight.grad differs between CP=1 and CP=2", + ) + + if rank == 0: + print(f"Test PASSED: Forward outputs and embedding gradients match between CP=1 and CP=2") + print(f"{'='*70}\n") + return 0 + + except AssertionError as e: + if rank == 0: + print(f"Test FAILED: 
{e}") + print(f"Note: Some numerical differences are expected with bfloat16 and multi-layer accumulation") + print(f"{'='*70}\n") + return 1 + + +def main(): + init_distributed() + exit_code = run_test() + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/tests/functional_tests/context_parallel/run_mamba_cp.py b/tests/functional_tests/context_parallel/run_mamba_cp.py new file mode 100644 index 000000000..3609980ea --- /dev/null +++ b/tests/functional_tests/context_parallel/run_mamba_cp.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Standalone test script for Mamba layer context parallelism validation. + +This script validates that the NemotronV3Mamba2Mixer produces identical forward +outputs and gradients when using context parallelism (CP=2, hidden-parallel +strategy) versus no context parallelism (CP=1). 
+ +Usage: + torchrun --nproc_per_node=2 tests/functional_tests/context_parallel/run_mamba_cp.py +""" + +import os +import sys + +import torch +import torch.distributed as dist + +from nemo_automodel.components.distributed.cp_utils import _dual_chunk_swap_select + + +def dual_chunk_swap_unsplit(chunks_per_rank, cp_size, seq_dim=1): + """Reconstruct full sequence from DualChunkSwap-ordered rank outputs.""" + all_chunks = [None] * (2 * cp_size) + for rank_idx, rank_output in enumerate(chunks_per_rank): + c0, c1 = torch.chunk(rank_output, 2, dim=seq_dim) + all_chunks[rank_idx] = c0 + all_chunks[2 * cp_size - rank_idx - 1] = c1 + return torch.cat(all_chunks, dim=seq_dim) + + +def init_distributed(): + """Initialize distributed environment from torchrun env vars.""" + if not (dist.is_available() and dist.is_initialized()): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + dist.init_process_group(backend="nccl") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + +class MockNemotronV3Config: + """Mock configuration for NemotronV3Mamba2Mixer.""" + + def __init__(self): + self.hidden_size = 256 + self.mamba_num_heads = 8 + self.mamba_head_dim = 32 + self.ssm_state_size = 16 + self.n_groups = 1 + self.chunk_size = 256 + self.conv_kernel = 4 + self.use_conv_bias = True + self.mamba_hidden_act = "silu" + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + self.time_step_floor = 1e-4 + self.use_bias = False + self.layer_norm_epsilon = 1e-5 + self.num_hidden_layers = 4 + + +def run_test(): + """Run the CP validation test for NemotronV3Mamba2Mixer.""" + world_size = dist.get_world_size() + rank = dist.get_rank() + + if world_size != 2: + if rank == 0: + print(f"ERROR: This test requires exactly 2 GPUs, got {world_size}", file=sys.stderr) + return 1 + + device = torch.device(f"cuda:{rank}") + + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + + from nemo_automodel.components.models.nemotron_v3.layers import 
NemotronV3Mamba2Mixer + + config = MockNemotronV3Config() + + mixer_no_cp = NemotronV3Mamba2Mixer(config, layer_idx=0).to(device).to(torch.bfloat16) + mixer_with_cp = NemotronV3Mamba2Mixer(config, layer_idx=0).to(device).to(torch.bfloat16) + + mixer_no_cp.eval() + mixer_with_cp.eval() + mixer_with_cp.load_state_dict(mixer_no_cp.state_dict()) + + for param_no_cp, param_with_cp in zip(mixer_no_cp.parameters(), mixer_with_cp.parameters()): + dist.broadcast(param_no_cp.data, src=0) + dist.broadcast(param_with_cp.data, src=0) + + # Baseline: CP=1 + batch_size = 2 + seq_len = 64 + + torch.manual_seed(42 + rank) + x_full = torch.randn( + batch_size, seq_len, config.hidden_size, + device=device, dtype=torch.bfloat16, requires_grad=True, + ) + + dist.broadcast(x_full.data, src=0) + x_no_cp = x_full.detach().clone().requires_grad_(True) + + output_no_cp = mixer_no_cp(x_no_cp) + loss_no_cp = output_no_cp.sum() + loss_no_cp.backward() + + output_baseline = output_no_cp.detach().clone() + grad_baseline = x_no_cp.grad.detach().clone() + in_proj_grad_baseline = mixer_no_cp.in_proj.weight.grad.detach().clone() + + dist.barrier() + + # Test: CP=2 + from torch.distributed.device_mesh import init_device_mesh + from nemo_automodel.components.distributed.mamba_cp import MambaContextParallel + + cp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("cp",)) + cp_group = cp_mesh["cp"].get_group() + + mixer_with_cp.cp = MambaContextParallel( + cp_group=cp_group, + num_heads=config.mamba_num_heads, + head_dim=config.mamba_head_dim, + n_groups=config.n_groups, + d_state=config.ssm_state_size, + conv1d=mixer_with_cp.conv1d, + dt_bias=mixer_with_cp.dt_bias, + A_log=mixer_with_cp.A_log, + D=mixer_with_cp.D, + ) + + cp_size = world_size + x_local = _dual_chunk_swap_select(x_full.detach(), cp_size=cp_size, cp_rank=rank, seq_dim=1).clone().requires_grad_(True) + half_seq = x_local.shape[1] + + output_with_cp = mixer_with_cp(x_local) + loss_with_cp = output_with_cp.sum() + 
loss_with_cp.backward() + + output_gathered = [ + torch.zeros(batch_size, half_seq, config.hidden_size, device=device, dtype=torch.bfloat16) + for _ in range(world_size) + ] + grad_gathered = [ + torch.zeros(batch_size, half_seq, config.hidden_size, device=device, dtype=torch.bfloat16) + for _ in range(world_size) + ] + + dist.all_gather(output_gathered, output_with_cp.contiguous()) + dist.all_gather(grad_gathered, x_local.grad.contiguous()) + + output_with_cp_full = dual_chunk_swap_unsplit(output_gathered, cp_size=cp_size, seq_dim=1) + grad_with_cp_full = dual_chunk_swap_unsplit(grad_gathered, cp_size=cp_size, seq_dim=1) + + in_proj_grad_cp = mixer_with_cp.in_proj.weight.grad.detach().clone() + dist.all_reduce(in_proj_grad_cp, op=dist.ReduceOp.SUM) + + if rank == 0: + output_diff = (output_with_cp_full - output_baseline).abs() + grad_diff = (grad_with_cp_full - grad_baseline).abs() + in_proj_grad_diff = (in_proj_grad_cp - in_proj_grad_baseline).abs() + + print(f"\n{'='*70}") + print(f"Context Parallelism Validation Test - NemotronV3 Mamba2Mixer") + print(f"{'='*70}") + print(f"Output shape: CP={output_with_cp_full.shape}, Baseline={output_baseline.shape}") + print(f"Output diff - mean: {output_diff.mean():.6f}, max: {output_diff.max():.6f}, std: {output_diff.std():.6f}") + print(f"Output relative diff - mean: {(output_diff / (output_baseline.abs() + 1e-8)).mean():.6f}") + print(f"\nInput gradient statistics:") + print(f" Baseline - min: {grad_baseline.abs().min():.6f}, max: {grad_baseline.abs().max():.6f}, mean: {grad_baseline.abs().mean():.6f}") + print(f" CP - min: {grad_with_cp_full.abs().min():.6f}, max: {grad_with_cp_full.abs().max():.6f}, mean: {grad_with_cp_full.abs().mean():.6f}") + print(f"Grad diff - mean: {grad_diff.mean():.6f}, max: {grad_diff.max():.6f}, std: {grad_diff.std():.6f}") + print(f"\nin_proj.weight.grad statistics:") + print(f" Baseline - mean: {in_proj_grad_baseline.abs().mean():.6f}, max: {in_proj_grad_baseline.abs().max():.6f}") + 
print(f" CP - mean: {in_proj_grad_cp.abs().mean():.6f}, max: {in_proj_grad_cp.abs().max():.6f}") + print(f" Diff - mean: {in_proj_grad_diff.mean():.6f}, max: {in_proj_grad_diff.max():.6f}") + + try: + torch.testing.assert_close( + output_with_cp_full, + output_baseline, + rtol=1e-2, + atol=0.01, + msg=f"[Rank {rank}] Forward outputs differ between CP=1 and CP=2", + ) + + torch.testing.assert_close( + grad_with_cp_full, + grad_baseline, + rtol=2e-2, + atol=0.05, + msg=f"[Rank {rank}] Input gradients differ between CP=1 and CP=2", + ) + + torch.testing.assert_close( + in_proj_grad_cp, + in_proj_grad_baseline, + rtol=5e-2, + atol=1.5, + msg=f"[Rank {rank}] in_proj.weight.grad differs between CP=1 and CP=2", + ) + + if rank == 0: + print(f"Test PASSED: Forward outputs and gradients match between CP=1 and CP=2") + print(f"{'='*70}\n") + return 0 + + except AssertionError as e: + if rank == 0: + print(f"Test FAILED: {e}") + print(f"Note: Some numerical differences are expected with bfloat16 precision") + print(f"{'='*70}\n") + return 1 + + +def main(): + init_distributed() + exit_code = run_test() + if dist.is_initialized(): + dist.barrier() + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/tests/functional_tests/context_parallel/run_nemotron_v3_attention_cp.py b/tests/functional_tests/context_parallel/run_nemotron_v3_attention_cp.py new file mode 100644 index 000000000..1ca87a5c1 --- /dev/null +++ b/tests/functional_tests/context_parallel/run_nemotron_v3_attention_cp.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NemotronV3Attention CP validation: verifies CP=2 matches CP=1 for forward + grads. + +Uses BSHD format with TE p2p CP and DualChunkSwap sequence distribution. Each rank +receives two non-contiguous chunks selected via _dual_chunk_swap_select. TE p2p CP +ring-exchanges K,V between neighboring ranks and applies the correct global causal +mask. This matches the production code path in _split_batch_bshd_for_cp / +NemotronHParallelizationStrategy. + +Usage: + torchrun --nproc_per_node=2 tests/functional_tests/context_parallel/run_nemotron_v3_attention_cp.py +""" + +import os +import sys + +import torch +import torch.distributed as dist + +from nemo_automodel.components.distributed.cp_utils import _dual_chunk_swap_select + + +def dual_chunk_swap_unsplit(chunks_per_rank, cp_size, seq_dim=1): + """Reconstruct full sequence from DualChunkSwap-ordered rank outputs.""" + all_chunks = [None] * (2 * cp_size) + for rank_idx, rank_output in enumerate(chunks_per_rank): + c0, c1 = torch.chunk(rank_output, 2, dim=seq_dim) + all_chunks[rank_idx] = c0 + all_chunks[2 * cp_size - rank_idx - 1] = c1 + return torch.cat(all_chunks, dim=seq_dim) + + +def init_distributed(): + """Initialize distributed environment from torchrun env vars.""" + if not (dist.is_available() and dist.is_initialized()): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + dist.init_process_group(backend="nccl") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + +class MockNemotronV3AttentionConfig: + """Mock configuration for NemotronV3Attention.""" + + def __init__(self): + 
self.num_attention_heads = 8 + self.num_key_value_heads = 4 + self.head_dim = 32 + self.hidden_size = 256 # num_attention_heads * head_dim + self.attention_bias = False + self.attention_dropout = 0.0 + + +def run_test(): + """Run the CP validation test for NemotronV3Attention.""" + world_size = dist.get_world_size() + rank = dist.get_rank() + + if world_size != 2: + if rank == 0: + print(f"ERROR: This test requires exactly 2 GPUs, got {world_size}", file=sys.stderr) + return 1 + + try: + import transformer_engine.pytorch # noqa: F401 + except ImportError: + if rank == 0: + print("ERROR: transformer_engine is required but not installed", file=sys.stderr) + return 1 + + device = torch.device(f"cuda:{rank}") + + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + + from nemo_automodel.components.models.nemotron_v3.layers import NemotronV3Attention + from nemo_automodel.components.models.common import BackendConfig + + config = MockNemotronV3AttentionConfig() + backend = BackendConfig(linear="torch", attn="te") + + attn_no_cp = NemotronV3Attention(config, backend).to(device).to(torch.bfloat16) + attn_with_cp = NemotronV3Attention(config, backend).to(device).to(torch.bfloat16) + + attn_no_cp.train() + attn_with_cp.train() + + attn_with_cp.load_state_dict(attn_no_cp.state_dict()) + for param_no_cp, param_with_cp in zip(attn_no_cp.parameters(), attn_with_cp.parameters()): + dist.broadcast(param_no_cp.data, src=0) + dist.broadcast(param_with_cp.data, src=0) + + # ===== Baseline: CP=1 (no context parallelism) ===== + batch_size = 2 + seq_len = 128 + + torch.manual_seed(42) + x_full = torch.randn( + batch_size, seq_len, config.hidden_size, + device=device, dtype=torch.bfloat16, + ) + dist.broadcast(x_full, src=0) + + x_no_cp = x_full.detach().clone().requires_grad_(True) + + output_no_cp = attn_no_cp(x_no_cp) + loss_no_cp = output_no_cp.sum() + loss_no_cp.backward() + + output_baseline = output_no_cp.detach().clone() + grad_baseline = x_no_cp.grad.detach().clone() + 
q_proj_grad_baseline = attn_no_cp.q_proj.weight.grad.detach().clone() + + dist.barrier() + + # ===== Test: CP=2 (context parallelism with p2p / DualChunkSwap) ===== + # DualChunkSwap split: rank r gets chunks [r] and [2*cp_size-r-1] from the + # sequence partitioned into 2*cp_size equal pieces. + # TE p2p CP ring-exchanges K,V between neighboring ranks and applies the + # correct global causal mask. + from torch.distributed.device_mesh import init_device_mesh + from transformer_engine.pytorch.attention import DotProductAttention + + cp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("cp",)) + cp_group = cp_mesh["cp"].get_group() + + assert isinstance(attn_with_cp.attn_module, DotProductAttention) + + attn_with_cp.attn_module.set_context_parallel_group( + cp_group, + torch.distributed.get_process_group_ranks(cp_group), + torch.cuda.Stream(), + cp_comm_type="p2p", + ) + + cp_size = world_size + x_local = _dual_chunk_swap_select(x_full, cp_size=cp_size, cp_rank=rank, seq_dim=1).detach().clone().requires_grad_(True) + + output_with_cp = attn_with_cp(x_local) + loss_with_cp = output_with_cp.sum() + loss_with_cp.backward() + + # Gather outputs and input grads from all ranks + local_seq = x_local.shape[1] + output_gathered = [ + torch.zeros(batch_size, local_seq, config.hidden_size, device=device, dtype=torch.bfloat16) + for _ in range(cp_size) + ] + grad_gathered = [ + torch.zeros(batch_size, local_seq, config.hidden_size, device=device, dtype=torch.bfloat16) + for _ in range(cp_size) + ] + + dist.all_gather(output_gathered, output_with_cp.contiguous()) + dist.all_gather(grad_gathered, x_local.grad.contiguous()) + + output_with_cp_full = dual_chunk_swap_unsplit(output_gathered, cp_size=cp_size, seq_dim=1) + grad_with_cp_full = dual_chunk_swap_unsplit(grad_gathered, cp_size=cp_size, seq_dim=1) + + q_proj_grad_cp = attn_with_cp.q_proj.weight.grad.detach().clone() + dist.all_reduce(q_proj_grad_cp, op=dist.ReduceOp.SUM) + + if rank == 0: + output_diff = 
(output_with_cp_full - output_baseline).abs() + grad_diff = (grad_with_cp_full - grad_baseline).abs() + q_proj_grad_diff = (q_proj_grad_cp - q_proj_grad_baseline).abs() + + print(f"\n{'='*70}") + print(f"Context Parallelism Validation Test - NemotronV3Attention") + print(f"{'='*70}") + print(f"Config: heads={config.num_attention_heads}, kv_heads={config.num_key_value_heads}, " + f"head_dim={config.head_dim}, hidden_size={config.hidden_size}") + print(f"Sequence: batch={batch_size}, seq_len={seq_len} -> {local_seq} tokens/rank with CP=2") + print(f"CP comm type: p2p (DualChunkSwap sequence split)") + print(f"\nForward output statistics:") + print(f" Output shape: CP={output_with_cp_full.shape}, Baseline={output_baseline.shape}") + print(f" Output diff - mean: {output_diff.mean():.6f}, max: {output_diff.max():.6f}, " + f"std: {output_diff.std():.6f}") + print(f" Output relative diff - mean: {(output_diff / (output_baseline.abs() + 1e-8)).mean():.6f}") + print(f"\nInput gradient statistics:") + print(f" Baseline - min: {grad_baseline.abs().min():.6f}, max: {grad_baseline.abs().max():.6f}, " + f"mean: {grad_baseline.abs().mean():.6f}") + print(f" CP - min: {grad_with_cp_full.abs().min():.6f}, max: {grad_with_cp_full.abs().max():.6f}, " + f"mean: {grad_with_cp_full.abs().mean():.6f}") + print(f" Grad diff - mean: {grad_diff.mean():.6f}, max: {grad_diff.max():.6f}, " + f"std: {grad_diff.std():.6f}") + print(f"\nq_proj.weight.grad statistics:") + print(f" Baseline - mean: {q_proj_grad_baseline.abs().mean():.6f}, " + f"max: {q_proj_grad_baseline.abs().max():.6f}") + print(f" CP - mean: {q_proj_grad_cp.abs().mean():.6f}, " + f"max: {q_proj_grad_cp.abs().max():.6f}") + print(f" Diff - mean: {q_proj_grad_diff.mean():.6f}, max: {q_proj_grad_diff.max():.6f}") + + try: + torch.testing.assert_close( + output_with_cp_full, + output_baseline, + rtol=1e-2, + atol=1e-2, + msg=f"[Rank {rank}] Forward outputs differ between CP=1 and CP=2", + ) + + torch.testing.assert_close( + 
grad_with_cp_full, + grad_baseline, + rtol=1e-2, + atol=5e-2, + msg=f"[Rank {rank}] Input gradients differ between CP=1 and CP=2", + ) + + torch.testing.assert_close( + q_proj_grad_cp, + q_proj_grad_baseline, + rtol=5e-2, + atol=5e-2, + msg=f"[Rank {rank}] q_proj.weight.grad differs between CP=1 and CP=2", + ) + + if rank == 0: + print(f"Test PASSED: Forward outputs and gradients match between CP=1 and CP=2") + print(f"{'='*70}\n") + return 0 + + except AssertionError as e: + if rank == 0: + print(f"Test FAILED: {e}") + print(f"Note: Some numerical differences are expected with bfloat16 precision") + print(f"{'='*70}\n") + return 1 + + +def main(): + init_distributed() + exit_code = run_test() + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/tests/functional_tests/context_parallel/test_context_parallel.py b/tests/functional_tests/context_parallel/test_context_parallel.py index a423e4429..6d9911c0e 100644 --- a/tests/functional_tests/context_parallel/test_context_parallel.py +++ b/tests/functional_tests/context_parallel/test_context_parallel.py @@ -23,6 +23,8 @@ TEST_FOLDER = "context_parallel" CP_QWEN3_MOE_ATTENTION_TEST_FILENAME = "L2_CP_Qwen3MoE_Attention_Test.sh" CP_DEEPSEEK_V3_MLA_TEST_FILENAME = "L2_CP_DeepSeekV3_MLA_Test.sh" +CP_NEMOTRON_V3_MAMBA_TEST_FILENAME = "L2_CP_NemotronV3_Mamba_Test.sh" +CP_NEMOTRON_V3_ATTENTION_TEST_FILENAME = "L2_CP_NemotronV3_Attention_Test.sh" class TestContextParallelAttention: @@ -35,3 +37,11 @@ def test_cp_qwen3_moe_attention(self): def test_cp_deepseek_v3_mla(self): """Test DeepSeek V3 MLA layer with CP=1 vs CP=2.""" run_test_script(TEST_FOLDER, CP_DEEPSEEK_V3_MLA_TEST_FILENAME) + + def test_cp_nemotron_v3_mamba(self): + """Test NemotronV3Mamba2Mixer layer with CP=1 vs CP=2.""" + run_test_script(TEST_FOLDER, CP_NEMOTRON_V3_MAMBA_TEST_FILENAME) + + def test_cp_nemotron_v3_attention(self): + """Test NemotronV3Attention layer 
with CP=1 vs CP=2.""" + run_test_script(TEST_FOLDER, CP_NEMOTRON_V3_ATTENTION_TEST_FILENAME) diff --git a/tests/unit_tests/distributed/test_mamba_cp.py b/tests/unit_tests/distributed/test_mamba_cp.py new file mode 100644 index 000000000..ebf3ae130 --- /dev/null +++ b/tests/unit_tests/distributed/test_mamba_cp.py @@ -0,0 +1,594 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for :pyfile:`nemo_automodel/components/distributed/mamba_cp.py`. + +Tests mock the distributed process group so they can run on CPU-only CI +systems while still verifying dimension calculations, parameter slicing, +group replication logic, and activation shape transformations. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest +import torch +import torch.nn as nn + +from nemo_automodel.components.distributed.mamba_cp import MambaContextParallel + +# --------------------------------------------------------------------------- +# Lightweight stubs for torch.distributed.ProcessGroup +# --------------------------------------------------------------------------- + +class _FakeProcessGroup: + """Minimal stub emulating ``torch.distributed.ProcessGroup`` for unit tests. + + Only ``size()`` and ``rank()`` are used by ``MambaContextParallel.__init__`` + and the parameter-slicing helpers. 
+ """ + + def __init__(self, size: int, rank: int = 0): + self._size = size + self._rank = rank + + def size(self) -> int: + return self._size + + def rank(self) -> int: + return self._rank + +# --------------------------------------------------------------------------- +# Helpers to build a MambaContextParallel instance with dummy parameters +# --------------------------------------------------------------------------- + +def _make_conv1d(d_inner: int, n_groups: int, d_state: int, kernel_size: int = 4) -> nn.Conv1d: + """Create a Conv1d whose weight/bias values are deterministic (arange-based).""" + conv_dim = d_inner + 2 * n_groups * d_state + conv = nn.Conv1d(in_channels=conv_dim, out_channels=conv_dim, kernel_size=kernel_size, groups=conv_dim, bias=True) + # Fill weight with sequential values for easy verification. + with torch.no_grad(): + conv.weight.copy_(torch.arange(conv_dim * kernel_size, dtype=torch.float32).reshape(conv_dim, 1, kernel_size)) + conv.bias.copy_(torch.arange(conv_dim, dtype=torch.float32)) + return conv + +def _make_mamba_cp( + num_heads: int, + head_dim: int, + n_groups: int, + d_state: int, + cp_size: int, + cp_rank: int = 0, + kernel_size: int = 4, +) -> MambaContextParallel: + """Construct a ``MambaContextParallel`` with deterministic dummy parameters.""" + pg = _FakeProcessGroup(size=cp_size, rank=cp_rank) + conv1d = _make_conv1d(num_heads * head_dim, n_groups, d_state, kernel_size) + dt_bias = torch.arange(num_heads, dtype=torch.float32) + A_log = torch.arange(num_heads, dtype=torch.float32) + 100.0 + D = torch.arange(num_heads, dtype=torch.float32) + 200.0 + return MambaContextParallel( + cp_group=pg, + num_heads=num_heads, + head_dim=head_dim, + n_groups=n_groups, + d_state=d_state, + conv1d=conv1d, + dt_bias=dt_bias, + A_log=A_log, + D=D, + ) + +@pytest.mark.parametrize( + "num_heads, n_groups, cp_size, expected_heads_local, expected_d_inner_local, expected_n_groups_local, expected_repeat", + [ + # Basic: groups == heads, evenly 
divisible + (8, 8, 2, 4, 4 * 2, 4, 1), + (8, 8, 4, 2, 2 * 2, 2, 1), + # More groups than cp_size + (16, 8, 4, 4, 4 * 2, 2, 1), + # n_groups < cp_size -> replication + (8, 1, 2, 4, 4 * 2, 1, 2), + (8, 2, 4, 2, 2 * 2, 1, 2), + (16, 1, 4, 4, 4 * 2, 1, 4), + # cp_size == 1 (no parallelism) + (8, 4, 1, 8, 8 * 2, 4, 1), + ], + ids=[ + "groups_eq_heads_cp2", + "groups_eq_heads_cp4", + "more_groups_cp4", + "1group_cp2", + "2groups_cp4", + "1group_cp4", + "cp1_noop", + ], +) +def test_dimension_calculations( + num_heads, + n_groups, + cp_size, + expected_heads_local, + expected_d_inner_local, + expected_n_groups_local, + expected_repeat, +): + """Verify computed per-rank dimensions for various (num_heads, n_groups, cp_size) combos.""" + head_dim = 2 + mcp = _make_mamba_cp(num_heads=num_heads, head_dim=head_dim, n_groups=n_groups, d_state=4, cp_size=cp_size) + + assert mcp.num_heads_local == expected_heads_local + assert mcp.d_inner_local == expected_d_inner_local + assert mcp.n_groups_local == expected_n_groups_local + assert mcp.group_repeat_count == expected_repeat + assert mcp.d_inner == num_heads * head_dim + +def test_validation_error_heads_not_divisible_by_cp(): + """num_heads % cp_size != 0 must raise AssertionError.""" + with pytest.raises(AssertionError, match="num_heads.*must be divisible by cp_size"): + _make_mamba_cp(num_heads=7, head_dim=2, n_groups=1, d_state=4, cp_size=4) + +def test_validation_error_cp_not_divisible_by_groups(): + """When n_groups < cp_size, cp_size % n_groups != 0 must raise.""" + with pytest.raises(AssertionError, match="cp_size.*must be divisible by n_groups"): + _make_mamba_cp(num_heads=12, head_dim=2, n_groups=3, d_state=4, cp_size=4) + +def test_validation_error_groups_not_divisible_by_cp(): + """When n_groups >= cp_size, n_groups % cp_size != 0 must raise.""" + with pytest.raises(AssertionError, match="n_groups.*must be divisible by cp_size"): + _make_mamba_cp(num_heads=12, head_dim=2, n_groups=5, d_state=4, cp_size=4) + +class 
TestParameterSlicing: + """Verify that get_conv1d_weight, get_conv1d_bias, get_dt_bias, get_A_log, get_D + return correct slices per CP rank.""" + + NUM_HEADS = 8 + HEAD_DIM = 2 + N_GROUPS = 4 + D_STATE = 3 + KERNEL_SIZE = 4 + CP_SIZE = 2 + + @property + def d_inner(self) -> int: + return self.NUM_HEADS * self.HEAD_DIM + + @property + def groups_state_size(self) -> int: + return self.N_GROUPS * self.D_STATE + + @property + def conv_dim(self) -> int: + return self.d_inner + 2 * self.groups_state_size + + def _build(self, rank: int) -> MambaContextParallel: + return _make_mamba_cp( + num_heads=self.NUM_HEADS, + head_dim=self.HEAD_DIM, + n_groups=self.N_GROUPS, + d_state=self.D_STATE, + cp_size=self.CP_SIZE, + cp_rank=rank, + kernel_size=self.KERNEL_SIZE, + ) + + def test_dt_bias_slicing(self): + """dt_bias[num_heads] should be sliced into contiguous chunks per rank.""" + for rank in range(self.CP_SIZE): + mcp = self._build(rank) + sliced = mcp.get_dt_bias() + expected_start = rank * (self.NUM_HEADS // self.CP_SIZE) + expected = torch.arange(self.NUM_HEADS, dtype=torch.float32)[ + expected_start : expected_start + mcp.num_heads_local + ] + assert torch.equal(sliced, expected), f"dt_bias mismatch on rank {rank}" + + def test_A_log_slicing(self): + """A_log[num_heads] should be sliced into contiguous chunks per rank.""" + for rank in range(self.CP_SIZE): + mcp = self._build(rank) + sliced = mcp.get_A_log() + expected_start = rank * mcp.num_heads_local + expected = torch.arange(self.NUM_HEADS, dtype=torch.float32)[ + expected_start : expected_start + mcp.num_heads_local + ] + 100.0 + assert torch.equal(sliced, expected), f"A_log mismatch on rank {rank}" + + def test_D_slicing(self): + """D[num_heads] should be sliced into contiguous chunks per rank.""" + for rank in range(self.CP_SIZE): + mcp = self._build(rank) + sliced = mcp.get_D() + expected_start = rank * mcp.num_heads_local + expected = torch.arange(self.NUM_HEADS, dtype=torch.float32)[ + expected_start : 
expected_start + mcp.num_heads_local + ] + 200.0 + assert torch.equal(sliced, expected), f"D mismatch on rank {rank}" + + def test_conv1d_weight_slicing_shape(self): + """conv1d weight [conv_dim, 1, K] -> [conv_dim_local, K] per rank (squeezed).""" + for rank in range(self.CP_SIZE): + mcp = self._build(rank) + w = mcp.get_conv1d_weight() + d_inner_local = self.d_inner // self.CP_SIZE + n_groups_local = self.N_GROUPS // self.CP_SIZE + conv_dim_local = d_inner_local + 2 * n_groups_local * self.D_STATE + assert w.shape == (conv_dim_local, self.KERNEL_SIZE), f"Weight shape mismatch rank {rank}" + + def test_conv1d_bias_slicing_shape(self): + """conv1d bias [conv_dim] -> [conv_dim_local] per rank.""" + for rank in range(self.CP_SIZE): + mcp = self._build(rank) + b = mcp.get_conv1d_bias() + d_inner_local = self.d_inner // self.CP_SIZE + n_groups_local = self.N_GROUPS // self.CP_SIZE + conv_dim_local = d_inner_local + 2 * n_groups_local * self.D_STATE + assert b.shape == (conv_dim_local,), f"Bias shape mismatch rank {rank}" + + def test_conv1d_weight_slicing_values(self): + """Verify that the x-portion of conv1d weight is correctly sliced per rank.""" + # The full weight is filled with arange(conv_dim * K).reshape(conv_dim, 1, K). + # x-portion occupies rows [0, d_inner). Each rank gets d_inner/cp_size rows. + # get_conv1d_weight() squeezes dim-1, so expected shape is [rows, K]. 
+ full_weight = torch.arange(self.conv_dim * self.KERNEL_SIZE, dtype=torch.float32).reshape( + self.conv_dim, self.KERNEL_SIZE + ) + d_inner_local = self.d_inner // self.CP_SIZE + + for rank in range(self.CP_SIZE): + mcp = self._build(rank) + sliced = mcp.get_conv1d_weight() + x_start = rank * d_inner_local + x_expected = full_weight[x_start : x_start + d_inner_local] + x_actual = sliced[:d_inner_local] + assert torch.equal(x_actual, x_expected), f"Weight x-portion mismatch rank {rank}" + + def test_conv1d_bias_none(self): + """When conv1d has no bias, get_conv1d_bias() returns None.""" + pg = _FakeProcessGroup(size=2, rank=0) + conv = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + kernel_size=self.KERNEL_SIZE, + groups=self.conv_dim, + bias=False, + ) + dt_bias = torch.zeros(self.NUM_HEADS) + A_log = torch.zeros(self.NUM_HEADS) + D = torch.zeros(self.NUM_HEADS) + mcp = MambaContextParallel( + cp_group=pg, + num_heads=self.NUM_HEADS, + head_dim=self.HEAD_DIM, + n_groups=self.N_GROUPS, + d_state=self.D_STATE, + conv1d=conv, + dt_bias=dt_bias, + A_log=A_log, + D=D, + ) + assert mcp.get_conv1d_bias() is None + +class TestParameterSlicingWithReplication: + """When n_groups < cp_size, B/C conv param slicing uses group_repeat_count.""" + + NUM_HEADS = 8 + HEAD_DIM = 2 + N_GROUPS = 1 # < cp_size -> replication + D_STATE = 3 + KERNEL_SIZE = 4 + CP_SIZE = 2 + + @property + def d_inner(self) -> int: + return self.NUM_HEADS * self.HEAD_DIM + + @property + def groups_state_size(self) -> int: + return self.N_GROUPS * self.D_STATE + + @property + def conv_dim(self) -> int: + return self.d_inner + 2 * self.groups_state_size + + def _build(self, rank: int) -> MambaContextParallel: + return _make_mamba_cp( + num_heads=self.NUM_HEADS, + head_dim=self.HEAD_DIM, + n_groups=self.N_GROUPS, + d_state=self.D_STATE, + cp_size=self.CP_SIZE, + cp_rank=rank, + kernel_size=self.KERNEL_SIZE, + ) + + def test_replicated_groups_conv_weight(self): + """With n_groups=1, 
cp_size=2 all ranks should get the same B/C slice.""" + slices = [] + for rank in range(self.CP_SIZE): + mcp = self._build(rank) + w = mcp.get_conv1d_weight() + d_inner_local = self.d_inner // self.CP_SIZE + bc_portion = w[d_inner_local:] + slices.append(bc_portion) + + # Both ranks replicate the single group, so B/C slices must be identical. + assert torch.equal(slices[0], slices[1]), ( + "With n_groups < cp_size, B/C conv param slices should be identical across ranks" + ) + + def test_replicated_groups_conv_weight_shape(self): + """Verify conv weight shape with replication.""" + mcp = self._build(0) + w = mcp.get_conv1d_weight() + d_inner_local = self.d_inner // self.CP_SIZE + bc_size_local = mcp.n_groups_local * self.D_STATE + expected_conv_dim_local = d_inner_local + 2 * bc_size_local + assert w.shape == (expected_conv_dim_local, self.KERNEL_SIZE) + +class TestGroupReplication: + """Verify B/C state replication via expand+reshape when n_groups < cp_size.""" + + def test_bc_state_expansion(self): + """When n_groups=1 and cp_size=2, B/C states should be doubled before all-to-all.""" + num_heads = 8 + head_dim = 2 + n_groups = 1 + d_state = 3 + cp_size = 2 + d_inner = num_heads * head_dim + groups_state_size = n_groups * d_state + + mcp = _make_mamba_cp( + num_heads=num_heads, + head_dim=head_dim, + n_groups=n_groups, + d_state=d_state, + cp_size=cp_size, + cp_rank=0, + ) + + B, L_local = 2, 4 + proj_dim = d_inner + d_inner + groups_state_size + groups_state_size + num_heads + projected = torch.randn(B, L_local, proj_dim) + + captured_calls = [] + + def fake_cp2hp(tensor, cp_group, batch_size): + captured_calls.append(tensor.clone()) + H = tensor.shape[-1] + H_local = H // cp_size + return torch.randn(batch_size, L_local * cp_size, H_local) + + with patch("nemo_automodel.components.distributed.mamba_cp._all_to_all_cp2hp", side_effect=fake_cp2hp): + mcp.pre_conv_ssm(projected) + + assert len(captured_calls) == 5, f"Expected 5 all-to-all calls, got 
{len(captured_calls)}" + + b_state_input = captured_calls[2] + c_state_input = captured_calls[3] + expected_expanded_dim = n_groups * mcp.group_repeat_count * d_state # 1*2*3 = 6 + assert b_state_input.shape == (B, L_local, expected_expanded_dim), ( + f"B_state should be expanded to dim {expected_expanded_dim}, got {b_state_input.shape[-1]}" + ) + assert c_state_input.shape == (B, L_local, expected_expanded_dim), ( + f"C_state should be expanded to dim {expected_expanded_dim}, got {c_state_input.shape[-1]}" + ) + + def test_no_expansion_when_groups_ge_cp(self): + """When n_groups >= cp_size, B/C states should NOT be expanded.""" + num_heads = 8 + head_dim = 2 + n_groups = 4 + d_state = 3 + cp_size = 2 + d_inner = num_heads * head_dim + groups_state_size = n_groups * d_state + + mcp = _make_mamba_cp( + num_heads=num_heads, + head_dim=head_dim, + n_groups=n_groups, + d_state=d_state, + cp_size=cp_size, + cp_rank=0, + ) + assert mcp.group_repeat_count == 1 + + B, L_local = 2, 4 + proj_dim = d_inner + d_inner + groups_state_size + groups_state_size + num_heads + projected = torch.randn(B, L_local, proj_dim) + + captured_calls = [] + + def fake_cp2hp(tensor, cp_group, batch_size): + captured_calls.append(tensor.clone()) + H = tensor.shape[-1] + H_local = H // cp_size + return torch.randn(batch_size, L_local * cp_size, H_local) + + with patch("nemo_automodel.components.distributed.mamba_cp._all_to_all_cp2hp", side_effect=fake_cp2hp): + mcp.pre_conv_ssm(projected) + + b_state_input = captured_calls[2] + assert b_state_input.shape == (B, L_local, groups_state_size) + +class TestPrePostConvSsmShapes: + """Verify shape transformations of pre_conv_ssm and post_conv_ssm. + + All-to-all calls are mocked to avoid needing real distributed backends. + The mock simulates the correct shape transformation that all-to-all would + perform (sequence-sharded <-> hidden-sharded). 
+ """ + + @pytest.mark.parametrize("cp_size", [2, 4]) + def test_pre_conv_ssm_output_shape(self, cp_size): + """pre_conv_ssm: [B, L/cp, proj_dim] -> [B, L, proj_dim/cp].""" + num_heads = 8 + head_dim = 2 + n_groups = 4 + d_state = 3 + d_inner = num_heads * head_dim + groups_state_size = n_groups * d_state + + mcp = _make_mamba_cp( + num_heads=num_heads, + head_dim=head_dim, + n_groups=n_groups, + d_state=d_state, + cp_size=cp_size, + cp_rank=0, + ) + + B = 2 + L_local = 8 + L_global = L_local * cp_size + proj_dim = d_inner + d_inner + groups_state_size + groups_state_size + num_heads + projected = torch.randn(B, L_local, proj_dim) + + def fake_cp2hp(tensor, cp_group, batch_size): + B_t, L_t, H_t = tensor.shape + H_local = H_t // cp_size + return torch.randn(B_t, L_t * cp_size, H_local) + + with patch("nemo_automodel.components.distributed.mamba_cp._all_to_all_cp2hp", side_effect=fake_cp2hp): + output = mcp.pre_conv_ssm(projected) + + d_inner_local = d_inner // cp_size + n_groups_local = n_groups // cp_size + groups_state_local = n_groups_local * d_state + num_heads_local = num_heads // cp_size + proj_dim_local = d_inner_local + d_inner_local + groups_state_local + groups_state_local + num_heads_local + + assert output.shape == (B, L_global, proj_dim_local), ( + f"Expected ({B}, {L_global}, {proj_dim_local}), got {output.shape}" + ) + + @pytest.mark.parametrize("cp_size", [2, 4]) + def test_post_conv_ssm_output_shape(self, cp_size): + """post_conv_ssm: [B, L, d_inner/cp] -> [B, L/cp, d_inner].""" + num_heads = 8 + head_dim = 2 + n_groups = 4 + d_state = 3 + d_inner = num_heads * head_dim + + mcp = _make_mamba_cp( + num_heads=num_heads, + head_dim=head_dim, + n_groups=n_groups, + d_state=d_state, + cp_size=cp_size, + cp_rank=0, + ) + + B = 2 + L_global = 16 + L_local = L_global // cp_size + d_inner_local = d_inner // cp_size + ssm_output = torch.randn(B, L_global, d_inner_local) + + def fake_hp2cp(tensor, cp_group, batch_size): + B_t, L_t, H_t = tensor.shape + L_out 
= L_t // cp_size + H_out = H_t * cp_size + return torch.randn(B_t, L_out, H_out) + + with patch("nemo_automodel.components.distributed.mamba_cp._all_to_all_hp2cp", side_effect=fake_hp2cp): + output = mcp.post_conv_ssm(ssm_output) + + assert output.shape == (B, L_local, d_inner), ( + f"Expected ({B}, {L_local}, {d_inner}), got {output.shape}" + ) + + def test_pre_conv_ssm_noop_cp1(self): + """When cp_size == 1, pre_conv_ssm should return the input unchanged.""" + mcp = _make_mamba_cp(num_heads=4, head_dim=2, n_groups=2, d_state=3, cp_size=1, cp_rank=0) + inp = torch.randn(2, 8, 4 * 2 + 4 * 2 + 2 * 3 + 2 * 3 + 4) + out = mcp.pre_conv_ssm(inp) + assert out is inp, "pre_conv_ssm should be identity when cp_size==1" + + def test_post_conv_ssm_noop_cp1(self): + """When cp_size == 1, post_conv_ssm should return the input unchanged.""" + mcp = _make_mamba_cp(num_heads=4, head_dim=2, n_groups=2, d_state=3, cp_size=1, cp_rank=0) + inp = torch.randn(2, 8, 4 * 2) + out = mcp.post_conv_ssm(inp) + assert out is inp, "post_conv_ssm should be identity when cp_size==1" + +class TestAllToAllLayoutTransforms: + """Test _all_to_all_cp2hp and _all_to_all_hp2cp with mocked all-to-all. + + We mock _all_to_all (the autograd wrapper) to simulate an identity all-to-all + (single rank), which lets us verify the reshape/permute logic without a real PG. 
+ """ + + def test_cp2hp_shape(self): + """Verify _all_to_all_cp2hp output shape with identity all-to-all.""" + from nemo_automodel.components.distributed.mamba_cp import _all_to_all_cp2hp + + cp_size = 2 + B, L_local, H = 2, 4, 8 + pg = _FakeProcessGroup(size=cp_size, rank=0) + + inp = torch.randn(B, L_local, H) + + with patch("nemo_automodel.components.distributed.mamba_cp._all_to_all", side_effect=lambda t, g: t): + out = _all_to_all_cp2hp(inp, pg, B) + + assert out.shape == (B, L_local * cp_size, H // cp_size) + + def test_hp2cp_shape(self): + """Verify _all_to_all_hp2cp output shape with identity all-to-all.""" + from nemo_automodel.components.distributed.mamba_cp import _all_to_all_hp2cp + + cp_size = 2 + B, L_global, H_local = 2, 8, 4 + pg = _FakeProcessGroup(size=cp_size, rank=0) + + inp = torch.randn(B, L_global, H_local) + + with patch("nemo_automodel.components.distributed.mamba_cp._all_to_all", side_effect=lambda t, g: t): + out = _all_to_all_hp2cp(inp, pg, B) + + assert out.shape == (B, L_global // cp_size, H_local * cp_size) + +def test_cp_size_1_is_identity(): + """When cp_size == 1, all dimension calculations should match the unpartitioned case.""" + mcp = _make_mamba_cp(num_heads=8, head_dim=4, n_groups=4, d_state=16, cp_size=1, cp_rank=0) + + assert mcp.num_heads_local == 8 + assert mcp.d_inner_local == 32 + assert mcp.n_groups_local == 4 + assert mcp.group_repeat_count == 1 + + assert mcp.get_dt_bias().shape == (8,) + assert mcp.get_A_log().shape == (8,) + assert mcp.get_D().shape == (8,) + + w = mcp.get_conv1d_weight() + assert w.shape[0] == 32 + 2 * 4 * 16 + +def test_parameter_slices_allow_gradient_flow(): + """Sliced parameters should maintain gradient connectivity to the originals.""" + mcp = _make_mamba_cp(num_heads=4, head_dim=2, n_groups=2, d_state=3, cp_size=2, cp_rank=0) + + dt_slice = mcp.get_dt_bias() + assert dt_slice.data_ptr() == mcp.dt_bias[:2].data_ptr() + + a_slice = mcp.get_A_log() + assert a_slice.data_ptr() == 
mcp.A_log[:2].data_ptr() + + d_slice = mcp.get_D() + assert d_slice.data_ptr() == mcp.D[:2].data_ptr() diff --git a/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_layers.py b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_layers.py index c90687831..1ba157895 100644 --- a/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_layers.py +++ b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_layers.py @@ -128,34 +128,34 @@ def test_attention_init_with_bias(self, config): assert attn.v_proj.bias is not None assert attn.o_proj.bias is not None + @skip_if_no_gpu def test_attention_forward_shape(self, config): """Test attention forward pass produces correct shapes.""" - attn = NemotronV3Attention(config) + attn = NemotronV3Attention(config).cuda() batch_size, seq_len = 2, 16 - hidden_states = torch.randn(batch_size, seq_len, config.hidden_size) + hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16) output = attn(hidden_states) assert output.shape == (batch_size, seq_len, config.hidden_size) + @skip_if_no_gpu def test_attention_forward_with_mask(self, config): """Test attention forward pass with attention mask.""" - attn = NemotronV3Attention(config) + attn = NemotronV3Attention(config).cuda() batch_size, seq_len = 2, 8 - hidden_states = torch.randn(batch_size, seq_len, config.hidden_size) + hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16) - # Create 4D causal mask - attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len) - attention_mask = attention_mask.masked_fill( - torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool(), float("-inf") - ) + # 2D padding mask (1=valid, 0=pad) — TE handles causality internally + attention_mask = torch.ones(batch_size, seq_len, device="cuda", dtype=torch.long) output = attn(hidden_states, attention_mask=attention_mask) assert output.shape == (batch_size, seq_len, config.hidden_size) + @skip_if_no_gpu def 
test_attention_gqa_with_different_kv_heads(self): """Test GQA with different number of key-value heads.""" config = MockNemotronV3Config( @@ -164,25 +164,22 @@ def test_attention_gqa_with_different_kv_heads(self): head_dim=32, hidden_size=512, ) - attn = NemotronV3Attention(config) + attn = NemotronV3Attention(config).cuda() batch_size, seq_len = 2, 8 - hidden_states = torch.randn(batch_size, seq_len, config.hidden_size) + hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16) output = attn(hidden_states) assert output.shape == (batch_size, seq_len, config.hidden_size) - # Verify projection dimensions - assert attn.q_proj.out_features == config.num_attention_heads * config.head_dim - assert attn.k_proj.out_features == config.num_key_value_heads * config.head_dim - assert attn.v_proj.out_features == config.num_key_value_heads * config.head_dim + @skip_if_no_gpu def test_attention_init_weights(self, config): """Test attention weight initialization.""" config.attention_bias = True - attn = NemotronV3Attention(config) + attn = NemotronV3Attention(config).cuda() - device = torch.device("cpu") + device = torch.device("cuda") attn.init_weights( num_hidden_layers=config.num_hidden_layers, rescale_prenorm_residual=True, @@ -193,12 +190,13 @@ def test_attention_init_weights(self, config): assert torch.allclose(attn.q_proj.bias, torch.zeros_like(attn.q_proj.bias)) assert torch.allclose(attn.k_proj.bias, torch.zeros_like(attn.k_proj.bias)) + @skip_if_no_gpu def test_attention_forward_single_token(self, config): """Test attention with single token (seqlen=1).""" - attn = NemotronV3Attention(config) + attn = NemotronV3Attention(config).cuda() batch_size, seq_len = 2, 1 - hidden_states = torch.randn(batch_size, seq_len, config.hidden_size) + hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16) output = attn(hidden_states) @@ -283,13 +281,14 @@ def 
test_block_init_invalid_type(self, config, backend): with pytest.raises(ValueError, match="Invalid block_type"): NemotronV3Block(config, layer_idx=0, moe_config=None, backend=backend) + @skip_if_no_gpu def test_block_forward_attention(self, config, backend): """Test block forward pass with attention layer.""" config.layers_block_type = ["attention"] - block = NemotronV3Block(config, layer_idx=0, moe_config=None, backend=backend) + block = NemotronV3Block(config, layer_idx=0, moe_config=None, backend=backend).cuda() batch_size, seq_len = 2, 8 - hidden_states = torch.randn(batch_size, seq_len, config.hidden_size) + hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16) output = block(hidden_states) @@ -527,28 +526,31 @@ def backend(self): enable_hf_state_dict_adapter=False, ) + @skip_if_no_gpu def test_attention_no_cache_args(self, config): """Verify attn(hidden) still works without cache args.""" - attn = NemotronV3Attention(config) - hidden = torch.randn(2, 8, config.hidden_size) + attn = NemotronV3Attention(config).cuda() + hidden = torch.randn(2, 8, config.hidden_size, device="cuda", dtype=torch.bfloat16) out = attn(hidden) assert out.shape == (2, 8, config.hidden_size) + @skip_if_no_gpu def test_attention_mask_only(self, config): """Verify attn(hidden, attention_mask=...) 
still works without cache args.""" - attn = NemotronV3Attention(config) + attn = NemotronV3Attention(config).cuda() batch_size, seq_len = 2, 8 - hidden = torch.randn(batch_size, seq_len, config.hidden_size) - mask = torch.zeros(batch_size, 1, seq_len, seq_len) - mask.masked_fill_(torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool(), float("-inf")) + hidden = torch.randn(batch_size, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16) + # 2D padding mask (1=valid, 0=pad) — TE handles causality internally + mask = torch.ones(batch_size, seq_len, device="cuda", dtype=torch.long) out = attn(hidden, attention_mask=mask) assert out.shape == (batch_size, seq_len, config.hidden_size) + @skip_if_no_gpu def test_block_attention_no_cache_args(self, config, backend): """Verify block(hidden) still works for attention block without cache args.""" config.layers_block_type = ["attention"] - block = NemotronV3Block(config, layer_idx=0, moe_config=None, backend=backend) - hidden = torch.randn(2, 8, config.hidden_size) + block = NemotronV3Block(config, layer_idx=0, moe_config=None, backend=backend).cuda() + hidden = torch.randn(2, 8, config.hidden_size, device="cuda", dtype=torch.bfloat16) out = block(hidden) assert out.shape == (2, 8, config.hidden_size) diff --git a/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_model.py b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_model.py index 9f85feb02..18c276cb1 100644 --- a/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_model.py +++ b/tests/unit_tests/models/nemotron_v3/test_nemotron_v3_model.py @@ -674,18 +674,18 @@ def test_attention_with_cache_prefill_decode(self, config, backend): from nemo_automodel.components.models.nemotron_v3.cache import NemotronHybridCache from nemo_automodel.components.models.nemotron_v3.layers import NemotronV3Attention - attn = NemotronV3Attention(config) + attn = NemotronV3Attention(config, backend=backend).to(torch.bfloat16) batch_size, prompt_len = 2, 4 - cache = 
NemotronHybridCache(config, batch_size, torch.float32, torch.device("cpu")) + cache = NemotronHybridCache(config, batch_size, torch.bfloat16, torch.device("cpu")) # Prefill - hidden = torch.randn(batch_size, prompt_len, config.hidden_size) + hidden = torch.randn(batch_size, prompt_len, config.hidden_size, dtype=torch.bfloat16) out = attn(hidden, past_key_values=cache, layer_idx=0) assert out.shape == (batch_size, prompt_len, config.hidden_size) assert cache.get_seq_length(0) == prompt_len # Decode - hidden_decode = torch.randn(batch_size, 1, config.hidden_size) + hidden_decode = torch.randn(batch_size, 1, config.hidden_size, dtype=torch.bfloat16) out = attn(hidden_decode, past_key_values=cache, layer_idx=0) assert out.shape == (batch_size, 1, config.hidden_size) assert cache.get_seq_length(0) == prompt_len + 1