From 874b9c5183a220c049001922c1848a8d7fb4e6ae Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Sun, 22 Feb 2026 17:11:43 +0800 Subject: [PATCH 1/2] fix: combined projection (#1324) * fix: combined projection Signed-off-by: Zhiyu Li * lint Signed-off-by: Zhiyu Li * address comment Signed-off-by: Zhiyu Li * add tp output parity tests Signed-off-by: Zhiyu Li --------- Signed-off-by: Zhiyu Li Signed-off-by: HuiyingLi Signed-off-by: Claude Opus 4.6 (1M context) --- .../combined_projection/combined_mlp.py | 22 +- .../combined_projection/combined_qkv.py | 39 +- .../combined_projection/state_dict_adapter.py | 210 ++++++++-- .../L2_TP_Output_Parity_Minified.sh | 29 ++ .../run_tp_output_parity_minified.py | 390 ++++++++++++++++++ 5 files changed, 638 insertions(+), 52 deletions(-) create mode 100644 tests/functional_tests/llm_pretrain_and_kd/L2_TP_Output_Parity_Minified.sh create mode 100644 tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py diff --git a/nemo_automodel/components/models/common/combined_projection/combined_mlp.py b/nemo_automodel/components/models/common/combined_projection/combined_mlp.py index e15e7df45..d9c84d548 100644 --- a/nemo_automodel/components/models/common/combined_projection/combined_mlp.py +++ b/nemo_automodel/components/models/common/combined_projection/combined_mlp.py @@ -22,6 +22,8 @@ import torch.nn as nn from transformers.activations import ACT2FN +from nemo_automodel.components.models.common.combined_projection.combined_qkv import _assert_colwise_parallel + class CombinedGateUpMLP(nn.Module): """SwiGLU MLP with combined gate_up projection for efficiency. 
@@ -60,12 +62,15 @@ def __init__(self, config): self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias) self.act_fn = ACT2FN[config.hidden_act] + self._tp_checked = False def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass with combined gate_up projection. - Handles tensor parallelism by dynamically computing split sizes - based on actual tensor dimensions. + The gate_up weight uses a row-interleaved layout: + [gate_0, up_0, gate_1, up_1, ...] + This ensures ColwiseParallel TP sharding gives each rank matching + gate/up pairs. We de-interleave via a local reshape. Args: x: Input tensor [batch, seq_len, hidden_size] @@ -73,13 +78,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Output tensor [batch, seq_len, hidden_size] """ - # Project and split into gate and up - gate_up = self.gate_up_proj(x) + if not self._tp_checked: + _assert_colwise_parallel(self.gate_up_proj.weight, "gate_up_proj") + self._tp_checked = True - # Handle tensor parallelism: split based on actual tensor size - gate_up_size = gate_up.shape[-1] - local_intermediate_size = gate_up_size // 2 - gate, up = gate_up.split([local_intermediate_size, local_intermediate_size], dim=-1) + gate_up = self.gate_up_proj(x) + gate_up = gate_up.unflatten(-1, (-1, 2)) + gate = gate_up[..., 0] + up = gate_up[..., 1] # SwiGLU: down(act(gate) * up) return self.down_proj(self.act_fn(gate) * up) diff --git a/nemo_automodel/components/models/common/combined_projection/combined_qkv.py b/nemo_automodel/components/models/common/combined_projection/combined_qkv.py index a64939e5a..3adcd8b2f 100644 --- a/nemo_automodel/components/models/common/combined_projection/combined_qkv.py +++ b/nemo_automodel/components/models/common/combined_projection/combined_qkv.py @@ -20,6 +20,20 @@ import torch import torch.nn as nn +from torch.distributed.tensor import DTensor + + +def 
_assert_colwise_parallel(weight: torch.Tensor, name: str) -> None: + """Verify that a combined-projection weight uses ColwiseParallel (Shard(0)) if TP is active.""" + if isinstance(weight, DTensor) and weight.placements: + from torch.distributed.tensor.placement_types import Shard + + if weight.placements[0] != Shard(0): + raise ValueError( + f"{name} uses an interleaved layout that requires ColwiseParallel " + f"(Shard(0)) for correct TP sharding, but got placements={weight.placements}. " + f"Check your TP plan." + ) class CombinedQKVAttentionMixin: @@ -68,6 +82,9 @@ def setup_qkv_projection( self.use_combined_qkv = True # Always combined in custom implementations self.q_size = num_attention_heads * head_dim self.kv_size = num_key_value_heads * head_dim + self._num_kv_groups = num_attention_heads // num_key_value_heads + self._head_dim = head_dim + self._tp_checked = False # Combined QKV projection for improved efficiency self.qkv_proj = nn.Linear( @@ -79,7 +96,10 @@ def setup_qkv_projection( def compute_qkv(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute Q, K, V from hidden states using combined projection. - Handles tensor parallelism by dynamically computing split sizes based on actual tensor dimensions. + The QKV weight uses a KV-head-grouped interleaved layout: + [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | ...] + This ensures ColwiseParallel TP sharding gives each rank complete + KV-head groups. We split within each group (a local operation). Args: hidden_states: Input hidden states [batch, seq_len, hidden_size] @@ -87,14 +107,13 @@ def compute_qkv(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch. Returns: Tuple of (query, key, value) tensors, each [batch, seq_len, ...] 
""" - # Combined QKV projection and split - qkv = self.qkv_proj(hidden_states) + if not self._tp_checked: + _assert_colwise_parallel(self.qkv_proj.weight, "qkv_proj") + self._tp_checked = True - # Compute split sizes based on actual tensor size (handles TP sharding) - qkv_size = qkv.shape[-1] - total_size = self.q_size + 2 * self.kv_size - local_q_size = (self.q_size * qkv_size) // total_size - local_kv_size = (self.kv_size * qkv_size) // total_size + qkv = self.qkv_proj(hidden_states) - q, k, v = qkv.split([local_q_size, local_kv_size, local_kv_size], dim=-1) - return q, k, v + group_width = (self._num_kv_groups + 2) * self._head_dim + qkv = qkv.unflatten(-1, (-1, group_width)) + q, k, v = qkv.split([self._num_kv_groups * self._head_dim, self._head_dim, self._head_dim], dim=-1) + return q.flatten(-2), k.flatten(-2), v.flatten(-2) diff --git a/nemo_automodel/components/models/common/combined_projection/state_dict_adapter.py b/nemo_automodel/components/models/common/combined_projection/state_dict_adapter.py index 6eddea8c2..c63c39afc 100644 --- a/nemo_automodel/components/models/common/combined_projection/state_dict_adapter.py +++ b/nemo_automodel/components/models/common/combined_projection/state_dict_adapter.py @@ -73,6 +73,41 @@ def __init__(self, config): # Compute projection sizes self.q_size = self.num_attention_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim + self.group_size = self.num_attention_heads // self.num_key_value_heads + + def _interleave_qkv(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Interleave Q, K, V by KV-head groups for TP-correct ColwiseParallel sharding. + + Layout: [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | ...] + where each group has (group_size * head_dim) Q rows, head_dim K rows, head_dim V rows. 
+ """ + rest = q.shape[1:] + q = q.reshape(self.num_key_value_heads, self.group_size * self.head_dim, *rest) + k = k.reshape(self.num_key_value_heads, self.head_dim, *rest) + v = v.reshape(self.num_key_value_heads, self.head_dim, *rest) + return torch.cat([q, k, v], dim=1).reshape(-1, *rest) + + def _deinterleave_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """De-interleave QKV from KV-head-grouped layout back to separate Q, K, V. + + Works for both full (unsharded) and TP-sharded tensors — the group structure + is the same regardless of how many groups this rank holds. + """ + rest = qkv.shape[1:] + group_width = (self.group_size + 2) * self.head_dim + qkv = qkv.reshape(-1, group_width, *rest) + q, k, v = qkv.split([self.group_size * self.head_dim, self.head_dim, self.head_dim], dim=1) + return q.reshape(-1, *rest), k.reshape(-1, *rest), v.reshape(-1, *rest) + + def _interleave_gate_up(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """Interleave gate and up row-by-row for TP-correct ColwiseParallel sharding.""" + return torch.stack([gate, up], dim=1).reshape(-1, *gate.shape[1:]) + + def _deinterleave_gate_up(self, gate_up: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """De-interleave gate/up from row-interleaved layout.""" + rest = gate_up.shape[1:] + gate_up = gate_up.reshape(-1, 2, *rest) + return gate_up[:, 0].contiguous(), gate_up[:, 1].contiguous() def from_hf(self, hf_state_dict: dict[str, Any], **kwargs) -> dict[str, Any]: """Convert HuggingFace state dict to combined-projection format. 
@@ -110,8 +145,7 @@ def from_hf(self, hf_state_dict: dict[str, Any], **kwargs) -> dict[str, Any]: k_weight = hf_state_dict[k_weight_key] v_weight = hf_state_dict[v_weight_key] - # Concatenate along output dimension - qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) + qkv_weight = self._interleave_qkv(q_weight, k_weight, v_weight) custom_state_dict[f"{prefix}.self_attn.qkv_proj.weight"] = qkv_weight processed_keys.update([q_weight_key, k_weight_key, v_weight_key]) @@ -121,8 +155,8 @@ def from_hf(self, hf_state_dict: dict[str, Any], **kwargs) -> dict[str, Any]: k_bias_key = f"{prefix}.self_attn.k_proj.bias" v_bias_key = f"{prefix}.self_attn.v_proj.bias" - qkv_bias = torch.cat( - [hf_state_dict[q_bias_key], hf_state_dict[k_bias_key], hf_state_dict[v_bias_key]], dim=0 + qkv_bias = self._interleave_qkv( + hf_state_dict[q_bias_key], hf_state_dict[k_bias_key], hf_state_dict[v_bias_key] ) custom_state_dict[f"{prefix}.self_attn.qkv_proj.bias"] = qkv_bias processed_keys.update([q_bias_key, k_bias_key, v_bias_key]) @@ -135,8 +169,7 @@ def from_hf(self, hf_state_dict: dict[str, Any], **kwargs) -> dict[str, Any]: gate_weight = hf_state_dict[gate_weight_key] up_weight = hf_state_dict[up_weight_key] - # Concatenate along output dimension - gate_up_weight = torch.cat([gate_weight, up_weight], dim=0) + gate_up_weight = self._interleave_gate_up(gate_weight, up_weight) custom_state_dict[f"{prefix}.mlp.gate_up_proj.weight"] = gate_up_weight processed_keys.update([gate_weight_key, up_weight_key]) @@ -145,7 +178,7 @@ def from_hf(self, hf_state_dict: dict[str, Any], **kwargs) -> dict[str, Any]: if gate_bias_key in hf_state_dict: up_bias_key = f"{prefix}.mlp.up_proj.bias" - gate_up_bias = torch.cat([hf_state_dict[gate_bias_key], hf_state_dict[up_bias_key]], dim=0) + gate_up_bias = self._interleave_gate_up(hf_state_dict[gate_bias_key], hf_state_dict[up_bias_key]) custom_state_dict[f"{prefix}.mlp.gate_up_proj.bias"] = gate_up_bias processed_keys.update([gate_bias_key, 
up_bias_key]) @@ -202,15 +235,7 @@ def to_hf( qkv_weight_key = f"{prefix}.self_attn.qkv_proj.weight" if qkv_weight_key in state_dict: - qkv_weight = state_dict[qkv_weight_key] - - # Compute local split sizes based on actual tensor size (handles TP sharding) - qkv_actual_size = qkv_weight.shape[0] - total_size = self.q_size + 2 * self.kv_size - local_q_size = (self.q_size * qkv_actual_size) // total_size - local_kv_size = (self.kv_size * qkv_actual_size) // total_size - - q_weight, k_weight, v_weight = qkv_weight.split([local_q_size, local_kv_size, local_kv_size], dim=0) + q_weight, k_weight, v_weight = self._deinterleave_qkv(state_dict[qkv_weight_key]) hf_state_dict[f"{prefix}.self_attn.q_proj.weight"] = q_weight hf_state_dict[f"{prefix}.self_attn.k_proj.weight"] = k_weight @@ -220,12 +245,7 @@ def to_hf( # Handle biases if present qkv_bias_key = f"{prefix}.self_attn.qkv_proj.bias" if qkv_bias_key in state_dict: - qkv_bias = state_dict[qkv_bias_key] - qkv_bias_size = qkv_bias.shape[0] - local_q_size = (self.q_size * qkv_bias_size) // total_size - local_kv_size = (self.kv_size * qkv_bias_size) // total_size - - q_bias, k_bias, v_bias = qkv_bias.split([local_q_size, local_kv_size, local_kv_size], dim=0) + q_bias, k_bias, v_bias = self._deinterleave_qkv(state_dict[qkv_bias_key]) hf_state_dict[f"{prefix}.self_attn.q_proj.bias"] = q_bias hf_state_dict[f"{prefix}.self_attn.k_proj.bias"] = k_bias @@ -236,13 +256,7 @@ def to_hf( gate_up_weight_key = f"{prefix}.mlp.gate_up_proj.weight" if gate_up_weight_key in state_dict: - gate_up_weight = state_dict[gate_up_weight_key] - - # Compute local split sizes - gate_up_actual_size = gate_up_weight.shape[0] - local_intermediate_size = gate_up_actual_size // 2 - - gate_weight, up_weight = gate_up_weight.split([local_intermediate_size, local_intermediate_size], dim=0) + gate_weight, up_weight = self._deinterleave_gate_up(state_dict[gate_up_weight_key]) hf_state_dict[f"{prefix}.mlp.gate_proj.weight"] = gate_weight 
hf_state_dict[f"{prefix}.mlp.up_proj.weight"] = up_weight @@ -251,11 +265,7 @@ def to_hf( # Handle biases if present gate_up_bias_key = f"{prefix}.mlp.gate_up_proj.bias" if gate_up_bias_key in state_dict: - gate_up_bias = state_dict[gate_up_bias_key] - gate_up_bias_size = gate_up_bias.shape[0] - local_intermediate_size = gate_up_bias_size // 2 - - gate_bias, up_bias = gate_up_bias.split([local_intermediate_size, local_intermediate_size], dim=0) + gate_bias, up_bias = self._deinterleave_gate_up(state_dict[gate_up_bias_key]) hf_state_dict[f"{prefix}.mlp.gate_proj.bias"] = gate_bias hf_state_dict[f"{prefix}.mlp.up_proj.bias"] = up_bias @@ -271,3 +281,135 @@ def to_hf( hf_state_dict = {k: v for k, v in hf_state_dict.items() if not re.match(exclude_key_regex, k)} return hf_state_dict + + def _split_remaining_combined_projection_keys(self, hf_state_dict: dict[str, Any]) -> None: + """Split any remaining combined-projection keys in-place. + + Handles LoRA adapter weights (lora_A, lora_B), DoRA magnitude vectors, + and any base weight/bias keys that weren't caught by the layer-indexed loop + (e.g., keys with a ``base_model.model.`` prefix from PEFT saving). + + For keys containing ``.self_attn.qkv_proj.``: + - ``lora_A`` weights (input dimension) are duplicated to q/k/v projections. + - All other weights (lora_B, magnitude, weight, bias) are de-interleaved along dim 0 + from the KV-head-grouped layout (see ``_deinterleave_qkv``). + + For keys containing ``.mlp.gate_up_proj.``: + - ``lora_A`` weights are duplicated to gate/up projections. + - All other weights are de-interleaved row-by-row along dim 0 (see ``_deinterleave_gate_up``). + + Args: + hf_state_dict: State dict to modify in-place. + """ + combined_qkv_keys = [k for k in hf_state_dict if ".self_attn.qkv_proj." 
in k] + for key in combined_qkv_keys: + value = hf_state_dict.pop(key) + pre, suffix = key.split(".self_attn.qkv_proj.", 1) + + if "lora_A" in suffix: + # Input-dimension LoRA weight: identical for all projections + hf_state_dict[f"{pre}.self_attn.q_proj.{suffix}"] = value + hf_state_dict[f"{pre}.self_attn.k_proj.{suffix}"] = value.clone() + hf_state_dict[f"{pre}.self_attn.v_proj.{suffix}"] = value.clone() + else: + # Output-dimension weight (lora_B, magnitude, base weight/bias): de-interleave + q_val, k_val, v_val = self._deinterleave_qkv(value) + hf_state_dict[f"{pre}.self_attn.q_proj.{suffix}"] = q_val + hf_state_dict[f"{pre}.self_attn.k_proj.{suffix}"] = k_val + hf_state_dict[f"{pre}.self_attn.v_proj.{suffix}"] = v_val + + combined_gate_up_keys = [k for k in hf_state_dict if ".mlp.gate_up_proj." in k] + for key in combined_gate_up_keys: + value = hf_state_dict.pop(key) + pre, suffix = key.split(".mlp.gate_up_proj.", 1) + + if "lora_A" in suffix: + hf_state_dict[f"{pre}.mlp.gate_proj.{suffix}"] = value + hf_state_dict[f"{pre}.mlp.up_proj.{suffix}"] = value.clone() + else: + # Output-dimension weight: de-interleave + gate_val, up_val = self._deinterleave_gate_up(value) + hf_state_dict[f"{pre}.mlp.gate_proj.{suffix}"] = gate_val + hf_state_dict[f"{pre}.mlp.up_proj.{suffix}"] = up_val + + def _recombine_split_projection_keys(self, state_dict: dict[str, Any]) -> None: + """Recombine split projection LoRA/DoRA keys back to combined format. + + This is the reverse of ``_split_remaining_combined_projection_keys``. + It handles LoRA adapter weights and DoRA magnitude vectors that were + split for HF-PEFT compatibility during ``to_hf()`` and need to be + recombined when loading back into a model with combined projections. + + For keys containing ``.self_attn.q_proj.``: + - ``lora_A`` weights (which were duplicated during split) are + deduplicated — we take the ``q_proj`` version. + - All other weights (``lora_B``, magnitude, etc.) 
are interleaved + along dim 0 by KV-head groups (see ``_interleave_qkv``). + + For keys containing ``.mlp.gate_proj.``: + - ``lora_A`` weights are deduplicated — we take the ``gate_proj`` version. + - All other weights are interleaved row-by-row along dim 0 (see ``_interleave_gate_up``). + + Keys that end with ``.weight`` or ``.bias`` directly on the projection + (e.g., ``q_proj.weight``) are skipped because those are already handled + by the layer-indexed loop in ``from_hf``. + + Args: + state_dict: State dict to modify in-place. + """ + # --- QKV recombination --- + # Find q_proj keys that are NOT base weight/bias (already handled by layer loop) + q_keys = [ + k + for k in list(state_dict.keys()) + if ".self_attn.q_proj." in k + and not k.endswith(".self_attn.q_proj.weight") + and not k.endswith(".self_attn.q_proj.bias") + ] + + for q_key in q_keys: + k_key = q_key.replace(".self_attn.q_proj.", ".self_attn.k_proj.") + v_key = q_key.replace(".self_attn.q_proj.", ".self_attn.v_proj.") + + if k_key not in state_dict or v_key not in state_dict: + continue + + combined_key = q_key.replace(".self_attn.q_proj.", ".self_attn.qkv_proj.") + + q_val = state_dict.pop(q_key) + k_val = state_dict.pop(k_key) + v_val = state_dict.pop(v_key) + + if "lora_A" in q_key: + # lora_A weights were duplicated during split — just take one + state_dict[combined_key] = q_val + else: + # lora_B, magnitude, etc. — interleave by KV-head groups + state_dict[combined_key] = self._interleave_qkv(q_val, k_val, v_val) + + # --- gate_up recombination --- + gate_keys = [ + k + for k in list(state_dict.keys()) + if ".mlp.gate_proj." 
in k + and not k.endswith(".mlp.gate_proj.weight") + and not k.endswith(".mlp.gate_proj.bias") + ] + + for gate_key in gate_keys: + up_key = gate_key.replace(".mlp.gate_proj.", ".mlp.up_proj.") + + if up_key not in state_dict: + continue + + combined_key = gate_key.replace(".mlp.gate_proj.", ".mlp.gate_up_proj.") + + gate_val = state_dict.pop(gate_key) + up_val = state_dict.pop(up_key) + + if "lora_A" in gate_key: + # lora_A weights were duplicated during split — just take one + state_dict[combined_key] = gate_val + else: + # lora_B, magnitude, etc. — interleave row-by-row + state_dict[combined_key] = self._interleave_gate_up(gate_val, up_val) diff --git a/tests/functional_tests/llm_pretrain_and_kd/L2_TP_Output_Parity_Minified.sh b/tests/functional_tests/llm_pretrain_and_kd/L2_TP_Output_Parity_Minified.sh new file mode 100644 index 000000000..aa6b7a126 --- /dev/null +++ b/tests/functional_tests/llm_pretrain_and_kd/L2_TP_Output_Parity_Minified.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeuo pipefail + +export PYTHONPATH=${PYTHONPATH:-}:$(pwd) +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1}" + +# Override if needed, e.g. KL_THRESHOLD=1e-5 bash ... 
+KL_THRESHOLD="${KL_THRESHOLD:-1e-6}" + +torchrun --nproc_per_node=2 --nnodes=1 \ + tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py \ + --models qwen3 qwen3_seq_cls ministral3 llama qwen2 \ + --sequence_parallel both \ + --kl_threshold "${KL_THRESHOLD}" + diff --git a/tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py b/tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py new file mode 100644 index 000000000..b68fedd8b --- /dev/null +++ b/tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Standalone distributed TP output parity test (minified models). 
+ +Validates that TP=2 produces the same logits as TP=1 for tiny ("2-layer thin") +variants of: +- Qwen3ForCausalLM (HF) +- Qwen3ForSequenceClassification (HF) +- Ministral3ForCausalLM (NeMo Automodel custom) +- LlamaForCausalLM (NeMo Automodel custom, combined QKV + gate_up projections) +- Qwen2ForCausalLM (NeMo Automodel custom, combined QKV + gate_up projections) + +It also validates both tensor-parallel plans: +- sequence_parallel=False +- sequence_parallel=True + +Usage: + torchrun --nproc_per_node=2 tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py + + # Optional: select models / SP mode + torchrun --nproc_per_node=2 tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py \\ + --models qwen3 qwen3_seq_cls ministral3 \\ + --sequence_parallel both \\ + --kl_threshold 1e-6 +""" + +from __future__ import annotations + +import argparse +import os +import sys +from dataclasses import dataclass +from typing import Literal, Sequence, cast + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.parallel import parallelize_module +from torch.distributed.tensor.placement_types import Replicate + +from nemo_automodel._transformers.utils import apply_cache_compatibility_patches +from nemo_automodel.components.distributed.parallelizer import _get_parallel_plan +from nemo_automodel.components.models.common.utils import BackendConfig +from nemo_automodel.components.models.llama.model import LlamaConfig, LlamaForCausalLM +from nemo_automodel.components.models.mistral3.model import Ministral3Config, Ministral3ForCausalLM +from nemo_automodel.components.models.qwen2.model import Qwen2Config, Qwen2ForCausalLM +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM, 
Qwen3ForSequenceClassification + +ModelKind = Literal["qwen3", "qwen3_seq_cls", "ministral3", "llama", "qwen2"] +SPMode = Literal["true", "false", "both"] + + +def _is_distributed() -> bool: + return dist.is_available() and dist.is_initialized() + + +def _rank() -> int: + return dist.get_rank() if _is_distributed() else 0 + + +def _world_size() -> int: + return dist.get_world_size() if _is_distributed() else 1 + + +def _init_distributed() -> tuple[str, torch.device, str]: + """Init process group if launched via torchrun. + + Returns: + backend, device, device_type + """ + if not dist.is_available(): + return "none", torch.device("cpu"), "cpu" + if dist.is_initialized(): + # Best-effort infer device_type. + device_type = "cuda" if torch.cuda.is_available() else "cpu" + return ( + dist.get_backend(), + torch.device("cuda", torch.cuda.current_device()) if device_type == "cuda" else torch.device("cpu"), + device_type, + ) + + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + return "none", torch.device("cpu"), "cpu" + + if torch.cuda.is_available(): + backend = "nccl" + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + device_type = "cuda" + else: + # Ensure gloo binds to loopback on minimal containers. + os.environ.setdefault("GLOO_SOCKET_IFNAME", "lo") + backend = "gloo" + device = torch.device("cpu") + device_type = "cpu" + + dist.init_process_group(backend=backend) + return backend, device, device_type + + +def _maybe_gather_dtensor_to_replicated_local(x: torch.Tensor | DTensor, *, tp_mesh: DeviceMesh) -> torch.Tensor: + if isinstance(x, DTensor): + x = x.redistribute(device_mesh=tp_mesh, placements=[Replicate()]).to_local() + return cast(torch.Tensor, x) + + +def _kl_divergence_from_logits(*, reference_logits: torch.Tensor, candidate_logits: torch.Tensor) -> torch.Tensor: + """Return KL(reference || candidate) for each token (a 1-D tensor, not averaged). 
+ + Both inputs are expected to be full (non-sharded) logits with shape [B, T, V]. + """ + assert reference_logits.shape == candidate_logits.shape + vocab_size = reference_logits.shape[-1] + ref_log_probs = F.log_softmax(reference_logits.float(), dim=-1).reshape(-1, vocab_size) + cand_log_probs = F.log_softmax(candidate_logits.float(), dim=-1).reshape(-1, vocab_size) + # F.kl_div expects input=log(q), target=log(p) when log_target=True → KL(p || q) + return F.kl_div(cand_log_probs, ref_log_probs, reduction="none", log_target=True).sum(-1) + +@dataclass(frozen=True) +class _Case: + kind: ModelKind + sequence_parallel: bool + + def name(self) -> str: + return f"{self.kind}/sequence_parallel={self.sequence_parallel}" + + +def _build_minified_model(kind: ModelKind): + if kind == "ministral3": + cfg = Ministral3Config( + vocab_size=128, + hidden_size=64, + intermediate_size=256, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + max_position_embeddings=128, + use_cache=False, + tie_word_embeddings=True, + rope_parameters={ + "type": "yarn", + "rope_theta": 1000000.0, + "factor": 16.0, + "original_max_position_embeddings": 8, + "max_position_embeddings": 128, + "beta_fast": 32.0, + "beta_slow": 1.0, + "mscale_all_dim": 1.0, + "mscale": 1.0, + "llama_4_scaling_beta": 0.1, + }, + ) + return cfg, Ministral3ForCausalLM(cfg) + + if kind == "qwen3": + num_layers = 2 + cfg = Qwen3Config( + vocab_size=128, + hidden_size=64, + intermediate_size=256, + num_hidden_layers=num_layers, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + max_position_embeddings=128, + use_cache=False, + tie_word_embeddings=True, + attention_bias=False, + use_sliding_window=False, + layer_types=["full_attention"] * num_layers, + ) + return cfg, Qwen3ForCausalLM(cfg) + + if kind == "qwen3_seq_cls": + num_layers = 2 + cfg = Qwen3Config( + vocab_size=128, + hidden_size=64, + intermediate_size=256, + num_hidden_layers=num_layers, + num_attention_heads=4, 
+ num_key_value_heads=2, + head_dim=16, + max_position_embeddings=128, + use_cache=False, + tie_word_embeddings=True, + attention_bias=False, + use_sliding_window=False, + layer_types=["full_attention"] * num_layers, + num_labels=2, # must be divisible by TP size (2) if sharded + pad_token_id=0, # required for batch_size>1 pooling + ) + return cfg, Qwen3ForSequenceClassification(cfg) + + if kind == "llama": + num_layers = 2 + backend = BackendConfig(rms_norm="torch") + cfg = LlamaConfig( + vocab_size=128, + hidden_size=64, + intermediate_size=256, + num_hidden_layers=num_layers, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=128, + use_cache=False, + tie_word_embeddings=True, + attention_bias=False, + attn_implementation="eager", + dtype=torch.float32, + ) + return cfg, LlamaForCausalLM(cfg, backend=backend) + + if kind == "qwen2": + num_layers = 2 + backend = BackendConfig(rms_norm="torch") + cfg = Qwen2Config( + vocab_size=128, + hidden_size=64, + intermediate_size=256, + num_hidden_layers=num_layers, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=128, + use_cache=False, + tie_word_embeddings=True, + use_sliding_window=False, + sliding_window=None, + layer_types=["full_attention"] * num_layers, + attn_implementation="eager", + dtype=torch.float32, + ) + return cfg, Qwen2ForCausalLM(cfg, backend=backend) + + raise ValueError(f"Unknown model kind: {kind}") + + +def _run_case( + case: _Case, + *, + device: torch.device, + device_type: str, + kl_threshold: float, +) -> tuple[bool, float]: + """Return (ok, kl_divergence).""" + world_size = _world_size() + assert world_size == 2, f"This test is intended for TP=2; got world_size={world_size}" + + # Use the same initial weights for baseline and TP models. 
+ torch.manual_seed(1234) + if device_type == "cuda": + torch.cuda.manual_seed_all(1234) + + cfg, baseline = _build_minified_model(case.kind) + baseline = baseline.to(device=device, dtype=torch.float32) + baseline.eval() + + # Deterministic inputs across ranks. + torch.manual_seed(999) + if device_type == "cuda": + torch.cuda.manual_seed_all(999) + # Keep this small for a fast functional test. Also avoid 0 to keep seq-cls pad-token + # pooling deterministic. + input_ids = torch.randint(1, int(cfg.vocab_size), (2, 1024), dtype=torch.long, device=device) + attention_mask = torch.ones_like(input_ids) + + with torch.inference_mode(): + baseline_logits = cast( + torch.Tensor, + baseline(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits, + ) + + # Rebuild model with identical weights, then TP-parallelize. + torch.manual_seed(1234) + if device_type == "cuda": + torch.cuda.manual_seed_all(1234) + _, tp_model = _build_minified_model(case.kind) + tp_model = tp_model.to(device=device, dtype=torch.float32) + tp_model.eval() + + tp_mesh = DeviceMesh(device_type, torch.arange(world_size, device="cpu"), mesh_dim_names=("tp",)) + plan = _get_parallel_plan(tp_model, sequence_parallel=case.sequence_parallel) + parallelize_module(tp_model, tp_mesh, plan) + + with torch.inference_mode(): + tp_logits = tp_model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits + tp_logits_full = _maybe_gather_dtensor_to_replicated_local(tp_logits, tp_mesh=tp_mesh) + + kl = _kl_divergence_from_logits(reference_logits=baseline_logits, candidate_logits=tp_logits_full) + # NOTE: keep this threshold loose; different TP reduction orders can introduce tiny numeric drift. + ok = torch.all(kl <= kl_threshold) + return ok, kl.view(-1).max().item() + + +def main(argv: Sequence[str] | None = None) -> int: + # Ensure any required transformers compatibility patches are applied before we build models. 
+ apply_cache_compatibility_patches() + + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + default=["qwen3", "qwen3_seq_cls", "ministral3", "llama", "qwen2"], + choices=["qwen3", "qwen3_seq_cls", "ministral3", "llama", "qwen2"], + help="Which models to test.", + ) + parser.add_argument( + "--sequence_parallel", + default="both", + choices=["true", "false", "both"], + help="Run sequence_parallel=True/False/both.", + ) + parser.add_argument( + "--kl_threshold", + type=float, + default=1e-6, + help="Fail if KL(TP=1 || TP=2) exceeds this threshold (averaged per token).", + ) + args = parser.parse_args(list(argv) if argv is not None else None) + + _init_distributed() + rank = _rank() + world_size = _world_size() + + if world_size != 2: + if rank == 0: + print(f"ERROR: This test requires world_size=2 (TP=2), got {world_size}", file=sys.stderr) + return 1 + + # Derive per-rank device after init. + if torch.cuda.is_available() and dist.get_backend() == "nccl": + local_rank = int(os.environ.get("LOCAL_RANK", str(rank))) + device = torch.device(f"cuda:{local_rank}") + device_type = "cuda" + else: + device = torch.device("cpu") + device_type = "cpu" + + if args.sequence_parallel == "both": + sp_flags = [False, True] + else: + sp_flags = [args.sequence_parallel == "true"] + + cases = [_Case(kind=cast(ModelKind, k), sequence_parallel=sp) for k in args.models for sp in sp_flags] + + all_ok = True + + for case in cases: + # Keep ranks roughly in sync for cleaner output. 
+ dist.barrier() + ok, kl = _run_case(case, device=device, device_type=device_type, kl_threshold=float(args.kl_threshold)) + + ok_tensor = torch.tensor(1 if ok else 0, device=device, dtype=torch.int) + dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN) + all_ok = all_ok and bool(ok_tensor.item()) + + kl_tensor = torch.tensor(kl, device=device, dtype=torch.float32) + dist.all_reduce(kl_tensor, op=dist.ReduceOp.MAX) + + if rank == 0: + status = "PASS" if bool(ok_tensor.item()) else "FAIL" + print(f"{status}: {case.name()} (kl_div={kl_tensor.item():.6g}, threshold={args.kl_threshold:g})") + + if rank == 0 and all_ok: + print("PASS: all TP parity checks passed") + + return 0 if all_ok else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) + From 35347d651c38a7f0478fd1227fb86fd5e7286261 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 24 Feb 2026 21:14:01 +0800 Subject: [PATCH 2/2] fix: FSDP pre-shard combined projections on dim 1 for Qwen2.5-7B support (#1357) * fix: FSDP pre-shard combined projections on dim 1 for Qwen2.5-7B support Signed-off-by: Zhiyu Li * revert recipe change Signed-off-by: Zhiyu Li * lint Signed-off-by: Zhiyu Li --------- Signed-off-by: Zhiyu Li Signed-off-by: HuiyingLi Signed-off-by: Claude Opus 4.6 (1M context) --- .../components/distributed/parallelizer.py | 51 +++++++++++++++++++ .../combined_projection/combined_qkv.py | 10 +++- .../combined_projection/state_dict_adapter.py | 41 +++++++++++++-- .../run_tp_output_parity_minified.py | 7 ++- 4 files changed, 102 insertions(+), 7 deletions(-) diff --git a/nemo_automodel/components/distributed/parallelizer.py b/nemo_automodel/components/distributed/parallelizer.py index ac3c94a97..d1daae10c 100644 --- a/nemo_automodel/components/distributed/parallelizer.py +++ b/nemo_automodel/components/distributed/parallelizer.py @@ -405,6 +405,53 @@ def _register(cls): return _register +def _pre_shard_combined_projections( + module: nn.Module, + mesh: DeviceMesh, + mp_policy: 
Optional[MixedPrecisionPolicy], + offload_policy: Optional[OffloadPolicy] = None, +) -> None: + """Pre-shard combined projection modules (qkv_proj, gate_up_proj) on dim 1. + + Combined QKV and gate_up projections use interleaved layouts on dim 0 + (grouping Q/K/V rows or gate/up rows). Standard FSDP ``Shard(0)`` can + break these group boundaries when ``num_kv_heads`` does not divide evenly + by the FSDP shard count, causing reshape failures in the state-dict adapter. + + Sharding on dim 1 keeps dim 0 intact so reshape / split operations work + for any model configuration. 1-D tensors (biases) use ``Shard(0)`` because + FSDP's ``shard_placement_fn`` only accepts ``Shard`` placements; the + state-dict adapter's ``_gather_1d_if_needed`` gathers them when the local + shard doesn't align with interleaved group boundaries. + + Follows the same pattern as MoE expert sharding + (``nemo_automodel/components/moe/parallelizer.py``). + """ + try: + from nemo_automodel.components.models.common.combined_projection.combined_mlp import CombinedGateUpMLP + from nemo_automodel.components.models.common.combined_projection.combined_qkv import CombinedQKVAttentionMixin + except ImportError: + return + + _shard_fn = lambda p: Shard(1) if p.ndim >= 2 else Shard(0) + + for sub in module.modules(): + target = None + if isinstance(sub, CombinedQKVAttentionMixin) and hasattr(sub, "qkv_proj"): + target = sub.qkv_proj + elif isinstance(sub, CombinedGateUpMLP) and hasattr(sub, "gate_up_proj"): + target = sub.gate_up_proj + + if target is not None and not isinstance(target, FSDPModule): + fully_shard( + target, + mesh=mesh, + mp_policy=mp_policy, + offload_policy=offload_policy, + shard_placement_fn=_shard_fn, + ) + + def apply_fsdp2_sharding_recursively( module: nn.Module, mesh: DeviceMesh, @@ -441,6 +488,10 @@ def apply_fsdp2_sharding_recursively( if isinstance(child_module, nn.ModuleList): apply_fsdp2_sharding_recursively(child_module, mesh, mp_policy, offload_policy) else: + # Pre-shard 
combined projection submodules on dim 1 so that + # the parent fully_shard (dim 0) skips them automatically. + _pre_shard_combined_projections(child_module, mesh, mp_policy, offload_policy) + # As an optimization, do not reshard after forward for the last # transformer block since FSDP would prefetch it immediately reshard_after_forward = int(layer_id) < len(module) - 1 diff --git a/nemo_automodel/components/models/common/combined_projection/combined_qkv.py b/nemo_automodel/components/models/common/combined_projection/combined_qkv.py index 3adcd8b2f..19eaab26b 100644 --- a/nemo_automodel/components/models/common/combined_projection/combined_qkv.py +++ b/nemo_automodel/components/models/common/combined_projection/combined_qkv.py @@ -24,11 +24,17 @@ def _assert_colwise_parallel(weight: torch.Tensor, name: str) -> None: - """Verify that a combined-projection weight uses ColwiseParallel (Shard(0)) if TP is active.""" + """Verify that a combined-projection weight uses ColwiseParallel (Shard(0)) if TP is active. + + Shard(dim=1) is expected from FSDP pre-sharding of combined projections and + is excluded from the check — it is temporary and undone by FSDP all-gather + before the actual matmul. + """ if isinstance(weight, DTensor) and weight.placements: from torch.distributed.tensor.placement_types import Shard - if weight.placements[0] != Shard(0): + tp_shards = [p for p in weight.placements if isinstance(p, Shard) and p.dim != 1] + if tp_shards and not any(p.dim == 0 for p in tp_shards): raise ValueError( f"{name} uses an interleaved layout that requires ColwiseParallel " f"(Shard(0)) for correct TP sharding, but got placements={weight.placements}. 
" diff --git a/nemo_automodel/components/models/common/combined_projection/state_dict_adapter.py b/nemo_automodel/components/models/common/combined_projection/state_dict_adapter.py index c63c39afc..bbbb46f68 100644 --- a/nemo_automodel/components/models/common/combined_projection/state_dict_adapter.py +++ b/nemo_automodel/components/models/common/combined_projection/state_dict_adapter.py @@ -75,14 +75,44 @@ def __init__(self, config): self.kv_size = self.num_key_value_heads * self.head_dim self.group_size = self.num_attention_heads // self.num_key_value_heads + @staticmethod + def _gather_1d_if_needed(tensor: torch.Tensor, divisor: int) -> torch.Tensor: + """Gather a 1-D DTensor on dim 0 when the local shard isn't divisible by *divisor*. + + FSDP2's ``shard_placement_fn`` only accepts ``Shard`` placements (not + ``Replicate``), so 1-D bias vectors of combined projections end up with + ``Shard(0)`` even though their interleaved layout may not divide evenly + across the FSDP shard count. This helper gathers such biases to full + before reshape / split operations in the state-dict adapter. + + Weights are handled by FSDP ``Shard(1)``, so they never need this. + """ + if tensor.ndim != 1: + return tensor + try: + from torch.distributed.tensor import DTensor + from torch.distributed.tensor.placement_types import Replicate, Shard + except ImportError: + return tensor + if not isinstance(tensor, DTensor): + return tensor + if tensor.to_local().shape[0] % divisor == 0: + return tensor + new_placements = tuple(Replicate() if isinstance(p, Shard) and p.dim == 0 else p for p in tensor.placements) + return tensor.redistribute(tensor.device_mesh, new_placements) + def _interleave_qkv(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: """Interleave Q, K, V by KV-head groups for TP-correct ColwiseParallel sharding. Layout: [Q_group_0 | K_0 | V_0 | Q_group_1 | K_1 | V_1 | ...] 
where each group has (group_size * head_dim) Q rows, head_dim K rows, head_dim V rows. """ + q_group_width = self.group_size * self.head_dim + q = self._gather_1d_if_needed(q, q_group_width) + k = self._gather_1d_if_needed(k, self.head_dim) + v = self._gather_1d_if_needed(v, self.head_dim) rest = q.shape[1:] - q = q.reshape(self.num_key_value_heads, self.group_size * self.head_dim, *rest) + q = q.reshape(self.num_key_value_heads, q_group_width, *rest) k = k.reshape(self.num_key_value_heads, self.head_dim, *rest) v = v.reshape(self.num_key_value_heads, self.head_dim, *rest) return torch.cat([q, k, v], dim=1).reshape(-1, *rest) @@ -90,11 +120,13 @@ def _interleave_qkv(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> def _deinterleave_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """De-interleave QKV from KV-head-grouped layout back to separate Q, K, V. - Works for both full (unsharded) and TP-sharded tensors — the group structure - is the same regardless of how many groups this rank holds. + Works for full (unsharded), TP-sharded, and FSDP-sharded tensors. + When the tensor is a DTensor whose dim-0 local shard doesn't align with + the QKV group boundary, the tensor is gathered to full first. 
""" - rest = qkv.shape[1:] group_width = (self.group_size + 2) * self.head_dim + qkv = self._gather_1d_if_needed(qkv, group_width) + rest = qkv.shape[1:] qkv = qkv.reshape(-1, group_width, *rest) q, k, v = qkv.split([self.group_size * self.head_dim, self.head_dim, self.head_dim], dim=1) return q.reshape(-1, *rest), k.reshape(-1, *rest), v.reshape(-1, *rest) @@ -105,6 +137,7 @@ def _interleave_gate_up(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Ten def _deinterleave_gate_up(self, gate_up: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """De-interleave gate/up from row-interleaved layout.""" + gate_up = self._gather_1d_if_needed(gate_up, 2) rest = gate_up.shape[1:] gate_up = gate_up.reshape(-1, 2, *rest) return gate_up[:, 0].contiguous(), gate_up[:, 1].contiguous() diff --git a/tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py b/tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py index b68fedd8b..af05b6b7b 100644 --- a/tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py +++ b/tests/functional_tests/llm_pretrain_and_kd/run_tp_output_parity_minified.py @@ -248,7 +248,12 @@ def _build_minified_model(kind: ModelKind): attn_implementation="eager", dtype=torch.float32, ) - return cfg, Qwen2ForCausalLM(cfg, backend=backend) + model = Qwen2ForCausalLM(cfg, backend=backend) + with torch.no_grad(): + for _, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and module.bias is not None: + torch.nn.init.normal_(module.bias, mean=0.1, std=0.1) + return cfg, model raise ValueError(f"Unknown model kind: {kind}")