Skip to content

Conversation

@qsang-nv
Copy link
Collaborator

@qsang-nv qsang-nv commented Sep 25, 2025

📌 Description

Add xqa fp8 mha and fp8 kv cache. Add fp8 mla for sm120. Use vllm kv layout.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • MLA-based attention path (SM120/121) with MLA-specific modules and entrypoints
    • FP8 KV-cache support with optional paged KV layout and separate K/V inputs
    • New async tensor-map/TMA and matrix-descriptor primitives for high-throughput GPU transfers
    • Build/runtime APIs now accept dtype-driven configs and explicit SM-version selection
  • Bug Fixes

    • Improved numerical stability for attention mask initialization
  • Tests

    • Expanded tests covering MLA, FP8/FP16 modes and new cache layouts
  • Documentation

    • Added XQA API docs and documented new entrypoints

@qsang-nv qsang-nv changed the title add xqa fp8 mha add xqa fp8 mha and fp8 kv cache Sep 25, 2025
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FlashInfer library by integrating FP8 support for both Multi-Head Attention computations and the Key-Value cache within the XQA framework. These changes are primarily aimed at leveraging the advanced capabilities of NVIDIA Hopper GPUs (SM90+) to achieve substantial performance and memory efficiency gains. The implementation includes new CUDA kernels utilizing GMMA and TMA, along with Python-side modifications to enable configurable FP8 execution paths, ensuring that users can opt into these optimizations while maintaining numerical stability.

Highlights

  • FP8 Multi-Head Attention (MHA) Support: Introduced support for FP8 (8-bit floating point) Multi-Head Attention, enabling more efficient computations on compatible hardware.
  • FP8 Key-Value (KV) Cache: Added functionality for using FP8 for the Key-Value cache, which can significantly reduce memory footprint and improve performance.
  • NVIDIA Hopper (SM90+) Optimizations: Integrated specialized Hopper GPU features like GMMA (General Matrix Multiply Accumulate) and TMA (Tensor Memory Accelerator) for optimized FP8 operations and efficient memory access patterns.
  • Configurable FP8 Execution: The XQA module generation and runtime execution now allow explicit control over whether FP8 MHA and FP8 KV cache are utilized, providing flexibility for different precision requirements.
  • Numerical Stability Adjustments: Modified the safeInitRowMax value and adjusted test tolerances to account for the numerical characteristics of lower precision FP8 computations, ensuring stability.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FlashInfer library by introducing support for FP8 Multi-Head Attention (MHA) and FP8 Key-Value (KV) cache. These additions leverage advanced features of NVIDIA Hopper GPUs, such as GMMA and TMA, to achieve higher performance and memory efficiency for large language model inference. The changes span the CUDA C++ backend, including new kernel implementations and memory management utilities, as well as updates to the Python AOT compilation and testing framework to ensure robust integration and validation of the new FP8 capabilities.

Highlights

  • FP8 MHA Integration: Introduced a new "run_fp8_mha" boolean parameter to the "xqa_wrapper" function, enabling conditional execution of FP8 Multi-Head Attention kernels.
  • FP8 KV Cache Support: Added "fp8_kv_cache" parameter to the AOT compilation and Python interface, allowing the use of FP8 for Key-Value cache storage.
  • New Low-Level CUDA Kernels: Incorporated "gmma.cuh" for Generic Matrix Multiply Accumulate (GMMA) operations and "tma.h" for Tensor Memory Accelerator (TMA) asynchronous memory transfers, leveraging Hopper (SM90) architecture features for optimized FP8 performance.
  • CUDA Tensor Map Utilities: Introduced "tensorMap.cpp" and "tensorMap.h" to facilitate efficient access and management of KV caches using CUDA Tensor Maps.
  • Numerical Stability Improvements: Adjusted "safeInitRowMax" in "utils.cuh" and modified attention mask application in "mha.cu" to enhance numerical stability, especially relevant for lower precision formats.
  • Expanded Testing: Updated "test_xqa.py" to include comprehensive tests for FP8 MHA and FP8 KV cache, with adjusted numerical tolerance checks to account for reduced precision.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for FP8 multi-head attention (MHA) and FP8 KV cache within the XQA kernel, primarily targeting the NVIDIA Hopper architecture. This is a significant feature addition, enabled by new CUDA primitives for Hopper's Tensor Memory Access (TMA) and Grace Hopper MMA (GMMA) instructions. The changes are well-implemented, including new CUDA headers for hardware abstraction, a dispatch mechanism for the new FP8 kernel path, and corresponding updates to the Python build system and tests. The tests have been thoughtfully adjusted with relaxed tolerances for FP8 precision. My review includes one suggestion to refactor a small piece of duplicated code to enhance maintainability.

Comment on lines 37 to 72
if (run_fp8_mha) {
launchHopperF8MHAFlashInfer(
multiProcessorCount, nbKHeads, slidingWinSize, qScale,
reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen,
reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
} else {
launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The if and else blocks contain identical arguments passed to two different functions (launchHopperF8MHAFlashInfer and launchMHAFlashInfer). This code duplication can be reduced to improve maintainability. Since both functions share the same signature, you can use a function pointer to select the appropriate kernel and then call it once with the common set of arguments.

  using mha_launcher_t = decltype(&launchMHAFlashInfer);
  mha_launcher_t launcher = run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;
  launcher(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
           reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
           reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
           reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
           reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
           reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
           maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
           reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
           qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
           reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
           reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
           reinterpret_cast<void*>(scratch.data_ptr()), stream);

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the XQA (eXtended Query Attention) kernels by introducing support for FP8 Multi-Head Attention and FP8 Key-Value cache. These changes are designed to optimize performance and memory usage on NVIDIA Hopper (SM90+) GPUs through low-level CUDA programming, including asynchronous tensor memory access and matrix operations. The integration ensures that the system can efficiently handle lower precision data types, with comprehensive testing to maintain accuracy within acceptable bounds.

Highlights

  • FP8 Multi-Head Attention (MHA) Support: Introduced a new launchHopperF8MHAFlashInfer function and integrated it into the xqa_wrapper to enable FP8 MHA execution, controlled by a new run_fp8_mha boolean parameter.
  • FP8 KV Cache Implementation: Added support for FP8 Key-Value (KV) cache, including new CUtensorMap utilities for efficient memory access and configuration options in the AOT/JIT compilation system.
  • Low-Level CUDA Optimizations: Incorporated new CUDA kernel files (gmma.cuh, tma.h) that define asynchronous matrix multiply accumulate (GMMA) operations and tensor memory access (TMA) functions, leveraging Hopper architecture features for performance.
  • Numerical Stability Improvements: Adjusted the safeInitRowMax constant and its usage in applyMaskFromInput to enhance numerical stability, especially for large values, in attention calculations.
  • AOT/JIT Compilation and Testing: Updated the AOT and JIT compilation infrastructure to generate kernels for various FP8 configurations and expanded the test suite to thoroughly validate the new FP8 MHA and KV cache functionalities, including adjusted precision checks.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for FP8 multi-head attention (MHA) and FP8 KV cache in the XQA kernels, targeting Hopper architecture for performance improvements. The changes include new low-level CUDA files (gmma.cuh, tma.h, tensorMap.cpp) with Hopper-specific WGMMA and TMA instructions, a new FP8 MHA kernel entry point, and updates to the AOT compilation scripts and Python wrappers to handle the new FP8 variants. The tests have also been updated to include FP8 configurations and use a more lenient assertion method to account for precision differences.

My review focuses on code maintainability and clarity. I've suggested refactoring a duplicated code block in the C++ wrapper to improve readability and proposed adding a comment in the Python tests to clarify a magic number used for data scaling. Overall, the changes are well-structured and the addition of FP8 support is a valuable performance enhancement.

Comment on lines 37 to 72
if (run_fp8_mha) {
launchHopperF8MHAFlashInfer(
multiProcessorCount, nbKHeads, slidingWinSize, qScale,
reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen,
reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
} else {
launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()),
maxSeqLen, reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
reinterpret_cast<void*>(scratch.data_ptr()), stream);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a large block of duplicated code for launching the MHA kernels. The only difference between the if and else blocks is the function being called (launchHopperF8MHAFlashInfer vs. launchMHAFlashInfer). This could be refactored to improve maintainability and reduce redundancy.

Consider using a function pointer to select the kernel, and then make a single call. This would make the code cleaner and easier to manage if more arguments are added in the future.

For example:

void (*mha_func)(uint32_t, uint32_t, ...); // Using a function pointer type alias

if (run_fp8_mha) {
    mha_func = &launchHopperF8MHAFlashInfer;
} else {
    mha_func = &launchMHAFlashInfer;
}

mha_func(
    multiProcessorCount,
    nbKHeads,
    slidingWinSize,
    // ... other arguments
);
  using mha_func_t = void (*)(uint32_t, uint32_t, uint32_t, float, OutputHead*,
#if LOW_PREC_OUTPUT
                              float const*,
#endif
                              InputHead const*, float const*, GMemCacheHead*,
                              KVCachePageIndex const*, uint32_t, uint32_t const*, uint32_t,
                              float const* __restrict__,
#if SPEC_DEC
                              uint32_t, uint32_t const*, MaskType const*,
#endif
                              uint32_t*, void*, cudaStream_t);

  mha_func_t mha_func = run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;

  mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale,
           reinterpret_cast<OutputHead*>(output.data_ptr()),
#if LOW_PREC_OUTPUT
           reinterpret_cast<float const*>(rcpOutScale.data_ptr()),
#endif
           reinterpret_cast<InputHead const*>(q.data_ptr()), attentionSinksPtr,
           reinterpret_cast<GMemCacheHead*>(pool.data_ptr()),
           reinterpret_cast<KVCachePageIndex const*>(kvCachePageList.data_ptr()), maxSeqLen,
           reinterpret_cast<uint32_t const*>(seqLen.data_ptr()), batchSize,
           reinterpret_cast<float const*>(kvCacheScale.data_ptr()),
#if SPEC_DEC
           qSeqLen, reinterpret_cast<uint32_t const*>(qCuSeqLens.data_ptr()),
           reinterpret_cast<MaskType const*>(mask.data_ptr()),
#endif
           reinterpret_cast<uint32_t*>(semaphores.data_ptr()),
           reinterpret_cast<void*>(scratch.data_ptr()), stream);

Comment on lines 230 to 233
if fp8_kv_cache:
cache_heads /= 4.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The value 4.0 is used to scale down the cache_heads tensor when fp8_kv_cache is enabled. This appears to be a magic number. To improve code clarity and maintainability, please add a comment explaining the rationale for this specific scaling factor. For example, explaining that it's to prevent overflow and how 4.0 was determined would be very helpful for future readers.

Suggested change
if fp8_kv_cache:
cache_heads /= 4.0
if fp8_kv_cache:
# Scale down the cache heads to keep values within the representable range of FP8
# and prevent overflow during computation. The factor 4.0 is chosen empirically.
cache_heads /= 4.0

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for FP8 multi-head attention (MHA) and FP8 KV cache for Cross-Query Attention (XQA), targeting NVIDIA's Hopper architecture. This is a significant feature addition that leverages low-level hardware capabilities like TMA and GMMA for performance. The changes include new CUDA files for these Hopper-specific features, along with updates to the Python build system and tests to accommodate the new configurations. The review identified a critical bug in the new TMA storeAsync implementation and a high-severity correctness issue related to the handling of masked values in the softmax computation.

Comment on lines +220 to +243
: "memory");
} else if constexpr (nbDims == 5) {
asm volatile(
"cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], "
"[%6];\n"
:
: "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
"r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
: "memory");
} else {
static_assert(nbDims >= 1 && nbDims <= 5);
}
}

__device__ inline void setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr) {
asm volatile(
"tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;\n" ::"l"(&tensorMap),
"l"(ptr)
: "memory");
}

__device__ inline void commitGroup() {
asm volatile("cp.async.bulk.commit_group;\n" : : : "memory");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There appears to be a copy-paste error in the storeAsync template function. For nbDims of 3, 4, and 5, the inline assembly instruction is cp.async.bulk.tensor.2d..., but it should be cp.async.bulk.tensor.3d..., cp.async.bulk.tensor.4d..., and cp.async.bulk.tensor.5d... respectively. This will lead to incorrect memory access patterns and likely data corruption for higher-dimensional tensors.

    else if constexpr (nbDims == 3)
    {
        asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
                     :
                     : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
                     "l"(__cvta_generic_to_shared(src))
                     : "memory");
    }
    else if constexpr (nbDims == 4)
    {
        asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
                     :
                     : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
                     "r"(offset[3]), "l"(__cvta_generic_to_shared(src))
                     : "memory");
    }
    else if constexpr (nbDims == 5)
    {
        asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], [%6];\n"
                     :
                     : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
                     "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
                     : "memory");
    }

? true
: packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -INFINITY;
acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using safeInitRowMax for masked elements can lead to incorrect results. When an entire row/sequence is masked, all attention scores become safeInitRowMax. In the softmax computation, maxVal also becomes safeInitRowMax, and exp(score - maxVal) evaluates to 1 for all masked positions. This results in a uniform attention distribution over masked tokens, and the output becomes the average of values in V, instead of zero.

A correct implementation should ensure that the softmax output for masked tokens is zero. If the entire row is masked, the final output should also be zero. This might require changes in the softmax function to handle safeInitRowMax specially, and in the final normalization step to handle a row sum of zero.

@qsang-nv qsang-nv requested a review from yzh119 September 25, 2025 08:52
@@ -16,8 +16,8 @@

#include "pytorch_extension_utils.h"

void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize,
double qScale, at::Tensor output,
void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of making this a flag, could we pass a dtype?

Same for the other places where we pass:

  • the type of the input (only bf16 and fp16 supported I think)
  • the type of the kv-cache (fp8 or bf16)
  • the type in which we perform arithmetic (the same type as the kv-cache I think?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it is passing dtype in flashinfer/flashinfer/xqa.py


inline constexpr float log2e = 1.4426950408889634; // std::log2(M_E)
inline constexpr float safeInitRowMax = -1e+30F;
// we used an optimization where exp(x-rowMax) is computed as:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's interesting: what were the symptoms of the instability? Accuracy loss?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is copied from NVIDIA/TensorRT-LLM@c1aa7f3, you may ask the author, I am not sure about this question.

@@ -354,4 +364,21 @@ def cache_head_at(
kernel_output = output[req][b][
idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size
].to(torch.float32)
assert torch.allclose(ref_output, kernel_output, atol=0.01, rtol=0.01)
if fp8_kv_cache or run_fp8_mha:
atol = 0.05
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did you tune this tolerance? Can it be smaller?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From 0.01 to 0.05, add 0.01 every step. And it can't be smaller from my test.

Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
else:
flag_sliding_window = ["-DSLIDING_WINDOW=0"]

if sm_version == 100:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to add SM103 support by targeting SM100f instead of SM100a?

And similarly, can we add SM121 support by targeting SM120f instead of SM120a?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference in those archs? I mean SM103/SM100f/SM100a, and SM121/SM120f/SM120a.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "a" means "arch-specific" and the "f" means "family".

SM100 and SM103 are in the "SM100f family".

SM100a (SM100 arch-specific) will only run on SM100 devices.

But I believe SM103 devices have a superset of the SM100 features, and therefore if you target SM100f instead of SM100a during compilation, your cubin will be able to run on SM103 as well, without any loss of optimization on either device. So I think it's strictly better than targeting SM103a.

SM121 and SM120 have a similar story: it's better to target SM120f as a compilation target, yielding a cubin that will run on both SM120 and SM121 devices without any compromise to performance.

See this documentation for details: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#family-specific-features

@aleozlx can you confirm my understanding?

Copy link
Collaborator

@aleozlx aleozlx Oct 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes sm_100f is known as family specific or family conditional. this is important to enhance device compatibility in a sort of out of the box fashion.

i'd like to point out a few things tho based on my experience:

  1. with strictly jit compilation where the premise is that at runtime fewer devices are in the compatibility question, the arch conditionals may be the safest way to target the instruction supersets. (when the target is available to compile by the toolkit at the time of implementation)
  2. family conditionals are important for compatibility story (indeed aligning with your understanding conceptually) but it is not without inherent engineering complexity. from an engineer's perspective i naturally experience a slightly more complicated story at the levels beneath. i'll spare the details but leave it as a reliability intuition (in the context of jit).

however, i want to bring this CompilationContext.get_nvcc_flags_list() up for consideration to not hard code any guidance either way but have it abstracted so we can adjust if situation changes. briefly, how this works is for each op/backend if you whitelist your supported targets, this function shall serve as the mapping to provide the recommended flags.

we can put in sm103 support (likely fine for attn) supposing our cicd will catch the issue if not

Copy link
Member

@sricketts sricketts Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, let me clarify my guidance then:

  1. Since SM120 and SM121 are identical architectures, naively I would think that supporting both is only marginally harder than supporting one.
  2. The story seems even a little harder for SM100 and SM103, because they are not identical, but if you don't care about SM103-specific features, I would think the marginal effort of supporting SM103 shouldn't be massive.
  3. Therefore the default design for any solution for SM100 and SM120 should at least try to include SM103 and SM121 support, or should at least be designed with SM103 and SM121 in mind, even if some of the details are left as a future TODO.
  4. What compilation targets you use, and how you architect the code to query for those compilation targets, is an engineering implementation detail, and I probably shouldn't be opinionated about that. :)

The problem here is that the PR doesn't address (3), maybe because @qsang-nv didn't know about (1) and (2). I'm only proposing that we re-think the design here with the above in mind.

@aleozlx @qsang-nv do you agree?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed explanation! I've added support for sm100f and sm121a.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 17, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds MLA (SM120/121) and FP8-capable MHA paths, conditional paged KV-cache layout support, tensor-map and TMA async APIs, MatDesc/gmma helpers, MLA kernels/launchers, Python dtype-driven XQA build/runtime wiring, expanded tests, and docs updates.

Changes

Cohort / File(s) Summary
C++ binding / public API
csrc/flashinfer_xqa_binding.cu
Added xqa_wrapper_mla export and split xqa_wrapper variants under #if MLA_WRAPPER guards; introduced leading bool run_sm90_fp8_mha, tvm::ffi::Optional<TensorView> for attention sinks, and conditional KV-cache parameter groups with TVM_FFI_DLL_EXPORT_TYPED_FUNC entries and final #endif.
XQA wrapper / dispatcher
csrc/xqa/xqa_wrapper.cu
New MLA wrapper xqa_wrapper_mla; xqa_wrapper now takes run_sm90_fp8_mha, selects MHA launcher via a function pointer, and forwards conditional KV-cache args (either kCacheVLLM/vCacheVLLM or pool) and optional attention sinks to MLA/non-MLA paths.
MHA interfaces & kernels
csrc/xqa/mha.h, csrc/xqa/mha.cu
launchMHAFlashInfer, launchHopperF8MHA*, and related launches updated to accept conditional KV-cache params: kCacheVLLM,vCacheVLLM when PAGED_KV_CACHE_LAYOUT==1, else pool; callers updated; arch guard extended to include __CUDA_ARCH__ == 1000.
MLA SM120 path
csrc/xqa/mla_sm120.cu, csrc/xqa/mla_sm120.cuh
New MLA SM120 kernel and host launchers (kernel_mha, launchMLA, launchMLAFlashInfer) plus device helpers (row-max load/store/async, computeRowMax, hashRegData) and tensor-map helpers for Q layout when GENERATE_CUBIN is enabled.
Matrix / MMA helpers
csrc/xqa/gmma.cuh
New gmma namespace: SwizzleMode, packed MatDesc bitfield, makeMatDesc/addAddr, inst constants, templated mma_async_* declarations, fence/commit/wait primitives, and inclusion of gmma_impl.cuh.
TensorMap utilities
csrc/xqa/tensorMap.h, csrc/xqa/tensorMap.cpp
New header+impl to build CUtensorMap for contiguous and paged KV cache layouts, getElemBytes, swizzle/part selection, layout branching on PAGED_KV_CACHE_LAYOUT, and error checks around cuTensorMapEncodeTiled.
TMA async API
csrc/xqa/tma.h
New tma namespace exposing StateSpace, conditional CUtensorMap typedef (GENERATE_CUBIN path), cp.async/tensormap-based loadAsync/storeAsync/prefetchTensorMap, setTensorMapGlbAddr, commitGroup/waitGroup, and related device primitives.
Utils
csrc/xqa/utils.cuh
Changed safeInitRowMax from -1e+30F to -1e+5F with explanatory comments about numerical stability; extended kMAX_SMEM_SIZE arch conditional to include __CUDA_ARCH__ == 1000.
Python JIT / module generation
flashinfer/jit/xqa.py
gen_xqa_module refactored to dtype-driven signature (input_dtype, kv_cache_dtype, page_size, head_dim, head_group_ratio, sm_version); added gen_xqa_module_mla for MLA builds; NVCC flags chosen by sm_version; sources updated to include tensorMap and MLA sources; added -DPAGED_KV_CACHE_LAYOUT=1 and -lcuda where applicable.
Python AOT / build orchestration
flashinfer/aot.py
Updated gen_xqa to accept fp16_input_, fp8_kv_cache_, iterate sm_versions (90/100/120/121), and emit specs including sm_version and kv-cache dtype; imports gen_xqa_module_mla.
Python public API & wiring
flashinfer/xqa.py, flashinfer/__init__.py
get_xqa_module/xqa signatures converted to dtype-driven args; added MLA variants get_xqa_module_mla and xqa_mla; runtime callsites updated to pass run_sm90_fp8_mha, k_cache, v_cache, page_table, workspace_buffer, semaphores; package re-exports xqa_mla.
Tests
tests/attention/test_xqa.py
Tests extended for MLA/VLLM-style layout: shared page_list, separate K/V caches, fp16_input and fp8_kv_cache flags, SM120 gating, FP8 conversion/scaling, updated indexing, and adjusted tolerances/pass-ratio checks; new test_xqa_mla.
Docs
docs/api/attention.rst
Added flashinfer.xqa module doc entry and xqa autosummary symbol to API docs.

Sequence Diagram(s)

sequenceDiagram
    participant Python as Client
    participant Gen as ModuleGen
    participant NVCC as Compiler
    participant Module as CompiledModule
    participant CppBind as flashinfer_xqa_binding
    participant Wrapper as xqa_wrapper / xqa_wrapper_mla
    participant Launcher as MHA Dispatcher
    participant MLAKernel as MLA Kernel (SM120/121)
    participant HopperKernel as Hopper F8 Kernel
    participant StdKernel as Std MHA Kernel

    Python->>Gen: request module (input_dtype, kv_cache_dtype, sm_version, MLA?)
    Gen->>NVCC: compile with sources (tensorMap.cpp, mha_sm90.cu, mla_sm120.cu...) and flags
    NVCC-->>Gen: compiled Module
    Gen-->>Python: Module handle

    Python->>Module: call xqa(...) / xqa_mla(...)
    Module->>CppBind: call exported wrapper
    CppBind->>Wrapper: forward args (run_sm90_fp8_mha?, caches, page_table, ...)
    Wrapper->>Launcher: select launcher (MLA vs HopperF8 vs Std)
    alt MLA path
        Launcher->>MLAKernel: launchMLAFlashInfer(...)
    else if run_sm90_fp8_mha
        Launcher->>HopperKernel: launchHopperF8MHAFlashInfer(...)
    else
        Launcher->>StdKernel: launchMHAFlashInfer(...)
    end
    Kernel-->>Wrapper: results
    Wrapper-->>CppBind: done
    CppBind-->>Python: return
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • cyx-6
  • yzh119
  • yongwww
  • nvmbreughe
  • kahyunnam

Poem

🐇 I hopped through pages, K and V in tow,

FP8 sparks and swizzles set the flow,
Async maps hum, barriers clap in time,
SMs learn new steps, MLA keeps the rhyme,
A rabbit's cheer: fast attention — watch it go!

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 1 inconclusive)
Check name Status Explanation Resolution
Title Check ⚠️ Warning The PR title "add xqa fp8 mha and fp8 kv cache" is concise and clearly describes key features being added to the codebase. However, examining the raw summary reveals that the changeset includes substantial additional work beyond what the title indicates, most notably the introduction of a complete MLA (Multi-head Latent Attention) implementation with new SM120/SM121 kernels, tensor map utilities, and gmma/tma abstractions. While FP8 MHA and FP8 KV cache are indeed real components of this PR, the title omits the equally significant MLA pathway and related infrastructure, making it only partially descriptive of the full scope of changes. The title should be expanded to capture the full scope of changes. Consider revising to something like "Add XQA FP8 MHA, FP8 KV cache, and MLA support" or a similar formulation that acknowledges both the FP8 enhancements and the new MLA pathway, which are both substantial components of this PR.
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Description Check ❓ Inconclusive The pull request description is extremely minimal and lacks substantive detail about the changes. The description section contains only a single sentence listing features ("Add xqa fp8 mha and fp8 kv cache. Add fp8 mla for sm120. Use vllm kv layout.") without explaining what these components do, why they are needed, or how they work together. The related issues section is completely empty with no linked issues. While the pre-commit checklist items are marked as completed and tests are indicated as passing, the actual narrative description fails to document the scope and purpose of this significant PR, which involves changes across numerous files including new MLA support, tensor map utilities, FP8 functionality, and API modifications. The description reads as a feature checklist rather than a meaningful explanation suitable for code review. To address this, the author should expand the description to provide meaningful context about what FP8 MHA and KV cache support entails, explain the rationale for switching to VLLM KV layout, detail the MLA implementation for SM120, and document how these components integrate with the existing codebase. Additionally, any related GitHub issues should be linked in the "Related Issues" section, and specific areas requiring reviewer attention should be noted in the "Reviewer Notes" section to facilitate more effective code review.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/attention/test_xqa.py (1)

28-33: Avoid GPU property access at import time.

Accessing torch.cuda.get_device_properties(0) during import can break test discovery on CPU/multi-device envs. Move it inside the test after skip checks.

-props = torch.cuda.get_device_properties(0)
-sm_count = props.multi_processor_count
+sm_count = None  # set inside test to avoid import-time CUDA queries
♻️ Duplicate comments (4)
csrc/xqa/tma.h (1)

208-229: Bug: 3D/4D/5D storeAsync use 2D opcode (will corrupt data).

The cp.async store paths for nbDims 3–5 incorrectly use tensor.2d. Must be tensor.3d/4d/5d.

Apply this diff:

@@
   } else if constexpr (nbDims == 3) {
-    asm volatile(
-        "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
-        :
-        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
-          "r"(offset[2]), "l"(__cvta_generic_to_shared(src))
-        : "memory");
+    asm volatile(
+        "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
+        :
+        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
   } else if constexpr (nbDims == 4) {
-    asm volatile(
-        "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
-        :
-        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
-          "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(src))
-        : "memory");
+    asm volatile(
+        "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
+        :
+        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
   } else if constexpr (nbDims == 5) {
-    asm volatile(
-        "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], "
-        "[%6];\n"
-        :
-        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
-          "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
-        : "memory");
+    asm volatile(
+        "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], [%6];\n"
+        :
+        : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
   }
csrc/xqa/mha.cu (1)

479-479: Critical: Masked position initialization may cause incorrect attention output.

Using safeInitRowMax for masked elements can lead to incorrect results. When an entire row is masked, all scores become safeInitRowMax, and in softmax computation exp(score - maxVal) evaluates to 1 for all positions, producing a uniform distribution over masked tokens instead of zero output.

As noted in the previous review, the softmax function should handle safeInitRowMax specially to ensure masked tokens contribute zero to the output, or alternatively masked positions should use a different sentinel value that results in zero after softmax.

csrc/flashinfer_xqa_binding.cu (1)

19-21: Prefer a typed precision enum over a new boolean flag.

Using bool run_fp8_mha does not scale. Replace with a small enum (e.g., int32_t precision: {bf16, fp16, fp8}) and, similarly, pass/cache element/compute dtypes as enums instead of separate flags. This reduces combinatorial overload and ABI churn.

tests/attention/test_xqa.py (1)

241-246: Good: FP8 cache scaling is documented.

Comment explains the 4.0 factor and overflow concerns.

🧹 Nitpick comments (10)
csrc/xqa/tma.h (1)

74-83: Comment/code mismatch for nbDims==1 path.

The comment says “nbDims==1 does not need tensormap,” but the code uses the tensor.1d variant taking a tensor map. Either drop the map for 1D linear copies or update the comment.

Also applies to: 129-138

csrc/xqa/gmma.cuh (1)

27-56: Bitfield layout is implementation-defined; prefer explicit packing.

Relying on 64‑bit bitfield layout and reinterpret_cast to Raw can be brittle across compilers/ABIs. Recommend encoding/decoding with shifts/masks into a uint64_t to guarantee layout and endianness. Keep sizeof(MatDesc)==8 as a guard.

csrc/xqa/tensorMap.h (1)

3-3: cuda.h include: make header robust to non-CUDA analysis/compiles.

Static analysis flagged ‘cuda.h’ not found. If this header is transitively included by non‑CUDA TU(s), guard the include or move these declarations behind a build flag. Example: wrap with a small shim header included only from .cpp, or add a dedicated config that ensures CUDA include paths are present in CI.

csrc/xqa/tensorMap.cpp (1)

43-73: Tensor map for contiguous KV cache looks correct.

The function properly constructs a tensor map for contiguous KV cache layout:

  • Global dimensions and strides are configured appropriately
  • Swizzle selection based on cache line size (128B or 64B)
  • Error handling via checkCu wrapper

Minor suggestion: The error message on line 64 "unsupported cache head size" could be more specific about expected values.

-        throw std::runtime_error("unsupported cache head size");
+        throw std::runtime_error("unsupported partElems: " + std::to_string(partElems) + 
+                                 ", expected 128 or 64");
flashinfer/jit/xqa.py (1)

76-100: SM version selection and build configuration verified; optional refactor still recommended for clarity and error handling.

The changes are correct:

  • New source files (mha_sm90.cu, tensorMap.cpp) exist in csrc/xqa/
  • Build configuration properly references and links them
  • Required CUDA Driver API linker flag (-lcuda) and cache layout flag included

However, the SM version selection logic could be improved for maintainability. The current code defaults to sm90a_nvcc_flags for unrecognized versions, which implicitly handles sm_version=90 but obscures intent and provides no validation for truly unsupported architectures.

Consider making the SM90 case explicit and adding validation:

-    if sm_version == 100:
+    if sm_version == 90:
+        sm_nvcc_flags = sm90a_nvcc_flags
+    elif sm_version == 100:
         sm_nvcc_flags = sm100a_nvcc_flags
     elif sm_version == 120:
         sm_nvcc_flags = sm120a_nvcc_flags
     else:
-        sm_nvcc_flags = sm90a_nvcc_flags
+        raise ValueError(f"Unsupported sm_version: {sm_version}")

This makes supported architectures explicit and catches invalid SM versions early.

csrc/flashinfer_xqa_binding.cu (1)

25-35: KV cache params behind preprocessor guards: keep Python and C++ signatures locked.

Since params differ when PAGED_KV_CACHE_LAYOUT!=1, ensure JIT always defines it (the JIT path does) and document this contract near the binding to avoid accidental ABI mismatches. Consider adding a static assert/log print on init when it’s not set.

Also applies to: 28-29

tests/attention/test_xqa.py (3)

263-275: Remove unused beam_width arg (lint: ARG001).

beam_width in cache_head_at is unused; drop it and update call sites.

-def cache_head_at(
+def cache_head_at(
     batch,
     is_k,
     idx_kv_head,
     pos,
-    cache_k_heads,
-    cache_v_heads,
-    page_list,
-    beam_width,
+    cache_k_heads,
+    cache_v_heads,
+    page_list,
     nb_k_heads,
     tokens_per_page,
 ):
@@
-                    cache_head = cache_head_at(
+                    cache_head = cache_head_at(
                         batch,
                         kv == 0,
                         idx_kv_head,
                         pos,
                         cache_k_heads,
                         cache_v_heads,
-                        page_list_arg,
-                        beam_width,
+                        page_list_arg,
                         nb_k_heads,
                         tokens_per_page,
                     )

Also applies to: 291-303


317-319: Make scratch size configurable; 256 MiB may OOM CI.

Read from an env var with a sane default to reduce flakiness.

-    scratch_size = 256 << 20
+    import os
+    scratch_mb = int(os.environ.get("FLASHINFER_TEST_SCRATCH_MB", "256"))
+    scratch_size = scratch_mb << 20

You can validate with different values in CI matrix.


392-397: Stable epsilon for relative diff.

Optional: use dtype-aware epsilon via torch.finfo to avoid hard-coded 1e-8.

-                diff_rel = diff_abs / (torch.abs(ref_output) + 1e-8)
+                eps = torch.finfo(torch.float32).eps
+                diff_rel = diff_abs / (torch.abs(ref_output) + eps)
flashinfer/xqa.py (1)

147-150: Avoid repeated capability queries and shorten the error.

Cache CC once and use a shorter exception message (addresses Ruff TRY003).

-    if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
-        raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")
-    sm_version = int(get_compute_capability(torch.device(device="cuda"))[0] * 10)
+    cc_major, _ = get_compute_capability(torch.device(device="cuda"))
+    if cc_major not in (9, 10, 12):
+        raise RuntimeError("Unsupported GPU (require SM90/100/120)")
+    sm_version = int(cc_major * 10)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4b55b26 and 9ef83d1.

📒 Files selected for processing (13)
  • csrc/flashinfer_xqa_binding.cu (1 hunks)
  • csrc/xqa/gmma.cuh (1 hunks)
  • csrc/xqa/mha.cu (5 hunks)
  • csrc/xqa/mha.h (2 hunks)
  • csrc/xqa/tensorMap.cpp (1 hunks)
  • csrc/xqa/tensorMap.h (1 hunks)
  • csrc/xqa/tma.h (1 hunks)
  • csrc/xqa/utils.cuh (2 hunks)
  • csrc/xqa/xqa_wrapper.cu (2 hunks)
  • flashinfer/aot.py (3 hunks)
  • flashinfer/jit/xqa.py (2 hunks)
  • flashinfer/xqa.py (6 hunks)
  • tests/attention/test_xqa.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (10)
csrc/xqa/tensorMap.h (1)
csrc/xqa/tensorMap.cpp (6)
  • getElemBytes (10-41)
  • getElemBytes (10-10)
  • makeTensorMapForContiguousKVCache (43-73)
  • makeTensorMapForContiguousKVCache (43-47)
  • makeTensorMapForPagedKVCache (75-117)
  • makeTensorMapForPagedKVCache (75-78)
csrc/xqa/tensorMap.cpp (1)
csrc/xqa/utils.h (1)
  • checkCu (39-48)
flashinfer/jit/xqa.py (1)
flashinfer/jit/core.py (2)
  • JitSpec (181-280)
  • gen_jit_spec (283-347)
flashinfer/xqa.py (3)
flashinfer/jit/xqa.py (1)
  • gen_xqa_module (38-101)
flashinfer/jit/core.py (1)
  • build_and_load (268-280)
flashinfer/utils.py (3)
  • register_custom_op (266-275)
  • register_custom_op (285-304)
  • get_compute_capability (245-248)
csrc/xqa/xqa_wrapper.cu (2)
csrc/xqa/mha_sm90.cu (4)
  • launchHopperF8MHAFlashInfer (3168-3275)
  • launchHopperF8MHAFlashInfer (3168-3185)
  • scratch (506-513)
  • scratch (506-506)
csrc/xqa/mha.cu (2)
  • launchMHAFlashInfer (2657-2749)
  • launchMHAFlashInfer (2657-2674)
flashinfer/aot.py (1)
flashinfer/jit/core.py (1)
  • JitSpec (181-280)
csrc/xqa/mha.h (1)
csrc/xqa/mha_sm90.cu (4)
  • launchHopperF8MHAFlashInfer (3168-3275)
  • launchHopperF8MHAFlashInfer (3168-3185)
  • scratch (506-513)
  • scratch (506-506)
tests/attention/test_xqa.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (245-248)
csrc/xqa/tma.h (1)
csrc/xqa/mha_sm90.cu (16)
  • void (548-577)
  • void (579-584)
  • void (588-598)
  • void (1693-1727)
  • void (1765-1797)
  • void (1799-1816)
  • void (1841-1887)
  • void (1976-1997)
  • void (1999-2017)
  • void (2049-2131)
  • void (2180-2254)
  • void (2256-2275)
  • void (2278-2296)
  • void (2316-2332)
  • void (2336-2359)
  • void (2396-2420)
csrc/flashinfer_xqa_binding.cu (1)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (4)
  • output (230-396)
  • output (230-238)
  • output (398-566)
  • output (398-409)
🪛 Clang (14.0.6)
csrc/xqa/tensorMap.h

[error] 3-3: 'cuda.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.0)
flashinfer/xqa.py

148-148: Avoid specifying long messages outside the exception class

(TRY003)

tests/attention/test_xqa.py

271-271: Unused function argument: beam_width

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (21)
csrc/xqa/utils.cuh (2)

34-41: Numerical-stability note: initialize rowMax safely but validate ranges seen in practice.

Lowering safeInitRowMax to -1e5 avoids FMA overflow in x*log2e - bias, but it changes the effective lower bound. Please validate on adversarial logits (very negative rows) to ensure no early saturation and no accuracy regressions. Consider guarding the optimization per-arch or switching to compute (x - rowMax) before scaling to avoid FMA on large magnitudes.


49-51: Code is correct; review comment contains incorrect assumptions.

SM100 (Blackwell) opt-in dynamic shared memory per block is 227 KB, which matches the value at line 50. SM120 (Hopper Next) is 99 KB, which is already correctly configured on line 46—not on lines 49-51 as the review suggests.

The conditional structure properly segregates architectures:

  • Line 45-46: SM120 (__CUDA_ARCH__ == 1200) → 99 KB ✓
  • Line 49-50: SM100 (__CUDA_ARCH__ == 1000) → 227 KB ✓

Lines 49-51 handle only SM90 and SM100; SM120 is on a separate branch.

Likely an incorrect or invalid review comment.

csrc/xqa/gmma.cuh (1)

60-66: Assumption: shared address fits 18 bits (0x3FFFF) — verify on SM100/SM120.

The assert/encode mask assumes __cvta_generic_to_shared(data) fits 18 bits (>>4 stored). Please validate against SM100/SM120 SMEM addressing; increase field width or add debug fallback if this trips on larger SMEM configs.

csrc/xqa/mha.h (2)

138-150: API surface change verified — no action required.

Declaration (mha.h:133), definition (mha.cu:2657), and the single call site (xqa_wrapper.cu:43–68) all use identical #if PAGED_KV_CACHE_LAYOUT == 1 preprocessor guards. Arguments are correctly conditionalized at the call site to match the conditional signature. Python bindings invoke through xqa_wrapper, which handles the dispatch internally, isolating the FFI boundary from the conditional signature. ODR/ABI consistency is maintained within the compilation unit.


194-211: All concerns verified successfully.

The function declaration (mha.h:194) and definition (mha_sm90.cu:3168) match exactly, including all conditional parameters (LOW_PREC_OUTPUT, PAGED_KV_CACHE_LAYOUT, SPEC_DEC). The function is properly integrated: it's used in xqa_wrapper.cu:43 with correct FP8 conditional logic (run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer), defined in the correct TU (mha_sm90.cu), and linked properly without issues.

csrc/xqa/mha.cu (2)

95-95: LGTM: SM100 architecture support added correctly.

The addition of __CUDA_ARCH__ == 1000 appropriately extends support to SM100 devices with the same configuration as SM800/870/900, using larger kHeadPartBytes (128) and cacheVTileSeqLen (64).


2657-2722: Based on my analysis of the kernel signature and the function call, there is a parameter mismatch in the launchMHAFlashInfer function call when SPEC_DEC is defined.

Looking at the kernel_mha signature (mha_sm90.cu:615-652), the kernel expects:

  • nbKHeads (always)
  • slidingWinSize (if SLIDING_WINDOW)
  • qScale, output, rcpOutScale, q, attentionSinks, cacheList
  • beamSearchParams (if USE_BEAM_SEARCH)
  • batchSize, kvCacheScale
  • Tensor maps via grid_constant (not passed as regular parameters)
  • specDecParams (if SPEC_DEC)
  • semaphores, scratch

However, at line 2707-2722, when SPEC_DEC is defined, the call passes qSeqLen, nbKHeads, headGrpSize, qCuSeqLens as four separate parameters, but the kernel expects only nbKHeads at that position. Additionally, the call passes mask (line 2722) but the kernel has no mask parameter—it expects specDecParams instead.

The review comment requires verification of how SpecDecParams and BeamSearchParams should be constructed and passed, since the current call site appears to pass individual fields separately rather than properly constructed structs.

flashinfer/jit/xqa.py (4)

18-24: LGTM: Imports updated appropriately.

The added imports for SM-specific NVCC flags enable proper multi-architecture support.


26-35: LGTM: NVCC flags configured correctly.

The flags properly enable paged KV cache with layout 1, consistent with the conditional compilation paths in the C++ code.


47-55: LGTM: Flag generation logic is correct.

The conditional flag generation properly handles:

  • FP16 vs BF16 input (DTYPE and INPUT_FP16)
  • FP8 vs FP16/BF16 KV cache (CACHE_ELEM_ENUM)

38-46: All call sites are already updated with the new signature.

Verification confirms that:

  • The new fp16_input, fp8_kv_cache, and sm_version parameters are consistently used across the codebase
  • Both call sites (flashinfer/aot.py:404 and flashinfer/xqa.py:40) correctly pass the new parameters
  • Wrapper functions (get_xqa_module and xqa) use the updated signature
  • No use_fp16 parameter exists anywhere in the codebase

The API changes are complete and properly integrated.

csrc/xqa/xqa_wrapper.cu (2)

22-38: LGTM: Function signature updated appropriately.

The signature changes are well-designed:

  • run_fp8_mha parameter enables runtime selection between FP8 and standard MHA
  • Optional<TensorView> for attentionSinks is more idiomatic than raw pointers
  • Conditional KV cache parameters based on PAGED_KV_CACHE_LAYOUT properly support both layout modes

39-65: Function pointer approach is type-safe; signatures are compatible.

Verification confirms that both launchHopperF8MHAFlashInfer and launchMHAFlashInfer have identical signatures with matching conditional compilation blocks and parameter lists, making the function pointer assignment safe and correct.

flashinfer/aot.py (3)

358-372: LGTM: gen_xqa signature updated for multi-architecture support.

The function signature changes are consistent with the JIT module updates:

  • Parameter renaming improves clarity
  • SM gating ensures generation only when supported architectures are available
  • New fp8_kv_cache_ parameter enables FP8 KV cache configurations

373-412: Multi-SM architecture support implemented correctly.

The iteration logic properly:

  • Constructs sm_versions list based on available architectures
  • Iterates over SM versions along with other configuration parameters
  • Validates configurations before generating modules
  • Passes all parameters to gen_xqa_module consistently

527-546: LGTM: gen_all_modules updated consistently.

The changes to gen_all_modules properly wire through the new parameters and SM version support to the XQA generator.

csrc/xqa/tensorMap.cpp (2)

10-41: LGTM: Data type size lookup implemented correctly.

The getElemBytes function provides comprehensive coverage of CUDA tensor map data types with appropriate error handling.


75-117: Paged KV cache tensor map correctly supports two layout modes with consistent stride calculations.

The implementation correctly configures tensor map dimensions and strides for two distinct layouts:

  • VLLM Layout (PAGED_KV_CACHE_LAYOUT == 1): dimensions {headElems, nbKHeads, tokensPerPage, pages} with strides accounting for head-first ordering
  • XQA Layout (PAGED_KV_CACHE_LAYOUT == 0, default): dimensions {headElems, tokensPerPage, nbKHeads, pages} with strides accounting for token-first ordering

The dimension ordering aligns with memory access patterns throughout the codebase (verified in mha.cu, mhaUtils.cuh, and mha_sm90.cu). Both layouts apply the same swizzle modes and error handling. No issues identified.

csrc/flashinfer_xqa_binding.cu (2)

24-25: Good: Optional attention sinks.

Switching to tvm::ffi::Optional<TensorView> makes the API safer and clearer.


21-23: No changes needed; LOW_PREC_OUTPUT=0 is already set in compilation flags.

The codebase already includes "-DLOW_PREC_OUTPUT=0" in the xqa_nvcc_flags list within flashinfer/jit/xqa.py. This flag is passed to extra_cuda_cflags in the gen_jit_spec() call, ensuring the rcpOutScale parameter is not included in the C++ function signature. There is no ABI drift risk because the conditional parameter is compiled out consistently.

flashinfer/xqa.py (1)

50-72: Signature wiring looks consistent with the binding.

Param order matches xqa_wrapper (including run_fp8_mha, optional attentionSinks, and separate K/V caches).

If LOW_PREC_OUTPUT is ever enabled, extend these call sites to pass rcpOutScale or force -DLOW_PREC_OUTPUT=0 in JIT.

Also applies to: 73-91

Comment on lines 180 to 183
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] != 9:
pytest.skip("XQA only supports on Hopper at this moment")
if compute_capability[0] != 9 and run_fp8_mha:
pytest.skip("XQA supports fp8 mha only on Hopper GPUs")
set_random_seed(42)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Compute sm_count inside the test.

Set SM count after capability checks to avoid premature CUDA access.

 def test_xqa(
@@
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] != 9 and run_fp8_mha:
         pytest.skip("XQA supports fp8 mha only on Hopper GPUs")
     set_random_seed(42)
+    props = torch.cuda.get_device_properties(torch.cuda.current_device())
+    sm_count = props.multi_processor_count

Also applies to: 329-330

🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 180-183, the test currently calls
into CUDA to compute sm_count before checking compute capability and may access
CUDA prematurely; move the sm_count computation so it runs after the
compute_capability check and any pytest.skip decision (i.e., compute sm_count
only after verifying compute_capability and run_fp8_mha), and apply the same
change to the other occurrence around lines 329-330; ensure you call the
sm_count helper (or get_sm_count) with the CUDA device only after the skip logic
and after set_random_seed(42) if that ordering is required.

Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/attention/test_xqa.py (1)

271-294: Address unused beam_width parameter in both cache_head_at definitions and call sites.

The static analysis flag is correct: beam_width is unused. However, the codebase has two identical definitions of cache_head_at (lines 271 and 488) with two corresponding call sites (lines 299 and 516). Both must be updated:

  1. Remove beam_width parameter from the second definition at line 488 (identical to line 271 fix)
  2. Remove beam_width argument from the second call at line 516 (identical to line 299 fix)

Apply the same diffs from the review comment to both locations to ensure consistency across the test file.

flashinfer/aot.py (1)

338-341: Stale comment contradicts implementation

Comment says “fp8 kv not supported in MLA”, but MLA path enforces fp8 KV (and code passes float8_e4m3fn). Update the note to avoid confusion.

-    # MLA
-    # NOTE: fp8 kv not supported in MLA
+    # MLA
+    # NOTE: MLA path uses FP8 input and FP8 KV cache (float8_e4m3fn)
```<!-- review_comment_end -->

</blockquote></details>

</blockquote></details>
♻️ Duplicate comments (3)
tests/attention/test_xqa.py (1)

250-255: Good: FP8 cache scale-down is explained.

The added rationale for the 4.0 factor addresses prior feedback and improves maintainability.

flashinfer/jit/xqa.py (1)

85-93: Prefer family-wide SM12x flags or centralize arch flag selection

Hardcoding 120→sm120a and 121→sm121a risks unnecessary cubin proliferation. Consider targeting the SM120 family (sm120f) for both 120/121, or delegating selection to CompilationContext.get_nvcc_flags_list() to avoid “a vs f” drift and ease future SM additions. This also aligns with prior guidance on family targets.

Would you like a follow-up patch to route sm_nvcc_flags via CompilationContext?

csrc/xqa/xqa_wrapper.cu (1)

71-73: Good dedup with function pointer for MHA launcher

This removes duplicated argument blocks and matches prior recommendations.

Also applies to: 73-94

🧹 Nitpick comments (8)
tests/attention/test_xqa.py (5)

28-30: Avoid initializing CUDA and reading device 0 at import time; infer SM count in xqa/xqa_mla or compute after skip.

Computing props/sm_count globally can initialize CUDA prematurely and may pick the wrong device on multi-GPU. Either rely on library-side inference (preferred) or compute inside the test after skip using torch.cuda.current_device().

-props = torch.cuda.get_device_properties(0)
-sm_count = props.multi_processor_count
+# Prefer letting flashinfer infer SM count from q.device, or compute later if strictly needed.

Follow-ups below remove explicit sm_count passing at the call sites.


328-344: Let xqa infer sm_count; remove explicit sm_count argument.

flashinfer.xqa already infers SM count from q.device when sm_count=None. This also avoids the import-time CUDA query.

     xqa(
         q_heads,
         cache_k_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_k_heads,
         cache_v_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_v_heads,
         page_list_arg,
         seq_len_list,
         output,
         scratch_buf,
         semaphores,
         nb_k_heads,
         tokens_per_page,
         sinks=attention_sinks,
         q_scale=q_scale,
         kv_scale=kv_cache_scale,
         sliding_win_size=sliding_win_size,
-        sm_count=sm_count,
     )

545-558: Same: let xqa_mla infer sm_count.

     xqa_mla(
         q_heads.to(torch.float8_e4m3fn),
         cache_k_heads.to(torch.float8_e4m3fn),
         cache_v_heads.to(torch.float8_e4m3fn),
         page_list_arg,
         seq_len_list,
         output,
         scratch_buf,
         semaphores,
         tokens_per_page,
         q_scale=q_scale,
         kv_scale=kv_cache_scale,
-        sm_count=sm_count,
     )

488-511: Same unused parameter in MLA helper; remove and fix calls.

-    def cache_head_at(
+    def cache_head_at(
         batch,
         is_k,
         idx_kv_head,
         pos,
         cache_k_heads,
         cache_v_heads,
         page_list,
-        beam_width,
         nb_k_heads,
         tokens_per_page,
     ):

And update the call (Lines 517-527):

                     cache_head = cache_head_at(
                         batch,
                         kv == 0,
                         idx_kv_head,
                         pos,
                         cache_k_heads,
                         cache_v_heads,
                         page_list_arg,
-                        beam_width,
                         nb_k_heads,
                         tokens_per_page,
                     )

91-98: Allocate ref tensors on q’s device, not hardcoded "cuda".

Keeps tests robust on non-0 devices and honors current CUDA context.

-    k_cache_f32 = torch.zeros(
-        seq_len, valid_elems_per_head, dtype=torch.float32, device="cuda"
-    )
+    k_cache_f32 = torch.zeros(
+        seq_len, valid_elems_per_head, dtype=torch.float32, device=q.device
+    )
-    v_cache_f32 = torch.zeros(
-        seq_len, valid_elems_per_v_head, dtype=torch.float32, device="cuda"
-    )
+    v_cache_f32 = torch.zeros(
+        seq_len, valid_elems_per_v_head, dtype=torch.float32, device=q.device
+    )
csrc/xqa/mla_sm120.cu (1)

1874-1978: Launcher tensor-map dtype mapping: good coverage; consider centralizing dtype->CUtensorMapDataType mapping.

Mapping duplicated across launchers; extract helper to keep consistent if new dtypes are added.

flashinfer/aot.py (1)

423-446: De-duplicate MLA module gen for SM120/121

Compact the two blocks into one loop over available SM12x variants.

@@
-    if has_sm120:
-        for token_per_page in token_per_page_:
-            yield gen_xqa_module_mla(
-                input_dtype=torch.float8_e4m3fn,
-                kv_cache_dtype=torch.float8_e4m3fn,
-                page_size=token_per_page,
-                head_dim=576,
-                head_group_ratio=128,
-                use_sliding_window=False,
-                sm_version=120,
-            )
-
-    if has_sm121:
-        for token_per_page in token_per_page_:
-            yield gen_xqa_module_mla(
-                input_dtype=torch.float8_e4m3fn,
-                kv_cache_dtype=torch.float8_e4m3fn,
-                page_size=token_per_page,
-                head_dim=576,
-                head_group_ratio=128,
-                use_sliding_window=False,
-                sm_version=121,
-            )
+    sm12x = []
+    if has_sm120: sm12x.append(120)
+    if has_sm121: sm12x.append(121)
+    for smv in sm12x:
+        for token_per_page in token_per_page_:
+            yield gen_xqa_module_mla(
+                input_dtype=torch.float8_e4m3fn,
+                kv_cache_dtype=torch.float8_e4m3fn,
+                page_size=token_per_page,
+                head_dim=576,
+                head_group_ratio=128,
+                use_sliding_window=False,
+                sm_version=smv,
+            )
```<!-- review_comment_end -->

</blockquote></details>
<details>
<summary>flashinfer/xqa.py (1)</summary><blockquote>

`364-413`: **MLA docstring and dtype expectations are inconsistent**

This path enforces FP8 input/KV (and head_dim=576, head_group_ratio=128). Update docstring to reflect FP8 requirements and fixed constraints to prevent misuse.

```diff
@@
-    r"""Apply attention with paged KV cache using XQA kernel.
+    r"""Apply MLA attention with paged KV cache (SM12x, FP8-only).
@@
-    q : torch.Tensor
-        Query tensor with shape ``[batch_size, beam_width, num_q_heads, head_dim]``.
-        Data type should be torch.float16 or torch.bfloat16.
+    q : torch.Tensor
+        Query tensor with shape ``[batch_size, beam_width, num_q_heads, 576]``.
+        Data type must be torch.float8_e4m3fn.
@@
-    k_cache: torch.Tensor
+    k_cache: torch.Tensor
         Paged K cache tensor with shape ``[total_num_cache_heads, head_dim]``.
-        Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation.
+        Data type must be torch.float8_e4m3fn (FP8 KV cache).
@@
-    v_cache: torch.Tensor
+    v_cache: torch.Tensor
         Paged V cache tensor with shape ``[total_num_cache_heads, head_dim]``.
-        Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation.
+        Data type must be torch.float8_e4m3fn (FP8 KV cache).
```<!-- review_comment_end -->

</blockquote></details>

</blockquote></details>

<details>
<summary>📜 Review details</summary>

**Configuration used**: CodeRabbit UI

**Review profile**: CHILL

**Plan**: Pro

<details>
<summary>📥 Commits</summary>

Reviewing files that changed from the base of the PR and between 7372dd602f723e378d110257c168509c0bfae5a0 and 9b0fd24b955289d712ef536bd127405cb14979aa.

</details>

<details>
<summary>📒 Files selected for processing (10)</summary>

* `csrc/flashinfer_xqa_binding.cu` (2 hunks)
* `csrc/xqa/mha.h` (3 hunks)
* `csrc/xqa/mla_sm120.cu` (1 hunks)
* `csrc/xqa/mla_sm120.cuh` (1 hunks)
* `csrc/xqa/xqa_wrapper.cu` (2 hunks)
* `flashinfer/__init__.py` (1 hunks)
* `flashinfer/aot.py` (3 hunks)
* `flashinfer/jit/xqa.py` (1 hunks)
* `flashinfer/xqa.py` (3 hunks)
* `tests/attention/test_xqa.py` (11 hunks)

</details>

<details>
<summary>🧰 Additional context used</summary>

<details>
<summary>🧬 Code graph analysis (9)</summary>

<details>
<summary>flashinfer/jit/xqa.py (1)</summary><blockquote>

<details>
<summary>flashinfer/jit/core.py (2)</summary>

* `JitSpec` (181-280)
* `gen_jit_spec` (283-347)

</details>

</blockquote></details>
<details>
<summary>flashinfer/__init__.py (1)</summary><blockquote>

<details>
<summary>flashinfer/xqa.py (4)</summary>

* `xqa` (56-93)
* `xqa` (124-267)
* `xqa_mla` (294-323)
* `xqa_mla` (350-464)

</details>

</blockquote></details>
<details>
<summary>tests/attention/test_xqa.py (2)</summary><blockquote>

<details>
<summary>flashinfer/xqa.py (4)</summary>

* `xqa` (56-93)
* `xqa` (124-267)
* `xqa_mla` (294-323)
* `xqa_mla` (350-464)

</details>
<details>
<summary>flashinfer/utils.py (1)</summary>

* `get_compute_capability` (245-248)

</details>

</blockquote></details>
<details>
<summary>csrc/flashinfer_xqa_binding.cu (1)</summary><blockquote>

<details>
<summary>csrc/xqa/xqa_wrapper.cu (2)</summary>

* `xqa_wrapper_mla` (23-47)
* `xqa_wrapper_mla` (23-31)

</details>

</blockquote></details>
<details>
<summary>csrc/xqa/xqa_wrapper.cu (4)</summary><blockquote>

<details>
<summary>csrc/xqa/mha_sm90.cu (4)</summary>

* `scratch` (506-513)
* `scratch` (506-506)
* `launchHopperF8MHAFlashInfer` (3168-3275)
* `launchHopperF8MHAFlashInfer` (3168-3185)

</details>
<details>
<summary>csrc/tvm_ffi_utils.h (1)</summary>

* `get_stream` (272-274)

</details>
<details>
<summary>csrc/xqa/mla_sm120.cu (2)</summary>

* `launchMLAFlashInfer` (2010-2131)
* `launchMLAFlashInfer` (2010-2025)

</details>
<details>
<summary>csrc/xqa/mha.cu (2)</summary>

* `launchMHAFlashInfer` (2657-2749)
* `launchMHAFlashInfer` (2657-2674)

</details>

</blockquote></details>
<details>
<summary>csrc/xqa/mla_sm120.cu (3)</summary><blockquote>

<details>
<summary>csrc/xqa/utils.h (3)</summary>

* `checkedVal` (284-287)
* `checkCu` (39-48)
* `checkCuda` (32-37)

</details>
<details>
<summary>csrc/xqa/tma.h (1)</summary>

* `waitGroup` (247-249)

</details>
<details>
<summary>csrc/xqa/hostUtils.h (1)</summary>

* `makeLaunchConfig` (4-12)

</details>

</blockquote></details>
<details>
<summary>csrc/xqa/mha.h (2)</summary><blockquote>

<details>
<summary>csrc/xqa/mha_sm90.cu (4)</summary>

* `launchHopperF8MHAFlashInfer` (3168-3275)
* `launchHopperF8MHAFlashInfer` (3168-3185)
* `scratch` (506-513)
* `scratch` (506-506)

</details>
<details>
<summary>csrc/xqa/mla_sm120.cu (4)</summary>

* `launchMLA` (1874-2008)
* `launchMLA` (1874-1893)
* `launchMLAFlashInfer` (2010-2131)
* `launchMLAFlashInfer` (2010-2025)

</details>

</blockquote></details>
<details>
<summary>flashinfer/xqa.py (4)</summary><blockquote>

<details>
<summary>flashinfer/jit/xqa.py (2)</summary>

* `gen_xqa_module` (41-116)
* `gen_xqa_module_mla` (119-180)

</details>
<details>
<summary>flashinfer/utils.py (6)</summary>

* `get_device_sm_count` (589-590)
* `register_custom_op` (266-275)
* `register_custom_op` (285-304)
* `register_fake_op` (277-281)
* `register_fake_op` (306-311)
* `get_compute_capability` (245-248)

</details>
<details>
<summary>csrc/flashinfer_xqa_binding.cu (2)</summary>

* `xqa_wrapper` (34-50)
* `xqa_wrapper_mla` (20-28)

</details>
<details>
<summary>csrc/xqa/xqa_wrapper.cu (4)</summary>

* `xqa_wrapper` (50-94)
* `xqa_wrapper` (50-66)
* `xqa_wrapper_mla` (23-47)
* `xqa_wrapper_mla` (23-31)

</details>

</blockquote></details>
<details>
<summary>flashinfer/aot.py (2)</summary><blockquote>

<details>
<summary>flashinfer/xqa.py (2)</summary>

* `xqa` (56-93)
* `xqa` (124-267)

</details>
<details>
<summary>flashinfer/jit/xqa.py (2)</summary>

* `gen_xqa_module` (41-116)
* `gen_xqa_module_mla` (119-180)

</details>

</blockquote></details>

</details><details>
<summary>🪛 GitHub Actions: pre-commit</summary>

<details>
<summary>csrc/xqa/mla_sm120.cuh</summary>

[error] 1-1: clang-format failed. The hook modified one or more files; please run 'pre-commit run --all-files' locally and commit the changes.

---

[error] 1-1: clang-format failed. The hook modified one or more files; please run 'pre-commit run --all-files' locally and commit the changes.

</details>
<details>
<summary>csrc/xqa/mla_sm120.cu</summary>

[error] 1-1: clang-format failed. The hook modified one or more files; please run 'pre-commit run --all-files' locally and commit the changes.

---

[error] 1-1: clang-format failed. The hook modified one or more files; please run 'pre-commit run --all-files' locally and commit the changes.

</details>

</details>
<details>
<summary>🪛 Ruff (0.14.1)</summary>

<details>
<summary>flashinfer/jit/xqa.py</summary>

55-57: Avoid specifying long messages outside the exception class

(TRY003)

---

67-69: Avoid specifying long messages outside the exception class

(TRY003)

---

73-75: Avoid specifying long messages outside the exception class

(TRY003)

---

144-146: Avoid specifying long messages outside the exception class

(TRY003)

</details>
<details>
<summary>tests/attention/test_xqa.py</summary>

279-279: Unused function argument: `beam_width`

(ARG001)

---

496-496: Unused function argument: `beam_width`

(ARG001)

</details>
<details>
<summary>flashinfer/xqa.py</summary>

234-234: Avoid specifying long messages outside the exception class

(TRY003)

---

435-435: Avoid specifying long messages outside the exception class

(TRY003)

</details>

</details>

</details>

<details>
<summary>⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)</summary>

* GitHub Check: Deploy Docs

</details>

<details>
<summary>🔇 Additional comments (8)</summary><blockquote>

<details>
<summary>csrc/xqa/mha.h (1)</summary><blockquote>

`213-252`: **API surface looks consistent for MLA flashinfer launcher.**

Signatures align with paged KV cache layout branching and VLLM vs pool modes. No blocking issues here.


Please ensure corresponding definitions in mla_sm120.cu match these exact parameter orders for both layout branches.

</blockquote></details>
<details>
<summary>csrc/xqa/mla_sm120.cuh (1)</summary><blockquote>

`1-5`: I cannot verify the formatting claim because:
- `clang-format` is not available in the sandbox environment
- CI build logs are not accessible to confirm the reported violations

To resolve this review comment, manually verify by running:
```bash
clang-format -i csrc/xqa/mla_sm120.cuh
git diff

Then commit any formatting changes if they are made. Alternatively, check the CI build output to confirm what violations were flagged.

csrc/xqa/mla_sm120.cu (1)

1-32: Review comment verified—formatting violations confirmed.

The clang-format check confirms multiple formatting violations in the file, including include reordering and indentation issues across the entire file (incomplete_format='false'). The review comment is accurate and requires no changes.

csrc/flashinfer_xqa_binding.cu (1)

19-31: No changes needed—review comment is based on incorrect premise.

The review comment claims that xqa_wrapper.cu declares TensorView batchSize, but both files actually declare int64_t batchSize at identical locations (flashinfer_xqa_binding.cu:27 and xqa_wrapper.cu:30). The function signatures are identical; there is no type mismatch or ABI risk.

Likely an incorrect or invalid review comment.

flashinfer/jit/xqa.py (1)

162-176: MLA input DTYPE macro: verify kernel expectations

gen_xqa_module sets -DDTYPE for input; gen_xqa_module_mla does not. If mla_sm120.cu infers FP8 internally, fine; otherwise add a matching -DDTYPE for FP8 to avoid mismatches.

Please confirm whether mla_sm120.cu requires a DTYPE macro. If yes, I can propose the exact flag wiring.

csrc/xqa/xqa_wrapper.cu (1)

22-48: Confirm inputSeqLen hardcoded to 1

launchMLAFlashInfer’s second arg is inputSeqLen (uniform). Setting it to 1 implies decode-only. If prefill or >1 token decode is expected, consider deriving it (e.g., from seqLen) or passing it through the API.

I can draft an update to plumb inputSeqLen if required.

flashinfer/__init__.py (1)

150-150: Public export looks good

Re-exporting xqa_mla aligns package API with the new MLA path.

flashinfer/xqa.py (1)

270-292: Verified: FP8 dtype is already supported in filename_safe_dtype_map

The filename_safe_dtype_map already includes torch.float8_e4m3fn: "e4m3", ensuring op names are stable when using FP8 as input or KV cache dtype. No action needed.

Comment on lines 78 to 89
template <typename T, uint32_t n>
__device__ inline uint32_t hashRegData(Vec<T, n> const& data)
{
static_assert(sizeof(T) == 4);
uint32_t result = 0;
#pragma unroll
for (uint32_t i = 0; i < n; i++)
{
result ^= reinterpret_cast<uint32_t const&>(data[i]);
}
return result;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Avoid strict-aliasing UB in hashRegData; use memcpy/bit-cast.

Reinterpreting float as uint32_t via reference breaks aliasing rules. Use a byte-safe load.

-    for (uint32_t i = 0; i < n; i++)
-    {
-        result ^= reinterpret_cast<uint32_t const&>(data[i]);
-    }
+    for (uint32_t i = 0; i < n; i++)
+    {
+        uint32_t word;
+        memcpy(&word, &data[i], sizeof(word)); // bit-safe
+        result ^= word;
+    }

If needed, add the header include near the top:

 #include "utils.cuh"
+#include <cstring>
🤖 Prompt for AI Agents
In csrc/xqa/mla_sm120.cuh around lines 78 to 89, the loop uses
reinterpret_cast<uint32_t const&>(data[i]) which violates strict-aliasing and
can produce undefined behavior; replace the reinterpret_cast with a byte-safe
copy (e.g., memcpy into a uint32_t temporary for each element or use
std::bit_cast<uint32_t> if C++20 is available) and xor that temporary into
result, keep the existing static_assert(sizeof(T) == 4); also add the
appropriate include near the top (#include <cstring> for memcpy or #include
<bit> for std::bit_cast) so the code compiles cleanly.

Comment on lines +78 to 79
flag_head_group_ratio = [f"-DHEAD_GRP_SIZE={head_group_ratio}"]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Validate head_group_ratio early

Currently any integer is accepted. Add a guard (e.g., >0 and within the supported set) to fail fast on invalid configs.

@@
-    flag_head_group_ratio = [f"-DHEAD_GRP_SIZE={head_group_ratio}"]
+    if not isinstance(head_group_ratio, int) or head_group_ratio <= 0:
+        raise ValueError(f"Invalid head_group_ratio: {head_group_ratio}")
+    flag_head_group_ratio = [f"-DHEAD_GRP_SIZE={head_group_ratio}"]
```<!-- review_comment_end -->

<details>
<summary>🤖 Prompt for AI Agents</summary>

In flashinfer/jit/xqa.py around lines 78-79, add an early validation for
head_group_ratio: before creating flag_head_group_ratio, check that
head_group_ratio is > 0 and present in a defined SUPPORTED_HEAD_GROUP_RATIOS
collection (or a small allowed tuple/list you add), and if it fails raise a
clear ValueError (or assert) with a message describing valid values; place the
validation just above the flag assignment so invalid configs fail fast.


</details>

<!-- fingerprinting:phantom:medusa:chinchilla -->

<!-- This is an auto-generated comment by CodeRabbit -->

Comment on lines +128 to +140
assert sm_version == 120 or sm_version == 121, (
"Only SM 120 and 121 are supported for xqa MLA"
)
assert head_group_ratio == 128, "Only head group ratio 128 is supported for xqa MLA"
assert head_dim == 576, "Only head dim 576 is supported for xqa_module_mla"
assert input_dtype == torch.float8_e4m3fn, (
"Only fp8 input is supported for xqa_module_mla"
)
assert kv_cache_dtype == torch.float8_e4m3fn, (
"Only fp8 kv cache is supported for xqa_module_mla"
)
assert not use_sliding_window, "Sliding window is not supported for xqa_module_mla"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Don’t use asserts for runtime validation in library code

asserts are stripped with Python -O, losing validation. Convert to explicit checks with clear exceptions.

@@
-    assert sm_version == 120 or sm_version == 121, (
-        "Only SM 120 and 121 are supported for xqa MLA"
-    )
-    assert head_group_ratio == 128, "Only head group ratio 128 is supported for xqa MLA"
-    assert head_dim == 576, "Only head dim 576 is supported for xqa_module_mla"
-    assert input_dtype == torch.float8_e4m3fn, (
-        "Only fp8 input is supported for xqa_module_mla"
-    )
-    assert kv_cache_dtype == torch.float8_e4m3fn, (
-        "Only fp8 kv cache is supported for xqa_module_mla"
-    )
-    assert not use_sliding_window, "Sliding window is not supported for xqa_module_mla"
+    if sm_version not in (120, 121):
+        raise ValueError("Only SM 120 and 121 are supported for xqa MLA")
+    if head_group_ratio != 128:
+        raise ValueError("Only head_group_ratio=128 is supported for xqa MLA")
+    if head_dim != 576:
+        raise ValueError("Only head_dim=576 is supported for xqa_module_mla")
+    if input_dtype != torch.float8_e4m3fn:
+        raise ValueError("Only fp8 input (float8_e4m3fn) is supported for xqa_module_mla")
+    if kv_cache_dtype != torch.float8_e4m3fn:
+        raise ValueError("Only fp8 kv cache (float8_e4m3fn) is supported for xqa_module_mla")
+    if use_sliding_window:
+        raise ValueError("Sliding window is not supported for xqa_module_mla")
```<!-- review_comment_end -->

<details>
<summary>🤖 Prompt for AI Agents</summary>

In flashinfer/jit/xqa.py around lines 128 to 140, replace the runtime asserts
with explicit validation that always runs: check each condition with an
if-statement and raise a clear exception (ValueError or TypeError) when the
check fails (e.g., if sm_version not in {120,121}: raise ValueError(...)); do
the same for head_group_ratio, head_dim, input_dtype, kv_cache_dtype, and
use_sliding_window (raise ValueError if True), and include the offending value
in each error message for easier debugging.


</details>

<!-- fingerprinting:phantom:medusa:chinchilla -->

<!-- This is an auto-generated comment by CodeRabbit -->

Comment on lines +213 to +215
# Calculate head_group_ratio
head_group_ratio = num_q_heads // num_kv_heads

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Validate head grouping divisibility

Add an explicit check that num_q_heads is divisible by num_kv_heads to prevent silent floor-division.

@@
-    head_group_ratio = num_q_heads // num_kv_heads
+    if num_kv_heads <= 0 or (num_q_heads % num_kv_heads) != 0:
+        raise ValueError(
+            f"num_q_heads ({num_q_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
+        )
+    head_group_ratio = num_q_heads // num_kv_heads
```<!-- review_comment_end -->

<!-- suggestion_start -->

<details>
<summary>📝 Committable suggestion</summary>

> ‼️ **IMPORTANT**
> Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```suggestion
    # Calculate head_group_ratio
    if num_kv_heads <= 0 or (num_q_heads % num_kv_heads) != 0:
        raise ValueError(
            f"num_q_heads ({num_q_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
        )
    head_group_ratio = num_q_heads // num_kv_heads
🤖 Prompt for AI Agents
In flashinfer/xqa.py around lines 213 to 215, the head_group_ratio is computed
by floor-dividing num_q_heads by num_kv_heads without validating divisibility;
add an explicit check (e.g., raise ValueError or assert) that num_q_heads %
num_kv_heads == 0 before computing head_group_ratio, and provide a clear error
message indicating the requirement (num_q_heads must be divisible by
num_kv_heads) so that silent floor-division is prevented.

Comment on lines +225 to +238
if (
k_cache.dtype == torch.float8_e4m3fn
and get_compute_capability(torch.device(device="cuda"))[0] == 9
):
run_sm90_fp8_mha = True
else:
run_sm90_fp8_mha = False

if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")
sm_version = int(
get_compute_capability(torch.device(device="cuda"))[0] * 10
+ get_compute_capability(torch.device(device="cuda"))[1]
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use q.device for compute capability; avoid repeated queries

Current code queries the default CUDA device, which breaks on multi-GPU when tensors aren’t on device 0, and repeats the query. Cache the CC from q.device and reuse.

@@
-    if (
-        k_cache.dtype == torch.float8_e4m3fn
-        and get_compute_capability(torch.device(device="cuda"))[0] == 9
-    ):
+    cc_major, cc_minor = get_compute_capability(q.device)
+    if k_cache.dtype == torch.float8_e4m3fn and cc_major == 9:
         run_sm90_fp8_mha = True
     else:
         run_sm90_fp8_mha = False
 
-    if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
+    if cc_major not in (9, 10, 12):
         raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")
-    sm_version = int(
-        get_compute_capability(torch.device(device="cuda"))[0] * 10
-        + get_compute_capability(torch.device(device="cuda"))[1]
-    )
+    sm_version = int(cc_major * 10 + cc_minor)
```<!-- review_comment_end -->

<details>
<summary>🧰 Tools</summary>

<details>
<summary>🪛 Ruff (0.14.1)</summary>

234-234: Avoid specifying long messages outside the exception class

(TRY003)

</details>

</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

In flashinfer/xqa.py around lines 225-238, the code repeatedly calls
get_compute_capability(torch.device(device="cuda")) and always queries the
default CUDA device; instead, obtain the compute capability once from the actual
tensor device (use q.device), store it in a local variable (e.g., cc =
get_compute_capability(q.device)), then replace all subsequent calls with cc to
determine run_sm90_fp8_mha (use cc[0]) and to compute sm_version (sm_version =
int(cc[0]*10 + cc[1])); ensure q is the tensor in scope and remove the repeated
get_compute_capability calls.


</details>

<!-- fingerprinting:phantom:medusa:chinchilla -->

<!-- This is an auto-generated comment by CodeRabbit -->

Comment on lines +434 to +439
if get_compute_capability(torch.device(device="cuda"))[0] not in [12]:
raise RuntimeError("XQA is only supported on SM120 GPUs")
sm_version = int(
get_compute_capability(torch.device(device="cuda"))[0] * 10
+ get_compute_capability(torch.device(device="cuda"))[1]
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use q.device for CC and clarify SM12x support message

Same device issue as above; also error message should mention SM121.

-    if get_compute_capability(torch.device(device="cuda"))[0] not in [12]:
-        raise RuntimeError("XQA is only supported on SM120 GPUs")
-    sm_version = int(
-        get_compute_capability(torch.device(device="cuda"))[0] * 10
-        + get_compute_capability(torch.device(device="cuda"))[1]
-    )
+    cc_major, cc_minor = get_compute_capability(q.device)
+    if cc_major != 12:
+        raise RuntimeError("XQA MLA is only supported on SM12x (SM120/SM121) GPUs")
+    sm_version = int(cc_major * 10 + cc_minor)
```<!-- review_comment_end -->

<details>
<summary>🧰 Tools</summary>

<details>
<summary>🪛 Ruff (0.14.1)</summary>

435-435: Avoid specifying long messages outside the exception class

(TRY003)

</details>

</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

In flashinfer/xqa.py around lines 434 to 439, replace repeated
get_compute_capability(torch.device(device="cuda")) calls with a single call
using the model/query device (q.device), store the returned tuple (major,
minor), compute sm_version as major*10 + minor, and change the check to require
major == 12 and minor in (0,1) (i.e., SM120 or SM121); update the RuntimeError
message to explicitly mention SM12x (or list SM120 and SM121) so it is accurate
and helpful.


</details>

<!-- fingerprinting:phantom:medusa:chinchilla -->

<!-- This is an auto-generated comment by CodeRabbit -->

Comment on lines 166 to 169
@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] != 9,
reason="XQA is only supported on SM90 GPUs",
get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12],
reason="XQA is only supported on SM90, SM100, SM120 GPUs",
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Guard skipif with cuda availability to avoid import-time crashes on CPU-only runners.

Add torch.cuda.is_available() to the condition so get_compute_capability isn’t called when CUDA isn’t present.

-@pytest.mark.skipif(
-    get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12],
-    reason="XQA is only supported on SM90, SM100, SM120 GPUs",
-)
+@pytest.mark.skipif(
+    (not torch.cuda.is_available())
+    or (get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]),
+    reason="XQA is only supported on SM90, SM100, SM120 GPUs",
+)
🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 166 to 169, the pytest.mark.skipif
condition calls get_compute_capability(torch.device(device="cuda")) even on
CPU-only runners which can cause import-time crashes; update the condition to
first check torch.cuda.is_available() and only call get_compute_capability when
CUDA is available (e.g., skip if not torch.cuda.is_available() or the compute
capability is not in [9,10,12]) so the skip guard short-circuits safely on
machines without CUDA.

Comment on lines +406 to +409
@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] not in [12],
reason="XQA mla is only supported on SM120 GPUs",
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Same guard for MLA skipif.

Protect import-time check on environments without CUDA.

-@pytest.mark.skipif(
-    get_compute_capability(torch.device(device="cuda"))[0] not in [12],
-    reason="XQA mla is only supported on SM120 GPUs",
-)
+@pytest.mark.skipif(
+    (not torch.cuda.is_available())
+    or (get_compute_capability(torch.device(device="cuda"))[0] not in [12]),
+    reason="XQA mla is only supported on SM120 GPUs",
+)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] not in [12],
reason="XQA mla is only supported on SM120 GPUs",
)
@pytest.mark.skipif(
(not torch.cuda.is_available())
or (get_compute_capability(torch.device(device="cuda"))[0] not in [12]),
reason="XQA mla is only supported on SM120 GPUs",
)
🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 406 to 409, the skipif currently
calls get_compute_capability at import time which fails on systems without CUDA;
change the condition to short-circuit by first checking CUDA availability so
get_compute_capability is only invoked when CUDA exists (e.g. use
pytest.mark.skipif(not torch.cuda.is_available() or
get_compute_capability(torch.device("cuda"))[0] not in [12], reason="XQA mla is
only supported on SM120 GPUs")), ensuring torch is imported before the
decorator.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9b0fd24 and e1a5c0c.

📒 Files selected for processing (2)
  • csrc/xqa/mla_sm120.cu (1 hunks)
  • csrc/xqa/mla_sm120.cuh (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/xqa/mla_sm120.cuh
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/xqa/mla_sm120.cu (3)
csrc/xqa/utils.h (3)
  • checkedVal (284-287)
  • checkCu (39-48)
  • checkCuda (32-37)
csrc/xqa/tma.h (1)
  • waitGroup (247-249)
csrc/xqa/hostUtils.h (1)
  • makeLaunchConfig (4-12)
🔇 Additional comments (2)
csrc/xqa/mla_sm120.cu (2)

171-175: Non‑paged loader dims rely on baseOffset selecting K/V lane.

Once kvKind is added, the non‑paged DimsLE<4>{..., baseOffset} remains correct. Please re‑run any non‑paged configuration tests (if supported) to confirm.


936-974: Verify FP8 scaling pipeline in mla_sm120.cu against mha_sm90.cu pattern.

Code analysis reveals an inconsistency in how kvCacheScale is applied between MHA and MLA implementations:

  • mha_sm90.cu (line 798): qkScale *= kvCacheScale after loading K/V with rcpKScale = 1.F / kvCacheScale[0] (unquantization)
  • mla_sm120.cu (line 497): qkScaleLog2e *= kvCacheScale without visible unquantization during K/V load, and again (line 1376): scaleVec *= xvScale where xvScale = kvCacheScale

The risk is real if mla_sm120.cu loads FP8 cache values without dequantization: applying kvCacheScale in both qkScaleLog2e and finalize() could apply scaling twice. Run the proposed numerical validation test to confirm whether K/V cache values are pre-dequantized or require unquantization before use.

Comment on lines +117 to +127
#if PAGED_KV_CACHE_LAYOUT == 1
,
baseOffset{idxReq * cacheList.maxNbPagesPerSeq}
#else
,
baseOffset{((idxReq * beamWidth) * 2) * cacheList.maxNbPagesPerSeq}
#endif
#else
,
baseOffset{(idxReq * beamWidth) * 2}
#endif
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix K/V selection in KV page indexing (layout 0 and non‑paged). V currently reads K pages.

In layout 0, kvCachePageList is shaped [batch][beam][2][maxPages]. baseOffset doesn’t include the K/V selector; V uses K pages. Non‑paged baseOffset also misses the +kvKind lane. Pass a kvKind (0=K, 1=V) into KVTilePartLoader and fold it into baseOffset; update call sites.

--- a/csrc/xqa/mla_sm120.cu
+++ b/csrc/xqa/mla_sm120.cu
@@ -73,6 +73,9 @@ struct KVTilePartLoader {
   static inline constexpr uint32_t const nbKHeads = 1;
   KVCacheList<usePagedKVCache> const& cacheList;
   uint32_t const idxReq;
+  // 0 = K, 1 = V (ignored for layout 1 where page list is shared)
+  uint32_t const kvKind;
+
   static inline constexpr uint32_t const idxHeadGrp = 0;
@@ -88,7 +91,8 @@ struct KVTilePartLoader {
-  __device__ KVTilePartLoader(KVCacheList<usePagedKVCache> const& cacheList, uint32_t idxReq,
-                              CUtensorMap const& tensorMap
+  __device__ KVTilePartLoader(KVCacheList<usePagedKVCache> const& cacheList, uint32_t idxReq,
+                              uint32_t kvKind, CUtensorMap const& tensorMap
 #if USE_PAGED_KV_CACHE
                               ,
                               uint32_t nbPages
 #endif
   );
@@ -104,12 +108,14 @@ __device__ inline KVTilePartLoader::KVTilePartLoader(KVCacheList<usePagedKVCache
-                                                     uint32_t idxReq, CUtensorMap const& tensorMap
+                                                     uint32_t idxReq, uint32_t kvKind,
+                                                     CUtensorMap const& tensorMap
 #if USE_PAGED_KV_CACHE
                                                      ,
                                                      uint32_t nbPages
 #endif
                                                      )
     : cacheList{cacheList},
       idxReq{idxReq},
+      kvKind{kvKind},
       tensorMap{tensorMap}
 #if USE_PAGED_KV_CACHE
       ,
       nbPages{nbPages}
 #if PAGED_KV_CACHE_LAYOUT == 1
       ,
       baseOffset{idxReq * cacheList.maxNbPagesPerSeq}
 #else
       ,
-      baseOffset{((idxReq * beamWidth) * 2) * cacheList.maxNbPagesPerSeq}
+      baseOffset{(((idxReq * beamWidth) * 2) + kvKind) * cacheList.maxNbPagesPerSeq}
 #endif
 #else
       ,
-      baseOffset{(idxReq * beamWidth) * 2}
+      baseOffset{(idxReq * beamWidth) * 2 + kvKind}
 #endif
 {
@@ -862,10 +868,11 @@ __device__ inline void Producer::loadK() {
-  KVTilePartLoader loader{args.cacheList, idxReq, args.tensorMapK
+  KVTilePartLoader loader{args.cacheList, idxReq, /*kvKind=*/0, args.tensorMapK
 #if USE_PAGED_KV_CACHE
                           ,
                           divUp(seqLen, tokensPerPage)
 #endif
   };
@@ -1342,10 +1349,11 @@ __device__ inline void Consumer::loadV() {
-  KVTilePartLoader loader(args.cacheList, idxReq, args.tensorMapV
+  KVTilePartLoader loader(args.cacheList, idxReq, /*kvKind=*/1, args.tensorMapV
 #if USE_PAGED_KV_CACHE
                           ,
                           divUp(seqLen, tokensPerPage)
 #endif
   );

Also applies to: 125-127, 862-870, 1342-1350

🤖 Prompt for AI Agents
In csrc/xqa/mla_sm120.cu around lines 117-127 (and also apply same change at
125-127, 862-870, 1342-1350): the baseOffset calculation for KV page indexing
omits the K/V selector so V reads K pages for layout 0 and non-paged layouts;
modify the KVTilePartLoader interface to accept an extra kvKind parameter (0 =
K, 1 = V) and fold kvKind into baseOffset computations (for layout==0: include
the "+ kvKind" lane inside the index that multiplies maxNbPagesPerSeq so the
shape [batch][beam][2][maxPages] is respected; for non-paged: add the kvKind
offset to (idxReq * beamWidth) * 2); update all KVTilePartLoader call sites in
the file to pass the appropriate kvKind value and ensure any preprocessor
branches use the new parameter consistently.

Comment on lines +907 to +915
auto& dst = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][ctaRank];
tma::store1DAsync(&dst, &smem.x, sizeof(CgaXBuffer));
tma::commitGroup();
tma::waitGroup<0>();
// it's turn for the other math group to produce.
uint32_t const idxBarNext = (iter + 1) % SharedMemA::nbXBars;
auto& xBarNext = smem.xBars[idxBarNext];
xBarNext.consumed.arrive();
asm volatile("fence.release.cluster;\n");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Use maxNbSubSeq for scratch indexing to avoid cross‑request overlap.

Scratch is sized with nbSubSeqPerSeq (host). Using nbSubSeq as a stride can alias buffers when nbSubSeq varies per request. Index with maxNbSubSeq consistently.

-      auto& dst = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][ctaRank];
+      auto& dst = args.cgaXBuf[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq][ctaRank];
@@
-      auto& src = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][idxScratchXBuf];
+      auto& src = args.cgaXBuf[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq][idxScratchXBuf];

Also applies to: 1329-1338

🤖 Prompt for AI Agents
In csrc/xqa/mla_sm120.cu around lines 907 to 915, the scratch indexing uses
nbSubSeq as a stride which can alias across requests when nbSubSeq varies;
replace the stride expression so the index uses maxNbSubSeq (i.e. index =
maxNbSubSeq * idxInputTokenGlobal + idxSubSeq) when addressing cgaXBuf (and any
other scratch buffers in this block), and apply the same change to the other
affected block at lines ~1329-1338 so all scratch accesses consistently use
maxNbSubSeq as the per-request stride.

Comment on lines +1248 to +1251
auto const data = ldmatrix_16x16_trans<2>(
&vBuf.template at<true>(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
AtomB const v[2] = {data[0], data[2], data[1], data[3]};
#pragma unroll
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix AtomB array initialization; current form provides 4 initializers to a 2‑element array.

Use nested initializers to build two AtomB from the four ldmatrix words.

-        auto const data = ldmatrix_16x16_trans<2>(
-            &vBuf.template at<true>(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
-        AtomB const v[2] = {data[0], data[2], data[1], data[3]};
+        auto const data = ldmatrix_16x16_trans<2>(
+            &vBuf.template at<true>(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
+        // Reorder (0,2) and (1,3) into two AtomB = Vec<uint32_t,2>
+        AtomB const v[2] = { AtomB{data[0], data[2]}, AtomB{data[1], data[3]} };
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
auto const data = ldmatrix_16x16_trans<2>(
&vBuf.template at<true>(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
AtomB const v[2] = {data[0], data[2], data[1], data[3]};
#pragma unroll
auto const data = ldmatrix_16x16_trans<2>(
&vBuf.template at<true>(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
// Reorder (0,2) and (1,3) into two AtomB = Vec<uint32_t,2>
AtomB const v[2] = { AtomB{data[0], data[2]}, AtomB{data[1], data[3]} };
#pragma unroll
🤖 Prompt for AI Agents
In csrc/xqa/mla_sm120.cu around lines 1248 to 1251, the AtomB array is declared
with 2 elements but initialized with 4 scalars; replace the flat 4-value
initializer with nested initializers that form two AtomB values from the four
ldmatrix words (use the four words data[0], data[1], data[2], data[3] to
construct AtomB[0] from data[0] and data[2] and AtomB[1] from data[1] and
data[3], respectively), so the array has exactly two AtomB initializers built
from the four words.

Comment on lines +1775 to +1787
if (std::is_same_v<CacheElem, half>) {
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
} else if (std::is_same_v<CacheElem, __nv_bfloat16>) {
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if (std::is_same_v<CacheElem, __nv_fp8_e4m3>) {
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
}
throw std::runtime_error("unsupported cache element type");
}();

auto const tensorMapQ = makeTensorMapForQ(q, dtype, validElemsPerHead,
headGrpSize * inputSeqLen * batchSize, partElemsK);
#if PAGED_KV_CACHE_LAYOUT == 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Ensure Q tensor‑map dtype matches InputHead, not CacheElem.

dtype is derived from CacheElem then reused for Q. If InputHead != CacheElem (e.g., FP16 Q with FP8 KV), TMA encoding for Q will be wrong. Derive a separate qDtype from InputHead.

-  auto const dtype = [] {
+  auto const kvDtype = [] {
     if (std::is_same_v<CacheElem, half>) {
       return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
     } else if (std::is_same_v<CacheElem, __nv_bfloat16>) {
       return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
     } else if (std::is_same_v<CacheElem, __nv_fp8_e4m3>) {
       return CU_TENSOR_MAP_DATA_TYPE_UINT8;
     }
     throw std::runtime_error("unsupported cache element type");
   }();
-
-  auto const tensorMapQ = makeTensorMapForQ(q, dtype, validElemsPerHead,
+  auto const qDtype = [] {
+    if (std::is_same_v<InputHead, half>) {
+      return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
+    } else if (std::is_same_v<InputHead, __nv_bfloat16>) {
+      return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
+    } else if (std::is_same_v<InputHead, __nv_fp8_e4m3>) {
+      return CU_TENSOR_MAP_DATA_TYPE_UINT8;
+    }
+    throw std::runtime_error("unsupported Q element type");
+  }();
+
+  auto const tensorMapQ = makeTensorMapForQ(q, qDtype, validElemsPerHead,
                                             headGrpSize * inputSeqLen * batchSize, partElemsK);
-  auto const tensorMapK = makeTensorMapForPagedKVCache(
-      kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile);
+  auto const tensorMapK = makeTensorMapForPagedKVCache(
+      kCacheVLLM, kvDtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile);
-  auto const tensorMapV = makeTensorMapForPagedKVCache(
-      vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile);
+  auto const tensorMapV = makeTensorMapForPagedKVCache(
+      vCacheVLLM, kvDtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile);

Repeat the same change in launchMLAFlashInfer.

To verify types across the repo:

Also applies to: 1906-1913


🏁 Script executed:

#!/bin/bash
# Confirm InputHead/CacheElem definitions and all dtype usages.
rg -n -C2 'using\s+InputHead|using\s+CacheElem|typedef\s+.*InputHead|typedef\s+.*CacheElem|makeTensorMapForQ\(|makeTensorMapForPagedKVCache\(' --type=cpp

Length of output: 1316


🏁 Script executed:

#!/bin/bash
# Find InputElem definition and verify it can differ from CacheElem
rg -n 'using\s+InputElem|typedef\s+.*InputElem' --type=cpp

# Extract the specific lines mentioned in the review
echo "=== Lines 1775-1787 from mla_sm120.cu ==="
sed -n '1775,1787p' csrc/xqa/mla_sm120.cu

echo "=== Lines 1906-1913 from mla_sm120.cu ==="
sed -n '1906,1913p' csrc/xqa/mla_sm120.cu

Length of output: 1269


🏁 Script executed:

#!/bin/bash
# Find launchMLAFlashInfer function and search for similar dtype pattern
rg -n 'launchMLAFlashInfer' --type=cpp -A 50 | head -100

# Also search for all instances of "auto const dtype" to find all occurrences
echo "=== All 'auto const dtype' occurrences ==="
rg -n 'auto const dtype' --type=cpp

Length of output: 2789


🏁 Script executed:

#!/bin/bash
# Get broader context around line 1775 to see dtype definition
echo "=== Lines 1760-1800 from mla_sm120.cu ==="
sed -n '1760,1800p' csrc/xqa/mla_sm120.cu

# Search for where dtype is declared (broader search)
echo "=== Search for 'dtype' with context ==="
rg -n '\bdtype\b' csrc/xqa/mla_sm120.cu -B 2 -A 2 | head -80

Length of output: 4759


🏁 Script executed:

#!/bin/bash
# Find which functions contain these dtype patterns
echo "=== Function context for line 1774 ==="
sed -n '1700,1800p' csrc/xqa/mla_sm120.cu | grep -n 'void\|__global__\|auto const dtype'

echo "=== Function context for line 1895 ==="
sed -n '1850,1920p' csrc/xqa/mla_sm120.cu | grep -n 'void\|__global__\|auto const dtype'

# Broader check - how many dtype definitions from CacheElem exist?
echo "=== All CacheElem-based dtype patterns ==="
rg -n 'auto const dtype.*=.*\[\].*\{' csrc/xqa/mla_sm120.cu -A 8

Length of output: 1210


Separate Q and KV tensor-map dtypes to match their actual element types.

When InputElem ≠ CacheElem (e.g., FP16 Q with FP8 KV), the current dtype derived only from CacheElem causes incorrect TMA encoding for the Q tensor. Create qDtype from InputHead and kvDtype from CacheElem, then use them accordingly.

Apply the fix to both launchMLA() (lines 1774–1796) and the second function (lines 1895–1918):

  • Rename dtypekvDtype for the CacheElem-based lambda
  • Add a new qDtype lambda deriving from InputHead
  • Use qDtype for makeTensorMapForQ()
  • Use kvDtype for makeTensorMapForPagedKVCache() calls
🤖 Prompt for AI Agents
In csrc/xqa/mla_sm120.cu around lines 1775 to 1787 (and similarly apply the same
change in the second function around lines 1895 to 1918), the code derives a
single dtype from CacheElem and uses it for both Q and KV tensor maps, which is
incorrect when InputElem ≠ CacheElem; change the CacheElem-based lambda name to
kvDtype, add a new qDtype lambda that derives the data type from InputHead, then
call makeTensorMapForQ(...) with qDtype and makeTensorMapForPagedKVCache(...)
(and any other KV tensor-map calls) with kvDtype so Q and KV use their
respective element-type encodings.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (2)
flashinfer/aot.py (2)

374-383: Simplify sm_versions construction (nit).

Current appends are verbose; a small tidy improves readability.

-    sm_versions = []
-    if has_sm90:
-        sm_versions.append(90)
-    if has_sm100:
-        sm_versions.append(100)
-    if has_sm120:
-        sm_versions.append(120)
-    if has_sm121:
-        sm_versions.append(121)
+    sm_versions = [
+        v for v, present in [(90, has_sm90), (100, has_sm100), (120, has_sm120), (121, has_sm121)]
+        if present
+    ]

423-446: Apply optional refactor to deduplicate MLA generation for SM120/121.

The two blocks are identical except sm_version. Looping over (version, flag) tuples reduces repetition while maintaining unique JitSpec names (which include sm_version as a suffix).

-    if has_sm120:
-        for token_per_page in token_per_page_:
-            yield gen_xqa_module_mla(
-                input_dtype=torch.float8_e4m3fn,
-                kv_cache_dtype=torch.float8_e4m3fn,
-                page_size=token_per_page,
-                head_dim=576,
-                head_group_ratio=128,
-                use_sliding_window=False,
-                sm_version=120,
-            )
-
-    if has_sm121:
-        for token_per_page in token_per_page_:
-            yield gen_xqa_module_mla(
-                input_dtype=torch.float8_e4m3fn,
-                kv_cache_dtype=torch.float8_e4m3fn,
-                page_size=token_per_page,
-                head_dim=576,
-                head_group_ratio=128,
-                use_sliding_window=False,
-                sm_version=121,
-            )
+    for sm_v, present in [(120, has_sm120), (121, has_sm121)]:
+        if not present:
+            continue
+        for token_per_page in token_per_page_:
+            yield gen_xqa_module_mla(
+                input_dtype=torch.float8_e4m3fn,
+                kv_cache_dtype=torch.float8_e4m3fn,
+                page_size=token_per_page,
+                head_dim=576,
+                head_group_ratio=128,
+                use_sliding_window=False,
+                sm_version=sm_v,
+            )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e1a5c0c and 94b0f7c.

📒 Files selected for processing (1)
  • flashinfer/aot.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/aot.py (3)
flashinfer/xqa.py (2)
  • xqa (56-93)
  • xqa (124-267)
flashinfer/jit/xqa.py (2)
  • gen_xqa_module (41-116)
  • gen_xqa_module_mla (119-180)
flashinfer/jit/cpp_ext.py (1)
  • get_cuda_version (59-78)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/aot.py (2)

70-70: XQA MLA import wiring — looks good.

The new import aligns with the added MLA generation paths below.


359-369: Docstring is incomplete; update to document the new parameters.

The gen_xqa function at line 358 has been extended with new parameters (fp16_input_, fp8_kv_cache_, has_sm100, has_sm120, has_sm121), but the docstring (line 370) only contains a generic description. Update it to document each parameter's purpose and type, including the new SM version flags.

The call site at line 570 is correctly updated with all new arguments.

Comment on lines +559 to +561
if (
add_xqa and get_cuda_version() > Version("12.8")
): # TODO: Earlier cuda versions have compile issues, will be fixed in future releases
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Fix CUDA version gate: ‘> 12.8’ contradicts the comment; use ‘>= 12.8’.

As written, CUDA exactly 12.8 will skip XQA despite the note implying only earlier versions have issues. Align the condition with the intent.

-    if (
-        add_xqa and get_cuda_version() > Version("12.8")
-    ):  # TODO: Earlier cuda versions have compile issues, will be fixed in future releases
+    if (
+        add_xqa and get_cuda_version() >= Version("12.8")
+    ):  # Earlier CUDA versions (< 12.8) have compile issues; enable for 12.8+

🏁 Script executed:

#!/bin/bash
# Show current CUDA version as parsed to confirm gating behavior around 12.8
rg -n 'def get_cuda_version' -n --type=py -C5
# Quick check of other gates for consistency (>= in detect_sm_capabilities)
rg -nP 'get_cuda_version\(\)\s*[<>]=?\s*Version' --type=py -C2

Length of output: 2238


Fix CUDA version gate: change > Version("12.8") to >= Version("12.8").

The current condition > Version("12.8") excludes CUDA 12.8 itself, contradicting the comment which states only earlier versions have compile issues. Align with the codebase pattern: lines 760 and 82 use >= for version gates.

    if (
-       add_xqa and get_cuda_version() > Version("12.8")
+       add_xqa and get_cuda_version() >= Version("12.8")
    ):  # TODO: Earlier cuda versions have compile issues, will be fixed in future releases
🤖 Prompt for AI Agents
In flashinfer/aot.py around lines 559 to 561, the CUDA version check currently
uses a strict greater-than ("> Version(\"12.8\")") which incorrectly excludes
CUDA 12.8; change the comparison to greater-than-or-equal (">=
Version(\"12.8\")") so CUDA 12.8 is allowed, matching the comment and the
project's version-gate pattern used elsewhere.

seq_len,
tokens_per_page,
use_fp16,
fp16_input,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this parameter more explicit? like input_data_type?

def gen_xqa(
use_fp16_: List[bool],
fp16_input_: List[bool],
fp8_kv_cache_: List[bool],
Copy link
Contributor

@nvmbreughe nvmbreughe Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, maybe we can make the data types explicit instead of a combination of bools.
It seems like the options are:

  •         kv_cache_dtype = torch.float8_e4m3fn
    
  •         kv_cache_dtype = torch.float16
    
  •         kv_cache_dtype = torch.bfloat16
    

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants