Metal kernel optimizations, ComfyUI custom node, M5/M4 Pro benchmarks#1

Open
DanielHou315 wants to merge 27 commits into tashiscool:main from DanielHou315:main

@DanielHou315

Summary

  • Drop C++ bridge: fp8_bridge.cpp and setup.py are removed; the native torch.mps.compile_shader() path has zero CPU round-trips and requires no build step
  • ComfyUI custom node install — add the repo as a custom node (custom_nodes/fp8-mps-metal/__init__.py); the patch self-installs on startup with no ComfyUI source edits needed; supports LTX-Video 2.3 workflow and all FLUX/SD3.5 FP8 workloads
  • SGMMA kernel (Family 10 GPU) — adds a Metal matrix-unit kernel using simdgroup_matrix_multiply_accumulate, with automatic fallback to the tiled kernel on older GPUs; enables hardware-accelerated FP8 matmul on M4/M5 chips
  • Kernel optimizations (P1–P18) — tiled matmul with threadgroup memory, 256-entry LUT FP8 decode, coalesced vecmat access, fused scale+dequant, FP16 activation fast-path, transposed-B cache in monkey-patch, and more; collectively ~2–5x faster than the original implementation
  • M5 Pro benchmarks — measured on M5 Pro (64GB, PyTorch 2.11); at K=14336 the FP8 fused kernel beats FP16 native (0.98 ms vs 1.50 ms); 7–22x faster than CPU fallback

Test plan

  • cd tests && uv run pytest test_correctness.py -v — all 25 correctness tests pass
  • uv run python bench_ai_workloads.py — confirm perf numbers on your hardware
  • Load a FLUX or LTX-Video 2.3 FP8 model in ComfyUI with the custom node installed and verify no CPU fallback warnings

🤖 Generated with Claude Code

DanielHou315 and others added 27 commits April 8, 2026 14:56
…PEEDUP.md

Auto-selector: crossover benchmarks show fused kernel wins at M=1-4 but
loses to fast path (dequant+native matmul) from M=8+. Old threshold M<=16
routed M=5..16 to the slower untiled fused kernel.
Improvement for M=8..16: 12-58% faster depending on K/N.
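
A minimal sketch of the routing rule described above. The function name, the path labels, and the `fused_max_m` parameter are hypothetical; the default of 4 is assumed from the benchmark note that the fused kernel wins at M=1-4.

```python
def select_path(m: int, fused_max_m: int = 4) -> str:
    """Pick a matmul path from M, the row count of the activation matrix.

    Hypothetical helper: "fused" stands for the fused FP8 kernel and
    "fast" for dequant + native matmul.
    """
    if m < 1:
        raise ValueError("M must be positive")
    return "fused" if m <= fused_max_m else "fast"


sizes = (1, 4, 8, 16, 32)
old_routing = [select_path(m, fused_max_m=16) for m in sizes]  # old threshold
new_routing = [select_path(m) for m in sizes]                  # assumed new one
```

With the old threshold, M=8 and M=16 land on the fused kernel; with the lower one they go to the fast path, which is where the 12-58% improvement is reported.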

Also adds:
- tests/ with correctness suite (25 tests) and AI workload benchmarks
- CLAUDE.md project instructions
- SPEEDUP.md benchmark results and P1-P7 performance analysis
- .gitignore update for tests/.venv

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace fp8_e4m3fn_to_float() inline function (2 branches + exp2()
per element) with a precomputed constant-address-space LUT.

256 float values cover all possible uint8 FP8 e4m3fn bit patterns.
Zero runtime init cost, no threadgroup barriers, cached per GPU core.
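
For reference, the 256-entry table can be reproduced on the host. This is a Python sketch of the e4m3fn decode rule (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities), not the Metal source; the names `decode_e4m3fn`, `FP8_LUT`, and `dequantize` are illustrative.

```python
import math


def decode_e4m3fn(byte: int) -> float:
    """Decode one FP8 e4m3fn bit pattern (0..255) to a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan  # the only NaN encoding; e4m3fn has no infinities
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6  # zero and subnormals
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


# Built once; decoding a byte is then a single table lookup.
FP8_LUT = [decode_e4m3fn(b) for b in range(256)]


def dequantize(raw: bytes) -> list:
    """Table-driven dequantization of a byte string."""
    return [FP8_LUT[b] for b in raw]
```

A table like this can also serve as a host-side reference when checking kernel output.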

Fused kernel: 13-18% faster across all M sizes.
Dequant kernel (fast path): 3-5% faster.
M=1 decode FFN: 0.85x → 0.73x vs FP16 (larger win margin).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add fp8_to_scaled_half_kernel that applies output = half(lut[input] * scale)
in one pass, replacing separate dequant + elementwise scale multiply.

Reduces fast path from 6 GPU dispatches to 4 per matmul call.
Fast path 15-24% faster for large-N shapes (FFN layers).
Decode FFN ratio vs FP16 now 0.69x (was 0.73x).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Restructure inner loop so consecutive SIMD lanes access consecutive
bytes (one cache line per iteration). Old pattern strided by 128 bytes
between lanes, preventing memory coalescing.
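
The two access patterns can be compared with a small address model. This is a sketch under assumptions: a 32-lane SIMD group, a 128-byte cache line, and index formulas chosen to match the description above rather than taken from the kernel.

```python
SIMD_WIDTH = 32
CACHE_LINE = 128  # bytes; assumed line size for this model


def strided_addresses(i: int, stride: int = 128) -> list:
    """Old pattern: each lane walks its own row, lanes sit `stride` bytes apart."""
    return [lane * stride + i for lane in range(SIMD_WIDTH)]


def coalesced_addresses(i: int) -> list:
    """New pattern: consecutive lanes read consecutive bytes."""
    return [i * SIMD_WIDTH + lane for lane in range(SIMD_WIDTH)]


def cache_lines_touched(addresses: list) -> int:
    """Number of distinct cache lines covered by one SIMD-wide read."""
    return len({addr // CACHE_LINE for addr in addresses})


old_lines = cache_lines_touched(strided_addresses(0))
new_lines = cache_lines_touched(coalesced_addresses(0))
```

In this model one iteration of the strided loop touches 32 cache lines, while the coalesced loop touches one.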

M=1 decode FFN (N=14336): 6% faster, 0.66x vs FP16.
No regressions on other shapes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Pre-dequantize FP8 weights to scaled FP16 once, reuse across calls.
Auto selector detects float16 B and skips per-call dequant.

Prefill (M>=4): 1.4-2.0x faster with prepared weights.
  prefill128/ffn: 2.05ms → 1.03ms (1.99x)
  prefill128/qkv: 1.02ms → 0.55ms (1.85x)
Decode (M=1): prepared path is slower (0.49ms vs 0.26ms fused)
  because vecmat kernel with uint8 avoids FP16 intermediates.
  Users should use unprepared path for M=1 autoregressive decode.

Adds 6 new correctness tests (31 total) and prepared-weight benchmarks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace untiled one-thread-per-element matmul with 16x16 tiled blocked
GEMM using threadgroup memory. Each tile iteration cooperatively loads
16x16 blocks of A and B, then all threads reuse them.

Reduces global memory reads by roughly the tile width (16x): each loaded element is reused by 16 threads, so traffic falls from O(M*N*K) to about O(M*N*K/16).
Fused kernel 67-68% faster for prefill sizes (M>=128).
  M=128 K=N=4096: 10.4ms → 3.5ms
  M=128 K=4096 N=14336: 35.5ms → 11.5ms
  M=512 K=N=4096: 40.5ms → 13.1ms
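
The effect of tiling on global loads can be illustrated with a pure-Python model that counts one load per element read from A or B. It is a sketch, not the Metal kernel: B is stored K x N here for simplicity, caches are ignored, and the function names are illustrative.

```python
def matmul_naive(a, b, m, n, k):
    """One 'thread' per output element; every operand comes from global memory."""
    c = [[0.0] * n for _ in range(m)]
    loads = 0
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += a[i][p] * b[p][j]
                loads += 2
            c[i][j] = acc
    return c, loads


def matmul_tiled(a, b, m, n, k, tile):
    """Blocked GEMM: each tile is loaded once, then reused by the whole block."""
    c = [[0.0] * n for _ in range(m)]
    loads = 0
    for i0 in range(0, m, tile):
        for j0 in range(0, n, tile):
            for p0 in range(0, k, tile):
                # Cooperative load into "threadgroup memory".
                a_tile = [row[p0:p0 + tile] for row in a[i0:i0 + tile]]
                b_tile = [row[j0:j0 + tile] for row in b[p0:p0 + tile]]
                loads += sum(len(r) for r in a_tile) + sum(len(r) for r in b_tile)
                for ii, a_row in enumerate(a_tile):
                    for jj in range(len(b_tile[0])):
                        acc = c[i0 + ii][j0 + jj]
                        for pp, a_val in enumerate(a_row):
                            acc += a_val * b_tile[pp][jj]
                        c[i0 + ii][j0 + jj] = acc
    return c, loads


M = N = K = 8
A = [[float(i + p) for p in range(K)] for i in range(M)]
B = [[float(p - j) for j in range(N)] for p in range(K)]
c_naive, loads_naive = matmul_naive(A, B, M, N, K)
c_tiled, loads_tiled = matmul_tiled(A, B, M, N, K, tile=4)
```

In this model the load count drops by the tile width (4x for `tile=4`), which is the reuse that threadgroup memory provides.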

Tiled kernel now beats fast path up to M=16 (was M=4).
Auto threshold updated M<=4 → M<=16.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add deprecation notices to fp8_bridge.cpp, setup.py, and pybind11 module
docstring. The native path (fp8_mps_native.py via torch.mps.compile_shader)
is zero-copy and significantly faster. C++ bridge kept for PyTorch < 2.10.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
All 7 performance optimizations complete.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
11 optimizations in 4 tiers targeting the remaining 1.2-1.65x gap to FP16.
Key design: single-library conditional compilation (#if __METAL_VERSION__)
with dir(lib) runtime detection. SGMMA kernel compiles on M2+ (MSL 3.0),
scalar tiled kernel remains as M1 fallback. No try/except needed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Eliminates one GPU kernel dispatch per fast-path call.
The monkey-patch already handles out_dtype casting.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Mirrors existing B.dtype check. When A arrives as FP16 (from LayerNorm,
attention output), the fast path uses it directly instead of dispatching
a dequant kernel. Auto selector routes FP16 A to fast path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace fp8_to_half_kernel + elementwise scale with single
fp8_to_scaled_half_kernel call. Eliminates 1 dispatch per dequantize.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Cache other.t().contiguous() result keyed by data_ptr. Avoids
reallocating N*K weight copy on every _scaled_mm call.
Cache cleared on uninstall().
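
A minimal sketch of a cache keyed by a buffer address. The class and its methods are hypothetical, and `id()` on a Python list stands in for a tensor's `data_ptr()`.

```python
class TransposeCache:
    """Memoize an expensive derived copy, keyed by the source buffer address."""

    def __init__(self):
        self._entries = {}
        self.misses = 0

    def get(self, key, make_transposed):
        """Return the cached copy for `key`, building it on first use."""
        if key not in self._entries:
            self.misses += 1
            self._entries[key] = make_transposed()
        return self._entries[key]

    def clear(self):
        """Drop all cached copies (the commit does this on uninstall())."""
        self._entries.clear()


def transpose(rows):
    return [list(col) for col in zip(*rows)]


weight = [[1, 2, 3], [4, 5, 6]]
cache = TransposeCache()
first = cache.get(id(weight), lambda: transpose(weight))
second = cache.get(id(weight), lambda: transpose(weight))  # served from cache
```

One caveat with address keys in general: if a buffer is freed and its address reused, or the weight is modified in place, a stale entry could be returned, so clearing the cache when weights change is the safe choice.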

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Change threadgroup tiles from float to half, halving threadgroup
memory (2KB → 1KB). Accumulator stays float32. FP8 values (max ±448)
fit losslessly in FP16 range. Improves occupancy by allowing more
concurrent threadgroups.
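
Both claims can be checked on the host. The sketch below assumes two 16x16 tiles (one for A, one for B) per threadgroup and uses Python's `struct` half-precision format to test the FP16 round trip; the helper names are illustrative.

```python
import struct

TILE = 16


def tile_bytes(elem_size: int, tiles: int = 2) -> int:
    """Threadgroup memory for `tiles` TILE x TILE tiles of the given element size."""
    return tiles * TILE * TILE * elem_size


def e4m3fn_finite_values() -> list:
    """Every finite value representable in FP8 e4m3fn, both signs."""
    mags = [m / 8.0 * 2.0 ** -6 for m in range(8)]  # zero and subnormals
    mags += [(1.0 + m / 8.0) * 2.0 ** (e - 7)
             for e in range(1, 16) for m in range(8)
             if not (e == 15 and m == 7)]  # that bit pattern encodes NaN
    return [s * v for s in (1.0, -1.0) for v in mags]


def roundtrip_half(x: float) -> float:
    """Round x to IEEE half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


values = e4m3fn_finite_values()
lossless = all(roundtrip_half(v) == v for v in values)
float_tiles = tile_bytes(4)  # float32 tiles
half_tiles = tile_bytes(2)   # float16 tiles
```

Every finite e4m3fn value survives the round trip, so only the tile storage changes precision; the accumulator stays float32.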

Benchmarks show 22-25% improvement at prefill sizes (M=128-2048)
and 32-48% improvement for small M=1 square matmuls. Accuracy
unchanged (37/37 tests pass).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add float_to_fp8_scaled_kernel that does input*scale + FP8 encode in
one pass. Eliminates separate elementwise multiply and .contiguous()
copy from fp8_quantize (2 fewer GPU dispatches).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Uses simdgroup_matrix_multiply_accumulate (Metal 3 / M2+) for ~2x
higher ALU throughput on the fused FP8 matmul path. Conditionally
compiled via #if __METAL_VERSION__ >= 300. Falls back to scalar tiled
kernel on M1 (MSL 2.x).

Runtime detection via dir(lib) — single shader library, no try/except.
BM=BN=32, BK=16, 4 simdgroups per threadgroup, each computing a 16x16
sub-tile via 2x2 grid of 8x8 SGMMA ops with float32 accumulators.
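
The tile decomposition can be sanity-checked with a small index model. The row-major assignment of simdgroups to 16x16 sub-tiles is an assumption for illustration; the commit only states the tile sizes.

```python
from collections import Counter

BM = BN = 32   # threadgroup output tile
SG_TILE = 16   # sub-tile owned by one simdgroup
MMA = 8        # one simdgroup matrix op covers 8x8


def owner(row: int, col: int) -> tuple:
    """Map an output element of the threadgroup tile to
    (simdgroup, mma_row, mma_col). Row-major simdgroup layout is assumed."""
    simdgroup = (row // SG_TILE) * (BN // SG_TILE) + (col // SG_TILE)
    return simdgroup, (row % SG_TILE) // MMA, (col % SG_TILE) // MMA


coverage = Counter(owner(r, c) for r in range(BM) for c in range(BN))
simdgroups = {sg for sg, _, _ in coverage}
```

Under this layout the 32x32 tile splits into 4 simdgroups, each issuing a 2x2 grid of 8x8 operations, with every output element owned exactly once.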

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
P17: Move fp8_quantize calls outside timed loop in layer benchmark.
P18: Replace gid/32 with tgid_x*8+simd_group in vecmat kernel.
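
The P18 index change computes the same value when a threadgroup holds 8 simdgroups of 32 threads. That threadgroup size (256) is an assumption inferred from the `*8` factor; the sketch below checks the equivalence under it, with illustrative function names.

```python
SIMD_WIDTH = 32
SIMDGROUPS_PER_TG = 8                             # assumed from the *8 factor
THREADS_PER_TG = SIMD_WIDTH * SIMDGROUPS_PER_TG   # 256


def index_from_gid(gid: int) -> int:
    """Old form: divide the global thread id by the SIMD width."""
    return gid // SIMD_WIDTH


def index_from_simdgroup(tgid_x: int, simd_group: int) -> int:
    """New form: built from the threadgroup and simdgroup indices."""
    return tgid_x * SIMDGROUPS_PER_TG + simd_group


equivalent = all(
    index_from_gid(tgid_x * THREADS_PER_TG + tid)
    == index_from_simdgroup(tgid_x, tid // SIMD_WIDTH)
    for tgid_x in range(4)
    for tid in range(THREADS_PER_TG)
)
```

The new form presumably avoids a per-thread integer divide by using indices the hardware already provides.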

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
All 18 optimizations complete. Key results:
- SGMMA fused kernel 88-90% faster than original for prefill
- Prepared weights within 10-15% of native FP16 at M=512+
- M=1 decode FFN: FP8 wins by 24-34% vs FP16

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Update README with proper custom node install instructions (replacing
the outdated pip + manual import approach). Add ComfyUI integration
section to CLAUDE.md with note about the hardcoded source path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… for MPS FP8

comfy_kitchen (ComfyUI's FP8 backend) uses torch.nn.functional.scaled_mm
(PyTorch 2.10 new API) instead of torch._scaled_mm, bypassing the existing
patch. When that fails, it falls back to dequantize_per_tensor_fp8 which
calls x.to(output_dtype) on an FP8 MPS tensor — also unsupported.

Add two additional intercepts in install():
- torch.nn.functional.scaled_mm → routes FP8+MPS to Metal kernel
- comfy_kitchen.backends.eager.quantization.dequantize_per_tensor_fp8
  → views FP8 as uint8 and calls fp8_mps_native.fp8_dequantize

Also refactor shared logic into _to_uint8, _make_mps_fp8_result,
_transpose_cached helpers to reduce duplication between the two matmul
intercept paths.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
comfy_kitchen's dequantize call chain:
  fp8.py → ck.dequantize_per_tensor_fp8()   ← public API in __init__.py
           → torch.ops.comfy_kitchen.dequantize_fp8  ← custom PyTorch op
           → registry.get_implementation()   ← holds direct ref to eager fn

Patching the backend module attribute is invisible to the registry.
Patch comfy_kitchen.dequantize_per_tensor_fp8 at the module level instead,
which intercepts the call before it reaches the custom op entirely.
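
The reason the backend-level patch has no effect can be shown with plain Python objects. All names below are stand-ins for the comfy_kitchen pieces in the call chain above, not its real code.

```python
from types import SimpleNamespace


def eager_dequantize(x):
    return "eager"


def metal_dequantize(x):
    return "metal"


# The registry captures a direct reference to the eager function at import time.
registry = {"dequantize_fp8": eager_dequantize}
backend = SimpleNamespace(dequantize_per_tensor_fp8=eager_dequantize)


def _public_dequantize(x):
    # The public API dispatches through the registry, not the backend attribute.
    return registry["dequantize_fp8"](x)


package = SimpleNamespace(dequantize_per_tensor_fp8=_public_dequantize)

# Patch 1: rebinding the backend attribute leaves the registry reference intact.
backend.dequantize_per_tensor_fp8 = metal_dequantize
after_backend_patch = package.dequantize_per_tensor_fp8(None)

# Patch 2: rebinding the public attribute intercepts before the registry lookup.
package.dequantize_per_tensor_fp8 = metal_dequantize
after_package_patch = package.dequantize_per_tensor_fp8(None)
```

The second patch only works for callers that look the function up through the package attribute at call time, which matches the `ck.dequantize_per_tensor_fp8()` call shown in the chain.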

Also: the primary failure mode for LTX is that activations are bfloat16
while only weights are FP8 QuantizedTensors, so _handle_fp8_linear skips
_scaled_mm and goes straight to dequantize on every forward pass.
The dequantize fix is the critical path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Drop fp8_bridge.cpp and setup.py — the native path (torch.mps.compile_shader)
has been the only active code path since it was introduced. Update all docs to
reflect single-path architecture and add measured M5 Pro performance numbers
alongside the existing M4 Pro reference table.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Previous measurement used wrong shape (K=4096 not K=14336) and wrong
methodology (CPU-only float32 vs original test's dequant+CPU→MPS transfer).
Re-ran with identical shapes and methodology as M4 Pro baseline for an
apples-to-apples comparison.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Previous measurement used per-call torch.mps.synchronize(), which inflates
small-matmul latency by paying full MPS dispatch overhead on every iteration.
Re-ran with amortized timing (single sync pair over 20 iters) matching the
original M4 Pro methodology exactly.
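
The two timing methodologies differ only in where the synchronization happens. A sketch with hypothetical helper names follows; `sync` would be `torch.mps.synchronize` on a real device, and a counting stub is used here so the example runs anywhere.

```python
import time


def time_per_call_sync(fn, sync, iters=20):
    """Synchronize around every call: each iteration pays the full sync cost."""
    total = 0.0
    for _ in range(iters):
        sync()
        start = time.perf_counter()
        fn()
        sync()
        total += time.perf_counter() - start
    return total / iters


def time_amortized(fn, sync, iters=20):
    """One sync pair around the whole loop: overhead is spread over all calls."""
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) / iters


sync_calls = []


def counting_sync():
    sync_calls.append(1)  # stand-in for a device synchronize


time_amortized(lambda: None, counting_sync)
amortized_syncs = len(sync_calls)

sync_calls.clear()
time_per_call_sync(lambda: None, counting_sync)
per_call_syncs = len(sync_calls)
```

With 20 iterations the amortized loop synchronizes twice, while the per-call loop synchronizes 40 times, which is the overhead the corrected numbers no longer include.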

FP16 baseline now matches the expected M5-vs-M4 relationship (M5 ~15% faster).
FP8 fused beats FP16 at K=14336 where weight size dominates.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>