@@ -517,7 +517,16 @@ def triton_fused_moe(
517517 """Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""
518518 x_shape = x .shape
519519 x2d = x .view (- 1 , x_shape [- 1 ])
520- topk_ids = selected_experts .to (torch .int32 ).contiguous ()
520+
521+ # Get number of local experts from weight shape
522+ num_experts = w1_stacked_weight .shape [0 ]
523+
524+ # Clamp expert IDs to valid range to handle EP sharding
525+ # After EP sharding, some expert IDs may be negative (for experts on other ranks)
526+ # Clamp them to 0 (first expert) - these will be masked by routing_weights=0 anyway
527+ selected_experts_clamped = torch .clamp (selected_experts , min = 0 , max = num_experts - 1 )
528+
529+ topk_ids = selected_experts_clamped .to (torch .int32 ).contiguous ()
521530 topk_weights = routing_weights .to (torch .float32 ).contiguous ()
522531
523532 out2d = _fused_moe_mlp_relu2 (x2d , w1_stacked_weight , w2_stacked_weight , topk_ids , topk_weights )
@@ -565,7 +574,16 @@ def triton_quant_fp8_moe(

     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    topk_ids = selected_experts.to(torch.int32).contiguous()
+
+    # Get number of local experts from weight shape
+    num_experts = w1_weight.shape[0]
+
+    # Clamp expert IDs to valid range to handle EP sharding
+    # After EP sharding, some expert IDs may be negative (for experts on other ranks)
+    # Clamp them to 0 (first expert) - these will be masked by routing_weights=0 anyway
+    selected_experts_clamped = torch.clamp(selected_experts, min=0, max=num_experts - 1)
+
+    topk_ids = selected_experts_clamped.to(torch.int32).contiguous()
     topk_weights = routing_weights.to(torch.float32).contiguous()

     # Weights are already stacked [E, ...] - just ensure contiguous and extract scales
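
Why the clamp is safe: tokens routed to experts that live on other EP ranks arrive with a routing weight of 0, so remapping their negative IDs onto local expert 0 only adds a contribution that is multiplied by zero. The following is a minimal standalone sketch of that masking behavior, using hypothetical shapes and a naive reference loop rather than the Triton kernel:

import torch

# Minimal sketch (hypothetical names/shapes, not the fused kernel): show that
# clamping out-of-range expert IDs does not change the MoE output when the
# corresponding routing weights are already zero after EP sharding.
torch.manual_seed(0)
num_local_experts, hidden, tokens, top_k = 4, 8, 3, 2

# After EP sharding, experts hosted on other ranks show up as negative IDs
# and the router has zeroed out their routing weights.
selected_experts = torch.tensor([[0, -1], [2, 3], [-1, 1]])
routing_weights = torch.tensor([[0.7, 0.0], [0.4, 0.6], [0.0, 1.0]])

# Same clamp as in the diff: map remote experts onto local expert 0.
clamped = torch.clamp(selected_experts, min=0, max=num_local_experts - 1)

expert_w = torch.randn(num_local_experts, hidden, hidden)
x = torch.randn(tokens, hidden)

def moe(ids: torch.Tensor) -> torch.Tensor:
    # Naive reference accumulation: weighted sum over the top-k experts.
    out = torch.zeros_like(x)
    for t in range(tokens):
        for k in range(top_k):
            e = int(ids[t, k])
            if e < 0:  # skip remote experts entirely (ground truth)
                continue
            out[t] += routing_weights[t, k] * (x[t] @ expert_w[e])
    return out

# Clamped IDs give the same result: the remapped experts are weighted by 0.
assert torch.allclose(moe(selected_experts), moe(clamped))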