From 1c5ddd56e13f29a413edef9e3fae4f8a6f73ebdc Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Wed, 19 Nov 2025 20:07:00 +0100 Subject: [PATCH 1/5] 1) make the moe's load_balance_coeff configurable 2) add the batch and seq-wise aux loss for load balance --- torchtitan/components/loss.py | 18 +++ torchtitan/config/job_config.py | 15 +++ torchtitan/models/deepseek_v3/model/args.py | 5 + torchtitan/models/deepseek_v3/model/model.py | 25 +++- torchtitan/models/moe/moe.py | 124 ++++++++++++++++++- torchtitan/train.py | 8 +- 6 files changed, 186 insertions(+), 9 deletions(-) diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py index 6fb80f39cb..b8b52993e1 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -78,3 +78,21 @@ def build_mse_loss(job_config: JobConfig, **kwargs): logger.info("Compiling the loss function with torch.compile") loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend) return loss_fn + + +def moe_loss( + pred: tuple[torch.Tensor, torch.Tensor] | torch.Tensor, + labels: torch.Tensor, + loss_fn: LossFunction, +) -> torch.Tensor: + """Sequence-wise auxiliary load balance loss function for MoE + model training. + """ + if isinstance(pred, tuple): + pred, load_balance_loss = pred + loss = loss_fn(pred, labels) + # USE STE to make the magnitude of loss remain the same + loss = loss + (load_balance_loss - load_balance_loss.detach()) + else: + loss = loss_fn(pred, labels) + return loss diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 95588d2c3b..c4d40108c8 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -97,6 +97,18 @@ class Metrics: """Whether to log metrics to Weights & Biases""" +@dataclass +class ExtraLosses: + load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise" + """Type of load balance loss to use""" + + load_balance_loss_weight: float = 0 + """Weight of load balance loss""" + + load_balance_coeff: float | None = 1e-3 + """Coefficient of bias update for aux-loss-free load balancing""" + + @dataclass class Model: name: str = "llama3" @@ -130,6 +142,9 @@ class Model: converters have been applied. 
""" + extra_losses: ExtraLosses = field(default_factory=ExtraLosses) + """Extra losses to use""" + @dataclass class Optimizer: diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 48d4b5ece1..fab9862b91 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -95,6 +95,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len + losses_config = job_config.model.extra_losses + self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type + self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight + self.moe_args.load_balance_coeff = losses_config.load_balance_coeff + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): logger.warning( "Failed to use grouped mm, which is only supported on SM90 or later", diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 3cf56eb1b2..25b168806a 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -309,6 +309,7 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, + accumulated_load_balance_loss: torch.Tensor, attention_masks: AttentionMasksType | None, ): """ @@ -323,10 +324,15 @@ def forward( """ x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) if self.moe_enabled: - x = x + self.moe(self.ffn_norm(x)) + ffn_moe_output, load_balance_loss = self.moe(self.ffn_norm(x)) + accumulated_load_balance_loss = ( + accumulated_load_balance_loss + load_balance_loss + ) else: - x = x + self.feed_forward(self.ffn_norm(x)) - return x + ffn_moe_output = self.feed_forward(self.ffn_norm(x)) + + x = x + ffn_moe_output + return x, accumulated_load_balance_loss def init_weights(self, buffer_device: torch.device): for norm in (self.attention_norm, self.ffn_norm): @@ -410,6 +416,7 @@ def get_attention_masks( def forward( self, tokens: torch.Tensor, + accumulated_load_balance_loss: torch.Tensor | None = None, attention_masks: AttentionMasksType | None = None, ): """ @@ -427,8 +434,16 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens + accumulated_load_balance_loss = ( + torch.zeros((), device=h.device, dtype=torch.float32) + if accumulated_load_balance_loss is None + else accumulated_load_balance_loss + ) + for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks) + h, accumulated_load_balance_loss = layer( + h, self.freqs_cis, accumulated_load_balance_loss, attention_masks + ) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h - return output + return output, accumulated_load_balance_loss diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 295e2193a5..471fc7c076 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -30,7 +30,8 @@ class MoEArgs: top_k: int = 1 use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation load_balance_coeff: float | None = 1e-3 - + load_balance_loss_weight: float = 0 + load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise" _debug_force_load_balance: bool = False # if True, we force each experts get same amount of token via round-robin @@ -287,7 +288,7 @@ def forward( max=self.num_experts, ) - return top_scores, selected_experts_indices, num_tokens_per_expert + return 
top_scores, scores, selected_experts_indices, num_tokens_per_expert def init_weights(self, init_std: float): nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) @@ -359,6 +360,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): super().__init__() num_experts = moe_args.num_experts + self.topk = moe_args.top_k + self.num_experts = num_experts self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, @@ -386,6 +389,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): # NOTE: tokens_per_expert is accumulated in the model forward pass. # expert_bias is updated outside the model in an optimizer step pre hook # to work with gradient accumulation. + self.load_balance_loss_weight = moe_args.load_balance_loss_weight + self.load_balance_loss_type = moe_args.load_balance_loss_type self.load_balance_coeff = moe_args.load_balance_coeff if self.load_balance_coeff is not None: assert self.load_balance_coeff > 0.0 @@ -418,6 +423,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # num_tokens_per_expert shape (num_experts,) ( top_scores, + scores, selected_experts_indices, num_tokens_per_expert, ) = self.router(x, self.expert_bias) @@ -430,6 +436,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) + if self.training: + if self.load_balance_loss_type == "sequence_wise": + load_balance_loss = MoE.sequence_wise_aux_loss( + scores, + selected_experts_indices.long(), + bs, + slen, + self.topk, + self.load_balance_loss_weight, + ) + elif self.load_balance_loss_type == "batch_wise": + load_balance_loss = MoE.batch_wise_aux_loss( + scores, + num_tokens_per_expert, + self.topk, + self.load_balance_loss_weight, + ) + else: + load_balance_loss = torch.tensor(0.0, device=out.device, dtype=out.dtype) + # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) # NOTE: the reason we need to compute num_tokens_per_expert again is: @@ -479,7 +505,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dim=0, index=token_indices_experts_sorted, src=routed_output ) out = out.reshape(bs, slen, dim) - return out + + return out, load_balance_loss def init_weights( self, @@ -499,3 +526,94 @@ def init_weights( self.expert_bias = torch.zeros( self.experts.num_experts, dtype=torch.float32 ) + + @staticmethod + @torch.compile(fullgraph=True) + def sequence_wise_aux_loss( + scores: torch.Tensor, # Shape: (B*S, N) - Raw Sigmoid Affinities (s_{i,t}) + indices: torch.Tensor, # Shape: (B*S, K) - Selected Expert Indices + B: int, # Batch size + S: int, # Sequence length (T in the paper) + top_k: int, # K_r + aux_loss_alpha: float, # Alpha + ) -> torch.Tensor: + """ + Computes Sequence-Wise Auxiliary Loss (DeepSeek-V3 Equations 17-20). + + Args: + scores: The dense affinity scores (s_{i,t}) for routed experts. + Should be the output of Sigmoid, shape (B*S, N). + indices: The top-k selected expert indices. Shape (B*S, K). + """ + if aux_loss_alpha <= 0: + return torch.tensor(0.0, device=scores.device, dtype=scores.dtype) + + # N_r: Total number of routed experts + N = scores.size(-1) + + # 1. Reshape inputs to handle each sequence separately: (B, S, N) + # This ensures we calculate P_i and f_i per sequence (Eq 20 & 18). + scores_per_seq = scores.view(B, S, N) + indices_per_seq = indices.view(B, S, top_k) + + # 2. Eq 19: Normalize affinity scores s_{i,t} to get s'_{i,t} + # DeepSeek-V3 uses Sigmoid, so scores don't sum to 1. 
+ # Eq 19 explicitly requires dividing by the sum of all affinities. + # denominator shape: (B, S, 1) + denominator = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20 + probs_per_seq = scores_per_seq / denominator # This is s'_{i,t} + + # 3. Eq 20: Calculate P_i (Average probability per expert for each sequence) + # P_i = (1/T) * sum_{t=1}^T (s'_{i,t}) + # We average over the Sequence dimension (dim=1). + # P_i shape: (B, N) + P_i = probs_per_seq.mean(dim=1) + + # 4. Eq 18: Calculate f_i (Fraction of tokens selecting expert i per sequence) + # f_i = (N / (K * T)) * count_i + + # Flatten the top-k dimension to count hits per sequence: (B, S*K) + flat_indices_per_seq = indices_per_seq.view(B, -1) + selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype) + src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype) + selection_counts.scatter_add_(1, flat_indices_per_seq, src) + + # Calculate f_i for each sequence, T (tokens in sequence) is S + f_i = selection_counts * (N / (top_k * S)) + + # 5. Eq 17: Calculate Balance Loss + loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_alpha + + return loss_per_seq.mean() + + @staticmethod + @torch.compile(fullgraph=True) + def batch_wise_aux_loss( + scores: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + top_k: int, + aux_loss_alpha: float, + ) -> torch.Tensor: + """ + Computes Batch-Wise Auxiliary Loss. + Args: + scores: Dense probabilities (BS, N). + num_tokens_per_expert: Token counts (N). + top_k: Number of experts selected per token. + aux_loss_alpha: Scaling factor for the loss. + """ + if aux_loss_alpha <= 0: + return torch.tensor(0.0, device=scores.device, dtype=scores.dtype) + + # Total number of routed experts (N) + N = scores.size(1) + # Total number of tokens (T = BS * S) + T = scores.size(0) + + P_i = scores.mean(dim=0) + + f_i = num_tokens_per_expert.to(scores.dtype) * (N / (top_k * T)) + + loss = (f_i * P_i).sum() * aux_loss_alpha + + return loss diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cfab998b2..bab206cb00 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import functools
 import importlib
 import os
 import time
@@ -18,7 +19,7 @@
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.dataloader import DataloaderExhaustedError
 from torchtitan.components.ft import FTManager, maybe_semi_sync_training
-from torchtitan.components.loss import rescale_accumulated_loss
+from torchtitan.components.loss import moe_loss, rescale_accumulated_loss
 from torchtitan.components.metrics import (
     build_metrics_processor,
     ensure_pp_loss_visible,
@@ -184,6 +185,11 @@ def __init__(self, job_config: JobConfig):
             job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
         )
 
+        self.loss_fn = functools.partial(
+            moe_loss,
+            loss_fn=self.loss_fn,
+        )
+
         # verify batch sizes
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:

From 77dd533dcad8c7c78fcbe1168ac37975d5eb4401 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Tue, 9 Dec 2025 02:58:37 +0100
Subject: [PATCH 2/5] fix loss build and move loss-weights to training scope,
 and other fixes

---
 torchtitan/components/loss.py               | 26 +++++++-
 torchtitan/components/optimizer.py          |  8 +--
 torchtitan/config/job_config.py             | 30 +++++-----
 torchtitan/experiments/gpt_oss/__init__.py  |  6 +-
 torchtitan/models/deepseek_v3/__init__.py   |  4 +-
 torchtitan/models/deepseek_v3/model/args.py |  6 +-
 torchtitan/models/llama4/__init__.py        |  4 +-
 torchtitan/models/llama4/model/args.py      |  7 +++
 torchtitan/models/moe/moe.py                | 66 ++++++++++-----------
 torchtitan/models/qwen3/__init__.py         |  4 +-
 torchtitan/models/qwen3/model/args.py       |  7 +++
 torchtitan/train.py                         |  8 +--
 12 files changed, 104 insertions(+), 72 deletions(-)

diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py
index b8b52993e1..c0e1a85076 100644
--- a/torchtitan/components/loss.py
+++ b/torchtitan/components/loss.py
@@ -85,14 +85,36 @@ def moe_loss(
     labels: torch.Tensor,
     loss_fn: LossFunction,
 ) -> torch.Tensor:
-    """Sequence-wise auxiliary load balance loss function for MoE
+    """Sequence-wise (or batch-wise) auxiliary load balance loss function for MoE
     model training.
     """
     if isinstance(pred, tuple):
         pred, load_balance_loss = pred
         loss = loss_fn(pred, labels)
-        # USE STE to make the magnitude of loss remain the same
+        # Add auxiliary loss to the computation graph for gradients in the backward pass,
+        # but cancel out its numeric value so the forward pass only logs language model task loss.
         loss = loss + (load_balance_loss - load_balance_loss.detach())
     else:
         loss = loss_fn(pred, labels)
     return loss
+
+
+def moe_loss_wrap(original_build_fn):
+    """
+    Wraps a loss builder function. It builds the base loss function first,
+    then wraps it with the 'moe_loss' logic before returning it.
+    """
+
+    @functools.wraps(original_build_fn)  # Preserves name/docstring of original
+    def wrapper(job_config, **kwargs):
+        # 1. Create the base loss function (e.g., standard CrossEntropy)
+        #    We pass job_config and kwargs through exactly as the original expects
+        base_loss_fn = original_build_fn(job_config, **kwargs)
+
+        # 2. Apply the MoE wrapper immediately
+        #    This binds 'base_loss_fn' to the 'loss_fn' argument of 'moe_loss'
+        final_loss_fn = functools.partial(moe_loss, loss_fn=base_loss_fn)
+
+        return final_loss_fn
+
+    return wrapper
diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index 7fc5098800..353003eec5 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -344,8 +344,8 @@ def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool:
     for model_part in model_parts:
         for transformer_block in model_part.layers.values():
             if transformer_block.moe_enabled:
-                # Assumption: load_balance_coeff is set universally on all moe blocks.
-                return bool(transformer_block.moe.load_balance_coeff)
+                # Assumption: moe_aux_loss_free_bias_coeff is set universally on all moe blocks.
+                return bool(transformer_block.moe.moe_aux_loss_free_bias_coeff)
     return False

 # for MoE auxiliary-loss-free load balancing
@@ -366,7 +366,7 @@ def _update_expert_bias(
         for transformer_block in model_part.layers.values():
             if not transformer_block.moe_enabled:
                 continue
-            if transformer_block.moe.load_balance_coeff is None:
+            if transformer_block.moe.moe_aux_loss_free_bias_coeff is None:
                 return
             tokens_per_expert = transformer_block.moe.tokens_per_expert
             if _is_recomputation_enabled(transformer_block):
@@ -401,7 +401,7 @@ def _update_expert_bias(

         # update the expert bias
         # this is not exactly the same as https://arxiv.org/pdf/2408.15664 proposed
-        expert_bias_delta = moe.load_balance_coeff * torch.sign(
+        expert_bias_delta = moe.moe_aux_loss_free_bias_coeff * torch.sign(
             tokens_per_expert.mean() - tokens_per_expert
         )
         expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index c4d40108c8..3bfb6a60a1 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -97,18 +97,6 @@ class Metrics:
     """Whether to log metrics to Weights & Biases"""

-@dataclass
-class ExtraLosses:
-    load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
-    """Type of load balance loss to use"""
-
-    load_balance_loss_weight: float = 0
-    """Weight of load balance loss"""
-
-    load_balance_coeff: float | None = 1e-3
-    """Coefficient of bias update for aux-loss-free load balancing"""
-
-
 @dataclass
 class Model:
     name: str = "llama3"
@@ -142,9 +130,6 @@ class Model:
     converters have been applied.
""" - extra_losses: ExtraLosses = field(default_factory=ExtraLosses) - """Extra losses to use""" - @dataclass class Optimizer: @@ -213,6 +198,18 @@ class LRScheduler: """ +@dataclass +class ExtraLosses: + load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise" + """Type of load balance loss to use""" + + load_balance_loss_weight: float = 0 + """Weight of load balance loss""" + + moe_aux_loss_free_bias_coeff: float | None = 1e-3 + """The coefficient of the bias update for aux-loss-free load balancing""" + + @dataclass class Training: dataset: str = "c4_test" @@ -241,6 +238,9 @@ class Training: steps: int = 10000 """How many train steps to run""" + extra_losses: ExtraLosses = field(default_factory=ExtraLosses) + """If we have multiple of losses, we can configure their weights here""" + enable_cpu_offload: bool = False """ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py index c12ad13a5c..0fcb4f8637 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -38,7 +38,7 @@ score_before_experts=False, top_k=4, use_grouped_mm=True, - load_balance_coeff=1e-3, + moe_aux_loss_free_bias_coeff=1e-3, ), attn_mask_type="causal", ), @@ -53,7 +53,7 @@ score_before_experts=False, top_k=4, use_grouped_mm=True, - load_balance_coeff=1e-3, + moe_aux_loss_free_bias_coeff=1e-3, ), ), "120b": GptOssModelArgs( @@ -67,7 +67,7 @@ score_before_experts=False, top_k=4, use_grouped_mm=True, - load_balance_coeff=1e-3, + moe_aux_loss_free_bias_coeff=1e-3, ), ), } diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 525bd96c13..fb53d89744 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.loss import build_cross_entropy_loss, moe_loss_wrap from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from torchtitan.components.tokenizer import build_hf_tokenizer @@ -167,6 +167,6 @@ def get_train_spec() -> TrainSpec: build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_text_dataloader, build_tokenizer_fn=build_hf_tokenizer, - build_loss_fn=build_cross_entropy_loss, + build_loss_fn=moe_loss_wrap(build_cross_entropy_loss), state_dict_adapter=DeepSeekV3StateDictAdapter, ) diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index fab9862b91..4f7f4320f2 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -95,10 +95,12 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len - losses_config = job_config.model.extra_losses + losses_config = job_config.training.extra_losses self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight - self.moe_args.load_balance_coeff = losses_config.load_balance_coeff + self.moe_args.moe_aux_loss_free_bias_coeff = ( + losses_config.moe_aux_loss_free_bias_coeff + ) if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): logger.warning( diff --git a/torchtitan/models/llama4/__init__.py b/torchtitan/models/llama4/__init__.py index 24196c2326..bd7f06b54e 100644 --- a/torchtitan/models/llama4/__init__.py +++ b/torchtitan/models/llama4/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.loss import build_cross_entropy_loss, moe_loss_wrap from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from torchtitan.components.tokenizer import build_hf_tokenizer @@ -112,7 +112,7 @@ def get_train_spec() -> TrainSpec: build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_text_dataloader, build_tokenizer_fn=build_hf_tokenizer, - build_loss_fn=build_cross_entropy_loss, + build_loss_fn=moe_loss_wrap(build_cross_entropy_loss), build_validator_fn=build_validator, state_dict_adapter=Llama4StateDictAdapter, ) diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index 7fcc9871f5..3e9baebd84 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -70,6 +70,13 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len + losses_config = job_config.training.extra_losses + self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type + self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight + self.moe_args.moe_aux_loss_free_bias_coeff = ( + losses_config.moe_aux_loss_free_bias_coeff + ) + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): logger.warning( "Failed to use grouped mm, which is only supported on SM90 or later", diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 471fc7c076..b2cc8fdc04 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -29,7 +29,7 @@ class MoEArgs: # token-choice top_k: int = 1 use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation - load_balance_coeff: float | None = 1e-3 + moe_aux_loss_free_bias_coeff: float | None = 1e-3 load_balance_loss_weight: float = 0 load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise" _debug_force_load_balance: bool = False @@ -360,8 +360,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): super().__init__() num_experts = moe_args.num_experts - self.topk = moe_args.top_k - self.num_experts = num_experts + self.top_k = moe_args.top_k self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, @@ -391,9 +390,9 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): # to work with gradient accumulation. 
self.load_balance_loss_weight = moe_args.load_balance_loss_weight self.load_balance_loss_type = moe_args.load_balance_loss_type - self.load_balance_coeff = moe_args.load_balance_coeff - if self.load_balance_coeff is not None: - assert self.load_balance_coeff > 0.0 + self.moe_aux_loss_free_bias_coeff = moe_args.moe_aux_loss_free_bias_coeff + if self.moe_aux_loss_free_bias_coeff is not None: + assert self.moe_aux_loss_free_bias_coeff > 0.0 self.register_buffer( "expert_bias", torch.zeros(num_experts, dtype=torch.float32), @@ -443,14 +442,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: selected_experts_indices.long(), bs, slen, - self.topk, + self.top_k, self.load_balance_loss_weight, ) elif self.load_balance_loss_type == "batch_wise": load_balance_loss = MoE.batch_wise_aux_loss( scores, num_tokens_per_expert, - self.topk, + self.top_k, self.load_balance_loss_weight, ) else: @@ -505,7 +504,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dim=0, index=token_indices_experts_sorted, src=routed_output ) out = out.reshape(bs, slen, dim) - return out, load_balance_loss def init_weights( @@ -522,20 +520,19 @@ def init_weights( self.tokens_per_expert = torch.zeros( self.experts.num_experts, dtype=torch.float32 ) - if self.load_balance_coeff is not None: + if self.moe_aux_loss_free_bias_coeff is not None: self.expert_bias = torch.zeros( self.experts.num_experts, dtype=torch.float32 ) @staticmethod - @torch.compile(fullgraph=True) def sequence_wise_aux_loss( - scores: torch.Tensor, # Shape: (B*S, N) - Raw Sigmoid Affinities (s_{i,t}) - indices: torch.Tensor, # Shape: (B*S, K) - Selected Expert Indices - B: int, # Batch size - S: int, # Sequence length (T in the paper) - top_k: int, # K_r - aux_loss_alpha: float, # Alpha + scores: torch.Tensor, + indices: torch.Tensor, + B: int, + S: int, + top_k: int, + aux_loss_weight: float, ) -> torch.Tensor: """ Computes Sequence-Wise Auxiliary Loss (DeepSeek-V3 Equations 17-20). @@ -544,17 +541,20 @@ def sequence_wise_aux_loss( scores: The dense affinity scores (s_{i,t}) for routed experts. Should be the output of Sigmoid, shape (B*S, N). indices: The top-k selected expert indices. Shape (B*S, K). + B: Batch size + S: Sequence length (T in the paper) + top_k: K_r in the paper, the number of experts each token will be routed to + aux_loss_weight: the weight of the auxiliary loss """ - if aux_loss_alpha <= 0: + if aux_loss_weight <= 0: return torch.tensor(0.0, device=scores.device, dtype=scores.dtype) - # N_r: Total number of routed experts + # N: Total number of routed experts N = scores.size(-1) # 1. Reshape inputs to handle each sequence separately: (B, S, N) # This ensures we calculate P_i and f_i per sequence (Eq 20 & 18). scores_per_seq = scores.view(B, S, N) - indices_per_seq = indices.view(B, S, top_k) # 2. Eq 19: Normalize affinity scores s_{i,t} to get s'_{i,t} # DeepSeek-V3 uses Sigmoid, so scores don't sum to 1. 
@@ -573,36 +573,36 @@ def sequence_wise_aux_loss( # f_i = (N / (K * T)) * count_i # Flatten the top-k dimension to count hits per sequence: (B, S*K) - flat_indices_per_seq = indices_per_seq.view(B, -1) - selection_counts = torch.zeros((B, N), device=scores.device, dtype=scores.dtype) - src = torch.ones_like(flat_indices_per_seq, dtype=scores.dtype) - selection_counts.scatter_add_(1, flat_indices_per_seq, src) + flat_indices_per_seq = indices.view(B, -1) + offset = torch.arange(B, device=flat_indices_per_seq.device).unsqueeze(1) * N + flat_indices = (flat_indices_per_seq + offset).reshape(-1) + selection_counts = torch.bincount(flat_indices, minlength=B * N).reshape(B, N) + selection_counts = selection_counts.to(dtype=scores.dtype) # Calculate f_i for each sequence, T (tokens in sequence) is S f_i = selection_counts * (N / (top_k * S)) # 5. Eq 17: Calculate Balance Loss - loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_alpha + loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_weight return loss_per_seq.mean() @staticmethod - @torch.compile(fullgraph=True) def batch_wise_aux_loss( scores: torch.Tensor, num_tokens_per_expert: torch.Tensor, top_k: int, - aux_loss_alpha: float, + aux_loss_weight: float, ) -> torch.Tensor: """ Computes Batch-Wise Auxiliary Loss. Args: - scores: Dense probabilities (BS, N). - num_tokens_per_expert: Token counts (N). - top_k: Number of experts selected per token. - aux_loss_alpha: Scaling factor for the loss. + scores: the dense probabilities (s_{i,t}) for routed experts. + num_tokens_per_expert: the number of tokens assigned to each expert (f_i). + top_k: K_r in the paper, the number of experts each token will be routed to + aux_loss_weight: the weight of the auxiliary loss. """ - if aux_loss_alpha <= 0: + if aux_loss_weight <= 0: return torch.tensor(0.0, device=scores.device, dtype=scores.dtype) # Total number of routed experts (N) @@ -614,6 +614,6 @@ def batch_wise_aux_loss( f_i = num_tokens_per_expert.to(scores.dtype) * (N / (top_k * T)) - loss = (f_i * P_i).sum() * aux_loss_alpha + loss = (f_i * P_i).sum() * aux_loss_weight return loss diff --git a/torchtitan/models/qwen3/__init__.py b/torchtitan/models/qwen3/__init__.py index 0cd569b697..6915cb74fc 100644 --- a/torchtitan/models/qwen3/__init__.py +++ b/torchtitan/models/qwen3/__init__.py @@ -6,7 +6,7 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
-from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.loss import build_cross_entropy_loss, moe_loss_wrap
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
 from torchtitan.components.tokenizer import build_hf_tokenizer
@@ -201,7 +201,7 @@ def get_train_spec() -> TrainSpec:
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_text_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
-        build_loss_fn=build_cross_entropy_loss,
+        build_loss_fn=moe_loss_wrap(build_cross_entropy_loss),
         build_validator_fn=build_validator,
         state_dict_adapter=Qwen3StateDictAdapter,
     )
diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py
index 0c700ce2e0..bac5149d56 100644
--- a/torchtitan/models/qwen3/model/args.py
+++ b/torchtitan/models/qwen3/model/args.py
@@ -55,6 +55,13 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len
 
+        losses_config = job_config.training.extra_losses
+        self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
+        self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
+        self.moe_args.moe_aux_loss_free_bias_coeff = (
+            losses_config.moe_aux_loss_free_bias_coeff
+        )
+
         self.moe_args._debug_force_load_balance = (
             job_config.debug.moe_force_load_balance
         )
diff --git a/torchtitan/train.py b/torchtitan/train.py
index bab206cb00..5cfab998b2 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import functools
 import importlib
 import os
 import time
@@ -19,7 +18,7 @@
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.dataloader import DataloaderExhaustedError
 from torchtitan.components.ft import FTManager, maybe_semi_sync_training
-from torchtitan.components.loss import moe_loss, rescale_accumulated_loss
+from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.components.metrics import (
     build_metrics_processor,
     ensure_pp_loss_visible,
@@ -185,11 +184,6 @@ def __init__(self, job_config: JobConfig):
             job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
         )
 
-        self.loss_fn = functools.partial(
-            moe_loss,
-            loss_fn=self.loss_fn,
-        )
-
         # verify batch sizes
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:

From 1bdc48b0ee4464e1e56787eb56515ab96d89c7d0 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Wed, 10 Dec 2025 22:50:58 +0100
Subject: [PATCH 3/5] fix load balance loss when it's disabled

---
 torchtitan/models/moe/moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
index b2cc8fdc04..be7e412f7f 100644
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -453,7 +453,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 self.load_balance_loss_weight,
             )
         else:
-            load_balance_loss = torch.tensor(0.0, device=out.device, dtype=out.dtype)
+            load_balance_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
 
         # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)

From 2830bcb073e0ec6b86b5244c63e1ad031b26a4c4 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Wed, 10 Dec 2025 23:45:10 +0100
Subject: [PATCH 4/5] fix aux loss

---
 torchtitan/models/moe/moe.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
index be7e412f7f..dfd6ba894d 100644
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -262,11 +262,13 @@ def forward(
             _, selected_experts_indices = torch.topk(
                 scores + expert_bias, k=self.top_k, dim=1
             )
+            expert_indices_for_load_balance = torch.topk(scores, k=self.top_k, dim=1)[1]
             top_scores = scores.gather(dim=1, index=selected_experts_indices)
         else:
             top_scores, selected_experts_indices = torch.topk(
                 scores, k=self.top_k, dim=1
             )
+            expert_indices_for_load_balance = torch.topk(scores, k=self.top_k, dim=1)[1]
 
         # debug override: balanced round-robin routing
         if self._debug_force_load_balance:
@@ -288,7 +290,13 @@ def forward(
             max=self.num_experts,
         )
 
-        return top_scores, scores, selected_experts_indices, num_tokens_per_expert
+        return (
+            top_scores,
+            scores,
+            selected_experts_indices,
+            num_tokens_per_expert,
+            expert_indices_for_load_balance,
+        )
 
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
@@ -425,6 +433,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             scores,
             selected_experts_indices,
             num_tokens_per_expert,
+            expert_indices_for_load_balance,
         ) = self.router(x, self.expert_bias)
 
         # tokens_per_expert will be used to update the expert bias for load balancing.
@@ -439,7 +448,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             if self.load_balance_loss_type == "sequence_wise":
                 load_balance_loss = MoE.sequence_wise_aux_loss(
                     scores,
-                    selected_experts_indices.long(),
+                    expert_indices_for_load_balance.long(),
                     bs,
                     slen,
                     self.top_k,

From 1df2412d9991a59fb4888836da9778512b2e1161 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Thu, 11 Dec 2025 00:03:07 +0100
Subject: [PATCH 5/5] fix

---
 torchtitan/models/moe/moe.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
index dfd6ba894d..d5fdcdef9c 100644
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -262,13 +262,12 @@ def forward(
             _, selected_experts_indices = torch.topk(
                 scores + expert_bias, k=self.top_k, dim=1
             )
-            expert_indices_for_load_balance = torch.topk(scores, k=self.top_k, dim=1)[1]
             top_scores = scores.gather(dim=1, index=selected_experts_indices)
         else:
             top_scores, selected_experts_indices = torch.topk(
                 scores, k=self.top_k, dim=1
             )
-            expert_indices_for_load_balance = torch.topk(scores, k=self.top_k, dim=1)[1]
+        expert_indices_for_load_balance = torch.topk(scores, k=self.top_k, dim=1)[1]
 
         # debug override: balanced round-robin routing
         if self._debug_force_load_balance:
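
Notes (reviewer sketches, not part of the patches above).

The `moe_loss` wrapper in PATCH 1/2 relies on a detach trick: the auxiliary load-balance
term enters the computation graph so it receives gradients in backward, but its numeric
value is cancelled so the reported loss stays the pure task loss. A minimal self-contained
illustration, with scalar stand-ins for the real task and balance losses:

    import torch

    w = torch.tensor(2.0, requires_grad=True)
    task_loss = (w - 1.0) ** 2           # stand-in for loss_fn(pred, labels)
    aux_loss = 0.1 * w                   # stand-in for load_balance_loss
    # same form as: loss + (load_balance_loss - load_balance_loss.detach())
    total = task_loss + (aux_loss - aux_loss.detach())
    print(total.item())   # 1.0 -> numerically identical to the task loss
    total.backward()
    print(w.grad.item())  # 2.1 = d(task_loss)/dw + d(aux_loss)/dw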
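The sequence-wise loss (DeepSeek-V3 Eqs 17-20) can be checked in isolation. The sketch
below mirrors the final `MoE.sequence_wise_aux_loss` (same shapes, same bincount-based
counting); with perfectly uniform affinities and an even round-robin assignment, f_i == 1
and P_i == 1/N, so the loss collapses to exactly aux_loss_weight:

    import torch

    def sequence_wise_aux_loss(scores, indices, B, S, top_k, aux_loss_weight):
        N = scores.size(-1)
        scores_per_seq = scores.view(B, S, N)
        # Eq 19: normalize sigmoid affinities so each token's scores sum to 1
        probs = scores_per_seq / (scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20)
        P_i = probs.mean(dim=1)                                   # Eq 20, (B, N)
        flat = indices.view(B, -1)                                # (B, S*top_k)
        offset = torch.arange(B, device=scores.device).unsqueeze(1) * N
        counts = torch.bincount((flat + offset).reshape(-1), minlength=B * N)
        f_i = counts.reshape(B, N).to(scores.dtype) * (N / (top_k * S))  # Eq 18
        return ((f_i * P_i).sum(dim=1) * aux_loss_weight).mean()  # Eq 17

    B, S, N, K = 2, 8, 4, 2
    scores = torch.full((B * S, N), 0.5)                    # uniform affinities
    indices = (torch.arange(B * S * K) % N).view(B * S, K)  # balanced routing
    print(sequence_wise_aux_loss(scores, indices, B, S, K, 1.0))  # ~tensor(1.)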
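The batch-wise variant (`MoE.batch_wise_aux_loss`) pools the same Eq-17 statistics over all
B*S tokens instead of per sequence, and consumes the router's dense scores as-is; the
uniform-routing check below therefore assumes the scores are already normalized:

    import torch

    def batch_wise_aux_loss(scores, num_tokens_per_expert, top_k, aux_loss_weight):
        N, T = scores.size(1), scores.size(0)
        P_i = scores.mean(dim=0)                     # mean routing prob per expert
        f_i = num_tokens_per_expert.to(scores.dtype) * (N / (top_k * T))
        return (f_i * P_i).sum() * aux_loss_weight   # same Eq-17 form as above

    T, N, K = 16, 4, 2
    scores = torch.full((T, N), 1.0 / N)        # already-normalized scores
    counts = torch.full((N,), T * K / N)        # perfectly even: 8 tokens/expert
    print(batch_wise_aux_loss(scores, counts, K, 1.0))  # tensor(1.)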