|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from typing import Any |
| 16 | + |
| 17 | +import torch |
| 18 | +from torch import nn |
| 19 | + |
| 20 | +from nemo_automodel.components.attention.utils import ( |
| 21 | + initialize_attn_module_and_func, |
| 22 | + postprocess_output_for_attn, |
| 23 | + preprocess_args_and_kwargs_for_attn, |
| 24 | +) |
| 25 | +from nemo_automodel.components.models.common import ( |
| 26 | + BackendConfig, |
| 27 | + initialize_linear_module, |
| 28 | + initialize_rms_norm_module, |
| 29 | +) |
| 30 | +from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk |
| 31 | + |
| 32 | + |
| 33 | +class Qwen3Attention(nn.Module): |
| 34 | + """Qwen3 dense attention with per-head QK RMSNorm and RoPE. |
| 35 | +
|
| 36 | + Identical to the Qwen3 MoE attention layer — the attention mechanism |
| 37 | + is shared between dense and MoE variants. |
| 38 | + """ |
| 39 | + |
| 40 | + def __init__(self, config, backend: BackendConfig): |
| 41 | + super().__init__() |
| 42 | + self.backend = backend |
| 43 | + |
| 44 | + self.num_heads = config.num_attention_heads |
| 45 | + self.num_kv_heads = config.num_key_value_heads |
| 46 | + self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) |
| 47 | + |
| 48 | + attention_bias = getattr(config, "attention_bias", False) |
| 49 | + |
| 50 | + self.q_proj = initialize_linear_module( |
| 51 | + backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias |
| 52 | + ) |
| 53 | + self.k_proj = initialize_linear_module( |
| 54 | + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias |
| 55 | + ) |
| 56 | + self.v_proj = initialize_linear_module( |
| 57 | + backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias |
| 58 | + ) |
| 59 | + self.o_proj = initialize_linear_module( |
| 60 | + backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias |
| 61 | + ) |
| 62 | + |
| 63 | + self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) |
| 64 | + self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps) |
| 65 | + |
| 66 | + softmax_scale = self.head_dim**-0.5 |
| 67 | + self.attn_module, self.attn_func = initialize_attn_module_and_func( |
| 68 | + attn_impl=backend.attn, |
| 69 | + num_attention_heads=self.num_heads, |
| 70 | + num_qk_channels=self.head_dim, |
| 71 | + num_v_channels=self.head_dim, |
| 72 | + softmax_scale=softmax_scale, |
| 73 | + num_gqa_groups=self.num_kv_heads, |
| 74 | + ) |
| 75 | + |
| 76 | + def forward( |
| 77 | + self, |
| 78 | + x: torch.Tensor, |
| 79 | + *, |
| 80 | + freqs_cis: torch.Tensor, |
| 81 | + attention_mask: torch.Tensor | None = None, |
| 82 | + **attn_kwargs: Any, |
| 83 | + ) -> torch.Tensor: |
| 84 | + if len(x.shape) == 2: |
| 85 | + qkv_format = "thd" |
| 86 | + num_tokens = x.shape[0] |
| 87 | + else: |
| 88 | + qkv_format = "bshd" |
| 89 | + bsz, seqlen, _ = x.size() |
| 90 | + |
| 91 | + q = self.q_proj(x) |
| 92 | + k = self.k_proj(x) |
| 93 | + v = self.v_proj(x) |
| 94 | + |
| 95 | + if qkv_format == "thd": |
| 96 | + q = q.view(num_tokens, self.num_heads, self.head_dim) |
| 97 | + k = k.view(num_tokens, self.num_kv_heads, self.head_dim) |
| 98 | + v = v.view(num_tokens, self.num_kv_heads, self.head_dim) |
| 99 | + else: |
| 100 | + q = q.view(bsz, seqlen, self.num_heads, self.head_dim) |
| 101 | + k = k.view(bsz, seqlen, self.num_kv_heads, self.head_dim) |
| 102 | + v = v.view(bsz, seqlen, self.num_kv_heads, self.head_dim) |
| 103 | + |
| 104 | + q = self.q_norm(q) |
| 105 | + k = self.k_norm(k) |
| 106 | + |
| 107 | + q, k = apply_rotary_emb_qk( |
| 108 | + q, |
| 109 | + k, |
| 110 | + freqs_cis, |
| 111 | + format=qkv_format, |
| 112 | + rope_fusion=self.backend.rope_fusion, |
| 113 | + cu_seqlens=attn_kwargs.get("cu_seqlens", None), |
| 114 | + cp_size=attn_kwargs.get("cp_size", 1), |
| 115 | + cp_rank=attn_kwargs.get("cp_rank", 0), |
| 116 | + ) |
| 117 | + |
| 118 | + q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn( |
| 119 | + q, k, v, attention_mask, self.backend.attn, **attn_kwargs |
| 120 | + ) |
| 121 | + out = self.attn_func(q, k, v, **_attn_kwargs) |
| 122 | + out = postprocess_output_for_attn(out, self.backend.attn) |
| 123 | + |
| 124 | + flatten_dim = 2 if qkv_format == "bshd" else 1 |
| 125 | + out = self.o_proj(out.flatten(flatten_dim)) |
| 126 | + return out |
| 127 | + |
| 128 | + def init_weights(self, buffer_device: torch.device, init_std: float = 0.02): |
| 129 | + for linear in [self.q_proj, self.k_proj, self.v_proj, self.o_proj]: |
| 130 | + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) |
| 131 | + if hasattr(linear, "bias") and linear.bias is not None: |
| 132 | + nn.init.zeros_(linear.bias) |
| 133 | + for norm in (self.q_norm, self.k_norm): |
| 134 | + norm.reset_parameters() |
0 commit comments