Add Experts4bit for 4-bit quantization of fused MoE experts #1965
pjordanandrsn wants to merge 2 commits into
…ytes-foundation#1849)

transformers v5 stores fused MoE experts as a single 3D nn.Parameter (e.g. OlmoeExperts, Qwen3MoeExperts), which the nn.Linear-based 4-bit walker skips. The experts stay in full precision and load_in_4bit barely shrinks the model (issue bitsandbytes-foundation#1849).

Experts4bit holds gate_up_proj and down_proj packed in NF4/FP4 as plain nn.Parameter buffers, with per-expert absmax kept on the module itself. The forward pass dequantizes one expert at a time (a per-expert loop), mirroring the reference fused-experts forward. There is no Params4bit tensor-subclass machinery, so the module serializes through the default state_dict with no custom hooks.

- from_float() quantizes existing bf16/fp16 expert stacks
- enforces in_features % blocksize == 0 for clean per-expert blocking
- double-quant (compress_statistics) and grouped-GEMM intentionally deferred for a first cut
- tests: quant round-trip, forward vs. full-precision reference, state_dict round-trip, and validation guards
Hi, thanks for the PR. I am a little concerned with how quickly it was opened after discussion. With that said, I'll follow up soon, but we likely won't merge something for this until after the v0.50.0 release.

Thanks @matthewdouglas, fair concern. The asking-first part was real: nothing was written until the shape was pinned down, and the PR follows it (plain […]). No rush from me; post-v0.50.0 was always the plan. Converting to draft so it reads as what it is: something concrete to react to when you pick the feature up. Happy to rework it toward whatever you land on, or for you to cherry-pick the useful parts.
What
Adds bitsandbytes.nn.Experts4bit, a module that stores fused Mixture-of-Experts weights in 4-bit (NF4/FP4) precision.

Fixes the memory issue in #1849: transformers v5 stores MoE experts as a single 3D nn.Parameter (e.g. OlmoeExperts and Qwen3MoeExperts, with gate_up_proj shaped [num_experts, 2*intermediate, hidden] and down_proj shaped [num_experts, hidden, intermediate]). The nn.Linear-based 4-bit walker only swaps nn.Linear, so these fused experts are skipped, stay in full precision, and dominate the loaded footprint.
Design
This follows the approach @matthewdouglas outlined on the issue:

- Plain nn.Parameter for the packed weights (not Params4bit), with per-expert absmax kept on the module as buffers. This avoids bending Params4bit's tensor-subclass + device-movement machinery around a 3D stack, and the module serializes through the default state_dict with no custom save/load hooks.
- Per-expert loop in forward (mirrors the reference fused-experts forward in OlmoeExperts / FP8Experts): one expert's weight is dequantized, used, and freed at a time. This keeps the runtime working set small and leaves a clean path to a grouped-GEMM kernel later.
- Enforces in_features % blocksize == 0 so per-expert quantization blocks tile each expert exactly and never straddle an expert boundary.
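To make the blocking constraint concrete, here is a minimal pure-Python sketch of blockwise absmax quantization applied per expert. It is an illustration under stated assumptions, not the PR's implementation: a uniform 16-level code stands in for NF4/FP4, the block size is shrunk to 4 for readability, and the helper names (check_blocking, quantize_expert, dequantize_expert) are hypothetical.

```python
# Illustrative only: a uniform 16-level code stands in for NF4/FP4, and none of
# these helper names come from bitsandbytes.
BLOCKSIZE = 4  # tiny for readability; the PR's measurements use 64


def check_blocking(in_features: int, blocksize: int) -> None:
    """Reject shapes whose rows cannot be tiled by whole blocks."""
    if in_features % blocksize != 0:
        raise ValueError(
            f"in_features={in_features} is not divisible by blocksize={blocksize}"
        )


def quantize_expert(weight, blocksize=BLOCKSIZE):
    """Quantize one expert's 2D weight (a list of rows) block by block.

    Returns (codes, absmax): one 4-bit code per weight, one scale per block.
    """
    check_blocking(len(weight[0]), blocksize)
    flat = [w for row in weight for w in row]
    codes, absmax = [], []
    for start in range(0, len(flat), blocksize):
        block = flat[start:start + blocksize]
        scale = max(abs(w) for w in block) or 1.0  # avoid dividing by zero
        absmax.append(scale)
        # Map [-scale, scale] onto the 16 integer codes 0..15.
        codes.extend(round((w / scale + 1.0) * 7.5) for w in block)
    return codes, absmax


def dequantize_expert(codes, absmax, blocksize=BLOCKSIZE):
    """Invert quantize_expert, returning a flat list of floats."""
    return [
        (code / 7.5 - 1.0) * absmax[i // blocksize]
        for i, code in enumerate(codes)
    ]


# Two experts, each 2 x 8: every expert owns exactly (2 * 8) / 4 = 4 blocks,
# so no block mixes weights from two experts.
experts = [
    [[0.1, -0.4, 0.25, 0.8, -1.0, 0.5, 0.0, 0.3], [0.2] * 8],
    [[2.0, -2.0, 1.0, 0.5, 0.1, 0.1, -0.1, 0.05], [-0.6] * 8],
]
packed = [quantize_expert(w) for w in experts]
blocks_per_expert = [len(absmax) for _, absmax in packed]

restored = dequantize_expert(*packed[0])
original = [w for row in experts[0] for w in row]
max_err = max(abs(a - b) for a, b in zip(original, restored))
```

Because each expert is quantized independently and its row length is a multiple of the block size, every scale in absmax belongs to exactly one expert, which is what lets a forward pass dequantize a single expert without touching its neighbours.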
Relationship to replace_parameter_4bit (#1720): that generic parametrization also quantizes arbitrary nn.Parameters, but dequantizes the entire [num_experts, …] stack on every access. Experts4bit is MoE-aware: it only touches the experts a batch actually routes to, which is what enables the grouped-GEMM follow-up.
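The difference between dequantizing the whole stack and dequantizing only routed experts can be sketched with a toy loop. Everything here is a hypothetical stand-in: dequantize just copies a list and counts calls, each "expert" is a scale/bias pair rather than a weight matrix, and routing is simplified to one expert per token.

```python
# Hypothetical stand-ins: `dequantize` here just copies a list and counts calls;
# it is not the bitsandbytes dequantization routine.
from collections import defaultdict

dequant_calls = 0


def dequantize(packed_expert):
    """Pretend to unpack one expert's weights, counting how often we do it."""
    global dequant_calls
    dequant_calls += 1
    return list(packed_expert)


def experts_forward(tokens, routes, packed_experts):
    """Apply each token's routed expert, dequantizing one expert at a time.

    tokens: list of floats; routes: one expert index per token;
    packed_experts: per-expert [scale, bias] pairs standing in for weights.
    """
    # Group token positions by the expert they were routed to.
    by_expert = defaultdict(list)
    for position, expert_id in enumerate(routes):
        by_expert[expert_id].append(position)

    out = [0.0] * len(tokens)
    for expert_id, positions in by_expert.items():
        scale, bias = dequantize(packed_experts[expert_id])  # transient copy
        for position in positions:
            out[position] = tokens[position] * scale + bias
        # The unpacked values are rebound on the next iteration, so only one
        # expert's full-precision copy is alive at a time.
    return out


packed_experts = [[1.0, 0.0], [2.0, 1.0], [3.0, 0.0], [0.5, -1.0]]
tokens = [1.0, 2.0, 3.0, 4.0]
routes = [1, 3, 1, 1]  # experts 0 and 2 receive no tokens
output = experts_forward(tokens, routes, packed_experts)
print(output, dequant_calls)
```

With four experts and only two of them routed to, the loop unpacks two experts rather than four; a whole-stack parametrization would pay for all of them on every access.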
Intentionally deferred for this first cut (per the issue discussion): double-quant (compress_statistics), a grouped-GEMM forward, and the transformers-side walker wiring.
API
Footprint & validation (measured on an RTX A2000 12 GB, sm_86)
For one real OLMoE-1B-7B layer (num_experts=64, hidden=2048, intermediate=1024, NF4, blocksize 64, no double-quant), measured Experts4bit vs. the bf16 stack: Experts4bit (192 MB packed + 24 MB absmax) is 3.56× smaller for the expert weights, which are the bulk of the model. Combined with the existing Linear4bit path on the non-expert layers, this takes OLMoE-1B-7B from ~13 GB to ~3.5 GB (fits a single 12 GB card). A forward over the real-sized layer peaks at 1295 MB of VRAM: because experts are dequantized one at a time, the working set never materializes the full bf16 stack, which is the property that makes the grouped-GEMM follow-up worthwhile.
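The reported sizes can be cross-checked with back-of-envelope arithmetic from the layer shape. This sketch assumes that "MB" means MiB and that absmax is stored as float32 (4 bytes per block); neither is stated explicitly above, but both are consistent with the 192 MB + 24 MB figures.

```python
# Back-of-envelope check; float32 absmax (4 bytes per block) and MB == MiB
# are assumptions, not facts taken from the PR.
NUM_EXPERTS, HIDDEN, INTERMEDIATE, BLOCKSIZE = 64, 2048, 1024, 64
MIB = 1024 ** 2

gate_up = 2 * INTERMEDIATE * HIDDEN      # gate_up_proj weights per expert
down = HIDDEN * INTERMEDIATE             # down_proj weights per expert
params = NUM_EXPERTS * (gate_up + down)  # all expert weights in one layer

bf16_mib = params * 2 / MIB                   # 2 bytes per weight
packed_mib = params / 2 / MIB                 # 4 bits = half a byte per weight
absmax_mib = params / BLOCKSIZE * 4 / MIB     # one float32 scale per block
ratio = bf16_mib / (packed_mib + absmax_mib)

print(f"bf16 {bf16_mib:.0f} MiB, packed {packed_mib:.0f} MiB, "
      f"absmax {absmax_mib:.0f} MiB, ratio {ratio:.2f}x")
```

Under those assumptions the arithmetic reproduces the 192 MB packed and 24 MB absmax split and the 3.56× ratio, and it shows that at blocksize 64 the scales cost one eighth as much as the packed weights.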
Testing
tests/test_experts4bit.py has 11 cases, all green on the CPU default backend:

- packed-weight / absmax shape + dtype assertions
- forward vs. a full-precision reference forward (gated + non-gated), float32 compute, rtol=atol=1e-4
- state_dict round-trip: bit-exact restore of packed weights + absmax, identical forward after reload
- validation guards (in_features % blocksize, invalid quant_type)

On CUDA (A2000, bnb 0.49.2 / torch 2.4.1) the NF4 round-trip mean-abs error is 0.0073 and the forward matches the full-precision reference exactly (max-abs 0.0).
Closes #1849.
cc @matthewdouglas @SunMarc