Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions nemo_automodel/components/distributed/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,53 @@ def _register(cls):
return _register


def _pre_shard_combined_projections(
    module: nn.Module,
    mesh: DeviceMesh,
    mp_policy: Optional[MixedPrecisionPolicy],
    offload_policy: Optional[OffloadPolicy] = None,
) -> None:
    """Wrap fused ``qkv_proj`` / ``gate_up_proj`` submodules in FSDP, sharded on dim 1.

    The fused QKV and gate_up weights are laid out in interleaved groups along
    dim 0 (Q/K/V rows, or gate/up rows). The default FSDP ``Shard(0)`` placement
    may cut through a group when ``num_kv_heads`` is not evenly divisible by the
    FSDP shard count, which makes the state-dict adapter's reshape fail.

    Placing the shard on dim 1 leaves dim 0 whole, so reshape / split keeps
    working regardless of model configuration. Biases and other 1-D tensors
    fall back to ``Shard(0)`` since FSDP's ``shard_placement_fn`` accepts only
    ``Shard`` placements; the state-dict adapter's ``_gather_1d_if_needed``
    gathers them when a local shard does not line up with the interleaved
    group boundaries.

    This mirrors the approach used for MoE expert sharding
    (``nemo_automodel/components/moe/parallelizer.py``).

    Args:
        module: Root module whose submodules are scanned for fused projections.
        mesh: Device mesh forwarded to ``fully_shard``.
        mp_policy: Mixed-precision policy forwarded to ``fully_shard``.
        offload_policy: Offload policy forwarded to ``fully_shard``.
    """
    try:
        from nemo_automodel.components.models.common.combined_projection.combined_mlp import CombinedGateUpMLP
        from nemo_automodel.components.models.common.combined_projection.combined_qkv import CombinedQKVAttentionMixin
    except ImportError:
        # Combined-projection classes are unavailable: nothing to pre-shard.
        return

    def _placement_for(param):
        # 2-D (and higher) tensors shard on dim 1; 1-D tensors can only use dim 0.
        if param.ndim >= 2:
            return Shard(1)
        return Shard(0)

    # (owner class, attribute holding the fused projection); first match wins.
    fused_owners = (
        (CombinedQKVAttentionMixin, "qkv_proj"),
        (CombinedGateUpMLP, "gate_up_proj"),
    )

    for submodule in module.modules():
        projection = None
        for owner_cls, attr_name in fused_owners:
            if isinstance(submodule, owner_cls) and hasattr(submodule, attr_name):
                projection = getattr(submodule, attr_name)
                break

        # Skip non-matching submodules and projections that FSDP already wraps.
        if projection is None or isinstance(projection, FSDPModule):
            continue

        fully_shard(
            projection,
            mesh=mesh,
            mp_policy=mp_policy,
            offload_policy=offload_policy,
            shard_placement_fn=_placement_for,
        )


def apply_fsdp2_sharding_recursively(
module: nn.Module,
mesh: DeviceMesh,
Expand Down Expand Up @@ -441,6 +488,10 @@ def apply_fsdp2_sharding_recursively(
if isinstance(child_module, nn.ModuleList):
apply_fsdp2_sharding_recursively(child_module, mesh, mp_policy, offload_policy)
else:
# Pre-shard combined projection submodules on dim 1 so that
# the parent fully_shard (dim 0) skips them automatically.
_pre_shard_combined_projections(child_module, mesh, mp_policy, offload_policy)

# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(module) - 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import torch.nn as nn
from transformers.activations import ACT2FN

from nemo_automodel.components.models.common.combined_projection.combined_qkv import _assert_colwise_parallel


class CombinedGateUpMLP(nn.Module):
"""SwiGLU MLP with combined gate_up projection for efficiency.
Expand Down Expand Up @@ -60,26 +62,30 @@ def __init__(self, config):
self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
self._tp_checked = False

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass with combined gate_up projection.

    The gate_up weight uses a row-interleaved layout:
        [gate_0, up_0, gate_1, up_1, ...]
    This ensures ColwiseParallel TP sharding gives each rank matching
    gate/up pairs. We de-interleave via a local reshape.

    Args:
        x: Input tensor [batch, seq_len, hidden_size]

    Returns:
        Output tensor [batch, seq_len, hidden_size]
    """
    # Validate the weight placement once, on the first call only.
    if not self._tp_checked:
        _assert_colwise_parallel(self.gate_up_proj.weight, "gate_up_proj")
        self._tp_checked = True

    # Project exactly once; the earlier contiguous half/half split is gone
    # because it does not match the interleaved layout described above.
    gate_up = self.gate_up_proj(x)

    # View the last dim as (pairs, 2): slot 0 holds gate, slot 1 holds up.
    gate_up = gate_up.unflatten(-1, (-1, 2))
    gate = gate_up[..., 0]
    up = gate_up[..., 1]

    # SwiGLU: down(act(gate) * up)
    return self.down_proj(self.act_fn(gate) * up)
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,26 @@

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor


def _assert_colwise_parallel(weight: torch.Tensor, name: str) -> None:
"""Verify that a combined-projection weight uses ColwiseParallel (Shard(0)) if TP is active.

Shard(dim=1) is expected from FSDP pre-sharding of combined projections and
is excluded from the check — it is temporary and undone by FSDP all-gather
before the actual matmul.
"""
if isinstance(weight, DTensor) and weight.placements:
from torch.distributed.tensor.placement_types import Shard

tp_shards = [p for p in weight.placements if isinstance(p, Shard) and p.dim != 1]
if tp_shards and not any(p.dim == 0 for p in tp_shards):
raise ValueError(
f"{name} uses an interleaved layout that requires ColwiseParallel "
f"(Shard(0)) for correct TP sharding, but got placements={weight.placements}. "
f"Check your TP plan."
)


class CombinedQKVAttentionMixin:
Expand Down Expand Up @@ -68,6 +88,9 @@ def setup_qkv_projection(
self.use_combined_qkv = True # Always combined in custom implementations
self.q_size = num_attention_heads * head_dim
self.kv_size = num_key_value_heads * head_dim
self._num_kv_groups = num_attention_heads // num_key_value_heads
self._head_dim = head_dim
self._tp_checked = False

# Combined QKV projection for improved efficiency
self.qkv_proj = nn.Linear(
Expand All @@ -79,22 +102,24 @@ def setup_qkv_projection(
def compute_qkv(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute Q, K, V from hidden states using combined projection.

    The QKV weight uses a KV-head-grouped interleaved layout:
        [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | ...]
    This ensures ColwiseParallel TP sharding gives each rank complete
    KV-head groups. We split within each group (a local operation).

    Args:
        hidden_states: Input hidden states [batch, seq_len, hidden_size]

    Returns:
        Tuple of (query, key, value) tensors, each [batch, seq_len, ...]
    """
    # Validate the weight placement once, on the first call only.
    if not self._tp_checked:
        _assert_colwise_parallel(self.qkv_proj.weight, "qkv_proj")
        self._tp_checked = True

    # Project exactly once. The stale contiguous [q | k | v] split and its
    # early return are removed: they contradict the grouped layout above and
    # made the grouped split below unreachable.
    qkv = self.qkv_proj(hidden_states)

    # Each KV-head group holds `_num_kv_groups` query heads plus one K and one V head.
    q_width = self._num_kv_groups * self._head_dim
    group_width = (self._num_kv_groups + 2) * self._head_dim

    # Expose the groups as their own axis, split inside each group, then
    # merge (group, width) back into a single feature dimension.
    qkv = qkv.unflatten(-1, (-1, group_width))
    q, k, v = qkv.split([q_width, self._head_dim, self._head_dim], dim=-1)
    return q.flatten(-2), k.flatten(-2), v.flatten(-2)
Loading