feat: add full MPS (Apple Silicon) support #288
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
This PR enables true MPS GPU acceleration for macOS Apple Silicon users, providing a 44x speedup over CPU fallback.
Performance Benchmarks
Background
PR #167 added
example/simple-mac.pyas a Mac support solution, but it was actually a CPU fallback workaround that avoided MPS issues rather than fixing them. This PR addresses the root causes to enable true MPS GPU acceleration.Changes
1. Guard CUDAGraph Call (Critical)
File:
dia/model.py:701-702Why:
cudagraph_mark_step_begin()is CUDA-specific and has no MPS equivalent. Called ~860 times per generation in the autoregressive loop.2. Add Device Check to torch.compile (Critical)
File:
dia/model.py:657-667Why:
mode="max-autotune"requires Triton which is CUDA/ROCm only. MPS torch.compile is still experimental (PyTorch #150121).3. Fix CLI Device Auto-Detection
File:
cli.py:12-18, 82-83Why: Mac users running
python cli.py "text" --output out.wavnow automatically get MPS acceleration instead of slow CPU fallback.4. Fix simple-mac.py dtype
File:
example/simple-mac.py:4-6Why: Matches
app.pyrecommendation. MPS float16 has documented precision bugs and provides minimal speedup on Apple Silicon (no dedicated Tensor Cores).5. Guard Triton Config in Benchmark
File:
example/benchmark.py:9-13Why: Triton is only available on Linux/Windows per
pyproject.toml. These settings fail on macOS.6. Add MPS Seed Management
Files:
cli.py:30-32,app.py:66-68Why: Explicit MPS RNG seeding for reproducibility (though MPS reproducibility has known limitations).
7. Add ARM64 Linux Dockerfile
File:
docker/Dockerfile.arm(new)Adds support for ARM64 Linux servers (AWS Graviton, Ampere Altra, etc.) with CPU-only inference.
Note: Docker cannot access MPS on macOS (runs in Linux VM). macOS users must install natively for GPU acceleration.
Test Plan
python example/simple-mac.pyruns without errors on MPSpython cli.py "[S1] Hello world." --output test.wavauto-detects MPSReferences
🤖 Generated with Claude Code