
OnnxKQuantQuantization skips per-expert MatMuls in static-unroll MoE blocks #2489

@justinchuby


Summary

OnnxKQuantQuantization only quantizes MatMul nodes whose second input is a 2-D static initializer (kquant_quantization.py:303-308). This is the correct fast path for dense transformer weights, but it leaves a whole class of MatMul nodes unquantized: the per-expert MatMuls in fallback/static-unroll Mixture-of-Experts (MoE) blocks, where the B input is produced at runtime by other nodes instead of coming directly from an initializer.
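A minimal sketch of the eligibility rule described above, assuming hypothetical `Value`/`Node` stand-ins rather than Olive's or onnx's real graph classes; shapes are arbitrary:

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for graph objects; these are not Olive's or onnx's classes.
@dataclass
class Value:
    name: str
    shape: tuple
    is_initializer: bool = False

@dataclass
class Node:
    op_type: str
    inputs: list = field(default_factory=list)

def is_kquant_eligible(node: Node) -> bool:
    # The rule described above: a MatMul whose B input is a 2-D initializer.
    if node.op_type != "MatMul" or len(node.inputs) < 2:
        return False
    b = node.inputs[1]
    return b.is_initializer and len(b.shape) == 2

# Shapes are arbitrary and chosen only for illustration.
x = Value("x", (1, 64))
dense_w = Value("q_proj.weight", (64, 128), is_initializer=True)
expert_w = Value("squeeze_out", (64, 32))  # produced by Squeeze, not an initializer

print(is_kquant_eligible(Node("MatMul", [x, dense_w])))   # True
print(is_kquant_eligible(Node("MatMul", [x, expert_w])))  # False
```

The expert MatMul fails the check even though its B input has a 2-D shape, because that input is a node output rather than an initializer.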

Concrete impact: on a google/gemma-4-26b-a4b-it export from mobius (onnxruntime/mobius#324) the decoder has 7680 expert MatMul nodes (30 layers × 128 experts × 2 matmuls per expert). The k-quant pass quantizes only the 236 dense MatMuls (attention projections, lm_head, …) and leaves all 7680 expert MatMuls at fp16. The resulting Q4_K_M model is only ~6% smaller than the fp16 source (47 GB vs 51 GB).
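The node counts above, worked through as a quick check. This counts MatMul nodes only, not bytes, so the fraction is not a size ratio:

```python
layers, experts, matmuls_per_expert = 30, 128, 2
expert_matmuls = layers * experts * matmuls_per_expert
dense_matmuls = 236  # quantized by the current pass
total = expert_matmuls + dense_matmuls
quantized_fraction = dense_matmuls / total  # share of MatMul nodes that get quantized

print(expert_matmuls)               # 7680
print(total)                        # 7916
print(f"{quantized_fraction:.1%}")  # 3.0%
```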

Pattern

Static-unroll MoE typically packs all experts' weights into a single 3-D initializer [E, fc1_inter, hidden] (matching HF's experts.gate_up_proj) and then dispatches per-expert at runtime:

W_all = [E, K, N]  ← static initializer
                          ↓
W_expert = Gather(W_all, [expert_idx], axis=0)        # [1, K, N]
W_2d     = Squeeze(W_expert, [0])                     # [K, N]
y        = MatMul(x, W_2d)                            # ← B input is *not* an initializer
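The dispatch above selects a constant slice: for a given expert_idx, Gather(axis=0) with a 1-D index followed by Squeeze([0]) yields exactly W_all[expert_idx]. A small pure-Python emulation of those ONNX op semantics (not onnxruntime itself; the helper names are illustrative) makes this concrete:

```python
def gather_axis0(data, indices):
    # ONNX Gather with axis=0 and 1-D indices: picks whole slices, keeps a leading dim.
    return [data[i] for i in indices]

def squeeze_axis0(t):
    assert len(t) == 1, "Squeeze on axis 0 requires a size-1 leading dim"
    return t[0]

def matmul(a, b):
    # a: [M, K], b: [K, N] -> [M, N]
    return [[sum(a[m][k] * b[k][n] for k in range(len(b))) for n in range(len(b[0]))]
            for m in range(len(a))]

E, K, N = 3, 2, 2
W_all = [[[e * 10 + r * 2 + c for c in range(N)] for r in range(K)] for e in range(E)]
x = [[1, 1]]
expert_idx = 2

y_dispatch = matmul(x, squeeze_axis0(gather_axis0(W_all, [expert_idx])))
y_direct = matmul(x, W_all[expert_idx])  # same result, B is now a plain 2-D constant
print(y_dispatch)  # [[42, 44]]
```

This equivalence is what both fixes below rely on: each per-expert weight can be treated as its own 2-D constant.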

K-quant skips that MatMul because node.inputs[1].is_initializer() is False. For any given expert the weight is a constant slice of W_all, but static analysis sees only the output of a Squeeze.
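Recognizing this case means looking past the Squeeze to its producers. A minimal matcher sketch, assuming a hypothetical graph representation (a `producers` map from value name to node and an `initializers` map from name to shape); none of these names are Olive APIs:

```python
from dataclasses import dataclass

# Hypothetical graph node for illustration; not an Olive or onnx class.
@dataclass
class Node:
    op_type: str
    inputs: list   # input value names
    outputs: list  # output value names
    attrs: dict

def find_packed_expert_weight(matmul, producers, initializers):
    """Return the 3-D initializer behind MatMul's B input, or None.

    Matches Gather(W_3d, idx, axis=0) -> optional Squeeze -> MatMul. The Squeeze
    is optional because Gather with a scalar index already drops axis 0.
    """
    src = producers.get(matmul.inputs[1])
    if src is not None and src.op_type == "Squeeze":
        src = producers.get(src.inputs[0])
    if src is None or src.op_type != "Gather" or src.attrs.get("axis", 0) != 0:
        return None
    weight = src.inputs[0]
    shape = initializers.get(weight)
    return weight if shape is not None and len(shape) == 3 else None

gather = Node("Gather", ["W_all", "expert_idx"], ["w_e"], {"axis": 0})
squeeze = Node("Squeeze", ["w_e", "axes0"], ["w_2d"], {})
mm = Node("MatMul", ["x", "w_2d"], ["y"], {})
producers = {out: n for n in (gather, squeeze, mm) for out in n.outputs}
initializers = {"W_all": (128, 64, 32)}
print(find_packed_expert_weight(mm, producers, initializers))  # W_all
```

The sketch does not check the Squeeze axes; a real pass would also confirm it squeezes axis 0.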

Two reasonable fixes

(A) Pattern-aware quantization in OnnxKQuantQuantization. When MatMul.input[1] traces back through an optional Squeeze to Gather(3D_initializer, axis=0), slice the 3-D initializer along axis 0 and quantize each [K, N] 2-D plane independently, then reassemble the results into a 3-D MatMulNBits-style packed initializer indexed by the same expert id. This gives the smallest model but is invasive: it needs a new packed 3-D MatMulNBits op (or GroupedMatMulNBits) that ORT can dispatch.
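A sketch of the slicing-and-per-plane step of option A. It swaps in a simple symmetric 4-bit blockwise scheme as a stand-in; this is NOT the Q4_K_M k-quant algorithm, and the packed layout an ORT op would actually need is not modeled:

```python
def quantize_plane_int4(plane, block_size=4):
    """Symmetric 4-bit quantization of one [K, N] plane, one scale per block of K rows per column.

    A simplified illustration, NOT the Q4_K_M k-quant scheme. Codes lie in [-7, 7].
    """
    K, N = len(plane), len(plane[0])
    codes = [[0] * N for _ in range(K)]
    scales = []  # [ceil(K / block_size)][N]
    for k0 in range(0, K, block_size):
        rows = range(k0, min(k0 + block_size, K))
        block_scales = []
        for n in range(N):
            amax = max(abs(plane[k][n]) for k in rows)
            scale = amax / 7 if amax else 1.0
            block_scales.append(scale)
            for k in rows:
                codes[k][n] = round(plane[k][n] / scale)
        scales.append(block_scales)
    return codes, scales

def dequantize_plane(codes, scales, block_size=4):
    return [[codes[k][n] * scales[k // block_size][n] for n in range(len(codes[0]))]
            for k in range(len(codes))]

def quantize_packed_experts(w_all, block_size=4):
    # Slice along axis 0, quantize each expert's [K, N] plane independently,
    # and keep the results stacked so they stay indexable by expert id.
    planes = [quantize_plane_int4(p, block_size) for p in w_all]
    codes = [c for c, _ in planes]   # [E, K, N]
    scales = [s for _, s in planes]  # [E, ceil(K / block_size), N]
    return codes, scales

w_all = [
    [[0.7, -1.4], [0.35, 0.0], [-0.7, 1.4], [0.0, 0.7]],  # expert 0, [K=4, N=2]
    [[7.0, 3.5], [-7.0, 0.0], [1.0, -3.5], [2.0, 1.0]],   # expert 1
]
codes, scales = quantize_packed_experts(w_all)
print(len(codes), len(codes[0]), len(codes[0][0]))  # 2 4 2
```

Because each plane gets its own scales, experts with very different magnitudes (as above) do not share a quantization range.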

(B) Pre-pass that unstacks 3-D MoE weights. Add a separate pass (StackedMatMulUnstack or similar) that runs before OnnxKQuantQuantization:

  1. Find Gather(W_3d, expert_idx, axis=0) → Squeeze → MatMul(x, ...).
  2. For each constant expert_idx that the model could pick (or unconditionally for all E), replace the chain with If(expert_idx == k, MatMul(x, W_3d[k]), ...). A simpler alternative replaces the dispatch with E parallel MatMuls followed by a selection step, so only one result is used at runtime.
  3. After this pass each per-expert weight W_3d[k] is a separate 2-D initializer that K-quant picks up automatically.
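The steps above can be sketched in plain Python. This emulates the If-chain variant of step 2 rather than emitting real ONNX If nodes, and the f"{name}.expert{k}" naming scheme and helper names are hypothetical:

```python
import copy

def matmul(a, b):
    # a: [M, K], b: [K, N] -> [M, N]
    return [[sum(a[m][k] * b[k][n] for k in range(len(b))) for n in range(len(b[0]))]
            for m in range(len(a))]

def rank(t):
    r = 0
    while isinstance(t, list):
        r, t = r + 1, t[0]
    return r

def unstack_packed_weight(initializers, name):
    """Replace the 3-D initializer `name` [E, K, N] with E separate 2-D initializers.

    The f"{name}.expert{k}" naming scheme is hypothetical.
    """
    w_all = initializers.pop(name)
    names = []
    for k, plane in enumerate(w_all):
        initializers[f"{name}.expert{k}"] = plane
        names.append(f"{name}.expert{k}")
    return names

def dispatch_if_chain(x, expert_idx, expert_names, initializers):
    # Emulates nested If(expert_idx == k, MatMul(x, W_k), else: next If) nodes.
    for k, name in enumerate(expert_names):
        if expert_idx == k:
            return matmul(x, initializers[name])
    raise IndexError(f"expert_idx {expert_idx} out of range")

W_all = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # [E=2, K=2, N=2]
x = [[1, 0]]
inits = {"W_all": copy.deepcopy(W_all)}
names = unstack_packed_weight(inits, "W_all")
print(sorted(inits))                            # ['W_all.expert0', 'W_all.expert1']
print(all(rank(inits[n]) == 2 for n in names))  # True
print(dispatch_if_chain(x, 1, names, inits))    # [[5, 6]]
```

After unstacking, every expert weight is a 2-D initializer feeding a MatMul directly, which is the shape the existing k-quant check accepts.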

Option B keeps OnnxKQuantQuantization focused on its single job and is straightforward to implement, at the cost of E times more weight initializers. This is manageable: 128 experts × 30 layers = 3840 extra initializers per packed 3-D weight (7680 across both per-expert MatMuls) for Gemma 4 26B, but each is only ~2.2 MB after Q4_K_M.

Workaround today

Drop the fused com.microsoft::MoE op (which is currently broken for standard SwiGLU; see microsoft/onnxruntime#28738) and accept the tradeoff of fp16 expert weights, a large ONNX file, and slow session load. mobius onnxruntime/mobius#324 does this for Gemma 4.

Environment

  • olive-ai latest main
  • mobius onnxruntime/mobius#324 (Gemma 4 fallback MoE)
  • ONNX Runtime 1.27.0
