
Commit e051a05

[#8694][fix] fix AutoDeploy cuda memory access failure in nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3 (#8696)
Signed-off-by: Eran Geva <[email protected]>
1 parent b37a8a9 commit e051a05

File tree

  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py

1 file changed: 20 additions, 2 deletions

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py

Lines changed: 20 additions & 2 deletions
@@ -517,7 +517,16 @@ def triton_fused_moe(
     """Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""
     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    topk_ids = selected_experts.to(torch.int32).contiguous()
+
+    # Get number of local experts from weight shape
+    num_experts = w1_stacked_weight.shape[0]
+
+    # Clamp expert IDs to valid range to handle EP sharding
+    # After EP sharding, some expert IDs may be negative (for experts on other ranks)
+    # Clamp them to 0 (first expert) - these will be masked by routing_weights=0 anyway
+    selected_experts_clamped = torch.clamp(selected_experts, min=0, max=num_experts - 1)
+
+    topk_ids = selected_experts_clamped.to(torch.int32).contiguous()
     topk_weights = routing_weights.to(torch.float32).contiguous()
 
     out2d = _fused_moe_mlp_relu2(x2d, w1_stacked_weight, w2_stacked_weight, topk_ids, topk_weights)
@@ -565,7 +574,16 @@ def triton_quant_fp8_moe(
 
     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    topk_ids = selected_experts.to(torch.int32).contiguous()
+
+    # Get number of local experts from weight shape
+    num_experts = w1_weight.shape[0]
+
+    # Clamp expert IDs to valid range to handle EP sharding
+    # After EP sharding, some expert IDs may be negative (for experts on other ranks)
+    # Clamp them to 0 (first expert) - these will be masked by routing_weights=0 anyway
+    selected_experts_clamped = torch.clamp(selected_experts, min=0, max=num_experts - 1)
+
+    topk_ids = selected_experts_clamped.to(torch.int32).contiguous()
     topk_weights = routing_weights.to(torch.float32).contiguous()
 
     # Weights are already stacked [E, ...] - just ensure contiguous and extract scales
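For context, here is a minimal standalone sketch of the clamping trick applied in both hunks. It is not the Triton kernel itself; the tensor values, shapes, and the num_local_experts name are illustrative assumptions. It shows why clamping out-of-range expert IDs to 0 is safe when the corresponding routing weights are already zero.

import torch

# After EP sharding, expert IDs belonging to experts hosted on other ranks are
# marked negative and their routing weights are zeroed. Indexing local expert
# weights with a negative ID causes out-of-bounds CUDA memory access, so the
# IDs are clamped into [0, num_local_experts - 1] before building topk_ids.
num_local_experts = 4

# Hypothetical routing output for 3 tokens with top-2 experts; -1 marks a
# remote expert whose weight has already been set to 0.
selected_experts = torch.tensor([[0, -1],
                                 [2, 3],
                                 [-1, 1]])
routing_weights = torch.tensor([[0.7, 0.0],
                                [0.6, 0.4],
                                [0.0, 1.0]])

# Same clamp as in the commit: remote IDs fall back to expert 0, which is
# harmless because their routing weight of 0 removes their contribution.
selected_experts_clamped = torch.clamp(selected_experts, min=0, max=num_local_experts - 1)
topk_ids = selected_experts_clamped.to(torch.int32).contiguous()
topk_weights = routing_weights.to(torch.float32).contiguous()

print(topk_ids)  # every index is now valid for local expert weights of shape [E, ...]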
