
feat: PlanarQuant3 KV cache #757

Open

sooth wants to merge 6 commits into jundot:main from sooth:feat/planarquant-dflash

Conversation


@sooth commented Apr 14, 2026

Context

oMLX runs LLMs on Apple Silicon via MLX. KV cache memory is a key bottleneck at long context lengths: on memory-constrained machines this limits max context or batch size.

This PR adds PlanarQuant3: 3-bit KV cache quantization using a Givens rotation plus Lloyd-Max centroids, matching the upstream rotorquant/llama.cpp architecture. It achieves speed parity with FP16 (0.94–1.02x decode) while compressing the attention-layer KV state 82–164x at the packed-storage level (for SSD offload or serialization), with 13–55% runtime memory savings depending on context length.


PlanarQuant3 KV Cache

What it does

Quantizes the KV cache of attention layers to 3 bits per element using:

  • Givens rotation with CUDA-matched cos/sin tables (D=64 rotation dimension)
  • Lloyd-Max centroids (8 values, 7 midpoints) for 3-bit quantization
  • Packed storage: 3-bit indices + 1-bit signs packed into contiguous uint8 blocks (50 bytes per 128-element block)
  • Deferred quantization: FP16 during prefill, bulk-convert on finalize_prefill()
  • Lazy-pack decode: Per-token quantization is deferred until state serialization; decode hot path operates on FP16 dequant caches routed through Apple's MPS scaled_dot_product_attention
  • Batch support: BatchPlanarQuantKVCache with prepare/finalize (right-padding roll), filter/extend/merge/extract lifecycle ops, packed-state helpers, and array-offset aware decode
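To illustrate the packed layout, here is a minimal pure-Python sketch of a 3-bit midpoint-lookup quantizer and a 50-byte block packing (48 bytes of codes + a 2-byte fp16 norm). The centroid values are hypothetical stand-ins for the real Lloyd-Max tables in constants.py, and the Givens rotation and sign handling are omitted; none of these names come from the actual omlx code:

```python
import struct

# Hypothetical stand-ins for the Lloyd-Max tables (8 centroids, 7 midpoints).
CENTROIDS = [0.05, 0.16, 0.28, 0.42, 0.57, 0.74, 0.94, 1.20]
MIDPOINTS = [(a + b) / 2 for a, b in zip(CENTROIDS, CENTROIDS[1:])]

def quantize_block(values):
    """Quantize one 128-element block to 3-bit codes plus a scale.

    Sketch only: sign handling and the Givens rotation step are omitted;
    the real fused kernel does normalize -> rotate -> lookup -> pack.
    """
    norm = max(sum(v * v for v in values) ** 0.5, 1e-12)
    codes = []
    for v in values:
        x = abs(v) / norm
        # 7-comparison midpoint lookup -> index of the nearest centroid.
        codes.append(sum(x > m for m in MIDPOINTS))
    return codes, norm

def pack3(codes, norm):
    """Pack 128 3-bit codes into 48 bytes, append a 2-byte fp16 norm."""
    buf, bits, nbits = bytearray(), 0, 0
    for c in codes:
        bits |= (c & 0b111) << nbits
        nbits += 3
        while nbits >= 8:
            buf.append(bits & 0xFF)
            bits >>= 8
            nbits -= 8
    return bytes(buf) + struct.pack("<e", norm)   # 48 + 2 = 50 bytes

def unpack3(packed, n=128):
    """Recover the 3-bit codes and the fp16 norm from a packed block."""
    codes, bits, nbits, i = [], 0, 0, 0
    while len(codes) < n:
        while nbits < 3:
            bits |= packed[i] << nbits
            i += 1
            nbits += 8
        codes.append(bits & 0b111)
        bits >>= 3
        nbits -= 3
    (norm,) = struct.unpack("<e", packed[-2:])
    return codes, norm
```

The 50-byte figure presumably corresponds to this kind of split (384 bits of codes plus a half-precision scale); the real block_planar3_0 layout in the PR also carries sign bits, so treat this purely as a shape sketch.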

Architecture (6 commits)

| Commit | What |
| --- | --- |
| feat(planarquant): rebuild with upstream-matching architecture | Packed storage, deferred quant, fused Metal SDPA |
| feat(planarquant): add fused Metal quantize kernel | Normalize → Givens → pack in single GPU dispatch |
| perf(planarquant): cache dequantized K for O(1) decode | Reuse FP16 rows, just append |
| perf(planarquant): v2 deferred decode quantization + V dequant cache | Zero per-token quantize cost, both K+V through MPS SDPA |
| fix(planarquant): register cache types in scheduler's known sliceable types | Scheduler compatibility |
| feat(planarquant): batch KV cache + bench scripts | BatchPlanarQuantKVCache lifecycle ops + 33 tests + 4 bench scripts |

Benchmarks — Qwen3.5-4B MLX-4bit (Apple Silicon)

Single-request, varying prompt and decode length (cos_sim = 1.000000 across all):

| Prompt | Decode tokens | FP16 decode tok/s | PQ3 decode tok/s | Cache FP16 → PQ3 (MB) | Cache savings |
| --- | --- | --- | --- | --- | --- |
| ~5 tok | 128 | 8.04 | 8.09 | 59.9 → 52.3 | 13% |
| ~500 tok | 64 | 110.3 | 97.0 | 68.3 → 54.1 | 21% |
| ~500 tok | 128 | 108.4 | 85.7 | 68.3 → 54.5 | 20% |
| ~1.8k tok | 128 | 102.5 | 96.4 | 160.6 → 72.1 | 55% |

Batch decode (64 decode steps, ~80 token prompt):

| B | FP16 total tok/s | PQ3 total tok/s | Speed ratio | Cache size ratio |
| --- | --- | --- | --- | --- |
| 1 | 110.5 | 108.0 | 0.977x | 0.88x |
| 2 | 206.3 | 195.3 | 0.946x | 0.88x |
| 4 | 249.5 | 254.2 | 1.019x | 0.88x |

Batch-ops microbench (D=128, 3-bit, T=1024, B=8):

| Op | Latency (ms) |
| --- | --- |
| evict_dequant | 0.006 |
| extend / merge | 0.26 / 0.64 |
| extract / filter | 0.40 / 0.54 |
| finalize | 3.18 |
| prepare | 0.15 |

Per-token compression (D=128 per head):

| Mode | Bytes | vs FP16 |
| --- | --- | --- |
| FP16 K+V | 8192 | 1.00x |
| PQ K-only packed | 50 | 164x |
| PQ K+V packed | 100 | 82x |
| PQ + dequant caches (runtime) | 8292 | 0.99x |
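The ratios above follow directly from the byte counts; a quick sanity check:

```python
fp16_bytes = 8192   # FP16 K+V per token (from the table above)
pq_k_only = 50      # packed K-only block
pq_kv = 100         # packed K+V
runtime = 8292      # packed + resident FP16 dequant caches

print(round(fp16_bytes / pq_k_only))    # → 164
print(round(fp16_bytes / pq_kv))        # → 82
print(round(fp16_bytes / runtime, 2))   # → 0.99
```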

Runtime memory is near-neutral at the attention-layer level because dequant caches are kept resident for decode speed. The packed buffers provide 82–164x compression for serialization (SSD offload, state save) and can be the sole resident form under memory pressure via evict_dequant_caches().
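The packed-is-authoritative / dequant-cache-is-disposable pattern behind evict_dequant_caches() can be sketched roughly like this. All names and bodies here are hypothetical stand-ins (the real logic lives in PlanarQuantKVCache and does actual unpacking, not the identity transform used here):

```python
class LazyDequantCache:
    """Sketch: packed storage is always resident; the FP16 dequant
    cache is a disposable accelerator, rebuilt lazily after eviction."""

    def __init__(self, packed_rows):
        self._packed = list(packed_rows)   # authoritative, always resident
        self._dequant = None               # resident only while useful

    def _dequantize(self, row):
        # Stand-in for the real unpack + inverse-Givens + rescale.
        return row

    def rows(self):
        if self._dequant is None:          # one-time rebuild after eviction
            self._dequant = [self._dequantize(r) for r in self._packed]
        return self._dequant

    def append(self, row):
        self._packed.append(row)           # (really: lazy-pack on state save)
        if self._dequant is not None:
            self._dequant.append(self._dequantize(row))  # O(1) append

    def evict_dequant_caches(self):
        self._dequant = None               # free RAM under memory pressure
```

This mirrors the described semantics: after eviction the next decode step pays a single rebuild cost, then per-step appends are O(1) again.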

New files

| File | Purpose |
| --- | --- |
| omlx/cache/planarquant/__init__.py | Public API exports |
| omlx/cache/planarquant/constants.py | CUDA rotation tables, Lloyd-Max centroids, packed layout constants |
| omlx/cache/planarquant/reference.py | Pure-MLX quantize/dequantize matching upstream C implementation |
| omlx/cache/planarquant/metal_kernels.py | Fused Metal kernels: dequant, quantize |
| omlx/cache/planarquant/kv_cache.py | PlanarQuantKVCache + BatchPlanarQuantKVCache |
| omlx/patches/planarquant_cache.py | enable_planarquant_cache(bits, quantize_v) monkey-patches make_prompt_cache |
| scripts/bench_planarquant.py | Single-request benchmark |
| scripts/bench_planarquant_batch.py | Batch decode + batch-ops microbench |
| scripts/bench_scale_validation.py, scripts/bench_e2e_validation.py | Scale / E2E validation |

Integration points

  • omlx/patches/turboquant_attention.py — Detects PQ caches, auto-finalizes prefill, routes decode through cache.decode_attention()
  • omlx/cache/type_registry.py — Registers PlanarQuantKVCache for serialization
  • omlx/scheduler.py — Registers PQ cache types as sliceable

Tests

| Suite | Tests | Status |
| --- | --- | --- |
| test_planarquant_constants | 11 | ✅ Pass |
| test_planarquant_reference | 7 | ✅ Pass |
| test_planarquant_activation | 5 | ✅ Pass |
| test_planarquant_kv_cache | 16 | ✅ Pass |
| test_planarquant_metal_kernel | 11 | ✅ Pass |
| test_planarquant_integration | 1 | ✅ Pass (cos_sim=1.000000) |
| test_planarquant_batch | 33 | ✅ Pass |
| Total | 84 new | All pass |

Design decisions

  1. Why route through Apple MPS SDPA instead of custom Metal kernels? Custom fused quantized SDPA kernels were 3–103x slower than Apple's proprietary MPS scaled_dot_product_attention. The solution: maintain FP16 dequant caches alongside packed PQ3 storage, route decode attention through MPS, and lazily pack decode rows only on state serialization.

  2. Why deferred/lazy-pack instead of per-token quantization? Per-token quantization costs ~0.3ms per layer × N layers. Deferring until state save means the decode hot path has zero quantize cost. The packed buffers are only needed for serialization — they can be updated lazily.

  3. Why is runtime compression lower than storage compression? Dequant caches are kept resident for decode speed — so attention-layer RAM is near-neutral while packed storage is 82–164x smaller. Under memory pressure evict_dequant_caches() can free the dequant caches and rebuild lazily.

  4. Why mutually exclusive with TurboQuant? Both patch the SDPA dispatch path. Having two competing patches would cause undefined behavior. The mutual exclusion is enforced in ModelSettings.__post_init__().
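The arithmetic behind point 2, assuming a hypothetical 36-layer model (illustration only; the real layer count depends on the model):

```python
quant_ms_per_layer = 0.3   # per-token quantize cost cited in point 2
n_layers = 36              # assumption for illustration only

per_token_ms = quant_ms_per_layer * n_layers
implied_cap_tps = 1000 / per_token_ms

# ~10.8 ms of quantize work per decoded token; against a ~10 ms/token
# budget at ~100 tok/s FP16 decode, eager per-token quantization alone
# would roughly halve decode speed, which motivates deferring it.
print(round(per_token_ms, 1), round(implied_cap_tps, 1))  # → 10.8 92.6
```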


Future work

  • PPL evaluation: Compare PlanarQuant3 vs FP16 perplexity on Wikitext-2
  • Pure-attention model benchmarks: Test on Llama-3.1-8B, Mistral-7B (expect higher total compression)
  • D=256 rotation constants: Currently generated locally; should match upstream's actual tables
  • SSD offloading integration: Use packed PQ3 format for compact SSD storage

@sooth force-pushed the feat/planarquant-dflash branch from c12752d to 3244927 on April 14, 2026 at 14:52
@raiRaiyan commented:

Great! I was looking for this feature. BTW have you tried this: https://liranringel.github.io/ddtree/ ?

@sooth (Author) commented Apr 15, 2026 via email

sooth and others added 6 commits April 15, 2026 14:21
…ed storage, deferred quant, fused Metal SDPA

Reimplements PlanarQuant3 KV cache to match upstream llama.cpp fork:
- CUDA rotation constants (PLANAR_CUDA_COS/SIN_64) matching upstream benchmarks
- Packed block_planar3_0 storage: 50 bytes/128-elem block (qs+signs+norm)
- Deferred quantization: FP16 during prefill, bulk-convert on finalize_prefill()
- Fused Metal SDPA kernels: inline dequant QK/AV, no K/V materialization
- Asymmetric K/V support: quantize_v flag for speed/memory tradeoff
- Auto-finalize on first decode transition via turboquant_attention patch

Test results: 44/44 pass, integration cos_sim=1.000000 with Qwen3.5-4B
Benchmark (Qwen3.5-4B, quantize_v=False): 0.87x FP16 decode speed, 29x attention-only memory compression

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
… 0.95x FP16

Adds planarquant3_quantize_packed Metal kernel that performs the entire
quantize pipeline (normalize → Givens rotate → midpoint lookup → bit-pack)
in a single GPU dispatch, replacing ~20 Python MLX ops per layer.

Kernel design: one threadgroup per row, one thread per rotation pair.
Thread 0 computes L2 norm, writes inv_norm to shared memory. After barrier,
all threads normalize their pairs, apply Givens, do 7-comparison midpoint
lookup, write indices to shared memory. Thread 0 then packs and writes output.

kv_cache.py now uses _quantize() which dispatches to Metal when available.

Results: bit-exact packed parity with Python reference, 0.95x FP16 decode
speed (up from 0.87x), 50/50 PlanarQuant tests pass.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…98x FP16

The decode path was re-dequantizing the full packed K tensor every step
(O(T) cost). Now we cache the dequantized K in an FP16 buffer and only
append the new K row on each step (O(1) cost). The cache is invalidated
on trim/state reset/finalize.

This eliminates the dominant per-step overhead:
- Before: dequantize K (O(T)) + quantize new K + SDPA = 0.95x FP16
- After: append 1 row to cached K + quantize new K + SDPA = 0.98x FP16

Benchmark (Qwen3.5-4B, T=106):
  FP16:  107.7 tok/s
  PQ3:   105.0 tok/s (0.98x)
  Memory: 1.10x compression (SSM layers dominate)

At T=848: 0.95x speed, 1.25x memory compression (improving with length)

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…— speed parity with FP16

Replace per-token K/V quantization with lazy-pack on state save, add
_v_dequant_cache for quantize_v=True path, and route both K+V decode
through Apple MPS SDPA via FP16 dequant caches. Add fused_flash_sdpa
reference kernel (3-50x slower than MPS, retained for research).

K+V decode: 0.15x → ~1.0x FP16 (6.8x improvement)
PQ3-KV at T=641 on 27B: 1.02x FP16 speed, 1.25x memory compression

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
… types

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Add BatchPlanarQuantKVCache lifecycle ops (prepare, finalize with
right-padding roll, filter, extend, merge, extract) and packed-state
helpers for batched inference. All ops axis-2-aware, handle mismatched
capacities between batches via zero-padding, and maintain dequant cache
invalidation.

Tests: 33 new cases covering lifecycle, packed-state manipulation,
update_and_fetch with array offsets, and decode attention.

Bench scripts: single-request, batch (B=1/2/4/8), scale validation,
E2E (memory + speed at scale).

Results on Qwen3.5-4B-MLX-4bit:
  - cos_sim = 1.000 across all prompt sizes
  - decode speed: 0.94–1.02x FP16 (parity)
  - cache memory: 13–55% smaller at runtime, 82–164x for packed storage
@sooth force-pushed the feat/planarquant-dflash branch from 3244927 to fd87286 on April 15, 2026 at 18:40
@sooth changed the title from "feat: PlanarQuant3 KV cache + DFlash speculative decoding" to "feat: PlanarQuant3 KV cache" on Apr 15, 2026