
Commit b9b2802

[None][feat] Autodeploy: Update the ssm to use slice (#8667)
Signed-off-by: nvchenghaoz <[email protected]>
1 parent 7c8ba71 commit b9b2802

2 files changed: 52 additions, 26 deletions

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

Lines changed: 11 additions & 12 deletions
@@ -80,12 +80,11 @@ def _triton_cached_ssm(
     if num_prefill > 0:
         seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
         total_prefill_tokens = int(seq_len_prefill.sum().item())
-        prefill_idx = torch.arange(total_prefill_tokens, device=device, dtype=torch.long)

-        hs_prefill = hs_flat.index_select(0, prefill_idx).unsqueeze(0)  # [1, S_p, H, D]
-        B_prefill = B_flat.index_select(0, prefill_idx).unsqueeze(0)  # [1, S_p, G, N]
-        C_prefill = C_flat.index_select(0, prefill_idx).unsqueeze(0)  # [1, S_p, G, N]
-        dt_prefill = dt_flat.index_select(0, prefill_idx).unsqueeze(0)  # [1, S_p, H]
+        hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]
+        B_prefill = B_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
+        C_prefill = C_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
+        dt_prefill = dt_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H]

         cu_seqlens = torch.cat(
             [
@@ -128,20 +127,20 @@ def _triton_cached_ssm(
             return_varlen_states=True,
         )

-        y_flat.index_copy_(0, prefill_idx, y_prefill[0].to(y_flat.dtype))
+        y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
         ssm_state_cache.index_copy_(
             0, slot_idx[:num_prefill].to(torch.long), varlen_states.to(ssm_state_cache.dtype)
         )

     # Decode: batch single-token updates via selective_state_update
     if num_decode > 0:
-        decode_idx = seq_start[num_prefill:].to(torch.long)
+        total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
         slot_idx_decode = slot_idx[num_prefill:].to(torch.long)

-        x_decode = hs_flat.index_select(0, decode_idx)  # [nd, H, D]
-        B_decode = B_flat.index_select(0, decode_idx)  # [nd, G, N]
-        C_decode = C_flat.index_select(0, decode_idx)  # [nd, G, N]
-        dt_decode = dt_flat.index_select(0, decode_idx)  # [nd, H]
+        x_decode = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, H, D]
+        B_decode = B_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, G, N]
+        C_decode = C_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, G, N]
+        dt_decode = dt_flat[total_prefill_tokens : total_prefill_tokens + num_decode]  # [nd, H]

         dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
         dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
@@ -165,7 +164,7 @@ def _triton_cached_ssm(
             state_batch_indices=slot_idx_decode,
         )  # [nd, H, D]

-        y_flat.index_copy_(0, decode_idx, y_dec.to(y_flat.dtype))
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode] = y_dec.to(y_flat.dtype)

     return y

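Why slicing is safe here: in the flattened token layout the prefill sequences are packed first and each decode request contributes exactly one trailing token, so the removed `prefill_idx` was simply `arange(total_prefill_tokens)` and `seq_start[num_prefill:]` equals `total_prefill_tokens + arange(num_decode)`. Contiguous slices therefore select the same rows as the old `index_select` calls, but return views instead of gathered copies. A minimal sketch of that equivalence with made-up sizes (the tensor names mirror the op, but the shapes and the packed example are illustrative only):

import torch

# Hypothetical packed batch: two prefill sequences (lengths 3 and 2) followed by three decode tokens.
seq_len = torch.tensor([3, 2, 1, 1, 1])
num_prefill, num_decode = 2, 3
seq_start = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.long), seq_len[:-1]]), dim=0)

hs_flat = torch.randn(int(seq_len.sum()), 4, 8)  # [total_tokens, H, D], sizes chosen arbitrarily
total_prefill_tokens = int(seq_len[:num_prefill].sum().item())

# Old path: explicit index tensors plus gather kernels.
prefill_idx = torch.arange(total_prefill_tokens, dtype=torch.long)
decode_idx = seq_start[num_prefill:].to(torch.long)
hs_prefill_old = hs_flat.index_select(0, prefill_idx)
x_decode_old = hs_flat.index_select(0, decode_idx)

# New path: contiguous slices over the same rows (views, no copy).
hs_prefill_new = hs_flat[:total_prefill_tokens]
x_decode_new = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode]

assert torch.equal(hs_prefill_old, hs_prefill_new)
assert torch.equal(x_decode_old, x_decode_new)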
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py

Lines changed: 41 additions & 14 deletions
@@ -75,15 +75,16 @@ def _pack_routed_tokens_reference(
     return sorted_token_ids_used, expert_ids_used, used_len


-def test_triton_moe_matches_torch_moe_mlp_relu2():
+@pytest.mark.parametrize("early_exit", [False, True])
+def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit):
     torch.manual_seed(0)

     if not torch.cuda.is_available():
         pytest.skip("CUDA is required for triton_moe fused MLP test")
     device = "cuda"
     dtype = torch.bfloat16

-    M = 8  # tokens
+    M = 32 if early_exit else 8  # tokens
     HIDDEN_SIZE = 8
     INTERMEDIATE_SIZE = 16
     E = 8  # experts
@@ -102,12 +103,26 @@ def test_triton_moe_matches_torch_moe_mlp_relu2():
     w_up_stacked = torch.stack(w_up_list, dim=0).contiguous()  # [E, I, H]
     w_down_stacked = torch.stack(w_down_list, dim=0).contiguous()  # [E, H, I]

-    # Create routing with top-k normalization
-    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
-    routing_full = torch.softmax(router_logits, dim=-1)
-    routing_weights, selected_experts = torch.topk(routing_full, k=top_k, dim=-1)
-    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
-    routing_weights = routing_weights.to(torch.float32)
+    # Create routing based on whether we want to test early exit
+    if not early_exit:
+        # Random routing with top-k normalization
+        router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
+        routing_full = torch.softmax(router_logits, dim=-1)
+        routing_weights, selected_experts = torch.topk(routing_full, k=top_k, dim=-1)
+        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(torch.float32)
+    else:
+        # Imbalanced routing: concentrate 75% of tokens on first 2 experts
+        # This tests early exit logic in num_tokens_post_padded path
+        selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
+        for i in range(M):
+            if i < M * 3 // 4:
+                selected_experts[i, 0] = 0
+                selected_experts[i, 1] = 1
+            else:
+                selected_experts[i, 0] = i % E
+                selected_experts[i, 1] = (i + 1) % E
+        routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k

     # Triton fused MoE (mlp with relu^2 activation between two GEMMs)
     out_triton = torch.ops.auto_deploy.triton_moe_fused(
@@ -219,7 +234,8 @@ def test_moe_align_kernel_groups_tokens_by_expert_and_block_padding():


 @skip_pre_hopper
-def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
+@pytest.mark.parametrize("early_exit", [False, True])
+def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
     """Test triton_quant_fp8_moe against torch_quant_fp8_moe reference."""
     torch.manual_seed(0)

@@ -228,7 +244,7 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
     device = "cuda"
     dtype = torch.bfloat16

-    M = 32  # tokens
+    M = 64 if early_exit else 32  # tokens
     HIDDEN_SIZE = 16  # Must be multiple of 16 for FP8 linear
     INTERMEDIATE_SIZE = 32  # Must be multiple of 16 for FP8 linear
     E = 4  # experts
@@ -313,12 +329,23 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
     w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
     w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)

-    # Create controlled routing to ensure even token distribution across experts
+    # Create routing based on whether we want to test early exit
     selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
-    for i in range(M):
+    if not early_exit:
         # Distribute tokens evenly: token i goes to experts (i % E) and ((i+1) % E)
-        selected_experts[i, 0] = i % E
-        selected_experts[i, 1] = (i + 1) % E
+        for i in range(M):
+            selected_experts[i, 0] = i % E
+            selected_experts[i, 1] = (i + 1) % E
+    else:
+        # Imbalanced routing: concentrate 75% of tokens on first 2 experts
+        # This tests early exit logic in num_tokens_post_padded path
+        for i in range(M):
+            if i < M * 3 // 4:
+                selected_experts[i, 0] = 0
+                selected_experts[i, 1] = 1
+            else:
+                selected_experts[i, 0] = i % E
+                selected_experts[i, 1] = (i + 1) % E

     # Create equal routing weights
     routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
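The new `early_exit` cases drive the kernels with deliberately imbalanced routing: roughly 75% of the tokens are pinned to experts 0 and 1, while the remainder spread round-robin, so most padded blocks belong to just two experts and the path guarded by `num_tokens_post_padded` is exercised. A standalone sketch of that routing pattern (the helper name and the constants passed below are illustrative, not part of the test suite):

import torch

def make_imbalanced_routing(M: int, E: int, top_k: int = 2, device: str = "cpu"):
    """Route roughly 75% of tokens to experts 0 and 1, the rest round-robin."""
    selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
    for i in range(M):
        if i < M * 3 // 4:
            selected_experts[i, 0] = 0
            selected_experts[i, 1] = 1
        else:
            selected_experts[i, 0] = i % E
            selected_experts[i, 1] = (i + 1) % E
    # Equal per-expert weights keep the reference and fused outputs directly comparable.
    routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
    return selected_experts, routing_weights

selected_experts, routing_weights = make_imbalanced_routing(M=32, E=8)
# The histogram shows the skew: experts 0 and 1 dominate, the others see only a couple of slots each.
print(torch.bincount(selected_experts.flatten(), minlength=8))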
