diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
index 9edf1ce6836..64b62419162 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
@@ -80,12 +80,11 @@ def _triton_cached_ssm(
     if num_prefill > 0:
         seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
         total_prefill_tokens = int(seq_len_prefill.sum().item())
-        prefill_idx = torch.arange(total_prefill_tokens, device=device, dtype=torch.long)
-        hs_prefill = hs_flat.index_select(0, prefill_idx).unsqueeze(0)  # [1, S_p, H, D]
-        B_prefill = B_flat.index_select(0, prefill_idx).unsqueeze(0)  # [1, S_p, G, N]
-        C_prefill = C_flat.index_select(0, prefill_idx).unsqueeze(0)  # [1, S_p, G, N]
-        dt_prefill = dt_flat.index_select(0, prefill_idx).unsqueeze(0)  # [1, S_p, H]
+        hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]
+        B_prefill = B_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
+        C_prefill = C_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
+        dt_prefill = dt_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H]
 
         cu_seqlens = torch.cat(
             [
@@ -128,20 +127,20 @@ def _triton_cached_ssm(
             return_varlen_states=True,
         )
 
-        y_flat.index_copy_(0, prefill_idx, y_prefill[0].to(y_flat.dtype))
+        y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
         ssm_state_cache.index_copy_(
             0, slot_idx[:num_prefill].to(torch.long), varlen_states.to(ssm_state_cache.dtype)
         )
 
     # Decode: batch single-token updates via selective_state_update
     if num_decode > 0:
-        decode_idx = seq_start[num_prefill:].to(torch.long)
+        total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
         slot_idx_decode = slot_idx[num_prefill:].to(torch.long)
 
-        x_decode = hs_flat.index_select(0, decode_idx)  # [nd, H, D]
-        B_decode = B_flat.index_select(0, decode_idx)  # [nd, G, N]
-        C_decode = C_flat.index_select(0, decode_idx)  # [nd, G, N]
-        dt_decode = dt_flat.index_select(0, decode_idx)  # [nd, H]
+        x_decode = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, H, D]
+        B_decode = B_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, G, N]
+        C_decode = C_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, G, N]
+        dt_decode = dt_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, H]
 
         dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
         dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
@@ -165,7 +164,7 @@ def _triton_cached_ssm(
             state_batch_indices=slot_idx_decode,
         )  # [nd, H, D]
 
-        y_flat.index_copy_(0, decode_idx, y_dec.to(y_flat.dtype))
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode] = y_dec.to(y_flat.dtype)
 
     return y
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py
index 64207513c00..c639c355e82 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py
@@ -75,7 +75,8 @@ def _pack_routed_tokens_reference(
     return sorted_token_ids_used, expert_ids_used, used_len
 
 
-def test_triton_moe_matches_torch_moe_mlp_relu2():
+@pytest.mark.parametrize("early_exit", [False, True])
+def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit):
     torch.manual_seed(0)
 
     if not torch.cuda.is_available():
@@ -83,7 +84,7 @@ def test_triton_moe_matches_torch_moe_mlp_relu2():
     device = "cuda"
     dtype = torch.bfloat16
 
-    M = 8  # tokens
+    M = 32 if early_exit else 8  # tokens
     HIDDEN_SIZE = 8
     INTERMEDIATE_SIZE = 16
     E = 8  # experts
@@ -102,12 +103,26 @@ def test_triton_moe_matches_torch_moe_mlp_relu2():
     w_up_stacked = torch.stack(w_up_list, dim=0).contiguous()  # [E, I, H]
     w_down_stacked = torch.stack(w_down_list, dim=0).contiguous()  # [E, H, I]
 
-    # Create routing with top-k normalization
-    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
-    routing_full = torch.softmax(router_logits, dim=-1)
-    routing_weights, selected_experts = torch.topk(routing_full, k=top_k, dim=-1)
-    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
-    routing_weights = routing_weights.to(torch.float32)
+    # Create routing based on whether we want to test early exit
+    if not early_exit:
+        # Random routing with top-k normalization
+        router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
+        routing_full = torch.softmax(router_logits, dim=-1)
+        routing_weights, selected_experts = torch.topk(routing_full, k=top_k, dim=-1)
+        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(torch.float32)
+    else:
+        # Imbalanced routing: concentrate 75% of tokens on first 2 experts
+        # This tests early exit logic in num_tokens_post_padded path
+        selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
+        for i in range(M):
+            if i < M * 3 // 4:
+                selected_experts[i, 0] = 0
+                selected_experts[i, 1] = 1
+            else:
+                selected_experts[i, 0] = i % E
+                selected_experts[i, 1] = (i + 1) % E
+        routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
 
     # Triton fused MoE (mlp with relu^2 activation between two GEMMs)
     out_triton = torch.ops.auto_deploy.triton_moe_fused(
@@ -219,7 +234,8 @@ def test_moe_align_kernel_groups_tokens_by_expert_and_block_padding():
 
 
 @skip_pre_hopper
-def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
+@pytest.mark.parametrize("early_exit", [False, True])
+def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
     """Test triton_quant_fp8_moe against torch_quant_fp8_moe reference."""
     torch.manual_seed(0)
 
@@ -228,7 +244,7 @@
     device = "cuda"
     dtype = torch.bfloat16
 
-    M = 32  # tokens
+    M = 64 if early_exit else 32  # tokens
     HIDDEN_SIZE = 16  # Must be multiple of 16 for FP8 linear
     INTERMEDIATE_SIZE = 32  # Must be multiple of 16 for FP8 linear
    E = 4  # experts
@@ -313,12 +329,23 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
     w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
     w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)
 
-    # Create controlled routing to ensure even token distribution across experts
+    # Create routing based on whether we want to test early exit
     selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
-    for i in range(M):
+    if not early_exit:
         # Distribute tokens evenly: token i goes to experts (i % E) and ((i+1) % E)
-        selected_experts[i, 0] = i % E
-        selected_experts[i, 1] = (i + 1) % E
+        for i in range(M):
+            selected_experts[i, 0] = i % E
+            selected_experts[i, 1] = (i + 1) % E
+    else:
+        # Imbalanced routing: concentrate 75% of tokens on first 2 experts
+        # This tests early exit logic in num_tokens_post_padded path
+        for i in range(M):
+            if i < M * 3 // 4:
+                selected_experts[i, 0] = 0
+                selected_experts[i, 1] = 1
+            else:
+                selected_experts[i, 0] = i % E
+                selected_experts[i, 1] = (i + 1) % E
 
     # Create equal routing weights
     routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
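Note on the `_triton_cached_ssm` change above: replacing `index_select`/`index_copy_` with plain slices relies on the flattened token layout putting all prefill tokens first, followed by exactly one token per decode sequence. The snippet below is a minimal standalone sketch (not part of the patch; the toy `seq_len` values and tensor shapes are made up) that checks the slices select the same rows the old gather-based indexing did:

```python
import torch

# Toy metadata: two prefill sequences (lengths 3 and 2) followed by three
# single-token decode sequences, flattened along the token dimension.
seq_len = torch.tensor([3, 2, 1, 1, 1])
seq_start = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.long), seq_len[:-1]]), dim=0)
num_prefill, num_decode = 2, 3

hs_flat = torch.randn(int(seq_len.sum()), 4, 8)  # [total_tokens, H, D]

# Old indexing: explicit per-token gather indices.
total_prefill_tokens = int(seq_len[:num_prefill].sum().item())
prefill_idx = torch.arange(total_prefill_tokens, dtype=torch.long)
decode_idx = seq_start[num_prefill:].to(torch.long)

# New indexing: contiguous slices, valid because prefill tokens sit at the
# front and each decode sequence contributes exactly one trailing token.
assert torch.equal(hs_flat.index_select(0, prefill_idx), hs_flat[:total_prefill_tokens])
assert torch.equal(
    hs_flat.index_select(0, decode_idx),
    hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode],
)
```

If the scheduler ever interleaved decode tokens with prefill tokens in the flattened batch, the slice-based version would silently read the wrong rows, so that layout invariant is the main thing to confirm when reviewing this change.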