diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 7e2d35a5d9..31e450eb04 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -112,13 +112,13 @@
             num_experts=160,
             num_shared_experts=2,
             top_k=6,
+            num_expert_groups=8,
+            num_limited_groups=3,
             score_func="softmax",
             route_norm=False,
             route_scale=16.0,
             score_before_experts=False,
         ),
-        n_expert_groups=8,
-        n_limited_groups=3,
         q_lora_rank=1536,
         kv_lora_rank=512,
         qk_nope_head_dim=128,
@@ -139,13 +139,13 @@
             num_experts=256,
             num_shared_experts=1,
             top_k=8,
+            num_expert_groups=8,
+            num_limited_groups=4,
             score_func="sigmoid",
             route_norm=True,
             route_scale=2.5,
             score_before_experts=False,
         ),
-        n_expert_groups=8,
-        n_limited_groups=4,
         q_lora_rank=1536,
         kv_lora_rank=512,
         qk_nope_head_dim=128,
diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py
index e683905878..16f047021a 100644
--- a/torchtitan/models/deepseek_v3/model/args.py
+++ b/torchtitan/models/deepseek_v3/model/args.py
@@ -37,8 +37,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
         n_heads (int): Number of attention heads.
         norm_eps (float): Epsilon value used for RMSNorm.
         moe_args (MoEArgs): MoE configuration.
-        n_expert_groups (int): Number of expert groups.
-        n_limited_groups (int): Number of limited groups for MoE routing.
         q_lora_rank (int): LoRA rank for query projections.
         kv_lora_rank (int): LoRA rank for key-value projections.
         qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
@@ -66,9 +64,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
 
     # MoE
     moe_args: MoEArgs = field(default_factory=MoEArgs)
-    # TODO: node-limited routing is not supported yet
-    n_expert_groups: int = 1
-    n_limited_groups: int = 1
 
     # Multi-Head Latent Attention (MLA)
     q_lora_rank: int = 0
diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py
index 741c908eab..617246e839 100644
--- a/torchtitan/models/moe/moe.py
+++ b/torchtitan/models/moe/moe.py
@@ -26,8 +26,10 @@ class MoEArgs:
     route_scale: float = 1.0
     score_before_experts: bool = True
 
-    # token-choice
+    # token-choice with optional node-limited routing
     top_k: int = 1
+    num_expert_groups: int | None = None  # must be a divisor of num_experts
+    num_limited_groups: int | None = None
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3
 
@@ -180,9 +182,17 @@ class TokenChoiceTopKRouter(nn.Module):
     """This class implements token-choice routing. In token-choice top-K routing, each token is
     routed to top K experts based on the router scores.
 
+    Optionally supports node-limited (group-limited) routing, where experts are divided into groups
+    (e.g., by node), and only num_limited_groups groups are considered before selecting top_k experts.
+    This reduces cross-node communication in distributed settings.
+
     Args:
         dim (int): Dimension of input tokens.
         num_experts (int): Number of experts in each moe layer.
+        num_expert_groups (int | None): Number of expert groups for node-limited routing. If None,
+            standard top-k routing is used. Must be a divisor of num_experts.
+        num_limited_groups (int | None): Number of groups to select in node-limited routing.
+            Required when num_expert_groups is set.
         top_k (int): Number of experts each token will be routed to in token-choice routing.
         score_func (Literal["softmax", "sigmoid"]): Whether to use sigmoid or softmax for router scores.
         route_norm (bool): Whether to normalize the routing scores when using sigmoid.
@@ -193,6 +203,8 @@ def __init__(
         self,
         dim: int,
         num_experts: int,
+        num_expert_groups: int | None,
+        num_limited_groups: int | None,
         top_k: int,
         score_func: Literal["softmax", "sigmoid"],
         route_norm: bool,
@@ -202,6 +214,8 @@
         super().__init__()
         self.gate = nn.Linear(dim, num_experts, bias=False)
         self.num_experts = num_experts
+        self.num_expert_groups = num_expert_groups
+        self.num_limited_groups = num_limited_groups
         self.top_k = top_k
         self.score_func = score_func
         self.route_norm = route_norm
@@ -225,6 +239,47 @@ def _debug_force_load_balance_routing(
         top_scores = scores.gather(dim=1, index=selected_experts_indices)  # [N,K]
         return selected_experts_indices, top_scores
 
+    def _get_node_limited_routing_scores(
+        self,
+        scores_for_choice: torch.Tensor,
+    ) -> torch.Tensor:
+        """Select num_limited_groups groups based on group scores,
+        and set expert scores in non-selected groups to -inf.
+
+        Args:
+            scores_for_choice: Router scores with expert_bias (if any), shape (bs*slen, num_experts)
+
+        Returns:
+            scores_for_choice: shape (bs*slen, num_experts)
+        """
+        if self.num_limited_groups is None:
+            raise ValueError(
+                "num_limited_groups must be set when num_expert_groups is set"
+            )
+        if self.num_experts % self.num_expert_groups != 0:
+            raise ValueError(
+                f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})"
+            )
+        experts_per_group = self.num_experts // self.num_expert_groups
+        if experts_per_group < 2:
+            raise ValueError(f"experts_per_group ({experts_per_group}) must be >= 2")
+        scores_grouped = scores_for_choice.view(
+            -1, self.num_expert_groups, experts_per_group
+        )
+        top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1)
+        group_scores = top2_scores_in_group.sum(dim=-1)
+        _, group_idx = torch.topk(
+            group_scores, k=self.num_limited_groups, dim=-1, sorted=False
+        )
+        group_mask = torch.ones_like(group_scores, dtype=torch.bool)
+        group_mask.scatter_(1, group_idx, False)  # False = selected groups (keep)
+        # Mask out experts from non-selected groups
+        scores_for_choice = scores_grouped.masked_fill(
+            group_mask.unsqueeze(-1), float("-inf")
+        ).view(-1, self.num_experts)
+
+        return scores_for_choice
+
     def forward(
         self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -254,18 +309,18 @@
         else:
             raise NotImplementedError(f"Unknown score function {self.score_func}")
 
+        scores_for_choice = scores if expert_bias is None else scores + expert_bias
+        # Apply node-limited routing if configured
+        if self.num_expert_groups is not None:
+            scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice)
+        _, selected_experts_indices = torch.topk(
+            scores_for_choice, k=self.top_k, dim=-1, sorted=False
+        )
+
         # top scores shape (bs*slen, top_k)
         # NOTE: The expert_bias is only used for routing. The gating value
         # top_scores is still derived from the original scores.
-        if expert_bias is not None:
-            _, selected_experts_indices = torch.topk(
-                scores + expert_bias, k=self.top_k, dim=1
-            )
-            top_scores = scores.gather(dim=1, index=selected_experts_indices)
-        else:
-            top_scores, selected_experts_indices = torch.topk(
-                scores, k=self.top_k, dim=1
-            )
+        top_scores = scores.gather(dim=1, index=selected_experts_indices)
 
         # debug override: balanced round-robin routing
         if self._debug_force_load_balance:
@@ -367,6 +422,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
         self.router = TokenChoiceTopKRouter(
             dim=dim,
             num_experts=num_experts,
+            num_expert_groups=moe_args.num_expert_groups,
+            num_limited_groups=moe_args.num_limited_groups,
             top_k=moe_args.top_k,
             score_func=moe_args.score_func,
             route_norm=moe_args.route_norm,
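
To make the new routing path concrete, below is a minimal standalone sketch of the selection logic that _get_node_limited_routing_scores and the updated forward implement: each group is scored by the sum of its top-2 expert scores, only the best num_limited_groups groups survive, the final top-k runs over the masked scores, and the gating values are still gathered from the unmasked scores. The tensor sizes are toy values for illustration, not taken from the configs in this diff.

import torch

# Toy sizes (illustrative only): 16 experts in 4 groups; each token may use
# experts from its 2 best groups, and picks 4 experts among them.
num_tokens, num_experts = 5, 16
num_expert_groups, num_limited_groups, top_k = 4, 2, 4
experts_per_group = num_experts // num_expert_groups  # must divide evenly

scores = torch.rand(num_tokens, num_experts)  # router scores after softmax/sigmoid

# Score each group by the sum of its top-2 expert scores.
grouped = scores.view(-1, num_expert_groups, experts_per_group)
group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)  # (num_tokens, num_groups)

# Keep the num_limited_groups best groups; mask every expert in the rest to -inf.
_, group_idx = torch.topk(group_scores, k=num_limited_groups, dim=-1, sorted=False)
group_mask = torch.ones_like(group_scores, dtype=torch.bool)
group_mask.scatter_(1, group_idx, False)  # False = keep this group
masked_scores = grouped.masked_fill(
    group_mask.unsqueeze(-1), float("-inf")
).view(-1, num_experts)

# Top-k now only ever selects experts inside the surviving groups, while the
# gating values come from the original, unmasked scores (matching the NOTE above).
_, selected_experts_indices = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)
top_scores = scores.gather(dim=1, index=selected_experts_indices)
assert torch.isfinite(top_scores).all()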
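
For completeness, a sketch of how the new knobs are set from a model config, mirroring the second config entry in torchtitan/models/deepseek_v3/__init__.py above. The import path is assumed from the file layout in this diff, and this is only the MoE portion of a full model config.

from torchtitan.models.moe.moe import MoEArgs  # path assumed from this diff

moe_args = MoEArgs(
    num_experts=256,
    num_shared_experts=1,
    top_k=8,
    num_expert_groups=8,   # experts split into 8 groups (e.g., one per node)
    num_limited_groups=4,  # each token routes only within its 4 best groups
    score_func="sigmoid",
    route_norm=True,
    route_scale=2.5,
    score_before_experts=False,
)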