
fix(oq): chunked load/quantize and streaming VLM sanitizer for huge MoE models#737

Open
yohann-bearzi wants to merge 3 commits into jundot:main from yohann-bearzi:main

Conversation

@yohann-bearzi

PR description


Quantizing Qwen3.5-397B-A17B (and similar large MoE checkpoints) on Apple
Silicon failed in two ways:

1. mx.load and mx.quantize each issue a single Metal dispatch per tensor,
   which exceeds the command-buffer timeout on 512x2048x4096 expert tensors.
2. mlx-vlm's Model.sanitize() returns a transformed dict containing every
   weight, which OOMs a 512 GB Mac on a 397B-parameter model.

Changes:

* Replace eager mx.load with _LazyTensorIndex, a memory-mapped view over
  safetensors files. Tensors are read on demand via _LazyTensor._load_rows,
  which sub-chunks numpy->MLX conversion to stay under both the device's
  max_buffer_length (queried via mx.device_info) and MLX's int32 element
  count limit. Chunk budgets scale with hardware: ~14 GiB per chunk on M3
  Ultra, ~875 MiB on M1, with safe fallbacks if Metal info is unavailable.
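The chunk-budget arithmetic described above can be sketched as follows. This is an illustrative reconstruction, not the actual oMLX code: `rows_per_chunk`, the safety factor, and the 16 GiB buffer figure are all assumptions; only the int32 element cap and the idea of taking the minimum of a byte budget and an element budget come from the description.

```python
# Hypothetical sketch: given a device's max buffer size and MLX's int32
# element-count limit, compute how many rows of a tensor fit in one chunk.
INT32_MAX_ELEMENTS = 2**31 - 1  # MLX indexes elements with int32

def rows_per_chunk(shape, itemsize, max_buffer_bytes, safety=0.5):
    """Largest row count whose chunk fits both the byte and element budgets."""
    rows, *rest = shape
    row_elements = 1
    for d in rest:
        row_elements *= d
    row_bytes = row_elements * itemsize
    # Apply a safety factor so a chunk never sits right at the device limit.
    byte_budget = int(max_buffer_bytes * safety)
    budget_rows = min(byte_budget // row_bytes,
                      INT32_MAX_ELEMENTS // row_elements)
    return max(1, min(rows, budget_rows))

# Example: a 512x2048x4096 bf16 expert tensor against a 16 GiB buffer cap.
# Here the int32 element limit, not the byte budget, is the binding constraint.
print(rows_per_chunk((512, 2048, 4096), 2, 16 * 2**30))  # -> 255
```

Note how on a tensor this large the element-count cap (2^31-1 / 8,388,608 elements per row = 255 rows) bites before the byte budget does, which is why both limits must be checked.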

* Add _quantize_chunked, a drop-in replacement for mx.quantize that
  bisects on dim 0 and concatenates the per-chunk results. Same buffer
  and element-count budgets. mx.synchronize + mx.clear_cache between
  chunks drains the command queue.
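A minimal sketch of the bisect-and-concatenate strategy _quantize_chunked is described as using: if a tensor exceeds the per-dispatch budget, split it in half along dim 0, recurse, and concatenate the per-chunk results. numpy stands in for MLX here, and `fake_quantize` is a placeholder for the real quantization op; the synchronize/clear_cache calls between chunks are omitted.

```python
import numpy as np

def fake_quantize(x):
    # Placeholder for mx.quantize: any row-wise op works for the demo.
    return np.round(x, 1)

def quantize_chunked(x, op, max_elements):
    # Small enough (or no longer splittable): dispatch directly.
    if x.size <= max_elements or x.shape[0] <= 1:
        return op(x)
    mid = x.shape[0] // 2
    # Bisect on dim 0, recurse, and stitch the per-chunk results together.
    return np.concatenate(
        [quantize_chunked(x[:mid], op, max_elements),
         quantize_chunked(x[mid:], op, max_elements)], axis=0)

x = np.random.rand(8, 4).astype(np.float32)
chunked = quantize_chunked(x, fake_quantize, max_elements=8)
assert np.array_equal(chunked, fake_quantize(x))  # identical to one-shot result
```

This only works as a drop-in replacement when the op is independent per row along dim 0, which holds for row-group quantization.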

* Add _StreamingPlan, a streaming sanitizer for VLM models that builds a
  per-output-tensor transformation plan from the lazy index without
  materializing any weights. Implements the Qwen3.5 MoE Model.sanitize
  logic (drop mtp.*, optional lm_head tied-embedding drop, fused
  gate_up_proj split on axis -2, model.language_model -> language_model.model
  and model.visual -> vision_tower renames, lm_head -> language_model.lm_head,
  conv1d.weight axis (2,1) permute, +1.0 on 1D norm weights, and patch_embed
  Conv3d (out,in,T,H,W) -> (out,T,H,W,in) permute). The quantize loop pulls
  one tensor at a time via pop(), so peak RAM stays bounded.
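The key-rename portion of those rules can be illustrated with a small mapping function. This is a sketch of only the renames listed above (drop mtp.*, model.language_model -> language_model.model, model.visual -> vision_tower, lm_head -> language_model.lm_head); the real _StreamingPlan also plans tensor-level transforms, which this omits.

```python
def map_key(key):
    """Map a checkpoint key to its mlx-vlm name, or None to drop it."""
    if key.startswith("mtp."):
        return None  # mtp.* weights are dropped
    if key.startswith("model.language_model."):
        return "language_model.model." + key[len("model.language_model."):]
    if key.startswith("model.visual."):
        return "vision_tower." + key[len("model.visual."):]
    if key.startswith("lm_head."):
        return "language_model." + key
    return key

assert map_key("mtp.layers.0.weight") is None
assert (map_key("model.language_model.layers.0.q_proj.weight")
        == "language_model.model.layers.0.q_proj.weight")
assert (map_key("model.visual.patch_embed.proj.weight")
        == "vision_tower.patch_embed.proj.weight")
assert map_key("lm_head.weight") == "language_model.lm_head.weight"
```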

Tested end-to-end on M3 Ultra 512GB with Qwen3.5-397B-A17B oQ4: model
loads directly in mlx-vlm with no post-hoc converter, generates coherent
output at ~30 tok/s, peak memory 229 GB.

LLM-only models still go through the original _build_model_sanitizer
path; only VLM checkpoints (architectures containing 'ForConditionalGeneration')
use the streaming plan.
@yohann-bearzi
Author

I quantized Qwen3.5-397B-A17B with oMLX and you can check out what it produces at https://huggingface.co/collections/bearzi/qwen35-397b-a17b-oq

@yohann-bearzi
Author

yohann-bearzi commented Apr 12, 2026

I'm quite busy, so I may be slow to respond. If this stalls, I'm fine with someone taking over from here. Just leave me as author on the original commit and append your tweaks in follow-up commits.

@jundot force-pushed the main branch 4 times, most recently from 4eb9c29 to 7376c8e on April 14, 2026 at 09:08
@jundot (Owner) left a comment


Nice work on this, the chunked load/quantize approach is exactly what huge MoE models need. I found a few things that should be addressed before merging though.

The biggest issue is that _StreamingPlan activates for every ForConditionalGeneration architecture, but the transform rules inside are Qwen3.5-specific. So if someone quantizes a Gemma4 or LLaVA model, the rename logic would silently corrupt weight names. For example, Gemma4's model.language_model.model.layers.* would become language_model.model.model.layers.* with a double prefix. This needs to be scoped to Qwen3.5 MoE only, letting other VLMs keep using the existing _build_model_sanitizer path.

Related to that, _LazyTensorIndex.pop() returns a _LazyTensor object instead of mx.array. This is fine in the main quantize loop where you added the isinstance check, but sanitizers from mlx-lm/mlx-vlm don't know about _LazyTensor. Something like mx.stack([weights.pop(k) for ...]) in Qwen3 MoE's sanitize would blow up. I think pop() should just materialize like __getitem__ does.

A smaller one: __iter__ only yields _index keys but keys() includes _overrides too. If a sanitizer writes back with weights[k] = v, those keys become invisible during iteration.
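The __iter__/keys() mismatch described here is easy to demonstrate with a toy dict-like index. The class below is an illustrative stand-in for _LazyTensorIndex, not the real implementation: a base mapping plus an overrides dict, where iteration must cover both or keys written back by a sanitizer silently vanish.

```python
class LazyIndex:
    """Toy index: on-disk keys in _index, sanitize-written keys in _overrides."""
    def __init__(self, base):
        self._index = dict(base)
        self._overrides = {}

    def __setitem__(self, k, v):
        self._overrides[k] = v

    def keys(self):
        return set(self._index) | set(self._overrides)

    def __iter__(self):
        # Correct behavior: yield override keys too, not just _index keys.
        yield from self.keys()

idx = LazyIndex({"a": 1})
idx["b"] = 2  # written back by a sanitizer
assert set(iter(idx)) == idx.keys() == {"a", "b"}
```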

Since this is low-level binary parsing and chunked GPU operations, some basic tests would really help. A roundtrip test for _LazyTensorIndex, a comparison test between _quantize_chunked and mx.quantize, and a key mapping test for _StreamingPlan would catch the kind of silent corruption bugs that are hardest to debug later.

The core primitives here are genuinely useful and I'd love to see this merged once these are sorted out. Happy to help with follow-up commits if you need.

Replaces Qwen3.5-specific _StreamingPlan with a generic discovery mechanism
that runs the real Model.sanitize() on _TrackedTensor proxies. The proxies
record shape/dtype/lineage without materializing GPU data, and a set of
monkey-patched mx ops (stack/concatenate/split/moveaxis/transpose) capture
the transforms. Result is a plan of output_key -> {sources, transform, shape}
that _DiscoveredPlan materializes one tensor at a time with chunked stacking.
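The discovery idea can be sketched with a toy proxy: tensors that record shape and lineage while a patched `stack` captures which source tensors feed each output, so the dry run moves no real data. The class names follow the description above, but the details (attributes, plan format) are illustrative assumptions.

```python
class TrackedTensor:
    """Fake tensor proxy: records shape and which on-disk keys produce it."""
    def __init__(self, key, shape):
        self.key, self.shape = key, shape
        self.sources = [key]  # lineage back to source tensors

def tracked_stack(tensors):
    # Stand-in for a monkey-patched mx.stack: computes the output shape and
    # merges lineage without touching any weight data.
    out = TrackedTensor(None, (len(tensors),) + tensors[0].shape)
    out.sources = [t.key for t in tensors]
    return out

# Dry-run a sanitize-style transform: fuse per-expert weights into one tensor.
experts = [TrackedTensor(f"experts.{i}.w", (4, 8)) for i in range(3)]
fused = tracked_stack(experts)
plan = {"experts.w_stacked": {"sources": fused.sources, "shape": fused.shape}}
assert plan["experts.w_stacked"]["shape"] == (3, 4, 8)
assert plan["experts.w_stacked"]["sources"] == [
    "experts.0.w", "experts.1.w", "experts.2.w"]
```

The plan can then be replayed one output key at a time, materializing only the listed sources.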

Addresses review feedback on jundot#737:

- _StreamingPlan no longer corrupts non-Qwen VLMs — it's not even in the
  activation path anymore. Discovery handles every model mlx-lm/mlx-vlm
  supports (tested on Gemma 4 E2B, Trinity Nano AfMoE, Qwen 3.5 397B MoE).

- _LazyTensorIndex.pop() now materializes to mx.array instead of returning
  _LazyTensor, so third-party sanitizers that call mx.stack on popped
  tensors work correctly.

- _LazyTensorIndex.__iter__ and items() now include _overrides keys so
  sanitize-written tensors are visible during iteration.

- _LazyTensor.__getitem__ and _materialize_source now handle 0-dim scalars
  (needed for Gemma 4's scaling factors).

Tested end-to-end:
- Gemma 4 E2B oQ8: generates coherent text
- Qwen 3.5 397B: unchanged behavior (discovery produces same plan
  _StreamingPlan did)
@yohann-bearzi
Author

It's failing on MiniMax M2.7; I'm patching it, then I'll add a commit for it.


Replaces Qwen3.5-specific _StreamingPlan activation with a generic
discovery mechanism that works for any model architecture:

Discovery-based streaming sanitizer:
- _TrackedTensor: fake tensor proxy that records shape/dtype/lineage
  during a sanitize() dry run. Supports reshape, astype, arithmetic,
  None-broadcasting indexing, and slice patterns.
- _discover_sanitize_plan(): runs the real Model.sanitize() on tracked
  tensors with monkey-patched mx ops (stack/concatenate/split/moveaxis/
  transpose/from_fp8/pad/eval/clear_cache). Produces a transform plan
  without materializing any GPU data. Cost: <1s even on 42K-tensor models.
- _DiscoveredPlan: dict-like wrapper that materializes one tensor at a
  time using the discovered plan, with chunked stacking (16 experts per
  chunk) to bound peak memory on large MoE models.
- Graceful fallback to eager sanitize if discovery fails.

FP8 source model support (MiniMax-M2.7, DeepSeek FP8, etc.):
- _LazyTensor: F8_E4M3 and F8_E5M2 dtype support — loaded as uint8
  so sanitize can call mx.from_fp8() on them.
- _streaming_fp8_dequant(): processes FP8 weight/scale_inv pairs one
  at a time, runs block-scaled dequant (from_fp8 + pad + reshape +
  scale multiply + slice), writes bf16 results to scratch safetensors
  shards on disk, and re-indexes the lazy loader. Peak RAM bounded to
  one tensor at a time regardless of model size.
- FP8 sources bypass discovery (dequant chain is too complex to replay)
  and use eager sanitize after streaming dequant completes.
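The block-scaled dequant arithmetic in that chain (pad -> reshape into blocks -> multiply by scale_inv -> slice back) can be sketched in numpy. This is a hedged illustration: the real code decodes F8_E4M3 bytes via mx.from_fp8, while here the "fp8" input is already float so only the block-scaling steps are shown, with a tiny block size for readability.

```python
import numpy as np

def block_scaled_dequant(w, scale_inv, block=2):
    """Multiply each (block x block) tile of w by its entry in scale_inv."""
    rows, cols = w.shape
    pr, pc = (-rows) % block, (-cols) % block
    wp = np.pad(w, ((0, pr), (0, pc)))  # pad to a whole number of blocks
    br, bc = wp.shape[0] // block, wp.shape[1] // block
    blocks = wp.reshape(br, block, bc, block)
    # One scale per tile, broadcast over the tile's elements.
    blocks = blocks * scale_inv[:, None, :, None]
    # Reshape back and slice off the padding.
    return blocks.reshape(wp.shape)[:rows, :cols]

w = np.ones((3, 3), dtype=np.float32)  # 3x3 forces padding for block=2
s = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=np.float32)
out = block_scaled_dequant(w, s)
assert out.shape == (3, 3)
assert out[0, 0] == 2.0 and out[0, 2] == 3.0 and out[2, 2] == 5.0
```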

Other fixes:
- _LazyTensorIndex.pop() materializes mx.array instead of returning
  raw _LazyTensor objects.
- _LazyTensorIndex.__iter__ and items() include _overrides keys.
- _LazyTensor.__getitem__ and _materialize_source handle 0-dim scalars
  (needed for Gemma 4 scaling factors).

Tested end-to-end on M3 Ultra 512GB:
- Gemma 4 E2B oQ2-8: coherent output at all levels
- Trinity Nano Preview (AfMoE) oQ4-8: coherent output
- Qwen 3.5 397B oQ2-8: unchanged behavior
- MiniMax-M2.7 (FP8 source): streaming dequant completes, oQ8 builds
@yohann-bearzi
Author

Ready for a new review. I'm going to work on making GLM 5.1 quantizable. Hopefully I get lucky.

