
feat: add MLX inference for Apple Silicon (CTC + RNNT)#62

Open
misteral wants to merge 1 commit into salute-developers:main from misteral:feat/mlx-clean

Conversation


Summary

Native MLX inference for GigaAM v3 on Apple Silicon — supports both CTC (139× realtime) and RNNT (48× realtime, ~9% lower WER) models.

Supersedes #61 (cleaned up: rebased on main, removed dev artifacts).

What's included

| File | Description |
| --- | --- |
| `mlx_convert/gigaam_mlx.py` | Full MLX model: Conformer encoder (16 layers, 768d, RoPE), CTC + RNNT heads, mel spectrogram, streaming |
| `mlx_convert/convert_gigaam_to_mlx.py` | PyTorch → MLX conversion (safetensors + config.json) |
| `mlx_convert/gigaam-cli` | Single-file transcription CLI |
| `mlx_convert/gigaam-stream` | Real-time streaming (live mic + file) |
| `mlx_convert/gigaam-transcribe` | Shell wrapper |
| `mlx_convert/README.md` | Documentation: Python API, CLI, benchmarks |

Architecture

```
Audio (16 kHz) → Log-Mel (64 bins) → Conv1d Subsampling (4×)
  → 16× Conformer (RoPE MHSA + GLU Conv + SiLU FFN)
  → CTC Head → Greedy Decode
    or
  → RNNT Head (Joint + LSTM Decoder) → Greedy Decode
```
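The shape flow through the pipeline above can be checked with plain arithmetic. This is a sketch: the 10 ms hop length is an assumption (a common choice for 16 kHz ASR front ends), not taken from the PR.

```python
# Rough shape arithmetic for the pipeline above.
sample_rate = 16_000
hop_length = 160          # assumed: 10 ms STFT hop, not confirmed by the PR
n_mels = 64               # log-mel bins, from the diagram
subsampling = 4           # Conv1d subsampling factor, from the diagram
d_model = 768             # Conformer width, from the PR

audio_seconds = 11
n_samples = audio_seconds * sample_rate          # raw waveform samples
n_mel_frames = n_samples // hop_length           # frames of shape (n_mels,)
n_encoder_frames = n_mel_frames // subsampling   # frames of shape (d_model,)

print(n_samples, n_mel_frames, n_encoder_frames)
```

So an 11 s clip becomes a few hundred encoder frames, which is why greedy decoding dominates so little of the runtime.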

Key details:

  • RoPE applied before Q/K/V projections (matching original PyTorch model)
  • Mel filterbank saved from PyTorch (exact match, no recomputation drift)
  • All Conv1d weights transposed: [out, in, K] → [out, K, in] for MLX
  • RNNT LSTM weights properly mapped to MLX layout
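The Conv1d weight transpose in the list above is a single axis swap. A minimal NumPy sketch (the tensor shape here is illustrative, not a specific layer from the model):

```python
import numpy as np

# PyTorch Conv1d weight layout: (out_channels, in_channels, kernel_size)
torch_weight = np.zeros((768, 64, 3))

# MLX expects (out_channels, kernel_size, in_channels): swap the last two axes.
mlx_weight = torch_weight.transpose(0, 2, 1)

print(mlx_weight.shape)  # (768, 3, 64)
```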

Performance (Apple M4)

| Model | Time (11 s audio) | Realtime factor | Model size (fp16) |
| --- | --- | --- | --- |
| CTC | 81 ms | 139× | 421 MB |
| RNNT | 230 ms | 48× | 423 MB |
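Realtime factor here is simply audio duration over wall-clock time; a quick sanity check against the table (exact values depend on how the measured times were rounded):

```python
def realtime_factor(audio_seconds: float, wall_seconds: float) -> float:
    """Seconds of audio processed per second of compute."""
    return audio_seconds / wall_seconds

ctc_rtf = realtime_factor(11.0, 0.081)   # ~136x from the rounded 81 ms figure
rnnt_rtf = realtime_factor(11.0, 0.230)  # ~48x

print(round(ctc_rtf), round(rnnt_rtf))
```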

Python API

```python
from gigaam_mlx import load_model, load_audio

# CTC (fastest)
model = load_model("./gigaam-v3-ctc-mlx")
text = model.transcribe(load_audio("audio.wav"))

# RNNT (higher quality)
model = load_model("./gigaam-v3-rnnt-mlx")
text = model.transcribe(load_audio("audio.wav"))

# Streaming
for r in model.stream_generate(load_audio("audio.wav")):
    print(r.cumulative_text)
```
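For reference, greedy CTC decoding follows the standard collapse-repeats-then-drop-blanks rule. A minimal self-contained sketch (the token IDs and blank index are illustrative, not the actual `gigaam_mlx` vocabulary):

```python
BLANK = 0  # assumed blank token id for this sketch

def ctc_greedy_decode(token_ids: list[int]) -> list[int]:
    """Collapse consecutive repeats, then drop blanks (standard CTC greedy rule)."""
    out = []
    prev = None
    for t in token_ids:
        if t != prev and t != BLANK:
            out.append(t)
        prev = t
    return out

# Repeated 7s collapse; the blank between the two 11s keeps them distinct.
print(ctc_greedy_decode([7, 7, 0, 4, 0, 11, 11, 0, 11, 14]))  # [7, 4, 11, 11, 14]
```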

Testing

Tested on Apple M4 with various Russian speech samples. Output matches PyTorch reference (character-level exact match on short utterances).

Commit message

Add native MLX inference for GigaAM v3 on Apple Silicon:

- Full Conformer encoder (16 layers, 768d) with RoPE attention
- CTC and RNNT heads with greedy decoding
- PyTorch → MLX weight conversion (safetensors + config.json)
- Streaming transcription (growing buffer, live mic + file)
- CLI tools: gigaam-cli, gigaam-stream, gigaam-transcribe
- Python API: load_model, transcribe, stream_generate, stream_live

Performance on Apple M4:
- CTC: 139x realtime (11s audio in 81ms), fp16 421 MB
- RNNT: 48x realtime (11s in 230ms), ~9% lower WER than CTC

New files in mlx_convert/:
  gigaam_mlx.py              — MLX model + inference + streaming
  convert_gigaam_to_mlx.py   — PyTorch → MLX conversion
  gigaam-cli                 — single-file transcription CLI
  gigaam-stream              — real-time streaming CLI
  gigaam-transcribe          — shell wrapper
  README.md                  — documentation, API, benchmarks
