
feat: MLX VRAM & performance optimizations for Apple Silicon#8

Draft
cmoyates wants to merge 3 commits into main from feat/misc-optimizations

Conversation


cmoyates commented Mar 7, 2026

Summary

  • FP16 mixed precision: decoders cast to FP16 (backbone+refiner stay FP32 for accuracy)
  • Eliminate CPU-GPU roundtrips: MLX-native preprocessing, slim forward mode (4 outputs vs 9)
  • Decoupled backbone resolution: backbone at 1024 while refiner runs at full 2048 — 6.3x faster, 3.4x less peak memory
  • Hardened tiled inference: 96px overlap, lazy tile slicing, per-tile cache clearing, bounded memory growth
  • Engine API: new params fp16, backbone_size, tile_size, tile_overlap with validation
  • Bug fix: mask resize now handles mismatched image/mask dimensions
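The tiled-inference scheme above (fixed tile size with a 96 px overlap and a final tile flushed to the image edge) can be sketched as a small helper. This is a hypothetical illustration of the tiling math, not the engine code from this PR; the function name and signature are assumptions.

```python
def tile_starts(length: int, tile: int, overlap: int) -> list[int]:
    """Start offsets for tiles of size `tile` with `overlap` px of overlap.

    Sketch of the tiling scheme described above; the final tile is flushed
    to the image edge so the whole axis is covered.
    """
    if length <= tile:
        return [0]  # image fits in a single tile
    stride = tile - overlap
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # last tile flush with the edge
    return starts

# 2048 px axis, 512 px tiles, 96 px overlap -> 416 px stride
print(tile_starts(2048, 512, 96))  # → [0, 416, 832, 1248, 1536]
```

With lazy slicing, only the tile at each start offset needs to be materialized at a time, which is what bounds peak memory growth.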

Benchmark Highlights (M3 Max, 2048x2048)

| Config | Median | Peak MB |
| --- | --- | --- |
| Baseline (FP32, full backbone) | 9,068 ms | 27,560 |
| Backbone 1024 | 1,447 ms | 8,044 |
| Tiled 512 | 3,467 ms | 2,626 |
| All optimizations (FP16 + bb1024 + tiled) | 8,450 ms\* | 2,301 |

\*Wall-clock time including real-image loading and resize overhead

Test plan

  • 115 tests pass, 0 failures
  • Resolution benchmark (256/512/1024)
  • Engine config matrix (fp16/fp32 × backbone × tiled at 2048)
  • 2048 smoke test with all optimizations — PASSED
  • Full results in docs/benchmarks/RESULTS.md

🤖 Generated with Claude Code

cmoyates and others added 3 commits March 7, 2026 18:59
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…se 0)

- memory_snapshot() / reset_peak() helpers in utils/profiling
- bench_mlx.py: peak/active MB columns per resolution
- smoke_2048.py: replace try/except fallback with direct mx.* APIs
- baseline-fp32.md: 512/1024/2048 numbers (26.7GB peak at 2048)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
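The `memory_snapshot()` / `reset_peak()` helper pattern from this commit can be sketched as below. On Apple Silicon the getters would be MLX's allocator queries (e.g. `mx.get_active_memory` / `mx.get_peak_memory` in recent MLX versions); here they are injected as callables so the sketch runs anywhere. Names and structure are assumptions, not the actual `utils/profiling` code.

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class MemorySnapshot:
    active_mb: float
    peak_mb: float

def memory_snapshot(get_active: Callable[[], int],
                    get_peak: Callable[[], int]) -> MemorySnapshot:
    """Capture active/peak allocator bytes and report them in MB."""
    to_mb = 1024 * 1024
    return MemorySnapshot(get_active() / to_mb, get_peak() / to_mb)

# Stub getters stand in for the MLX allocator on non-Metal machines.
snap = memory_snapshot(lambda: 512 * 1024 * 1024, lambda: 2 * 1024**3)
print(snap)  # MemorySnapshot(active_mb=512.0, peak_mb=2048.0)
```

Calling the peak-reset between benchmark configs is what makes the per-resolution peak columns in `bench_mlx.py` comparable.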
…er FP32

Full FP16 exceeded 1e-3 tolerance (backbone drift 2.5e-3, refiner 10x
scale → 1.3e-2). Decoder-only FP16 stays within bounds.

- Add fp16 param to load_model() + _cast_model_fp16() helper
- 6 parity tests: 4 output keys + validity + FP32 opt-in check
- Update plan with empirical mixed precision findings

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
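The decoder-only mixed-precision policy from this commit (full-model FP16 drifted past the 1e-3 parity tolerance, so only decoder weights are cast while backbone and refiner stay FP32) can be sketched on a flat parameter dict. Module names and the prefix-matching approach are hypothetical; the real `_cast_model_fp16()` helper may work differently.

```python
import numpy as np

def cast_model_fp16(params: dict, fp16_prefixes=("decoder",)) -> dict:
    """Cast only decoder weights to FP16; everything else stays FP32."""
    out = {}
    for name, w in params.items():
        if name.startswith(fp16_prefixes):
            out[name] = w.astype(np.float16)
        else:
            out[name] = w  # backbone/refiner keep FP32 for accuracy
    return out

params = {
    "backbone.conv1": np.ones((4, 4), np.float32),
    "decoder.head":   np.ones((4, 4), np.float32),
}
casted = cast_model_fp16(params)
print(casted["decoder.head"].dtype, casted["backbone.conv1"].dtype)
# float16 float32
```

A parity test like the ones mentioned above would then compare FP32 and mixed-precision outputs against the 1e-3 tolerance.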
@cmoyates cmoyates marked this pull request as draft March 7, 2026 23:51