fix: fused_experts observer shape bug and router tuple crash #12

Open
naoufelito wants to merge 1 commit into CerebrasResearch:main from naoufelito:fix/fused-experts-observer-shape
Conversation

@naoufelito naoufelito commented Feb 12, 2026

Summary

Fixes two bugs in the fused_experts branch of MoETransformerObserver._hook_fn:

  • Router tuple crash: module.router(flat_input) may return a tuple (e.g. Llama4Router returns (scores, logits)), which makes the subsequent torch.topk call fail with a TypeError. Now handled by checking isinstance(router_output, tuple) and extracting the last element.

  • Shape mismatch: router_scores.size(0) equals total_tokens, not num_experts — the router output shape is (total_tokens, num_experts), not (num_experts, total_tokens) as the comment claimed. As a result, routed_in ended up with shape (total_tokens², hidden_dim) instead of (num_experts * total_tokens, hidden_dim), and the subsequent .view(num_experts, ...) crashed with a RuntimeError. Replaced the broken gather-based construction with flat_input.repeat(num_experts, 1).

Fixes #11
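The shape fix can be illustrated with a small standalone sketch. The dimensions are borrowed from the regression test (total_tokens=10, num_experts=4); the tensor contents are random placeholders, not the observer's real data, and this is not the actual hook code:

```python
import torch

# Hypothetical dimensions, matching the regression test in this PR.
total_tokens, num_experts, hidden_dim = 10, 4, 16

flat_input = torch.randn(total_tokens, hidden_dim)
# Router output has shape (total_tokens, num_experts), so size(0) is the
# token count, not the expert count -- the root cause of the shape bug.
router_scores = torch.randn(total_tokens, num_experts)
assert router_scores.size(0) == total_tokens

# Fixed construction: one full copy of the token block per expert,
# giving (num_experts * total_tokens, hidden_dim).
routed_in = flat_input.repeat(num_experts, 1)
assert routed_in.shape == (num_experts * total_tokens, hidden_dim)

# The per-expert view that previously raised a RuntimeError now succeeds.
per_expert = routed_in.view(num_experts, total_tokens, hidden_dim)
assert per_expert.shape == (num_experts, total_tokens, hidden_dim)
```

With the buggy (total_tokens², hidden_dim) layout, the same .view would fail whenever total_tokens² is not divisible into num_experts equal chunks of total_tokens rows.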

Test plan

  • Added regression test tests/test_fused_observer_shape_bug.py with total_tokens=10 and num_experts=4 (crashes before fix, passes after)
  • All 13 existing tests pass

naoufelito force-pushed the fix/fused-experts-observer-shape branch from a7e6d5b to dc7d47b on February 12, 2026 at 11:22
Fix two bugs in the fused_experts branch of MoETransformerObserver:

1. module.router(flat_input) returns a tuple for some routers (e.g.
   Llama4Router returns (scores, logits)), causing torch.topk to fail
   with a TypeError. Replace with F.linear(flat_input, router.weight,
   router.bias) to get raw logits directly.

2. router_scores.size(0) equals total_tokens, not num_experts, causing
   routed_in to have shape (total_tokens^2, hidden_dim) instead of
   (num_experts * total_tokens, hidden_dim). The subsequent .view()
   then crashes with a RuntimeError. Replace the broken gather-based
   construction with flat_input.repeat(num_experts, 1).

Fixes CerebrasResearch#11
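The F.linear workaround from point 1 can be sketched as follows. TupleRouter here is a hypothetical stand-in for a router such as Llama4Router that returns a tuple from forward(); it is not the project's actual router class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TupleRouter(nn.Module):
    """Hypothetical router whose forward() returns (scores, logits),
    mimicking routers like Llama4Router."""

    def __init__(self, hidden_dim: int, num_experts: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_experts, hidden_dim))
        self.bias = nn.Parameter(torch.zeros(num_experts))

    def forward(self, x: torch.Tensor):
        logits = F.linear(x, self.weight, self.bias)
        return torch.sigmoid(logits), logits  # tuple, not a tensor

router = TupleRouter(hidden_dim=16, num_experts=4)
flat_input = torch.randn(10, 16)

# Calling the router yields a tuple, so torch.topk(router(x), k) would
# raise a TypeError -- topk expects a tensor.
out = router(flat_input)
assert isinstance(out, tuple)

# Bypassing forward() with F.linear gives the raw logits tensor directly,
# regardless of what the router's forward() chooses to return.
logits = F.linear(flat_input, router.weight, router.bias)
assert logits.shape == (10, 4)
top_vals, top_idx = torch.topk(logits, k=2)  # now a valid call
```

This sidesteps router-specific return conventions entirely, at the cost of assuming the router exposes plain weight and bias attributes.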
naoufelito force-pushed the fix/fused-experts-observer-shape branch from dc7d47b to af642b2 on February 12, 2026 at 11:28
Development

Successfully merging this pull request may close these issues.

Bug: Shape mismatch and invalid tensor operations in fused experts path of MoETransformerObserver