
Commit 697f821

nzmora-nvidia authored and lkomali committed
[None][feature] AutoDeploy: tighter MoE UT thresholds (#9195)
Scale down the weights in the MoE test so that the output has reasonable magnitude, allowing for tighter atol and rtol.

Signed-off-by: Neta Zmora <[email protected]>
Signed-off-by: lkomali <[email protected]>
1 parent d77a3ef commit 697f821
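
Why this helps, in a minimal sketch (illustration only, not part of the commit): with random inputs, the magnitude of a matmul output grows with the scale of the weights, and so does the absolute gap between a low-precision and a full-precision evaluation of the same computation, so shrinking the weights shrinks the error that atol/rtol must absorb.

import torch

torch.manual_seed(0)
x = torch.randn(64, 512)
for w_scale in (1.0, 0.1):
    w = torch.randn(512, 512) * w_scale
    y_ref = x @ w                                  # fp32 reference
    y_low = (x.bfloat16() @ w.bfloat16()).float()  # low-precision evaluation
    err = (y_ref - y_low).abs().max()
    print(f"w_scale={w_scale}: |y|_max={y_ref.abs().max():.2f}, max abs err={err:.4f}")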

File tree

1 file changed: +29 −12 lines changed


tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py

Lines changed: 29 additions & 12 deletions
@@ -119,16 +119,16 @@ def relu2(x: torch.Tensor) -> torch.Tensor:


 def _get_test_data(
-    otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE
+    otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE, W_GEN_SCALE
 ):
     input_shape = (batch_size, hidden_size)
     w31_shape = (num_experts, 2 * intermediate_size, hidden_size)
     w2_shape = (num_experts, hidden_size, intermediate_size)

     x = cast_to_representable(gen_tensor(input_shape, otype, scale=X_GEN_SCALE))
     router_logits = gen_tensor((batch_size, num_experts), otype)
-    w31_weight = gen_tensor(w31_shape, otype, wtype)
-    w2_weight = gen_tensor(w2_shape, otype, wtype)
+    w31_weight = gen_tensor(w31_shape, otype, wtype, W_GEN_SCALE)
+    w2_weight = gen_tensor(w2_shape, otype, wtype, W_GEN_SCALE)
     w31_empty_scales = torch.empty(num_experts, 2, dtype=otype).cuda()
     w2_empty_scales = torch.empty(num_experts, 1, dtype=otype).cuda()
     return x, router_logits, w31_weight, w2_weight, w31_empty_scales, w2_empty_scales
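
The helper now threads a weight scale through to gen_tensor. A hypothetical stand-in for that helper, inferred only from the call sites above (the real test utility may differ), would look like:

import torch

# Hypothetical sketch of a gen_tensor-like helper, inferred from the call
# sites in this diff: random data in otype, scaled, optionally cast to wtype.
def gen_tensor_sketch(shape, otype, wtype=None, scale=1.0):
    t = torch.randn(shape, dtype=otype) * scale
    return t.to(wtype) if wtype is not None else t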
@@ -203,9 +203,17 @@ def test_trtllm_fused_moe(
         X_GEN_SCALE = 1.0
     else:
         X_GEN_SCALE = 0.5
+    W_GEN_SCALE = 0.1

     x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data(
-        otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE
+        otype,
+        wtype,
+        batch_size,
+        hidden_size,
+        num_experts,
+        intermediate_size,
+        X_GEN_SCALE,
+        W_GEN_SCALE,
     )

     routing_weights, selected_experts = compute_routing(router_logits, top_k)
@@ -278,14 +286,14 @@ def get_fc1_expert_weights(
         w1_weight.contiguous(),
         w2_weight.contiguous(),
     )[0].view(x.shape)
-    torch.testing.assert_close(output_triton_moe, ad_test_output, rtol=1e-1, atol=1e-1)
+    torch.testing.assert_close(output_triton_moe, ad_test_output, rtol=1e-2, atol=1e-2)

     diff = (ref_output - ad_test_output).abs()
     print(f"max diff: {diff.max()}")
     torch.testing.assert_close(ad_test_output, trtllm_test_output, rtol=1e-6, atol=1e-6)

     _print_diff_if(lambda diff: diff.max() > 1e-1, diff, ad_test_output, ref_output)
-    torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-1, atol=1e-1)
+    torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-2, atol=1e-2)


 FP8_TEST_DTYPES = [
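
The tightened thresholds lean on torch.testing.assert_close's documented elementwise check, |actual - expected| <= atol + rtol * |expected|: once the weight scale-down shrinks |expected|, both terms of the bound tighten, so rtol/atol can drop from 1e-1 to 1e-2 without flakiness. A small self-contained illustration (values made up for this note):

import torch

expected = torch.tensor([0.50, -0.25])
actual = expected + 0.004  # error well inside atol + rtol * |expected|
torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)  # passes
print("assert_close passed")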
@@ -305,7 +313,7 @@ def get_fc1_expert_weights(
     not fp8_compatible() or not trtllm_ops_available(),
     reason="Requires fp8 and trtllm support",
 )
-def test_trtllm_fused_fp8moe(
+def test_trtllm_fused_moe_fp8(
     batch_size,
     hidden_size,
     num_experts,
@@ -333,16 +341,18 @@ def test_trtllm_fused_fp8moe(
     else:
         X_GEN_SCALE = 0.5

-    def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales):
+    W_GEN_SCALE = 0.1
+
+    def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE):
         # input_shape = (batch_size, hidden_size)
         w31_shape = (num_experts, 2 * intermediate_size, hidden_size)
         w2_shape = (num_experts, hidden_size, intermediate_size)

         w31_dequantized = gen_tensor(w31_weight.shape, otype)
         w2_dequantized = gen_tensor(w2_weight.shape, otype)
         for expert_id in range(num_experts):
-            w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1))
-            w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09))
+            w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=W_GEN_SCALE))
+            w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=W_GEN_SCALE))
             w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31)
             w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2)
             w31_weight.data[expert_id].copy_(w31_quant)
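
dynamic_per_tensor_fp8_quant comes from the test's own imports; its exact behavior isn't shown in this diff. A typical per-tensor dynamic scheme (an assumption for illustration, not the actual implementation) derives one scale from the tensor's absolute maximum:

import torch

# Hypothetical per-tensor dynamic fp8 (e4m3) quantization; the real
# dynamic_per_tensor_fp8_quant used by the test may differ in details.
# Requires a PyTorch build with float8 dtype support.
def per_tensor_fp8_quant_sketch(w: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    scale = w.abs().max().clamp(min=1e-12) / fp8_max
    w_q = (w / scale).to(torch.float8_e4m3fn)  # quantize
    return w_q, scale  # dequantize with w_q.float() * scale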
@@ -354,11 +364,18 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales):
         return w31_dequantized, w2_dequantized

     x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data(
-        otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE
+        otype,
+        wtype,
+        batch_size,
+        hidden_size,
+        num_experts,
+        intermediate_size,
+        X_GEN_SCALE,
+        W_GEN_SCALE,
     )

     w31_dequantized, w2_dequantized = dequantize_weights(
-        w31_weight, w2_weight, w31_scales, w2_scales
+        w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE
     )

     routing_weights, selected_experts = compute_routing(router_logits, top_k)
