Pluggable attention backends #129

@NSagan271

Difficulty: 🔴 Advanced

Scope: Large — touches the hot path and the CUDA-graph capture story; needs a clean abstraction + perf parity.

Subsystems: engine/cache_manager.py · utils/flashinfer_utils.py · engine/kv_cache_engine.py · engine/cuda_graph_runner.py

Prerequisites: Attention kernels (FlashInfer / FlashAttention), paged KV cache, and CUDA-graph-capturable launch patterns.

Problem

There's effectively one KV-cache attention backend: FlashInfer paged attention.
It's a good default but not always the most performant choice. The wrappers
FlashInferPrefillWrapper / FlashInferDecodeWrapper
(utils/flashinfer_utils.py) are wired directly
into the cache manager's plan_attention and stored on _PlanState.wrapper
(engine/cache_manager.py). The key constraint is
that whatever backend we use must support the plan-then-replay scheme so it's
CUDA-graph capturable (persistent wrappers whose static buffers are updated by
plan() — see the CUDA-graph-mode notes in
cache_manager.py and the double-buffer logic in
kv_cache_engine.py).
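
To make the plan-then-replay constraint concrete, here is a toy sketch in plain Python. It is not the project's code: `StaticIndptrBuffer` and `make_replay` are hypothetical names, and a Python list stands in for a persistent device tensor. The point it illustrates is that plan() must overwrite a preallocated buffer in place, because a captured graph keeps reading the same memory on every replay.

```python
from typing import Callable, List


class StaticIndptrBuffer:
    """Toy stand-in for a persistent buffer such as an indptr tensor (hypothetical)."""

    def __init__(self, max_batch: int) -> None:
        # Allocated once at full capacity; a captured graph would hold this address.
        self.data: List[int] = [0] * (max_batch + 1)

    def plan(self, seq_lens: List[int]) -> None:
        """Write the prefix sums of seq_lens in place; never rebind self.data."""
        if len(seq_lens) + 1 > len(self.data):
            raise ValueError("batch exceeds the capacity the buffer was built for")
        total = 0
        self.data[0] = 0
        for i, n in enumerate(seq_lens):
            total += n
            self.data[i + 1] = total
        # Pad the tail so stale entries from an earlier, larger batch cannot leak.
        for j in range(len(seq_lens) + 1, len(self.data)):
            self.data[j] = total


def make_replay(buf: StaticIndptrBuffer) -> Callable[[], List[int]]:
    """'Capture' step: close over the buffer object itself, not a copy of its values."""
    captured = buf.data  # analogous to a pointer baked into a CUDA graph

    def replay() -> List[int]:
        # Recover per-slot sequence lengths from whatever the buffer holds now.
        return [captured[i + 1] - captured[i] for i in range(len(captured) - 1)]

    return replay


buf = StaticIndptrBuffer(max_batch=4)
buffer_identity = id(buf.data)

buf.plan([3, 5])
replay = make_replay(buf)
first = replay()

buf.plan([2, 2, 7])   # new batch, same buffer object
second = replay()     # the same "captured" callable sees the new plan
```

If plan() instead assigned a fresh list to `self.data`, `replay` would keep reading the old one. That is the failure mode a new backend has to avoid to stay capturable.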

This matters in practice: for some shapes a non-paged kernel (e.g.
flash_attn_varlen, or FlashInfer's ragged prefill) beats paged attention,
because the paged path also pays to scatter-write K/V into the page table. Part
of this issue is building a small microbenchmark that compares paged vs. ragged
vs. flash_attn_varlen per shape to quantify the gap.
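
One possible shape for that microbenchmark is sketched below, assuming nothing about the existing bench script. The names `bench`, `compare`, and `placeholder_kernel` are hypothetical, and the workloads are CPU placeholders rather than real kernels. For GPU kernels you would pass a device synchronization function (for example `torch.cuda.synchronize`) as `sync`, so the clock stops only after the kernel finishes.

```python
import statistics
import time
from typing import Callable, Dict, Iterable, Tuple

Shape = Tuple[int, int]  # (batch_size, seq_len); a hypothetical shape key
MakeFn = Callable[[Shape], Callable[[], object]]


def bench(fn: Callable[[], object], sync: Callable[[], None] = lambda: None,
          warmup: int = 3, iters: int = 10) -> float:
    """Return the median wall-clock time of fn in milliseconds."""
    for _ in range(warmup):
        fn()
    sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()  # wait for asynchronous work before stopping the clock
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def compare(candidates: Dict[str, MakeFn], shapes: Iterable[Shape],
            sync: Callable[[], None] = lambda: None) -> Dict[Shape, Dict[str, float]]:
    """Time every candidate on every shape; returns {shape: {name: median_ms}}."""
    results: Dict[Shape, Dict[str, float]] = {}
    for shape in shapes:
        results[shape] = {name: bench(make(shape), sync)
                          for name, make in candidates.items()}
    return results


def placeholder_kernel(cost: int) -> MakeFn:
    """CPU placeholder; a real entry would close over tensors and launch a kernel."""
    def make(shape: Shape) -> Callable[[], object]:
        batch, seq_len = shape
        return lambda: sum(range(batch * seq_len * cost))
    return make


table = compare(
    {"paged": placeholder_kernel(2), "ragged": placeholder_kernel(1)},
    shapes=[(1, 128), (4, 256)],
)
for shape, row in table.items():
    print(shape, {name: f"{ms:.3f} ms" for name, ms in row.items()})
```

For a fair comparison, the paged entry should include the K/V scatter-write, since that cost is part of what the paged path pays in practice.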

Suggested tasks

  • Extract the wrapper interface the cache manager depends on (plan(),
    run(), and the static-buffer accessors like _qo_indptr_buf that
    cache_manager.py reads) into an explicit backend protocol.
  • Add a FlashAttention backend implementing that protocol with the same
    plan + CUDA-graph-capturable forward scheme.
  • Make the backend selectable (config / YAML), defaulting to FlashInfer paged
    so existing deployments are unchanged.
  • Benchmark backends per model/walk (extend the existing bench script) and
    document when each wins.
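
As a starting point for the first and third tasks, a backend protocol plus a config-driven registry could look roughly like the sketch below. Everything here is an assumption for illustration: the `AttentionBackend` signature, the `attention_backend` config key, the backend names, and the stub classes. The real protocol should be derived from what cache_manager.py actually calls and reads.

```python
from typing import Any, Callable, Dict, Protocol, Sequence, runtime_checkable


@runtime_checkable
class AttentionBackend(Protocol):
    """Hypothetical protocol; argument lists are placeholders, not the real API."""

    def plan(self, qo_indptr: Sequence[int], kv_indptr: Sequence[int]) -> None:
        ...

    def run(self, q: Any, k: Any, v: Any) -> Any:
        ...


_BACKENDS: Dict[str, Callable[[], AttentionBackend]] = {}
DEFAULT_BACKEND = "flashinfer_paged"  # assumed name for the current default


def register_backend(name: str):
    """Class decorator that records a zero-argument factory under `name`."""
    def decorator(factory):
        _BACKENDS[name] = factory
        return factory
    return decorator


class _StubBase:
    """Records the last plan only; performs no attention."""

    def __init__(self) -> None:
        self.planned = None

    def plan(self, qo_indptr: Sequence[int], kv_indptr: Sequence[int]) -> None:
        self.planned = (list(qo_indptr), list(kv_indptr))

    def run(self, q: Any, k: Any, v: Any) -> Any:
        raise NotImplementedError("stub backend: no kernel wired in")


@register_backend("flashinfer_paged")
class StubPagedBackend(_StubBase):
    pass


@register_backend("flash_attn")
class StubFlashAttnBackend(_StubBase):
    pass


def make_backend(config: Dict[str, Any]) -> AttentionBackend:
    """Pick a backend from config, falling back to the default when unset."""
    name = config.get("attention_backend", DEFAULT_BACKEND)
    try:
        factory = _BACKENDS[name]
    except KeyError:
        raise ValueError(
            f"unknown attention backend {name!r}; choose from {sorted(_BACKENDS)}"
        ) from None
    return factory()


default_backend = make_backend({})
selected_backend = make_backend({"attention_backend": "flash_attn"})
```

Keeping the default inside `make_backend` means a config file that never mentions the key still gets the current behavior, which matches the "existing deployments are unchanged" requirement.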

Stretch / open

  • Some submodules may not benefit from paged attention at all (e.g. short,
    fixed-shape full-attention blocks). Investigate a non-paged path for those; this can be scoped as a follow-up.
  • An "in-house" kernel set is out of scope for a first PR, but feel free to open separate PRs for any kernel development.

Acceptance criteria

  • FlashInfer remains the default with no perf regression.
  • The FlashAttention backend produces correct results and is CUDA-graph
    capturable, and a benchmark shows the regime where it's competitive or better.
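
For the correctness criterion, one option is to compare each backend against a slow, obviously-correct reference on small ragged batches. The sketch below is a pure-Python, single-head, unmasked reference under assumed conventions: sequence i occupies rows `indptr[i]:indptr[i + 1]` of the packed inputs, and all function names are hypothetical. A real test would add causal masking, multiple heads, and a dtype-appropriate tolerance.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def reference_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Unmasked single-head softmax(q k^T / sqrt(d)) v for one sequence."""
    if not q:
        return []
    scale = 1.0 / math.sqrt(len(q[0]))
    out: Matrix = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        norm = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / norm
                    for c in range(len(v[0]))])
    return out


def reference_varlen(q: Matrix, k: Matrix, v: Matrix,
                     qo_indptr: Sequence[int], kv_indptr: Sequence[int]) -> Matrix:
    """Apply reference_attention per sequence of a packed (ragged) batch."""
    out: Matrix = []
    for i in range(len(qo_indptr) - 1):
        qs, qe = qo_indptr[i], qo_indptr[i + 1]
        ks, ke = kv_indptr[i], kv_indptr[i + 1]
        out.extend(reference_attention(q[qs:qe], k[ks:ke], v[ks:ke]))
    return out


def max_abs_diff(a: Matrix, b: Matrix) -> float:
    """Largest elementwise absolute difference between two same-shaped matrices."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Two sequences: one query over two keys, then one query over one key.
q = [[0.0, 0.0], [1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [3.0, 0.0], [5.0, 6.0]]
got = reference_varlen(q, k, v, qo_indptr=[0, 1, 2], kv_indptr=[0, 2, 3])

# A zero query weights its keys uniformly; a single key returns its value row.
expected = [[2.0, 0.0], [5.0, 6.0]]
error = max_abs_diff(got, expected)
```

The same `max_abs_diff` check can then be run between the reference and each backend's output for every shape in the benchmark sweep.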

New to M*? Skim How it works and the Contributing guide first.

Labels: enhancement