Port the Qwen3-Omni multimodal encoders into M* and optimize them #131

@NSagan271

Difficulty: 🟡 Intermediate (easier end of medium)

Scope: Medium; contained to the Qwen3-Omni audio/vision encoder submodules plus their weight loading. No changes to the scheduler or engine core.

Subsystems: model/qwen3_omni/ · model/components/

Prerequisites: Familiarity with the Qwen3-Omni vision (ViT + spatial merge) and audio (Whisper-style) encoders, and the M* NodeSubmodule pattern. No scheduler/conductor knowledge needed.

Problem

Qwen3-Omni's backbone (the Thinker) is already a native M* implementation
(ThinkerSubmodule in qwen3_omni/submodules.py wraps a FlashInfer-based MoE
transformer). The multimodal encoders are not:
AudioEncoderSubmodule and VisionEncoderSubmodule (same file) are thin
wrappers that just call the upstream HuggingFace module (self.audio_encoder(...)
/ self.vision_encoder(...)) and reshape the output. The engine then applies
torch.compile to them naively and uniformly (see issue #3), with no
encoder-specific optimization.

Both encoders also currently run once per request rather than batched across
requests (noted in their docstrings), which leaves throughput on the table when
multiple multimodal requests are in flight.
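
To make the batching opportunity concrete, here is a minimal sketch of one way to group in-flight requests so they can share an encoder forward pass. It is not the M* API: bucket_by_length, padding_fraction, and the bucket size are hypothetical, and the sketch assumes each request can be summarized by a single sequence length (patch count or audio frame count).

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def _round_up(n: int, bucket_size: int) -> int:
    """Round n up to the next multiple of bucket_size."""
    return -(-n // bucket_size) * bucket_size


def bucket_by_length(lengths: Sequence[int], bucket_size: int) -> Dict[int, List[int]]:
    """Group request indices by padded length (hypothetical helper).

    Requests in the same bucket can be stacked into one batch after
    padding each of them up to the bucket's length.
    """
    buckets: Dict[int, List[int]] = defaultdict(list)
    for idx, n in enumerate(lengths):
        buckets[_round_up(n, bucket_size)].append(idx)
    return dict(buckets)


def padding_fraction(lengths: Sequence[int], bucket_size: int) -> float:
    """Fraction of encoder tokens that would be padding under this bucketing."""
    real = sum(lengths)
    padded = sum(_round_up(n, bucket_size) for n in lengths)
    return (padded - real) / padded if padded else 0.0
```

A high padding fraction is an early sign that batching may cost more than it saves for a given traffic mix, so it is worth logging alongside the latency numbers.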

Why it's worth doing

  • Performance: a first-class M* implementation can be batched, CUDA-graph
    captured, and given an attention path suited to the encoder's fixed/known
    shapes, rather than relying on whatever the HF forward does under a blanket
    torch.compile.
  • Consistency: brings the encoders in line with the already-native Thinker,
    Talker, and Code2Wav, and removes a hard dependency on HF internals that
    can break across transformers versions.

Suggested tasks

  • Reimplement the vision encoder (ViT + spatial merge) and the audio encoder
    as native M* submodules, reusing shared building blocks in
    model/components/ where they exist and using ThinkerSubmodule as the
    template for a native, optimized submodule.
  • Wire up weight loading from the HF checkpoint into the native modules
    (state-dict remap).
  • Add batching across requests where the shapes allow, and CUDA-graph
    capture / an appropriate attention path for the encoder.
  • Validate parity against the HF encoder outputs (numerical closeness), then
    benchmark encoder latency/throughput before vs. after.
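
The state-dict remap task above could start from a prefix-rewrite table like the sketch below. All key prefixes in EXAMPLE_RULES are hypothetical placeholders, not the real checkpoint or M* parameter names; those have to be read from the HF checkpoint and from the native module's state_dict(). Values are treated as opaque, so the same function works on tensors.

```python
from typing import Any, Dict, List, Mapping, Sequence, Tuple

# Hypothetical (HF prefix -> native prefix) rules, for illustration only.
EXAMPLE_RULES: List[Tuple[str, str]] = [
    ("visual.blocks.", "vision_encoder.layers."),
    ("visual.merger.", "vision_encoder.spatial_merge."),
    ("audio_tower.", "audio_encoder."),
]


def remap_state_dict(
    hf_state: Mapping[str, Any],
    rules: Sequence[Tuple[str, str]],
) -> Tuple[Dict[str, Any], List[str]]:
    """Rename checkpoint keys by prefix; the first matching rule wins.

    Returns the remapped dict plus the list of keys no rule matched, so
    silently dropped weights are easy to spot.
    """
    remapped: Dict[str, Any] = {}
    unmatched: List[str] = []
    for key, value in hf_state.items():
        for old_prefix, new_prefix in rules:
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                if new_key in remapped:
                    raise ValueError(f"two checkpoint keys map to {new_key!r}")
                remapped[new_key] = value
                break
        else:
            # No rule matched this key.
            unmatched.append(key)
    return remapped, unmatched
```

Loading the result with PyTorch's load_state_dict(..., strict=True) and checking that the unmatched list is empty (or contains only keys that are intentionally skipped) catches most remap mistakes before any parity test runs.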

Acceptance criteria

  • Native encoders produce outputs matching the HF encoders within tolerance for
    reference image/audio inputs.
  • A before/after benchmark shows the encoder path is at least as fast (ideally
    faster, especially with multiple concurrent multimodal requests).
  • The HF-wrapper encoder path can be removed (or kept only as a reference/fallback).
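
One possible shape for the "within tolerance" check is sketched below. It compares every named output, so intermediate features are covered as well as the final embeddings. The output names, the flattened-list representation, and the default tolerances are assumptions for illustration; in practice the comparison would run on tensors, for example with torch.testing.assert_close, which applies the same |actual - expected| <= atol + rtol * |expected| criterion.

```python
from typing import Dict, List, Mapping, Sequence


def failing_outputs(
    reference: Mapping[str, Sequence[float]],
    candidate: Mapping[str, Sequence[float]],
    atol: float = 1e-3,  # assumed tolerance; tune per dtype
    rtol: float = 1e-3,  # assumed tolerance; tune per dtype
) -> List[str]:
    """Return the names of outputs that are missing, mis-sized, or out of tolerance.

    `reference` holds the HF encoder outputs and `candidate` the native
    ones, each flattened to a list of floats and keyed by output name.
    """
    failures: List[str] = []
    for name, ref_values in reference.items():
        cand_values = candidate.get(name)
        if cand_values is None or len(cand_values) != len(ref_values):
            failures.append(name)
            continue
        close = all(
            abs(c - r) <= atol + rtol * abs(r)
            for r, c in zip(ref_values, cand_values)
        )
        if not close:
            failures.append(name)
    return failures
```

An empty return value means every reference output was reproduced within tolerance; a NaN in the candidate fails the comparison rather than passing silently.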

Gotchas

  • Watch the DeepStack intermediate features the vision encoder returns for the
    Thinker — the native version has to reproduce them, not just the final
    embeddings.
  • Encoder outputs feed directly into the Thinker's KV cache splicing
    (prefill_vision / prefill_audio walks); keep the output contract identical
    so the Thinker side needs no changes.
  • Don't assume the standard optimizations uniformly help here; torch.compile, CUDA-graph
    capture, and cross-request batching don't always pay off (and can even hurt,
    e.g. when shapes vary enough to trigger recompiles, or batching adds padding
    overhead). A/B test each optimization against the eager/uncompiled baseline and
    keep only the ones that actually win for this encoder.
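
The A/B testing described in the last gotcha could be structured roughly as follows. This is a sketch, not an existing M* benchmark: median_latency, keep_winners, and the 5% minimum speedup are hypothetical choices. For GPU work, pass a synchronization callable such as torch.cuda.synchronize as sync so the timer measures completed kernels rather than launch time.

```python
import statistics
import time
from typing import Callable, Dict, Mapping, Optional


def median_latency(
    fn: Callable[[], object],
    warmup: int = 3,
    iters: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Median wall-clock seconds per call, measured after warmup runs.

    Warmup absorbs one-off costs such as compilation or graph capture.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def keep_winners(
    baseline_s: float,
    candidates_s: Mapping[str, float],
    min_speedup: float = 1.05,  # assumed threshold to filter out noise
) -> Dict[str, float]:
    """Map each candidate that beats the eager baseline to its speedup."""
    winners: Dict[str, float] = {}
    for name, seconds in candidates_s.items():
        if seconds > 0 and baseline_s / seconds >= min_speedup:
            winners[name] = baseline_s / seconds
    return winners
```

Measuring with the same mix of input shapes that production traffic produces matters here, since recompiles and padding overhead only show up when shapes vary.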

New to M*? Skim How it works and the Contributing guide first.
