**raiRaiyan** commented:

> Great! I was looking for this feature. BTW, have you tried this: https://liranringel.github.io/ddtree/ ?

**Author** replied:
Looks like someone beat me to adding dflash. Working on extracting just the PlanarQuant3, then I'll look at ddtree.
**…ed storage, deferred quant, fused Metal SDPA**

Reimplements the PlanarQuant3 KV cache to match the upstream llama.cpp fork:

- CUDA rotation constants (`PLANAR_CUDA_COS/SIN_64`) matching upstream benchmarks
- Packed `block_planar3_0` storage: 50 bytes per 128-element block (qs + signs + norm)
- Deferred quantization: FP16 during prefill, bulk-convert on `finalize_prefill()`
- Fused Metal SDPA kernels: inline dequant QK/AV, no K/V materialization
- Asymmetric K/V support: `quantize_v` flag for a speed/memory tradeoff
- Auto-finalize on the first decode transition via the `turboquant_attention` patch

Test results: 44/44 pass; integration cos_sim = 1.000000 with Qwen3.5-4B.

Benchmark (Qwen3.5-4B, `quantize_v=False`): 0.87x FP16 decode speed, 29x attention-only memory compression.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
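For intuition about the 50-byte block, here is one plausible accounting that matches the figure. The exact bit split is an assumption for illustration, not taken from the code:

```python
# Assumed layout: 3 bits/element as a 2-bit magnitude index plus a 1-bit
# sign, with one FP16 norm scale per 128-element block. This split is an
# illustration consistent with the 50-byte figure, not read from the source.
BLOCK_ELEMS = 128

qs_bytes = BLOCK_ELEMS * 2 // 8    # 2-bit magnitude indices -> 32 bytes
sign_bytes = BLOCK_ELEMS // 8      # 1 sign bit per element  -> 16 bytes
norm_bytes = 2                     # one FP16 norm scale     ->  2 bytes

block_bytes = qs_bytes + sign_bytes + norm_bytes   # 50 bytes per block
bits_per_element = block_bytes * 8 / BLOCK_ELEMS   # 3.125 bits incl. overhead
```

Whatever the true split, 50 bytes per 128 elements is 3.125 bits/element all-in, versus 16 bits for FP16.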
**… 0.95x FP16**

Adds a `planarquant3_quantize_packed` Metal kernel that performs the entire quantize pipeline (normalize → Givens rotate → midpoint lookup → bit-pack) in a single GPU dispatch, replacing ~20 Python MLX ops per layer.

Kernel design: one threadgroup per row, one thread per rotation pair. Thread 0 computes the L2 norm and writes `inv_norm` to shared memory. After a barrier, all threads normalize their pairs, apply the Givens rotation, do a 7-comparison midpoint lookup, and write indices to shared memory. Thread 0 then packs and writes the output.

`kv_cache.py` now uses `_quantize()`, which dispatches to Metal when available.

Results: bit-exact packed parity with the Python reference, 0.95x FP16 decode speed (up from 0.87x), 50/50 PlanarQuant tests pass.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
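The kernel's per-row pipeline can be sketched in plain NumPy. The rotation angles and 3-bit codebook below are placeholders; the real `PLANAR_CUDA_COS/SIN_64` constants and Lloyd-Max centroids are in `omlx/cache/planarquant/constants.py`:

```python
import numpy as np

# NumPy analogue of the Metal kernel's per-row steps, with placeholder
# rotation angles and a placeholder uniform 3-bit codebook.
D = 128
rng = np.random.default_rng(0)
angles = rng.uniform(0.0, 2.0 * np.pi, D // 2)
cos64, sin64 = np.cos(angles), np.sin(angles)
centroids = np.linspace(-1.5, 1.5, 8)              # 8 levels -> 3-bit index
midpoints = (centroids[:-1] + centroids[1:]) / 2   # 7 decision boundaries

def quantize_row(row):
    inv_norm = 1.0 / np.linalg.norm(row)   # thread 0's job in the kernel
    x = row * inv_norm                     # all threads normalize after barrier
    a, b = x[0::2], x[1::2]                # one "thread" per rotation pair
    ra = a * cos64 - b * sin64             # Givens rotation
    rb = a * sin64 + b * cos64
    rot = np.stack([ra, rb], axis=1).reshape(-1)
    # 7-comparison midpoint lookup -> centroid index 0..7 per element
    idx = np.searchsorted(midpoints, rot).astype(np.uint8)
    return idx, inv_norm

idx, inv_norm = quantize_row(rng.standard_normal(D))
```

Searching the 7 sorted midpoints is equivalent to the kernel's 7 comparisons: each element lands in one of 8 centroid bins, i.e. a 3-bit index, which thread 0 then bit-packs.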
**…98x FP16**

The decode path was re-dequantizing the full packed K tensor every step (O(T) cost). Now we cache the dequantized K in an FP16 buffer and only append the new K row on each step (O(1) cost). The cache is invalidated on trim, state reset, and finalize.

This eliminates the dominant per-step overhead:

- Before: dequantize K (O(T)) + quantize new K + SDPA = 0.95x FP16
- After: append 1 row to cached K + quantize new K + SDPA = 0.98x FP16

Benchmark (Qwen3.5-4B, T=106):

- FP16: 107.7 tok/s
- PQ3: 105.0 tok/s (0.98x)
- Memory: 1.10x compression (SSM layers dominate)

At T=848: 0.95x speed, 1.25x memory compression (improving with length).

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
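A minimal sketch of the cached-K idea (not the omlx implementation): preallocate an FP16 buffer and write one dequantized row per decode step, instead of re-dequantizing the whole packed tensor each step.

```python
import numpy as np

# Toy stand-in for the resident FP16 dequant cache described above.
class DequantKCache:
    def __init__(self, capacity, head_dim):
        self._buf = np.zeros((capacity, head_dim), dtype=np.float16)
        self._len = 0

    def append(self, row):
        self._buf[self._len] = row  # O(1): one in-place row write per step
        self._len += 1

    @property
    def k(self):
        return self._buf[:self._len]  # view handed to attention

    def invalidate(self):
        # mirrors the trim / state-reset / finalize invalidation paths
        self._len = 0

cache = DequantKCache(capacity=1024, head_dim=128)
for step in range(4):
    cache.append(np.random.standard_normal(128))
```

The append is O(1) because the buffer is preallocated; the only remaining per-step work is quantizing the single new K row.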
**…— speed parity with FP16**

Replace per-token K/V quantization with lazy-pack on state save, add `_v_dequant_cache` for the `quantize_v=True` path, and route both K+V decode through Apple MPS SDPA via FP16 dequant caches. Add a `fused_flash_sdpa` reference kernel (3–50x slower than MPS, retained for research).

- K+V decode: 0.15x → ~1.0x FP16 (6.8x improvement)
- PQ3-KV at T=641 on 27B: 1.02x FP16 speed, 1.25x memory compression

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
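Conceptually, decode attention then runs over the FP16 dequant caches with a stock SDPA. The toy NumPy attention below stands in for the Apple MPS `scaled_dot_product_attention` call the real path uses:

```python
import numpy as np

# Single-query attention over FP16 dequant caches (illustrative stand-in
# for the MPS SDPA kernel; shapes and dtypes are simplified).
def sdpa(q, k, v):
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ k.T) * scale
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

T, D = 8, 64
rng = np.random.default_rng(0)
k_cache = rng.standard_normal((T, D)).astype(np.float16)  # resident dequant caches
v_cache = rng.standard_normal((T, D)).astype(np.float16)
q = rng.standard_normal((1, D)).astype(np.float32)

out = sdpa(q, k_cache.astype(np.float32), v_cache.astype(np.float32))
```

Because the caches are already FP16, no dequantization happens in the attention hot path; packing to PQ3 only occurs on state save.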
**… types**

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Add `BatchPlanarQuantKVCache` lifecycle ops (`prepare`, `finalize` with right-padding roll, `filter`, `extend`, `merge`, `extract`) and packed-state helpers for batched inference. All ops are axis-2-aware, handle mismatched capacities between batches via zero-padding, and maintain dequant cache invalidation.

Tests: 33 new cases covering lifecycle, packed-state manipulation, `update_and_fetch` with array offsets, and decode attention.

Bench scripts: single-request, batch (B=1/2/4/8), scale validation, E2E (memory + speed at scale).

Results on Qwen3.5-4B-MLX-4bit:

- cos_sim = 1.000 across all prompt sizes
- decode speed: 0.94–1.02x FP16 (parity)
- cache memory: 13–55% smaller at runtime, 82–164x for packed storage
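One plausible reading of the right-padding roll in `finalize`, sketched on a toy (B, T, D) buffer — the real ops act on axis 2 of the cache tensors, and `roll_right_pad` is a hypothetical helper, not the omlx API:

```python
import numpy as np

# Sequences of different lengths share one capacity-C buffer; rolling each
# row by (C - length) moves its tokens to the right edge so per-sequence
# offsets line up for batched decode. Illustrative only.
def roll_right_pad(buf, lengths):
    C = buf.shape[1]
    out = np.zeros_like(buf)
    for b, L in enumerate(lengths):
        out[b] = np.roll(buf[b], C - L, axis=0)
    return out

buf = np.zeros((2, 4, 1))
buf[0, :2, 0] = [1, 2]      # sequence of length 2
buf[1, :3, 0] = [1, 2, 3]   # sequence of length 3
rolled = roll_right_pad(buf, [2, 3])
```

After the roll, every sequence ends at index C-1, so a single shared write position works for all batch entries on subsequent decode steps.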
## Context
oMLX runs LLMs on Apple Silicon via MLX. KV cache memory is a key bottleneck at long context lengths: on memory-constrained machines this limits max context or batch size.
This PR adds PlanarQuant3: 3-bit KV cache quantization using Givens rotation + Lloyd-Max centroids, matching the upstream rotorquant/llama.cpp architecture. It achieves speed parity with FP16 (0.94–1.02x) while compressing the attention-layer KV state by 82–164x at the packed-storage level (for SSD offload or serialization), and 13–55% runtime memory savings depending on context length.
## PlanarQuant3 KV Cache

### What it does
Quantizes the KV cache of attention layers to 3 bits per element using:

- Packed `uint8` blocks (50 bytes per 128-element block)
- Deferred quantization, bulk-converted on `finalize_prefill()`
- Decode attention routed through Apple MPS `scaled_dot_product_attention`
- `BatchPlanarQuantKVCache` with `prepare`/`finalize` (right-padding roll), `filter`/`extend`/`merge`/`extract` lifecycle ops, packed-state helpers, and array-offset-aware decode

### Architecture (6 commits)
1. `feat(planarquant): rebuild with upstream-matching architecture`
2. `feat(planarquant): add fused Metal quantize kernel`
3. `perf(planarquant): cache dequantized K for O(1) decode`
4. `perf(planarquant): v2 deferred decode quantization + V dequant cache`
5. `fix(planarquant): register cache types in scheduler's known sliceable types`
6. `feat(planarquant): batch KV cache + bench scripts`

### Benchmarks — Qwen3.5-4B MLX-4bit (Apple Silicon)
- Single-request, varying prompt and decode length: cos_sim = 1.000000 across all
- Batch decode: 64 decode steps, ~80-token prompt
- Batch-ops microbench: D=128, 3-bit, T=1024, B=8
- Per-token compression: D=128 per head
Runtime memory is near-neutral at the attention-layer level because dequant caches are kept resident for decode speed. The packed buffers provide 82–164x compression for serialization (SSD offload, state save) and can be made the sole resident form under memory pressure via `evict_dequant_caches()`.

### New files
- `omlx/cache/planarquant/__init__.py`
- `omlx/cache/planarquant/constants.py`
- `omlx/cache/planarquant/reference.py`
- `omlx/cache/planarquant/metal_kernels.py`
- `omlx/cache/planarquant/kv_cache.py`: `PlanarQuantKVCache` + `BatchPlanarQuantKVCache`
- `omlx/patches/planarquant_cache.py`: `enable_planarquant_cache(bits, quantize_v)` monkey-patches `make_prompt_cache`
- `scripts/bench_planarquant.py`
- `scripts/bench_planarquant_batch.py`
- `scripts/bench_scale_validation.py`, `scripts/bench_e2e_validation.py`

### Integration points
- `omlx/patches/turboquant_attention.py` — detects PQ caches, auto-finalizes prefill, routes decode through `cache.decode_attention()`
- `omlx/cache/type_registry.py` — registers `PlanarQuantKVCache` for serialization
- `omlx/scheduler.py` — registers PQ cache types as sliceable

### Tests
- `test_planarquant_constants`
- `test_planarquant_reference`
- `test_planarquant_activation`
- `test_planarquant_kv_cache`
- `test_planarquant_metal_kernel`
- `test_planarquant_integration`
- `test_planarquant_batch`

### Design decisions
**Why route through Apple MPS SDPA instead of custom Metal kernels?** Custom fused quantized SDPA kernels were 3–103x slower than Apple's proprietary MPS `scaled_dot_product_attention`. The solution: maintain FP16 dequant caches alongside packed PQ3 storage, route decode attention through MPS, and lazily pack decode rows only on state serialization.

**Why deferred/lazy-pack instead of per-token quantization?** Per-token quantization costs ~0.3 ms per layer × N layers. Deferring until state save means the decode hot path has zero quantize cost. The packed buffers are only needed for serialization, so they can be updated lazily.
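A back-of-envelope check of the per-token cost argument, assuming a hypothetical 36-layer model (the 0.3 ms/layer figure is from the text above; the layer count and throughput are illustrative):

```python
# Eager per-token quantization cost vs. the decode budget.
per_layer_ms = 0.3                  # quantize cost per layer (from the PR text)
n_layers = 36                       # assumed layer count for illustration
eager_quant_ms_per_token = per_layer_ms * n_layers   # ~10.8 ms per decoded token

# At an illustrative ~100 tok/s, the entire decode budget is ~10 ms/token,
# so eager quantization alone would roughly double per-token cost; deferring
# it to state save removes it from the hot path entirely.
decode_budget_ms = 1000.0 / 100.0
```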
**Why is runtime compression lower than storage compression?** Dequant caches are kept resident for decode speed, so attention-layer RAM is near-neutral while packed storage is 82–164x smaller. Under memory pressure, `evict_dequant_caches()` can free the dequant caches and rebuild them lazily.

**Why mutually exclusive with TurboQuant?** Both patch the SDPA dispatch path; two competing patches would cause undefined behavior. The mutual exclusion is enforced in `ModelSettings.__post_init__()`.

### Future work