
Commit b69710c

feat: add Qwen3 dense model handler for NeMo Automodel
Adds a custom Qwen3ForCausalLM implementation that supports the TE backend and context parallelism for dense Qwen3 models (e.g. Qwen3-14B). It uses the same attention with per-head QK RMSNorm as the existing Qwen3 MoE handler, with a standard SwiGLU MLP instead of MoE layers.

Registers Qwen3ForCausalLM in the model architecture mapping so NeMo routes it through the custom model path instead of falling back to vanilla HuggingFace (which doesn't support the custom backend or context parallelism).

Made-with: Cursor
1 parent 0be57dc commit b69710c

4 files changed

Lines changed: 408 additions & 0 deletions

File tree

nemo_automodel/_transformers/registry.py

Lines changed: 4 additions & 0 deletions
@@ -92,6 +92,10 @@
         "Qwen2ForCausalLM",
         ("nemo_automodel.components.models.qwen2.model", "Qwen2ForCausalLM"),
     ),
+    (
+        "Qwen3ForCausalLM",
+        ("nemo_automodel.components.models.qwen3.model", "Qwen3ForCausalLM"),
+    ),
     (
         "Qwen3MoeForCausalLM",
         ("nemo_automodel.components.models.qwen3_moe.model", "Qwen3MoeForCausalLM"),
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import torch
+from torch import nn
+
+from nemo_automodel.components.attention.utils import (
+    initialize_attn_module_and_func,
+    postprocess_output_for_attn,
+    preprocess_args_and_kwargs_for_attn,
+)
+from nemo_automodel.components.models.common import (
+    BackendConfig,
+    initialize_linear_module,
+    initialize_rms_norm_module,
+)
+from nemo_automodel.components.models.gpt_oss.rope_utils import apply_rotary_emb_qk
+
+
+class Qwen3Attention(nn.Module):
+    """Qwen3 dense attention with per-head QK RMSNorm and RoPE.
+
+    Identical to the Qwen3 MoE attention layer: the attention mechanism
+    is shared between dense and MoE variants.
+    """
+
+    def __init__(self, config, backend: BackendConfig):
+        super().__init__()
+        self.backend = backend
+
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
+
+        attention_bias = getattr(config, "attention_bias", False)
+
+        self.q_proj = initialize_linear_module(
+            backend.linear, config.hidden_size, self.num_heads * self.head_dim, attention_bias
+        )
+        self.k_proj = initialize_linear_module(
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias
+        )
+        self.v_proj = initialize_linear_module(
+            backend.linear, config.hidden_size, self.num_kv_heads * self.head_dim, attention_bias
+        )
+        self.o_proj = initialize_linear_module(
+            backend.linear, self.num_heads * self.head_dim, config.hidden_size, attention_bias
+        )
+
+        self.q_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = initialize_rms_norm_module(backend.rms_norm, self.head_dim, eps=config.rms_norm_eps)
+
+        softmax_scale = self.head_dim**-0.5
+        self.attn_module, self.attn_func = initialize_attn_module_and_func(
+            attn_impl=backend.attn,
+            num_attention_heads=self.num_heads,
+            num_qk_channels=self.head_dim,
+            num_v_channels=self.head_dim,
+            softmax_scale=softmax_scale,
+            num_gqa_groups=self.num_kv_heads,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        *,
+        freqs_cis: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        **attn_kwargs: Any,
+    ) -> torch.Tensor:
+        if len(x.shape) == 2:
+            qkv_format = "thd"
+            num_tokens = x.shape[0]
+        else:
+            qkv_format = "bshd"
+            bsz, seqlen, _ = x.size()
+
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+
+        if qkv_format == "thd":
+            q = q.view(num_tokens, self.num_heads, self.head_dim)
+            k = k.view(num_tokens, self.num_kv_heads, self.head_dim)
+            v = v.view(num_tokens, self.num_kv_heads, self.head_dim)
+        else:
+            q = q.view(bsz, seqlen, self.num_heads, self.head_dim)
+            k = k.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
+            v = v.view(bsz, seqlen, self.num_kv_heads, self.head_dim)
+
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        q, k = apply_rotary_emb_qk(
+            q,
+            k,
+            freqs_cis,
+            format=qkv_format,
+            rope_fusion=self.backend.rope_fusion,
+            cu_seqlens=attn_kwargs.get("cu_seqlens", None),
+            cp_size=attn_kwargs.get("cp_size", 1),
+            cp_rank=attn_kwargs.get("cp_rank", 0),
+        )
+
+        q, k, v, _attn_kwargs = preprocess_args_and_kwargs_for_attn(
+            q, k, v, attention_mask, self.backend.attn, **attn_kwargs
+        )
+        out = self.attn_func(q, k, v, **_attn_kwargs)
+        out = postprocess_output_for_attn(out, self.backend.attn)
+
+        flatten_dim = 2 if qkv_format == "bshd" else 1
+        out = self.o_proj(out.flatten(flatten_dim))
+        return out
+
+    def init_weights(self, buffer_device: torch.device, init_std: float = 0.02):
+        for linear in [self.q_proj, self.k_proj, self.v_proj, self.o_proj]:
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+            if hasattr(linear, "bias") and linear.bias is not None:
+                nn.init.zeros_(linear.bias)
+        for norm in (self.q_norm, self.k_norm):
+            norm.reset_parameters()
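Two details of this attention block are worth calling out: the input can arrive either packed as a 2-D (total_tokens, hidden) "thd" tensor or padded as a 3-D (batch, seqlen, hidden) "bshd" tensor, and RMSNorm is applied per head to the query and key projections (over the last head_dim axis) before RoPE, which is why q_norm and k_norm are sized head_dim rather than hidden_size. A self-contained sketch of that per-head QK RMSNorm step in plain PyTorch; the sizes are made up for illustration, nn.RMSNorm needs PyTorch 2.4+, and none of the NeMo backend modules are used here:

import torch
from torch import nn

# Illustrative sizes, not a real Qwen3 config.
bsz, seqlen, num_heads, num_kv_heads, head_dim = 2, 8, 4, 2, 16
hidden = num_heads * head_dim

q_proj = nn.Linear(hidden, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)
# Norms are sized head_dim, so the same learned scale normalizes every head.
q_norm = nn.RMSNorm(head_dim, eps=1e-6)
k_norm = nn.RMSNorm(head_dim, eps=1e-6)

x = torch.randn(bsz, seqlen, hidden)                    # padded "bshd"-style input
q = q_proj(x).view(bsz, seqlen, num_heads, head_dim)    # split heads before normalizing
k = k_proj(x).view(bsz, seqlen, num_kv_heads, head_dim)
q, k = q_norm(q), k_norm(k)                             # RMSNorm over the last (head_dim) axis
print(q.shape, k.shape)  # torch.Size([2, 8, 4, 16]) torch.Size([2, 8, 2, 16])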
