Per-submodule control over torch.compile #125

@NSagan271

Description

Difficulty: 🟡 Intermediate (partly open-ended; needs design buy-in first)

Scope: Medium; small code change, but the API design is the hard part.

Subsystems: engine/stateless_engine.py · model/submodule_base.py · model authors

Prerequisites: Familiarity with torch.compile modes/dynamic shapes and how recompiles get triggered.

Problem

torch.compile(dynamic=None) is applied uniformly across submodules by the
engine (the compile call lives in
stateless_engine.py and kv_cache_engine.py). The only knob a model
author has today is a coarse, all-or-nothing escape hatch: @torch.compiler.disable
on individual methods — used on the prepare_inputs / postprocess hooks in
model/submodule_base.py, and by models that need
to fence off a graph-breaking region (e.g. the BAGEL ViT encoder in
bagel/components/vit_encoder.py and
the Qwen3-Omni talker in
qwen3_omni/components/talker.py).

That leaves the author to manually fence off anything that would thrash the
compile cache (data-dependent loops, varlen shapes), with no middle ground
between "fully compiled with dynamic=None" and "not compiled at all." There is
no way to say "compile this submodule, but with dynamic=True", "use
mode='max-autotune' here", or "compile only forward, not forward_batched".

Open questions (resolve before coding)

  • Is per-submodule torch.compile config actually worth the surface area, or is
    the current @torch.compiler.disable escape hatch good enough for the models
    we care about?
  • What's the right granularity — per submodule, or per method?
  • What should the knobs be? (enabled, dynamic, mode, fullgraph,
    per-method include/exclude?)
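
To make the knob question concrete, here is one possible shape, offered as a sketch rather than a proposal. `CompileSpec`, its field names, and `for_method` are hypothetical; the fields mirror keyword arguments that `torch.compile` already accepts (`dynamic`, `mode`, `fullgraph`). A submodule-level default plus optional per-method overrides would cover both granularities in one object.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class CompileSpec:
    """Hypothetical per-submodule compile options (not an existing API)."""

    enabled: bool = True
    dynamic: Optional[bool] = None  # None matches the engine's current default
    mode: Optional[str] = None      # e.g. "max-autotune"
    fullgraph: bool = False
    # Per-method overrides keyed by method name,
    # e.g. {"forward_batched": {"enabled": False}}.
    per_method: Dict[str, Dict[str, Any]] = field(default_factory=dict)

    def for_method(self, name: str) -> "CompileSpec":
        """Return the spec for one method, with any override applied."""
        override = self.per_method.get(name)
        if not override:
            return self
        return replace(self, per_method={}, **override)

    def to_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments to forward to torch.compile."""
        kwargs: Dict[str, Any] = {"dynamic": self.dynamic}
        if self.mode is not None:
            kwargs["mode"] = self.mode
        if self.fullgraph:
            kwargs["fullgraph"] = True
        return kwargs


# Submodule-wide dynamic=True, but leave forward_batched uncompiled.
spec = CompileSpec(dynamic=True, per_method={"forward_batched": {"enabled": False}})
forward_kwargs = spec.for_method("forward").to_kwargs()
batched_enabled = spec.for_method("forward_batched").enabled
```

With no arguments, `CompileSpec().to_kwargs()` yields only `dynamic=None`, which is what the engine passes today, so a default spec would not change behavior.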

Suggested approach (if we proceed)

  • Let a submodule declare a compile spec (e.g. a
    get_torch_compile_config() returning per-method options, defaulting to
    today's behavior).
  • Have the engine read that spec where it currently compiles (the
    torch.compile call in stateless_engine.py, and the matching call in
    kv_cache_engine.py) instead of hardcoding dynamic=None.
  • Express the existing @torch.compiler.disable escape hatch in terms of the
    new spec (so "disable compilation for this submodule/method" becomes one
    option among several), and migrate any special-cased submodule onto it.
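
The three steps above could fit together roughly as follows. This is a sketch under assumptions, not the engine's actual code: the return type of `get_torch_compile_config()` (a dict from method name to `torch.compile` keyword arguments, with `None` meaning "leave eager"), the helper `compile_submodule`, the `METHODS` tuple, and the example `VarlenEncoder` class are all hypothetical. `compile_fn` is injected so the wiring can be exercised without importing torch; the engine would pass `torch.compile`.

```python
from typing import Any, Callable, Dict, Optional

# Today's engine behavior, used when a submodule declares nothing.
DEFAULT_KWARGS: Dict[str, Any] = {"dynamic": None}
# Hypothetical list of the methods the engine compiles.
METHODS = ("forward", "forward_batched")

MethodConfig = Dict[str, Optional[Dict[str, Any]]]


def compile_submodule(submodule: Any, compile_fn: Callable[..., Callable]) -> None:
    """Compile each method of `submodule` according to its declared spec."""
    get_config = getattr(submodule, "get_torch_compile_config", None)
    config: MethodConfig = get_config() if get_config is not None else {}
    for name in METHODS:
        method = getattr(submodule, name, None)
        if method is None:
            continue
        kwargs = config.get(name, DEFAULT_KWARGS)
        if kwargs is None:
            continue  # None means: leave this method eager
        setattr(submodule, name, compile_fn(method, **kwargs))


class VarlenEncoder:
    """Example submodule that opts into a non-default spec."""

    def get_torch_compile_config(self) -> MethodConfig:
        return {"forward": {"dynamic": True}, "forward_batched": None}

    def forward(self, x):
        return x

    def forward_batched(self, xs):
        return [self.forward(x) for x in xs]
```

Note that skipping `compile_fn` only covers methods the engine compiles directly. A helper called from inside a compiled method would still be traced, so fencing off such a region would still need `@torch.compiler.disable` or an equivalent that the spec applies on the author's behalf.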

Acceptance criteria

  • By default (no spec declared), existing models compile exactly as they do
    today: no new recompiles and no perf regression.
  • At least one submodule demonstrates a non-default setting (e.g. dynamic=True)
    end-to-end.

New to M*? Skim How it works and the Contributing guide first.

Labels: enhancement
