@@ -75,15 +75,16 @@ def _pack_routed_tokens_reference(
     return sorted_token_ids_used, expert_ids_used, used_len


-def test_triton_moe_matches_torch_moe_mlp_relu2():
+@pytest.mark.parametrize("early_exit", [False, True])
+def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit):
     torch.manual_seed(0)

     if not torch.cuda.is_available():
         pytest.skip("CUDA is required for triton_moe fused MLP test")
     device = "cuda"
     dtype = torch.bfloat16

-    M = 8  # tokens
+    M = 32 if early_exit else 8  # tokens
     HIDDEN_SIZE = 8
     INTERMEDIATE_SIZE = 16
     E = 8  # experts
@@ -102,12 +103,26 @@ def test_triton_moe_matches_torch_moe_mlp_relu2():
     w_up_stacked = torch.stack(w_up_list, dim=0).contiguous()  # [E, I, H]
     w_down_stacked = torch.stack(w_down_list, dim=0).contiguous()  # [E, H, I]

-    # Create routing with top-k normalization
-    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
-    routing_full = torch.softmax(router_logits, dim=-1)
-    routing_weights, selected_experts = torch.topk(routing_full, k=top_k, dim=-1)
-    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
-    routing_weights = routing_weights.to(torch.float32)
+    # Create routing based on whether we want to test early exit
+    if not early_exit:
+        # Random routing with top-k normalization
+        router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
+        routing_full = torch.softmax(router_logits, dim=-1)
+        routing_weights, selected_experts = torch.topk(routing_full, k=top_k, dim=-1)
+        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(torch.float32)
+    else:
+        # Imbalanced routing: concentrate 75% of tokens on first 2 experts
+        # This tests early exit logic in num_tokens_post_padded path
+        selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
+        for i in range(M):
+            if i < M * 3 // 4:
+                selected_experts[i, 0] = 0
+                selected_experts[i, 1] = 1
+            else:
+                selected_experts[i, 0] = i % E
+                selected_experts[i, 1] = (i + 1) % E
+        routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k

     # Triton fused MoE (mlp with relu^2 activation between two GEMMs)
     out_triton = torch.ops.auto_deploy.triton_moe_fused(
@@ -219,7 +234,8 @@ def test_moe_align_kernel_groups_tokens_by_expert_and_block_padding():


 @skip_pre_hopper
-def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
+@pytest.mark.parametrize("early_exit", [False, True])
+def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
     """Test triton_quant_fp8_moe against torch_quant_fp8_moe reference."""
     torch.manual_seed(0)

@@ -228,7 +244,7 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
     device = "cuda"
     dtype = torch.bfloat16

-    M = 32  # tokens
+    M = 64 if early_exit else 32  # tokens
     HIDDEN_SIZE = 16  # Must be multiple of 16 for FP8 linear
     INTERMEDIATE_SIZE = 32  # Must be multiple of 16 for FP8 linear
     E = 4  # experts
@@ -313,12 +329,23 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
     w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
     w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)

-    # Create controlled routing to ensure even token distribution across experts
+    # Create routing based on whether we want to test early exit
     selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
-    for i in range(M):
+    if not early_exit:
         # Distribute tokens evenly: token i goes to experts (i % E) and ((i+1) % E)
-        selected_experts[i, 0] = i % E
-        selected_experts[i, 1] = (i + 1) % E
+        for i in range(M):
+            selected_experts[i, 0] = i % E
+            selected_experts[i, 1] = (i + 1) % E
+    else:
+        # Imbalanced routing: concentrate 75% of tokens on first 2 experts
+        # This tests early exit logic in num_tokens_post_padded path
+        for i in range(M):
+            if i < M * 3 // 4:
+                selected_experts[i, 0] = 0
+                selected_experts[i, 1] = 1
+            else:
+                selected_experts[i, 0] = i % E
+                selected_experts[i, 1] = (i + 1) % E

     # Create equal routing weights
     routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
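
Note on the early_exit routing (a sketch, not part of the patch): with the first test's early-exit shapes (M = 32, top_k = 2, E = 8), the loop above pins 24 of the 32 tokens to experts 0 and 1, so those two experts receive 26 of the 64 (token, expert) assignments each while every other expert gets 2. Per the new comments, this skew is what the num_tokens_post_padded early-exit path is meant to exercise. The distribution can be checked with a few lines of plain PyTorch, mirroring the added branch:

import torch

# Standalone check of the imbalanced routing used in the early_exit case above
M, top_k, E = 32, 2, 8
selected_experts = torch.zeros((M, top_k), dtype=torch.int64)
for i in range(M):
    if i < M * 3 // 4:
        selected_experts[i, 0], selected_experts[i, 1] = 0, 1
    else:
        selected_experts[i, 0], selected_experts[i, 1] = i % E, (i + 1) % E
# Count how many (token, expert) assignments land on each expert
counts = torch.bincount(selected_experts.flatten(), minlength=E)
print(counts.tolist())  # [26, 26, 2, 2, 2, 2, 2, 2]

The FP8 test's early_exit case (M = 64, E = 4) produces the same kind of skew, just with larger counts.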