9 changes: 4 additions & 5 deletions docker/common/uv-pytorch.lock
@@ -3349,7 +3349,7 @@ requires-dist = [
{ name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'vlm'" },
{ name = "torchdata" },
{ name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'cuda'", specifier = "<=2.11.0" },
{ name = "transformers", specifier = ">=5.0.0" },
{ name = "transformers", specifier = ">=5.2.0" },
{ name = "wandb" },
]
provides-extras = ["cuda", "extra", "fa", "delta-databricks", "moe", "vlm", "all"]
@@ -6398,10 +6398,9 @@ wheels = [

[[package]]
name = "transformers"
version = "5.0.0"
version = "5.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
{ name = "huggingface-hub" },
{ name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" },
{ name = "numpy", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" },
@@ -6413,9 +6412,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typer-slim" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bc/79/845941711811789c85fb7e2599cea425a14a07eda40f50896b9d3fda7492/transformers-5.0.0.tar.gz", hash = "sha256:5f5634efed6cf76ad068cc5834c7adbc32db78bbd6211fb70df2325a9c37dec8", size = 8424830, upload-time = "2026-01-26T10:46:46.813Z" }
sdist = { url = "https://files.pythonhosted.org/packages/bd/7e/8a0c57d562015e5b16c97c1f0b8e0e92ead2c7c20513225dc12c2043ba9f/transformers-5.2.0.tar.gz", hash = "sha256:0088b8b46ccc9eff1a1dca72b5d618a5ee3b1befc3e418c9512b35dea9f9a650", size = 8618176, upload-time = "2026-02-16T18:54:02.867Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/52/f3/ac976fa8e305c9e49772527e09fbdc27cc6831b8a2f6b6063406626be5dd/transformers-5.0.0-py3-none-any.whl", hash = "sha256:587086f249ce64c817213cf36afdb318d087f790723e9b3d4500b97832afd52d", size = 10142091, upload-time = "2026-01-26T10:46:43.88Z" },
{ url = "https://files.pythonhosted.org/packages/4e/93/79754b0ca486e556c2b95d4f5afc66aaf4b260694f3d6e1b51da2d036691/transformers-5.2.0-py3-none-any.whl", hash = "sha256:9ecaf243dc45bee11a7d93f8caf03746accc0cb069181bbf4ad8566c53e854b4", size = 10403304, upload-time = "2026-02-16T18:53:59.699Z" },
]

[[package]]
13 changes: 13 additions & 0 deletions nemo_automodel/components/models/glm_moe_dsa/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
302 changes: 302 additions & 0 deletions nemo_automodel/components/models/glm_moe_dsa/model.py
@@ -0,0 +1,302 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
import torch.nn as nn
from transformers.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig

from nemo_automodel.components.models.common import BackendConfig, initialize_linear_module, initialize_rms_norm_module
from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin
from nemo_automodel.components.models.deepseek_v3.rope_utils import (
    freqs_cis_from_position_ids,
    precompute_freqs_cis,
)
from nemo_automodel.components.models.deepseek_v32.layers import DeepseekV32MLA
from nemo_automodel.components.models.glm_moe_dsa.state_dict_adapter import GlmMoeDsaStateDictAdapter
from nemo_automodel.components.moe.fsdp_mixin import MoEFSDPSyncMixin
from nemo_automodel.components.moe.layers import MLP, MoE, MoEConfig
from nemo_automodel.components.utils.model_utils import squeeze_input_for_thd
from nemo_automodel.shared.utils import dtype_from_str as get_dtype


class Block(nn.Module):
    def __init__(self, layer_idx: int, config: GlmMoeDsaConfig, moe_config: MoEConfig, backend: BackendConfig):
        super().__init__()
        self.self_attn = DeepseekV32MLA(config, backend)

        # Pick the feed-forward type per layer: entries marked "sparse" in mlp_layer_types
        # (or, when that field is absent, layers at or past first_k_dense_replace) use MoE;
        # the remaining layers use a dense MLP.
        mlp_layer_types = getattr(config, "mlp_layer_types", None)
        if mlp_layer_types is not None:
            is_moe_layer = mlp_layer_types[layer_idx] == "sparse"
        else:
            first_k_dense_replace = getattr(config, "first_k_dense_replace", 0)
            is_moe_layer = layer_idx >= first_k_dense_replace

        if is_moe_layer:
            self.mlp = MoE(moe_config, backend)
        else:
            self.mlp = MLP(config.hidden_size, config.intermediate_size, backend.linear)

        self.input_layernorm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = initialize_rms_norm_module(
            backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps
        )
        self.layer_idx = layer_idx

    def forward(
        self,
        x: torch.Tensor,
        *,
        freqs_cis: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        padding_mask: torch.Tensor | None = None,
        **attn_kwargs: Any,
    ) -> torch.Tensor:
        if attention_mask is not None and padding_mask is None:
            padding_mask = attention_mask.bool().logical_not()

        attn_out = self.self_attn(
            x=self.input_layernorm(x),
            freqs_cis=freqs_cis,
            attention_mask=attention_mask,
            **attn_kwargs,
        )
        x = x + attn_out

        mlp_out = self._mlp(x=self.post_attention_layernorm(x), padding_mask=padding_mask)
        x = x + mlp_out
        return x

    def _mlp(self, x: torch.Tensor, padding_mask: torch.Tensor | None) -> torch.Tensor:
        if isinstance(self.mlp, MLP):
            return self.mlp(x)
        else:
            assert isinstance(self.mlp, MoE)
            return self.mlp(x, padding_mask)

    def init_weights(self, buffer_device: torch.device):
        for norm in (self.input_layernorm, self.post_attention_layernorm):
            norm.reset_parameters()
        self.self_attn.init_weights(buffer_device)
        self.mlp.init_weights(buffer_device)


class GlmMoeDsaModel(nn.Module):
    def __init__(self, config: GlmMoeDsaConfig, backend: BackendConfig, *, moe_config: MoEConfig | None = None):
        super().__init__()
        self.backend = backend
        self.config = config

        self.moe_config = moe_config or MoEConfig(
            dim=config.hidden_size,
            inter_dim=config.intermediate_size,
            moe_inter_dim=config.moe_intermediate_size,
            n_routed_experts=config.n_routed_experts,
            n_shared_experts=config.n_shared_experts,
            n_activated_experts=config.num_experts_per_tok,
            n_expert_groups=config.n_group,
            n_limited_groups=config.topk_group,
            train_gate=True,
            gate_bias_update_factor=0.001,
            score_func="sigmoid",
            route_scale=config.routed_scaling_factor,
            aux_loss_coeff=0.0,
            norm_topk_prob=config.norm_topk_prob,
            expert_bias=False,
            router_bias=False,
            expert_activation="swiglu",
            softmax_before_topk=False,
        )

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
        )
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(config.num_hidden_layers):
            self.layers[str(layer_id)] = Block(layer_id, config, self.moe_config, backend)
        self.norm = initialize_rms_norm_module(backend.rms_norm, config.hidden_size, eps=config.rms_norm_eps)

        self.max_seq_len = config.max_position_embeddings
        self.qk_rope_head_dim = config.qk_rope_head_dim

        if hasattr(config, "rope_parameters") and config.rope_parameters is not None:
            rope_theta = config.rope_parameters["rope_theta"]
        else:
            rope_theta = config.rope_theta

        rope_scaling = getattr(config, "rope_scaling", None)

        self.freqs = precompute_freqs_cis(
            qk_rope_head_dim=self.qk_rope_head_dim,
            max_seq_len=self.max_seq_len,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        *,
        position_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        padding_mask: torch.Tensor | None = None,
        **attn_kwargs: Any,
    ) -> torch.Tensor:
        if position_ids is None:
            position_ids = (
                torch.arange(0, input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1)
            )

        freqs_cis = freqs_cis_from_position_ids(
            position_ids,
            self.freqs.to(position_ids.device),
            qkv_format=attn_kwargs.get("qkv_format", "bshd"),
            for_fused_rope=self.backend.rope_fusion,
            cp_size=attn_kwargs.get("cp_size", 1),
        )

        h = self.embed_tokens(input_ids) if self.embed_tokens is not None else input_ids

        for layer in self.layers.values():
            h = layer(
                x=h,
                freqs_cis=freqs_cis,
                attention_mask=attention_mask,
                padding_mask=padding_mask,
                **attn_kwargs,
            )

        h = self.norm(h) if self.norm else h
        return h

    @torch.no_grad()
    def init_weights(self, buffer_device: torch.device | None = None) -> None:
        buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}")

        with buffer_device:
            if self.embed_tokens is not None:
                nn.init.normal_(self.embed_tokens.weight)
            if self.norm is not None:
                self.norm.reset_parameters()

            for layer in self.layers.values():
                if layer is not None:
                    layer.init_weights(buffer_device=buffer_device)


class GlmMoeDsaForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin):
    @classmethod
    def from_config(
        cls,
        config: GlmMoeDsaConfig,
        moe_config: MoEConfig | None = None,
        backend: BackendConfig | None = None,
        **kwargs,
    ):
        return cls(config, moe_config, backend, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args,
        **kwargs,
    ):
        config = GlmMoeDsaConfig.from_pretrained(pretrained_model_name_or_path)
        return cls.from_config(config, *model_args, **kwargs)

    def __init__(
        self,
        config: GlmMoeDsaConfig,
        moe_config: MoEConfig | None = None,
        backend: BackendConfig | None = None,
        **kwargs,
    ):
        super().__init__()
        self.config = config
        self.backend = backend or BackendConfig()
        self.model = GlmMoeDsaModel(config, backend=self.backend, moe_config=moe_config)
        self.lm_head = initialize_linear_module(self.backend.linear, config.hidden_size, config.vocab_size, bias=False)
        if self.backend.enable_hf_state_dict_adapter:
            self.state_dict_adapter = GlmMoeDsaStateDictAdapter(
                self.config, self.model.moe_config, self.backend, dtype=get_dtype(config.torch_dtype, torch.bfloat16)
            )

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        *,
        position_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        padding_mask: torch.Tensor | None = None,
        **attn_kwargs: Any,
    ) -> torch.Tensor:
        # For packed (THD) input, drop the leading batch dimension before the backbone
        # and restore it on the logits below.
        if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd":
            input_ids, position_ids, padding_mask, attn_kwargs = squeeze_input_for_thd(
                input_ids, position_ids, padding_mask, attn_kwargs
            )
            attention_mask = None

        hidden = self.model(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            padding_mask=padding_mask,
            **attn_kwargs,
        )
        logits = self.lm_head(hidden) if self.lm_head else hidden
        if "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd":
            logits = logits.unsqueeze(0)
        return logits

    @torch.no_grad()
    def initialize_weights(
        self, buffer_device: torch.device | None = None, dtype: torch.dtype = torch.bfloat16
    ) -> None:
        buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}")
        with buffer_device:
            self.model.init_weights(buffer_device=buffer_device)
            final_out_std = self.config.hidden_size**-0.5
            cutoff_factor = 3
            if self.lm_head is not None:
                nn.init.trunc_normal_(
                    self.lm_head.weight,
                    mean=0.0,
                    std=final_out_std,
                    a=-cutoff_factor * final_out_std,
                    b=cutoff_factor * final_out_std,
                )

        self.to(dtype)
        for layer in self.model.layers.values():
            if isinstance(layer.mlp, MoE):
                layer.mlp.gate.e_score_correction_bias = torch.zeros(
                    (self.config.n_routed_experts), dtype=torch.float32
                ).to(buffer_device)


ModelClass = GlmMoeDsaForCausalLM
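
Illustrative usage sketch (not part of this diff): one way the new class might be exercised, assuming a local directory containing a GlmMoeDsaConfig; the path and tensor shapes are placeholders.

import torch

from nemo_automodel.components.models.glm_moe_dsa.model import GlmMoeDsaForCausalLM

# As defined above, from_pretrained only reads the config from the directory and builds the
# module; it does not load checkpoint weights here. initialize_weights() applies random init.
model = GlmMoeDsaForCausalLM.from_pretrained("/path/to/glm-moe-dsa-checkpoint").cuda()
model.initialize_weights()

input_ids = torch.randint(0, model.config.vocab_size, (1, 128), device="cuda")
logits = model(input_ids)  # [batch, seq_len, vocab_size]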