@@ -80,12 +80,11 @@ def _triton_cached_ssm(
if num_prefill > 0:
seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
total_prefill_tokens = int(seq_len_prefill.sum().item())
-        prefill_idx = torch.arange(total_prefill_tokens, device=device, dtype=torch.long)

-        hs_prefill = hs_flat.index_select(0, prefill_idx).unsqueeze(0) # [1, S_p, H, D]
-        B_prefill = B_flat.index_select(0, prefill_idx).unsqueeze(0) # [1, S_p, G, N]
-        C_prefill = C_flat.index_select(0, prefill_idx).unsqueeze(0) # [1, S_p, G, N]
-        dt_prefill = dt_flat.index_select(0, prefill_idx).unsqueeze(0) # [1, S_p, H]
+        hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
+        B_prefill = B_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
+        C_prefill = C_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
+        dt_prefill = dt_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H]

cu_seqlens = torch.cat(
[
@@ -128,20 +127,20 @@ def _triton_cached_ssm(
return_varlen_states=True,
)

-        y_flat.index_copy_(0, prefill_idx, y_prefill[0].to(y_flat.dtype))
+        y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
ssm_state_cache.index_copy_(
0, slot_idx[:num_prefill].to(torch.long), varlen_states.to(ssm_state_cache.dtype)
)

# Decode: batch single-token updates via selective_state_update
if num_decode > 0:
-        decode_idx = seq_start[num_prefill:].to(torch.long)
+        total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
slot_idx_decode = slot_idx[num_prefill:].to(torch.long)

-        x_decode = hs_flat.index_select(0, decode_idx) # [nd, H, D]
-        B_decode = B_flat.index_select(0, decode_idx) # [nd, G, N]
-        C_decode = C_flat.index_select(0, decode_idx) # [nd, G, N]
-        dt_decode = dt_flat.index_select(0, decode_idx) # [nd, H]
+        x_decode = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H, D]
+        B_decode = B_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N]
+        C_decode = C_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N]
+        dt_decode = dt_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H]

dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
@@ -165,7 +164,7 @@ def _triton_cached_ssm(
state_batch_indices=slot_idx_decode,
) # [nd, H, D]

-        y_flat.index_copy_(0, decode_idx, y_dec.to(y_flat.dtype))
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode] = y_dec.to(y_flat.dtype)

return y
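The change above from index_select / index_copy_ to plain slicing relies on the flattened token layout: all prefill tokens come first, followed by exactly one token per decode request. A minimal standalone sketch of that equivalence, with illustrative tensor names and shapes rather than the module's actual API:

import torch

# Assumed packed layout: 2 prefill requests (lengths 3 and 2), then 3 decode requests.
seq_len = torch.tensor([3, 2, 1, 1, 1])
num_prefill, num_decode = 2, 3
hs_flat = torch.randn(int(seq_len.sum().item()), 4, 8)  # [total_tokens, H, D]

total_prefill_tokens = int(seq_len[:num_prefill].sum().item())

# Gathering a contiguous prefix by explicit indices ...
prefill_idx = torch.arange(total_prefill_tokens)
gathered = hs_flat.index_select(0, prefill_idx)
# ... selects exactly the same rows as a plain slice (which is a view, not a copy).
assert torch.equal(gathered, hs_flat[:total_prefill_tokens])

# Decode tokens sit immediately after the prefill tokens, one per request, so the
# explicit decode indices are just total_prefill_tokens + arange(num_decode).
decode_idx = total_prefill_tokens + torch.arange(num_decode)
assert torch.equal(
    hs_flat.index_select(0, decode_idx),
    hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode],
)

Because the indices form a contiguous range, the slice forms touch the same rows while skipping the index-tensor allocation and the extra gather/scatter work.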

@@ -75,15 +75,16 @@ def _pack_routed_tokens_reference(
return sorted_token_ids_used, expert_ids_used, used_len


-def test_triton_moe_matches_torch_moe_mlp_relu2():
+@pytest.mark.parametrize("early_exit", [False, True])
+def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit):
torch.manual_seed(0)

if not torch.cuda.is_available():
pytest.skip("CUDA is required for triton_moe fused MLP test")
device = "cuda"
dtype = torch.bfloat16

-    M = 8  # tokens
+    M = 32 if early_exit else 8  # tokens
HIDDEN_SIZE = 8
INTERMEDIATE_SIZE = 16
E = 8 # experts
@@ -102,12 +103,26 @@ def test_triton_moe_matches_torch_moe_mlp_relu2():
w_up_stacked = torch.stack(w_up_list, dim=0).contiguous() # [E, I, H]
w_down_stacked = torch.stack(w_down_list, dim=0).contiguous() # [E, H, I]

-    # Create routing with top-k normalization
-    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
-    routing_full = torch.softmax(router_logits, dim=-1)
-    routing_weights, selected_experts = torch.topk(routing_full, k=top_k, dim=-1)
-    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
-    routing_weights = routing_weights.to(torch.float32)
+    # Create routing based on whether we want to test early exit
+    if not early_exit:
+        # Random routing with top-k normalization
+        router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
+        routing_full = torch.softmax(router_logits, dim=-1)
+        routing_weights, selected_experts = torch.topk(routing_full, k=top_k, dim=-1)
+        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(torch.float32)
+    else:
+        # Imbalanced routing: concentrate 75% of tokens on first 2 experts
+        # This tests early exit logic in num_tokens_post_padded path
+        selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
+        for i in range(M):
+            if i < M * 3 // 4:
+                selected_experts[i, 0] = 0
+                selected_experts[i, 1] = 1
+            else:
+                selected_experts[i, 0] = i % E
+                selected_experts[i, 1] = (i + 1) % E
+        routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
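The imbalanced branch above is meant to hit the kernel's early-exit path: token-slots are grouped by expert and each group is padded up to the GEMM block size, and grid blocks whose offset lands at or beyond num_tokens_post_padded return immediately. A rough standalone sketch of that bookkeeping, assuming a conventional moe_align-style layout; the helper name and block size below are illustrative, not this kernel's exact implementation:

import torch

def padded_token_count(selected_experts: torch.Tensor, num_experts: int, block_m: int) -> int:
    # Count token-slots routed to each expert, then pad every expert's segment
    # up to a multiple of block_m, the way a moe_align-style pass would.
    counts = torch.bincount(selected_experts.flatten(), minlength=num_experts)
    padded = ((counts + block_m - 1) // block_m) * block_m
    return int(padded.sum().item())

# With M = 32, top_k = 2, and 75% of the tokens pinned to experts 0 and 1 (as above),
# the valid padded length should fall short of the worst-case grid size, so the
# trailing blocks see offsets >= num_tokens_post_padded and take the early-exit branch.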

# Triton fused MoE (mlp with relu^2 activation between two GEMMs)
out_triton = torch.ops.auto_deploy.triton_moe_fused(
@@ -219,7 +234,8 @@ def test_moe_align_kernel_groups_tokens_by_expert_and_block_padding():


@skip_pre_hopper
-def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
+@pytest.mark.parametrize("early_exit", [False, True])
+def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
"""Test triton_quant_fp8_moe against torch_quant_fp8_moe reference."""
torch.manual_seed(0)

@@ -228,7 +244,7 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
device = "cuda"
dtype = torch.bfloat16

-    M = 32  # tokens
+    M = 64 if early_exit else 32  # tokens
HIDDEN_SIZE = 16 # Must be multiple of 16 for FP8 linear
INTERMEDIATE_SIZE = 32 # Must be multiple of 16 for FP8 linear
E = 4 # experts
@@ -313,12 +329,23 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)

-    # Create controlled routing to ensure even token distribution across experts
+    # Create routing based on whether we want to test early exit
    selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
-    for i in range(M):
+    if not early_exit:
        # Distribute tokens evenly: token i goes to experts (i % E) and ((i+1) % E)
-        selected_experts[i, 0] = i % E
-        selected_experts[i, 1] = (i + 1) % E
+        for i in range(M):
+            selected_experts[i, 0] = i % E
+            selected_experts[i, 1] = (i + 1) % E
+    else:
+        # Imbalanced routing: concentrate 75% of tokens on first 2 experts
+        # This tests early exit logic in num_tokens_post_padded path
+        for i in range(M):
+            if i < M * 3 // 4:
+                selected_experts[i, 0] = 0
+                selected_experts[i, 1] = 1
+            else:
+                selected_experts[i, 0] = i % E
+                selected_experts[i, 1] = (i + 1) % E

# Create equal routing weights
routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
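The FP8 variant above pins every weight scale to 1.0. For context, a per-tensor FP8 scheme typically derives the scale from the tensor's largest magnitude and dequantizes as value times scale; the sketch below illustrates that convention under stated assumptions and is not taken from this repository's quantization helpers:

import torch

# Hypothetical per-tensor FP8 (e4m3) quantize/dequantize helper; the name and the
# convention dequant = fp8.float() * scale are assumptions for illustration only.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quant_dequant_fp8(w: torch.Tensor):
    scale = (w.abs().max().float() / FP8_MAX).clamp(min=1e-12)  # per-tensor scale
    w_fp8 = (w.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    w_ref = w_fp8.float() * scale  # dequantized copy for torch-side reference math
    return w_fp8, scale, w_ref

# With the scale fixed to 1.0, as in the test above, quantization reduces to
# clamping and casting the bf16 weights straight to float8_e4m3fn.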