Skip to content

Commit 26785ae

Browse files
authored
Fix the init_weights for the MoE models (#42306)
* fix the modulars
* apply modulars
* forgot jamba
* fix doc
1 parent 6bc8121 commit 26785ae

20 files changed: +102 −28 lines

src/transformers/modeling_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,12 +2140,6 @@ def _init_weights(self, module):
21402140
init.ones_(module.weight)
21412141
if hasattr(module, "bias") and module.bias is not None:
21422142
init.zeros_(module.bias)
2143-
if isinstance(getattr(module, "gate_up_proj", None), nn.Parameter):
2144-
init.normal_(module.gate_up_proj, mean=0.0, std=std)
2145-
if isinstance(getattr(module, "down_proj", None), nn.Parameter):
2146-
init.normal_(module.down_proj, mean=0.0, std=std)
2147-
if isinstance(getattr(module, "gate", None), nn.Parameter):
2148-
init.normal_(module.gate, mean=0.0, std=std)
21492143

21502144
def _initialize_weights(self, module):
21512145
"""
@@ -2166,10 +2160,6 @@ def initialize_weights(self):
21662160
module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite
21672161
model would have to recurse a second time on all sub-models explicitly in the outer-most `_init_weights`, which
21682162
is extremely error prone and inefficient.
2169-
2170-
Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
2171-
`torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as
2172-
`module.weight.zero_()`.
21732163
"""
21742164
if not hasattr(torch.nn.Module, "smart_apply"):
21752165
# This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjust the function

src/transformers/models/deepseek_v2/modeling_deepseek_v2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,9 @@ class DeepseekV2PreTrainedModel(PreTrainedModel):
465465
@torch.no_grad()
466466
def _init_weights(self, module):
467467
super()._init_weights(module)
468-
if isinstance(module, DeepseekV2Moe):
469-
init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range)
468+
if isinstance(module, DeepseekV2Experts):
469+
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
470+
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
470471

471472

472473
@auto_docstring

src/transformers/models/deepseek_v2/modular_deepseek_v2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,9 @@ class DeepseekV2PreTrainedModel(LlamaPreTrainedModel):
437437
@torch.no_grad()
438438
def _init_weights(self, module):
439439
PreTrainedModel._init_weights(self, module)
440-
if isinstance(module, DeepseekV2Moe):
441-
init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range)
440+
if isinstance(module, DeepseekV2Experts):
441+
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
442+
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
442443

443444

444445
class DeepseekV2Model(LlamaModel):

src/transformers/models/deepseek_v3/modeling_deepseek_v3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,9 @@ def _init_weights(self, module):
554554
super()._init_weights(module)
555555
if isinstance(module, DeepseekV3TopkRouter):
556556
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
557+
elif isinstance(module, DeepseekV3NaiveMoe):
558+
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
559+
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
557560

558561

559562
@auto_docstring

src/transformers/models/deepseek_v3/modular_deepseek_v3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ def _init_weights(self, module):
310310
PreTrainedModel._init_weights(self, module)
311311
if isinstance(module, DeepseekV3TopkRouter):
312312
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
313+
elif isinstance(module, DeepseekV3NaiveMoe):
314+
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
315+
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
313316

314317

315318
class DeepseekV3Model(LlamaModel):

src/transformers/models/dots1/modeling_dots1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ def _init_weights(self, module):
472472
super()._init_weights(module)
473473
if isinstance(module, Dots1TopkRouter):
474474
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
475+
elif isinstance(module, Dots1NaiveMoe):
476+
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
477+
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
475478

476479

477480
@auto_docstring

src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,12 @@ def _init_weights(self, module):
501501
super()._init_weights(module)
502502
if isinstance(module, Ernie4_5_MoeStatics):
503503
init.zeros_(module.e_score_correction_bias)
504+
elif isinstance(module, Ernie4_5_MoeExperts):
505+
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
506+
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
507+
if module.gate_up_proj_bias is not None:
508+
init.zeros_(module.gate_up_proj_bias)
509+
init.zeros_(module.down_proj_bias)
504510

505511

506512
@auto_docstring

src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,12 @@ def _init_weights(self, module):
242242
PreTrainedModel._init_weights(self, module)
243243
if isinstance(module, Ernie4_5_MoeStatics):
244244
init.zeros_(module.e_score_correction_bias)
245+
elif isinstance(module, Ernie4_5_MoeExperts):
246+
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
247+
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
248+
if module.gate_up_proj_bias is not None:
249+
init.zeros_(module.gate_up_proj_bias)
250+
init.zeros_(module.down_proj_bias)
245251

246252

247253
@auto_docstring

src/transformers/models/glm4_moe/modeling_glm4_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,9 @@ def _init_weights(self, module):
498498
super()._init_weights(module)
499499
if isinstance(module, Glm4MoeTopkRouter):
500500
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
501+
elif isinstance(module, Glm4MoeNaiveMoe):
502+
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
503+
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
501504

502505

503506
@auto_docstring

src/transformers/models/glm4v_moe/modeling_glm4v_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,9 @@ def _init_weights(self, module):
559559
super()._init_weights(module)
560560
if isinstance(module, Glm4vMoeTextTopkRouter):
561561
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
562+
elif isinstance(module, Glm4vMoeTextNaiveMoe):
563+
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
564+
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)
562565

563566

564567
@dataclass

0 commit comments

Comments (0)