diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py
index 6fb80f39cb..c0e1a85076 100644
--- a/torchtitan/components/loss.py
+++ b/torchtitan/components/loss.py
@@ -78,3 +78,43 @@ def build_mse_loss(job_config: JobConfig, **kwargs):
         logger.info("Compiling the loss function with torch.compile")
         loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
     return loss_fn
+
+
+def moe_loss(
+    pred: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
+    labels: torch.Tensor,
+    loss_fn: LossFunction,
+) -> torch.Tensor:
+    """Wrap a base task loss with the sequence-wise (or batch-wise) auxiliary
+    load balance loss for MoE model training.
+    """
+    if isinstance(pred, tuple):
+        pred, load_balance_loss = pred
+        loss = loss_fn(pred, labels)
+        # Add the auxiliary loss to the computation graph for gradients in the backward pass,
+        # but cancel its numeric value so the reported forward loss is only the task loss.
+        loss = loss + (load_balance_loss - load_balance_loss.detach())
+    else:
+        loss = loss_fn(pred, labels)
+    return loss
+
+
+def moe_loss_wrap(original_build_fn):
+    """
+    Wraps a loss builder function. It builds the base loss function first,
+    then wraps it with the 'moe_loss' logic before returning it.
+    """
+
+    @functools.wraps(original_build_fn)  # Preserves name/docstring of original
+    def wrapper(job_config, **kwargs):
+        # 1. Create the base loss function (e.g., standard CrossEntropy)
+        # We pass job_config and kwargs through exactly as the original expects
+        base_loss_fn = original_build_fn(job_config, **kwargs)
+
+        # 2. Apply the MoE wrapper immediately
+        # This binds 'base_loss_fn' to the 'loss_fn' argument of 'moe_loss'
+        final_loss_fn = functools.partial(moe_loss, loss_fn=base_loss_fn)
+
+        return final_loss_fn
+
+    return wrapper
diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index 7fc5098800..353003eec5 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -344,8 +344,8 @@ def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool:
     for model_part in model_parts:
         for transformer_block in model_part.layers.values():
             if transformer_block.moe_enabled:
-                # Assumption: load_balance_coeff is set universally on all moe blocks.
-                return bool(transformer_block.moe.load_balance_coeff)
+                # Assumption: moe_aux_loss_free_bias_coeff is set universally on all moe blocks.
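+                # NOTE: bool() is False for both None and 0.0, so the balancing hook is
+                # skipped unless a non-zero coefficient is configured.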
+                return bool(transformer_block.moe.moe_aux_loss_free_bias_coeff)
     return False

 # for MoE auxiliary-loss-free load balancing
@@ -366,7 +366,7 @@
         for transformer_block in model_part.layers.values():
             if not transformer_block.moe_enabled:
                 continue
-            if transformer_block.moe.load_balance_coeff is None:
+            if transformer_block.moe.moe_aux_loss_free_bias_coeff is None:
                 return
             tokens_per_expert = transformer_block.moe.tokens_per_expert
             if _is_recomputation_enabled(transformer_block):
@@ -401,7 +401,7 @@

             # update the expert bias
             # this is not exactly the same as https://arxiv.org/pdf/2408.15664 proposed
-            expert_bias_delta = moe.load_balance_coeff * torch.sign(
+            expert_bias_delta = moe.moe_aux_loss_free_bias_coeff * torch.sign(
                 tokens_per_expert.mean() - tokens_per_expert
             )
             expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 95588d2c3b..3bfb6a60a1 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -198,6 +198,18 @@ class LRScheduler:
     """


+@dataclass
+class ExtraLosses:
+    load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
+    """Type of load balance loss to use"""
+
+    load_balance_loss_weight: float = 0
+    """Weight of the auxiliary load balance loss; 0 disables it"""
+
+    moe_aux_loss_free_bias_coeff: float | None = 1e-3
+    """The coefficient of the bias update for aux-loss-free load balancing"""
+
+
 @dataclass
 class Training:
     dataset: str = "c4_test"
@@ -226,6 +238,9 @@ class Training:
     steps: int = 10000
     """How many train steps to run"""

+    extra_losses: ExtraLosses = field(default_factory=ExtraLosses)
+    """Configuration of auxiliary losses (e.g. MoE load balancing) added on top of the task loss"""
+
     enable_cpu_offload: bool = False
     """
     Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP
diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py
index c12ad13a5c..0fcb4f8637 100644
--- a/torchtitan/experiments/gpt_oss/__init__.py
+++ b/torchtitan/experiments/gpt_oss/__init__.py
@@ -38,7 +38,7 @@
             score_before_experts=False,
             top_k=4,
             use_grouped_mm=True,
-            load_balance_coeff=1e-3,
+            moe_aux_loss_free_bias_coeff=1e-3,
         ),
         attn_mask_type="causal",
     ),
@@ -53,7 +53,7 @@
             score_before_experts=False,
             top_k=4,
             use_grouped_mm=True,
-            load_balance_coeff=1e-3,
+            moe_aux_loss_free_bias_coeff=1e-3,
         ),
     ),
     "120b": GptOssModelArgs(
@@ -67,7 +67,7 @@
             score_before_experts=False,
             top_k=4,
             use_grouped_mm=True,
-            load_balance_coeff=1e-3,
+            moe_aux_loss_free_bias_coeff=1e-3,
         ),
     ),
 }
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 525bd96c13..fb53d89744 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.loss import build_cross_entropy_loss, moe_loss_wrap
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
@@ -167,6 +167,6 @@ def get_train_spec() -> TrainSpec:
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_text_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
-        build_loss_fn=build_cross_entropy_loss,
+        build_loss_fn=moe_loss_wrap(build_cross_entropy_loss),
         state_dict_adapter=DeepSeekV3StateDictAdapter,
     )
diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py
index 48d4b5ece1..4f7f4320f2 100644
--- a/torchtitan/models/deepseek_v3/model/args.py
+++ b/torchtitan/models/deepseek_v3/model/args.py
@@ -95,6 +95,13 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len

+        losses_config = job_config.training.extra_losses
+        self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
+        self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
+        self.moe_args.moe_aux_loss_free_bias_coeff = (
+            losses_config.moe_aux_loss_free_bias_coeff
+        )
+
         if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
             logger.warning(
                 "Failed to use grouped mm, which is only supported on SM90 or later",
diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py
index 3cf56eb1b2..25b168806a 100644
--- a/torchtitan/models/deepseek_v3/model/model.py
+++ b/torchtitan/models/deepseek_v3/model/model.py
@@ -309,6 +309,7 @@ def forward(
         self,
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
+        accumulated_load_balance_loss: torch.Tensor,
         attention_masks: AttentionMasksType | None,
     ):
         """
@@ -323,10 +324,15 @@
         """
         x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
         if self.moe_enabled:
-            x = x + self.moe(self.ffn_norm(x))
+            ffn_moe_output, load_balance_loss = self.moe(self.ffn_norm(x))
+            accumulated_load_balance_loss = (
+                accumulated_load_balance_loss + load_balance_loss
+            )
         else:
-            x = x + self.feed_forward(self.ffn_norm(x))
-        return x
+            ffn_moe_output = self.feed_forward(self.ffn_norm(x))
+
+        x = x + ffn_moe_output
+        return x, accumulated_load_balance_loss

     def init_weights(self, buffer_device: torch.device):
         for norm in (self.attention_norm, self.ffn_norm):
@@ -410,6 +416,7 @@ def get_attention_masks(
     def forward(
         self,
         tokens: torch.Tensor,
+        accumulated_load_balance_loss: torch.Tensor | None = None,
         attention_masks: AttentionMasksType | None = None,
     ):
         """
@@ -427,8 +434,16 @@
         h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

+        accumulated_load_balance_loss = (
+            torch.zeros((), device=h.device, dtype=torch.float32)
+            if accumulated_load_balance_loss is None
+            else accumulated_load_balance_loss
+        )
+
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks)
+            h, accumulated_load_balance_loss = layer(
+                h, self.freqs_cis, accumulated_load_balance_loss, attention_masks
+            )

         h = self.norm(h) if self.norm is not None else h
         output = self.output(h) if self.output is not None else h
-        return output
+        return output, accumulated_load_balance_loss
diff --git a/torchtitan/models/llama4/__init__.py b/torchtitan/models/llama4/__init__.py
index 24196c2326..bd7f06b54e 100644
--- a/torchtitan/models/llama4/__init__.py
+++ b/torchtitan/models/llama4/__init__.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.loss import build_cross_entropy_loss, moe_loss_wrap
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
@@ -112,7 +112,7 @@ def get_train_spec() -> TrainSpec:
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_text_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
-        build_loss_fn=build_cross_entropy_loss,
+        build_loss_fn=moe_loss_wrap(build_cross_entropy_loss),
         build_validator_fn=build_validator,
         state_dict_adapter=Llama4StateDictAdapter,
     )
diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py
index 7fcc9871f5..3e9baebd84 100644
--- a/torchtitan/models/llama4/model/args.py
+++ b/torchtitan/models/llama4/model/args.py
@@ -70,6 +70,13 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len

+        losses_config = job_config.training.extra_losses
+        self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
+        self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
+        self.moe_args.moe_aux_loss_free_bias_coeff = (
+            losses_config.moe_aux_loss_free_bias_coeff
+        )
+
         if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
             logger.warning(
                 "Failed to use grouped mm, which is only supported on SM90 or later",
diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
index 295e2193a5..d5fdcdef9c 100644
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -29,8 +29,9 @@ class MoEArgs:
     # token-choice
     top_k: int = 1
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
-    load_balance_coeff: float | None = 1e-3
-
+    moe_aux_loss_free_bias_coeff: float | None = 1e-3
+    load_balance_loss_weight: float = 0
+    load_balance_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
     _debug_force_load_balance: bool = False
     # if True, we force each experts get same amount of token via round-robin

@@ -266,6 +267,7 @@ def forward(
         top_scores, selected_experts_indices = torch.topk(
             scores, k=self.top_k, dim=1
         )
+        expert_indices_for_load_balance = torch.topk(scores, k=self.top_k, dim=1)[1]

         # debug override: balanced round-robin routing
         if self._debug_force_load_balance:
@@ -287,7 +289,13 @@
             max=self.num_experts,
         )

-        return top_scores, selected_experts_indices, num_tokens_per_expert
+        return (
+            top_scores,
+            scores,
+            selected_experts_indices,
+            num_tokens_per_expert,
+            expert_indices_for_load_balance,
+        )

     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
@@ -359,6 +367,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
         super().__init__()

         num_experts = moe_args.num_experts
+        self.top_k = moe_args.top_k
         self.experts = GroupedExperts(
             dim=dim,
             hidden_dim=hidden_dim,
@@ -386,9 +395,11 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
         # NOTE: tokens_per_expert is accumulated in the model forward pass.
         # expert_bias is updated outside the model in an optimizer step pre hook
         # to work with gradient accumulation.
-        self.load_balance_coeff = moe_args.load_balance_coeff
-        if self.load_balance_coeff is not None:
-            assert self.load_balance_coeff > 0.0
+        self.load_balance_loss_weight = moe_args.load_balance_loss_weight
+        self.load_balance_loss_type = moe_args.load_balance_loss_type
+        self.moe_aux_loss_free_bias_coeff = moe_args.moe_aux_loss_free_bias_coeff
+        if self.moe_aux_loss_free_bias_coeff is not None:
+            assert self.moe_aux_loss_free_bias_coeff > 0.0
             self.register_buffer(
                 "expert_bias",
                 torch.zeros(num_experts, dtype=torch.float32),
@@ -418,8 +429,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # num_tokens_per_expert shape (num_experts,)
         (
             top_scores,
+            scores,
             selected_experts_indices,
             num_tokens_per_expert,
+            expert_indices_for_load_balance,
         ) = self.router(x, self.expert_bias)

         # tokens_per_expert will be used to update the expert bias for load balancing.
@@ -430,6 +443,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         with torch.no_grad():
             self.tokens_per_expert.add_(num_tokens_per_expert)

+        if self.training and self.load_balance_loss_type == "sequence_wise":
+            load_balance_loss = MoE.sequence_wise_aux_loss(
+                scores,
+                expert_indices_for_load_balance.long(),
+                bs,
+                slen,
+                self.top_k,
+                self.load_balance_loss_weight,
+            )
+        elif self.training and self.load_balance_loss_type == "batch_wise":
+            load_balance_loss = MoE.batch_wise_aux_loss(
+                scores,
+                num_tokens_per_expert,
+                self.top_k,
+                self.load_balance_loss_weight,
+            )
+        else:
+            # Covers eval mode as well as an unrecognized load_balance_loss_type.
+            load_balance_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
+
         # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
         # NOTE: the reason we need to compute num_tokens_per_expert again is:
@@ -479,7 +512,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             dim=0, index=token_indices_experts_sorted, src=routed_output
         )
         out = out.reshape(bs, slen, dim)
-        return out
+        return out, load_balance_loss

     def init_weights(
         self,
@@ -495,7 +528,100 @@ def init_weights(
         self.tokens_per_expert = torch.zeros(
             self.experts.num_experts, dtype=torch.float32
         )
-        if self.load_balance_coeff is not None:
+        if self.moe_aux_loss_free_bias_coeff is not None:
             self.expert_bias = torch.zeros(
                 self.experts.num_experts, dtype=torch.float32
             )
+
+    @staticmethod
+    def sequence_wise_aux_loss(
+        scores: torch.Tensor,
+        indices: torch.Tensor,
+        B: int,
+        S: int,
+        top_k: int,
+        aux_loss_weight: float,
+    ) -> torch.Tensor:
+        """
+        Computes Sequence-Wise Auxiliary Loss (DeepSeek-V3 Equations 17-20).
+
+        Args:
+            scores: The dense affinity scores (s_{i,t}) for routed experts.
+                Should be the output of Sigmoid, shape (B*S, N).
+            indices: The top-k selected expert indices. Shape (B*S, K).
+            B: Batch size
+            S: Sequence length (T in the paper)
+            top_k: K_r in the paper, the number of experts each token will be routed to
+            aux_loss_weight: the weight of the auxiliary loss
+        """
+        if aux_loss_weight <= 0:
+            return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
+
+        # N: Total number of routed experts
+        N = scores.size(-1)
+
+        # 1. Reshape inputs to handle each sequence separately: (B, S, N)
+        # This ensures we calculate P_i and f_i per sequence (Eq 20 & 18).
+        scores_per_seq = scores.view(B, S, N)
+
+        # 2. Eq 19: Normalize affinity scores s_{i,t} to get s'_{i,t}
+        # DeepSeek-V3 uses Sigmoid, so scores don't sum to 1.
+        # Eq 19 explicitly requires dividing by the sum of all affinities.
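+        # In symbols: s'_{i,t} = s_{i,t} / sum_j s_{j,t}; the small epsilon below keeps
+        # the division well-defined if an affinity row happens to be all zeros.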
+        # denominator shape: (B, S, 1)
+        denominator = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20
+        probs_per_seq = scores_per_seq / denominator  # This is s'_{i,t}
+
+        # 3. Eq 20: Calculate P_i (Average probability per expert for each sequence)
+        # P_i = (1/T) * sum_{t=1}^T (s'_{i,t})
+        # We average over the Sequence dimension (dim=1).
+        # P_i shape: (B, N)
+        P_i = probs_per_seq.mean(dim=1)
+
+        # 4. Eq 18: Calculate f_i (Fraction of tokens selecting expert i per sequence)
+        # f_i = (N / (K * T)) * count_i
+
+        # Flatten the top-k dimension to count hits per sequence: (B, S*K)
+        flat_indices_per_seq = indices.view(B, -1)
+        offset = torch.arange(B, device=flat_indices_per_seq.device).unsqueeze(1) * N
+        flat_indices = (flat_indices_per_seq + offset).reshape(-1)
+        selection_counts = torch.bincount(flat_indices, minlength=B * N).reshape(B, N)
+        selection_counts = selection_counts.to(dtype=scores.dtype)
+
+        # Calculate f_i for each sequence; T (tokens in a sequence) is S here
+        f_i = selection_counts * (N / (top_k * S))
+
+        # 5. Eq 17: Calculate Balance Loss
+        loss_per_seq = (f_i * P_i).sum(dim=1) * aux_loss_weight
+
+        return loss_per_seq.mean()
+
+    @staticmethod
+    def batch_wise_aux_loss(
+        scores: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor,
+        top_k: int,
+        aux_loss_weight: float,
+    ) -> torch.Tensor:
+        """
+        Computes Batch-Wise Auxiliary Loss.
+        Args:
+            scores: the dense affinity scores (s_{i,t}) for routed experts, shape (B*S, N).
+            num_tokens_per_expert: the number of tokens assigned to each expert, used to derive f_i.
+            top_k: K_r in the paper, the number of experts each token will be routed to
+            aux_loss_weight: the weight of the auxiliary loss.
+        """
+        if aux_loss_weight <= 0:
+            return torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
+
+        # Total number of routed experts (N)
+        N = scores.size(1)
+        # Total number of tokens (T = B * S)
+        T = scores.size(0)
+
+        P_i = scores.mean(dim=0)
+
+        f_i = num_tokens_per_expert.to(scores.dtype) * (N / (top_k * T))
+
+        loss = (f_i * P_i).sum() * aux_loss_weight
+
+        return loss
diff --git a/torchtitan/models/qwen3/__init__.py b/torchtitan/models/qwen3/__init__.py
index 0cd569b697..6915cb74fc 100644
--- a/torchtitan/models/qwen3/__init__.py
+++ b/torchtitan/models/qwen3/__init__.py
@@ -6,7 +6,7 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

-from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.loss import build_cross_entropy_loss, moe_loss_wrap
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
 from torchtitan.components.tokenizer import build_hf_tokenizer
@@ -201,7 +201,7 @@ def get_train_spec() -> TrainSpec:
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_text_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
-        build_loss_fn=build_cross_entropy_loss,
+        build_loss_fn=moe_loss_wrap(build_cross_entropy_loss),
         build_validator_fn=build_validator,
         state_dict_adapter=Qwen3StateDictAdapter,
     )
diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py
index 0c700ce2e0..bac5149d56 100644
--- a/torchtitan/models/qwen3/model/args.py
+++ b/torchtitan/models/qwen3/model/args.py
@@ -55,6 +55,13 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len

+        losses_config = job_config.training.extra_losses
+        self.moe_args.load_balance_loss_type = losses_config.load_balance_loss_type
+        self.moe_args.load_balance_loss_weight = losses_config.load_balance_loss_weight
+        self.moe_args.moe_aux_loss_free_bias_coeff = (
+            losses_config.moe_aux_loss_free_bias_coeff
+        )
+
         self.moe_args._debug_force_load_balance = (
             job_config.debug.moe_force_load_balance
         )