@@ -119,16 +119,16 @@ def relu2(x: torch.Tensor) -> torch.Tensor:
119119
120120
def _get_test_data(
    otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE, W_GEN_SCALE
):
    """Generate random input/weight tensors plus placeholder scale buffers for MoE tests.

    Args:
        otype: dtype for the activation tensor, router logits, and scale buffers.
        wtype: dtype for the expert weight tensors.
        batch_size: number of tokens in the activation batch.
        hidden_size: model hidden dimension.
        num_experts: number of MoE experts.
        intermediate_size: per-expert FFN intermediate dimension.
        X_GEN_SCALE: scale passed to ``gen_tensor`` for the activation ``x``.
        W_GEN_SCALE: scale passed to ``gen_tensor`` for both weight tensors.

    Returns:
        Tuple of ``(x, router_logits, w31_weight, w2_weight,
        w31_empty_scales, w2_empty_scales)``.

    NOTE(review): the two scale buffers are created with ``torch.empty`` and so
    hold uninitialized values — presumably the fp8 test path overwrites them
    (see the quantization helper in this file); confirm before relying on them.
    """
    input_shape = (batch_size, hidden_size)
    # w31 packs w3 and w1 along the output dimension, hence the factor of 2.
    w31_shape = (num_experts, 2 * intermediate_size, hidden_size)
    w2_shape = (num_experts, hidden_size, intermediate_size)

    x = cast_to_representable(gen_tensor(input_shape, otype, scale=X_GEN_SCALE))
    router_logits = gen_tensor((batch_size, num_experts), otype)
    # Weights are generated at W_GEN_SCALE so both w31 and w2 share one magnitude.
    w31_weight = gen_tensor(w31_shape, otype, wtype, W_GEN_SCALE)
    w2_weight = gen_tensor(w2_shape, otype, wtype, W_GEN_SCALE)
    # Uninitialized per-expert scale buffers: 2 scales for the fused w3/w1, 1 for w2.
    w31_empty_scales = torch.empty(num_experts, 2, dtype=otype).cuda()
    w2_empty_scales = torch.empty(num_experts, 1, dtype=otype).cuda()
    return x, router_logits, w31_weight, w2_weight, w31_empty_scales, w2_empty_scales
@@ -203,9 +203,17 @@ def test_trtllm_fused_moe(
203203 X_GEN_SCALE = 1.0
204204 else :
205205 X_GEN_SCALE = 0.5
206+ W_GEN_SCALE = 0.1
206207
207208 x , router_logits , w31_weight , w2_weight , w31_scales , w2_scales = _get_test_data (
208- otype , wtype , batch_size , hidden_size , num_experts , intermediate_size , X_GEN_SCALE
209+ otype ,
210+ wtype ,
211+ batch_size ,
212+ hidden_size ,
213+ num_experts ,
214+ intermediate_size ,
215+ X_GEN_SCALE ,
216+ W_GEN_SCALE ,
209217 )
210218
211219 routing_weights , selected_experts = compute_routing (router_logits , top_k )
@@ -278,14 +286,14 @@ def get_fc1_expert_weights(
278286 w1_weight .contiguous (),
279287 w2_weight .contiguous (),
280288 )[0 ].view (x .shape )
281- torch .testing .assert_close (output_triton_moe , ad_test_output , rtol = 1e-1 , atol = 1e-1 )
289+ torch .testing .assert_close (output_triton_moe , ad_test_output , rtol = 1e-2 , atol = 1e-2 )
282290
283291 diff = (ref_output - ad_test_output ).abs ()
284292 print (f"max diff: { diff .max ()} " )
285293 torch .testing .assert_close (ad_test_output , trtllm_test_output , rtol = 1e-6 , atol = 1e-6 )
286294
287295 _print_diff_if (lambda diff : diff .max () > 1e-1 , diff , ad_test_output , ref_output )
288- torch .testing .assert_close (ref_output , ad_test_output , rtol = 1e-1 , atol = 1e-1 )
296+ torch .testing .assert_close (ref_output , ad_test_output , rtol = 1e-2 , atol = 1e-2 )
289297
290298
291299FP8_TEST_DTYPES = [
@@ -305,7 +313,7 @@ def get_fc1_expert_weights(
305313 not fp8_compatible () or not trtllm_ops_available (),
306314 reason = "Requires fp8 and trtllm support" ,
307315)
308- def test_trtllm_fused_fp8moe (
316+ def test_trtllm_fused_moe_fp8 (
309317 batch_size ,
310318 hidden_size ,
311319 num_experts ,
@@ -333,16 +341,18 @@ def test_trtllm_fused_fp8moe(
333341 else :
334342 X_GEN_SCALE = 0.5
335343
336- def dequantize_weights (w31_weight , w2_weight , w31_scales , w2_scales ):
344+ W_GEN_SCALE = 0.1
345+
346+ def dequantize_weights (w31_weight , w2_weight , w31_scales , w2_scales , W_GEN_SCALE ):
337347 # input_shape = (batch_size, hidden_size)
338348 w31_shape = (num_experts , 2 * intermediate_size , hidden_size )
339349 w2_shape = (num_experts , hidden_size , intermediate_size )
340350
341351 w31_dequantized = gen_tensor (w31_weight .shape , otype )
342352 w2_dequantized = gen_tensor (w2_weight .shape , otype )
343353 for expert_id in range (num_experts ):
344- w31 = cast_to_representable (gen_tensor (w31_shape [1 :], otype , scale = 0.1 ))
345- w2 = cast_to_representable (gen_tensor (w2_shape [1 :], otype , scale = 0.09 ))
354+ w31 = cast_to_representable (gen_tensor (w31_shape [1 :], otype , scale = W_GEN_SCALE ))
355+ w2 = cast_to_representable (gen_tensor (w2_shape [1 :], otype , scale = W_GEN_SCALE ))
346356 w31_quant , s31 = dynamic_per_tensor_fp8_quant (w31 )
347357 w2_quant , s2 = dynamic_per_tensor_fp8_quant (w2 )
348358 w31_weight .data [expert_id ].copy_ (w31_quant )
@@ -354,11 +364,18 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales):
354364 return w31_dequantized , w2_dequantized
355365
356366 x , router_logits , w31_weight , w2_weight , w31_scales , w2_scales = _get_test_data (
357- otype , wtype , batch_size , hidden_size , num_experts , intermediate_size , X_GEN_SCALE
367+ otype ,
368+ wtype ,
369+ batch_size ,
370+ hidden_size ,
371+ num_experts ,
372+ intermediate_size ,
373+ X_GEN_SCALE ,
374+ W_GEN_SCALE ,
358375 )
359376
360377 w31_dequantized , w2_dequantized = dequantize_weights (
361- w31_weight , w2_weight , w31_scales , w2_scales
378+ w31_weight , w2_weight , w31_scales , w2_scales , W_GEN_SCALE
362379 )
363380
364381 routing_weights , selected_experts = compute_routing (router_logits , top_k )
0 commit comments