Skip to content

Commit 476a965

Browse files
authored
Implicit overlap of shared expert compute and token combine communication (#1741)
This PR moves the computation of the shared expert before the possible scoring of the routed expert output which leads to an implicit overlap between shared expert compute and token combine communication in MoE models. Repro (lowered the layer number to 2): ``` CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --profiling.enable_profiling --profiling.profile_freq 10 --training.steps 10 ``` Trace before the change: <img width="1503" height="625" alt="Screenshot 2025-09-23 at 12 08 31 AM" src="https://github.com/user-attachments/assets/bbcc41cf-6497-482e-972e-d917baf4498e" /> Trace after the change (note that all-to-all comm is now overlapping shared expert compute): <img width="1503" height="625" alt="Screenshot 2025-09-23 at 12 04 56 AM" src="https://github.com/user-attachments/assets/3504e77c-aa14-46fd-8e47-e247b88d7b9c" /> cc @tianyu-l @xmfan
1 parent 22d2d44 commit 476a965

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

torchtitan/models/moe.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -417,18 +417,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
417417
# shape (bs*slen*top_k, dim)
418418
routed_output = self.experts(routed_input, num_tokens_per_expert)
419419

420-
if not self.score_before_experts:
421-
routed_output = (
422-
routed_output.to(torch.float32)
423-
* top_scores_experts_sorted.reshape(-1, 1)
424-
).to(x.dtype)
425-
426420
# shared expert
421+
# Note: we execute the shared expert before scoring the output of the routed expert
422+
# to "implicitly" overlap the shared expert compute with token combine communication
427423
if self.shared_experts is not None:
428424
out = self.shared_experts(x)
429425
else:
430426
out = torch.zeros_like(x)
431427

428+
if not self.score_before_experts:
429+
routed_output = (
430+
routed_output.to(torch.float32)
431+
* top_scores_experts_sorted.reshape(-1, 1)
432+
).to(x.dtype)
433+
432434
out = out.scatter_add(
433435
dim=0, index=token_indices_experts_sorted, src=routed_output
434436
)

0 commit comments

Comments
 (0)