
Commit ad888cd

Merge branch 'main' into pmannan/hcp_fix
2 parents 311f8a8 + a4008d0 commit ad888cd

2 files changed

Lines changed: 32 additions & 3 deletions


megatron/core/transformer/moe/moe_layer.py

Lines changed: 17 additions & 2 deletions
@@ -451,10 +451,25 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None):
 
     def backward_dw(self, routed_experts: bool = True, shared_experts: bool = False):
         """Compute weight gradients for experts and shared experts."""
+        # TODO(Wohox): replace the "routed_experts" and "shared_experts" arguments with better
+        # naming to better explain that they are actually from different fine-grained callables,
+        # or use scanning to decide which backward_dw should be called.
         if routed_experts:
             self.experts.backward_dw()
-        if shared_experts and self.use_shared_expert and not self.shared_expert_overlap:
-            self.shared_experts.backward_dw()
+            if self.config.moe_latent_size:
+                # TODO(Wohox): fc2_latent_proj forward and backward are executed in comm stream,
+                # so we execute its backward_dw in the comm stream too. But this may harm the
+                # EP overlap performance. Better to check if there is a better way to handle this.
+                from megatron.core.pipeline_parallel.utils import get_comm_stream
+
+                comm_stream = get_comm_stream()
+                with torch.cuda.stream(comm_stream):
+                    self.fc2_latent_proj.backward_dw()
+        if shared_experts:
+            if self.use_shared_expert and not self.shared_expert_overlap:
+                self.shared_experts.backward_dw()
+            if self.config.moe_latent_size:
+                self.fc1_latent_proj.backward_dw()
 
     def set_for_recompute_pre_mlp_layernorm(self):
         """Set the MoE layer for recompute pre_mlp_layernorm. Only needed for fp8/fp4."""

megatron/training/training.py

Lines changed: 15 additions & 1 deletion
@@ -364,6 +364,7 @@ def transformer_flops():
         if args.moe_ffn_hidden_size is not None
         else args.ffn_hidden_size
     )
+    moe_latent_size = args.moe_latent_size
     shared_expert_ffn_hidden_size = (
         0
         if args.moe_shared_expert_intermediate_size is None
@@ -545,7 +546,20 @@ def transformer_flops():
             (args.ffn_hidden_size * ffn_expansion_factor)
             * num_dense_layers
             # routed experts
-            + (moe_ffn_hidden_size * num_experts_routed_to * ffn_expansion_factor)
+            + (
+                (moe_ffn_hidden_size * num_experts_routed_to * ffn_expansion_factor)
+                if moe_latent_size is None
+                else (
+                    (
+                        moe_ffn_hidden_size
+                        * num_experts_routed_to
+                        * ffn_expansion_factor
+                        * moe_latent_size
+                        / args.hidden_size
+                    )  # Routed experts run on moe_latent_size.
+                    + 2 * moe_latent_size  # Up proj and down proj.
+                )
+            )
             * num_moe_layers
             # Shared Experts.
             + (shared_expert_ffn_hidden_size * ffn_expansion_factor)
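As a sanity check on the new routed-expert term, here is a small worked example with made-up sizes (none of these numbers come from the commit). With `moe_latent_size` set, the expert FFN cost scales by `moe_latent_size / hidden_size` because the routed experts operate on the latent width, and `2 * moe_latent_size` accounts for the latent up and down projections.

```python
# Hypothetical sizes, for illustration only.
hidden_size = 4096
moe_ffn_hidden_size = 1024
num_experts_routed_to = 8       # top-k routing fan-out
ffn_expansion_factor = 4        # stand-in; the script derives this from the model config
moe_latent_size = 2048

# Term without latent projections (the removed line).
baseline_term = moe_ffn_hidden_size * num_experts_routed_to * ffn_expansion_factor

# Term with latent projections (the added expression).
latent_term = (
    moe_ffn_hidden_size
    * num_experts_routed_to
    * ffn_expansion_factor
    * moe_latent_size
    / hidden_size               # experts see moe_latent_size instead of hidden_size
) + 2 * moe_latent_size         # up proj and down proj

print(baseline_term)  # 32768
print(latent_term)    # 20480.0 (16384.0 for the experts + 4096 for the two projections)
```

Either term is then multiplied by `num_moe_layers` and the factors outside this hunk in `transformer_flops()`, so the example only illustrates the relative change to the routed-expert contribution.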
