@@ -451,10 +451,25 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None):
 
     def backward_dw(self, routed_experts: bool = True, shared_experts: bool = False):
         """Compute weight gradients for experts and shared experts."""
+        # TODO(Wohox): rename the "routed_experts" and "shared_experts" arguments to make
+        # clear that they correspond to different fine-grained callables, or use scanning
+        # to decide which backward_dw should be called.
         if routed_experts:
             self.experts.backward_dw()
-        if shared_experts and self.use_shared_expert and not self.shared_expert_overlap:
-            self.shared_experts.backward_dw()
+            if self.config.moe_latent_size:
+                # TODO(Wohox): fc2_latent_proj forward and backward run on the comm stream,
+                # so we run its backward_dw on the comm stream too. This may hurt EP overlap
+                # performance; check whether there is a better way to handle this.
+                from megatron.core.pipeline_parallel.utils import get_comm_stream
+
+                comm_stream = get_comm_stream()
+                with torch.cuda.stream(comm_stream):
+                    self.fc2_latent_proj.backward_dw()
+        if shared_experts:
+            if self.use_shared_expert and not self.shared_expert_overlap:
+                self.shared_experts.backward_dw()
+            if self.config.moe_latent_size:
+                self.fc1_latent_proj.backward_dw()
 
     def set_for_recompute_pre_mlp_layernorm(self):
         """Set the MoE layer for recompute pre_mlp_layernorm. Only needed for fp8/fp4."""
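
For context on the second TODO, here is a minimal, self-contained sketch of the side-stream pattern the new code relies on: queue a module's weight-gradient computation (backward_dw) on a separate CUDA stream and synchronize it against the compute stream. Only backward_dw, get_comm_stream, and fc2_latent_proj come from the diff above; the helper function and its synchronization points are illustrative, not Megatron-Core's actual implementation.

import torch

# Stand-in for megatron.core.pipeline_parallel.utils.get_comm_stream().
comm_stream = torch.cuda.Stream()

def run_backward_dw_on_comm_stream(module):
    # The comm stream must wait for tensors produced on the compute stream
    # (activations / output gradients) before the weight-gradient GEMM reads them.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        module.backward_dw()  # weight-gradient work is queued on comm_stream
    # Before anything on the compute stream consumes the resulting .grad buffers
    # (e.g. the optimizer step), it must wait for the comm stream to drain.
    torch.cuda.current_stream().wait_stream(comm_stream)

The trade-off the TODO points at follows from this pattern: work queued on the comm stream serializes with the EP all-to-all traffic on that same stream, so placing backward_dw there can reduce the overlap between expert compute and communication.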