
Conversation


@qsang-nv qsang-nv commented Nov 18, 2025

📌 Description

Enable XQA with speculative decoding and add a mask tensor to trtllm_batch_decode_with_kv_cache.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Speculative decoding: multi-token query support (q_seq_len) with optional attention mask threaded end-to-end.
  • API

    • Public APIs updated to accept q_seq_len and an optional mask; automatic reshaping and runtime checks for multi-token decoding.
  • JIT / Build

    • JIT now emits SPEC_DEC-enabled variants and includes spec-dec flags in generated specs.
  • Backend / Runtime

    • Mask propagation and architecture-aware backend selection improved for compatible kernels.
  • Tests

    • Added helpers and tests to generate causal masks and validate multi-token speculative decoding.


Signed-off-by: Qidi Sang <[email protected]>

coderabbitai bot commented Nov 18, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

The PR stabilizes speculative-decoding parameters across the XQA stack: public/FFI xqa_wrapper signatures now always include qSeqLen and an Optional mask, a SPEC_DEC-specific attention-sinks device routine was added, JIT and Python layers compute/propagate SPEC_DEC and mask, and tests were extended to generate and pass causal masks for multi-token queries.

Changes

Cohort / File(s) Summary
CUDA Binding Signature Updates
csrc/flashinfer_xqa_binding.cu
Non-MLA xqa_wrapper signature now always includes int64_t qSeqLen and tvm::ffi::Optional<TensorView> mask, reorders parameters (kvScaleTensor, qSeqLen, mask, semaphores, scratch, enable_pdl); previous #if SPEC_DEC parameter injection removed; MLA variant unchanged.
MHA / Attention Sink SPEC_DEC Variant
csrc/xqa/mha.cu
Added __device__ inline void addAttentionSinksSpecDec(...) under #if SPEC_DEC; per-block and per-tile reduction paths call this SPEC_DEC variant when enabled, non-SPEC_DEC paths unchanged.
XQA Wrapper Implementation
csrc/xqa/xqa_wrapper.cu
Updated public xqa_wrapper signature to match binding; under #if SPEC_DEC extract nullable MaskType const* maskPtr from Optional<TensorView> and pass qSeqLen, nullptr, maskPtr into downstream kernels; architecture-aware MHA selection retained.
Python Decode API
flashinfer/decode.py
Added optional mask: Optional[torch.Tensor] = None to trtllm_batch_decode_with_kv_cache and xqa_batch_decode_with_kv_cache; allow q_len_per_req > 1 for XQA by reshaping queries and forward q_seq_len and mask to backend calls.
JIT Module Generation
flashinfer/jit/xqa.py
Added q_seq_len: int = 1 to gen_xqa_module; compute use_spec_dec from q_seq_len and head_group_ratio, inject SPEC_DEC flags and include use_spec_dec in JIT spec name; conditionally include SM90-specific sources when applicable.
XQA Public API
flashinfer/xqa.py
Added q_seq_len and mask to public/internal xqa and _fake_xqa signatures and propagated them to the compiled module call; require mask when q_seq_len > 1; changed head inference to q.shape[-2].
Tests & Test Utilities
tests/attention/test_xqa_batch_decode.py, tests/attention/test_trtllm_gen_attention.py
Added generate_causal_mask(batch_size, q_seq_len, device) helper; removed skips for speculative decoding; when q_len_per_req > 1 tests generate and pass causal masks through decode paths and wrappers.
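
For orientation, here is a tiny self-contained example of the packed causal-mask layout those tests produce (format per the generate_causal_mask docstring quoted in the review comments below: one bit per visible kv position, packed into uint32 words that are reinterpreted as pairs of uint16, so each query row carries divUp(q_seq_len, 32) * 2 uint16 values). Shapes and names here are illustrative only:

import torch

# q_seq_len = 3: row i may attend to kv positions 0..i, so the packed uint32 words
# are 0b001, 0b011 and 0b111.
words = torch.tensor([[[0b001], [0b011], [0b111]]], dtype=torch.uint32)  # [batch=1, q_seq_len=3, n_words=1]
mask = words.view(torch.uint16)  # reinterpret as uint16 pairs -> shape [1, 3, 2]
print(mask[0])  # rows: [1, 0], [3, 0], [7, 0] (low half first on little-endian hosts)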

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant DecodePy as decode.py
    participant XQAPY as xqa.py
    participant JitXQA as jit/xqa.py
    participant Binding as flashinfer_xqa_binding.cu
    participant Kernel as xqa_wrapper / mha

    User->>DecodePy: xqa_batch_decode_with_kv_cache(q_seq_len, mask)
    DecodePy->>XQAPY: xqa(..., q_seq_len, mask)
    XQAPY->>JitXQA: get_xqa_module(..., q_seq_len)
    JitXQA->>Binding: compiled module (flags include SPEC_DEC when triggered)
    XQAPY->>Binding: xqa_wrapper(..., qSeqLen, mask, semaphores, scratch)
    Binding->>Kernel: invoke xqa_wrapper with maskPtr (nullable)
    alt SPEC_DEC active
        Kernel->>Kernel: addAttentionSinksSpecDec(...) (SPEC_DEC per-head logic)
    else SPEC_DEC inactive
        Kernel->>Kernel: addAttentionSinks(...) (existing logic)
    end
    Kernel-->>Binding: output
    Binding-->>XQAPY: output
    XQAPY-->>DecodePy: decoded output
    DecodePy-->>User: result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Pay special attention to FFI/ABI ordering and nullable maskPtr lifetime in csrc/flashinfer_xqa_binding.cu.
  • Review math, indexing, and normalization in csrc/xqa/mha.cu::addAttentionSinksSpecDec.
  • Verify JIT flag computation, conditional source inclusion, and deterministic spec naming in flashinfer/jit/xqa.py.
  • Check Python mask packing/format and runtime checks (require mask when q_seq_len > 1) in flashinfer/xqa.py and tests.

Suggested reviewers

  • cyx-6
  • djmmoss
  • wenscarl
  • aleozlx
  • nvmbreughe
  • jiahanc

Poem

🐇 I nudged the kernels, mask in paw,

q‑seqs stretched out for speculative awe.
Heads counted neat in token rows,
Sinks pooled where soft attention flows,
Outputs hop forth — I clap my paws.

Pre-merge checks

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: coverage is 53.33%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title 'enable xqa speculative decoding' clearly and specifically describes the main change of the pull request.
  • Description check ✅ Passed: the description identifies the main changes (enabling XQA with speculative decoding and adding a mask tensor) but is incomplete on the checklist items and lacks a related-issues reference.


@gemini-code-assist
Contributor

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FlashInfer library by integrating speculative decoding capabilities into its XQA kernel. This feature allows the system to process multiple predicted tokens in a single step, potentially leading to substantial speedups in large language model inference. The changes span from low-level CUDA kernel modifications to higher-level Python API adjustments, ensuring a seamless and efficient user experience for speculative decoding.

Highlights

  • Speculative Decoding Support: Enabled speculative decoding for XQA (eXtended Query Attention) kernels, allowing for faster inference by processing multiple tokens simultaneously.
  • Causal Mask Integration: Introduced a "mask" tensor to handle causal attention requirements during speculative decoding, ensuring correct token dependencies.
  • API Extension: Updated the Python APIs (trtllm_batch_decode_with_kv_cache and xqa_batch_decode_with_kv_cache) to accept the new "mask" parameter and "q_seq_len" for speculative decoding.
  • Dynamic Kernel Compilation: Modified the JIT compilation system to dynamically build XQA kernels with SPEC_DEC flags based on whether speculative decoding is active, optimizing performance.
  • Test Coverage: Added new test cases and a "generate_causal_mask" utility to validate the correctness of speculative decoding functionality across different configurations.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request enables speculative decoding for XQA kernels, which is a great enhancement. The changes across the C++ bindings, Python wrappers, and JIT compilation logic are consistent and well-implemented. The API is updated cleanly to support speculative decoding as a runtime option.

I've included a few review comments with suggestions for minor improvements, mainly focusing on code simplification in a CUDA kernel and improving maintainability in the test helpers by addressing code duplication and style. Overall, this is a solid contribution.

// In SPEC_DEC, layout is [token0_head0, token0_head1, ..., token1_head0, ...]
// Extract head index from head-token index
uint32_t headIdx = idxHeadToken % headGrpSize;
if (headIdx < headGrpSize && idxHeadToken < rowsPerBlock) {

medium

The check headIdx < headGrpSize is redundant. headIdx is calculated as idxHeadToken % headGrpSize. The result of the modulo operator % with a positive headGrpSize will always be in the range [0, headGrpSize - 1]. Removing this redundant check simplifies the code.

    if (idxHeadToken < rowsPerBlock) {

Comment on lines 342 to 386
def generate_causal_mask(
    batch_size: int,
    q_seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Generate causal attention mask for speculative decoding.
    Parameters
    ----------
    batch_size : int
        Batch size
    q_seq_len : int
        Query sequence length (number of speculative decoding tokens)
    device : torch.device
        Target device for the mask tensor
    Returns
    -------
    torch.Tensor
        Causal mask with shape [batch_size, q_seq_len, mask_size_per_row]
        where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units).
        Data type: torch.uint16
    """
    import numpy as np

    num_packed_masks_per_token = (q_seq_len + 31) // 32

    mask_np = np.zeros(
        (batch_size, q_seq_len, num_packed_masks_per_token), dtype=np.uint32
    )

    for q_pos in range(q_seq_len):
        for kv_pos in range(q_pos + 1):  # Causal: only see previous queries and self
            word_idx = kv_pos // 32
            bit_in_word = kv_pos % 32
            if word_idx < num_packed_masks_per_token:
                mask_np[:, q_pos, word_idx] |= np.uint32(1 << bit_in_word)

    mask_uint32 = torch.from_numpy(mask_np).to(device)
    mask_uint16 = mask_uint32.view(torch.uint16)

    return mask_uint16


medium

The helper function generate_causal_mask has a few areas for improvement:

  • Import location: The import numpy as np statement is inside the function. According to PEP 8, imports should be at the top of the file to make dependencies clear and avoid re-importing.
  • Redundant check: The check if word_idx < num_packed_masks_per_token: on line 379 is redundant. Given kv_pos <= q_pos < q_seq_len, word_idx will always be less than num_packed_masks_per_token. Removing this simplifies the logic.

Consider refactoring this function to address these points.

Comment on lines 293 to 337
def generate_causal_mask(
    batch_size: int,
    q_seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Generate causal attention mask for speculative decoding.
    Parameters
    ----------
    batch_size : int
        Batch size
    q_seq_len : int
        Query sequence length (number of speculative decoding tokens)
    device : torch.device
        Target device for the mask tensor
    Returns
    -------
    torch.Tensor
        Causal mask with shape [batch_size, q_seq_len, mask_size_per_row]
        where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units).
        Data type: torch.uint16
    """
    import numpy as np

    num_packed_masks_per_token = (q_seq_len + 31) // 32

    mask_np = np.zeros(
        (batch_size, q_seq_len, num_packed_masks_per_token), dtype=np.uint32
    )

    for q_pos in range(q_seq_len):
        for kv_pos in range(q_pos + 1):  # Causal: only see previous queries and self
            word_idx = kv_pos // 32
            bit_in_word = kv_pos % 32
            if word_idx < num_packed_masks_per_token:
                mask_np[:, q_pos, word_idx] |= np.uint32(1 << bit_in_word)

    mask_uint32 = torch.from_numpy(mask_np).to(device)
    mask_uint16 = mask_uint32.view(torch.uint16)

    return mask_uint16


medium

This generate_causal_mask function has a few areas for improvement:

  • Code Duplication: This function is identical to the one in tests/attention/test_trtllm_gen_attention.py. To improve maintainability, consider moving this helper to a shared test utility file (e.g., in tests/test_helpers/) and importing it where needed.
  • Import location: The import numpy as np statement is inside the function. Per PEP 8, it should be at the top of the file.
  • Redundant check: The check if word_idx < num_packed_masks_per_token: on line 330 is redundant for the same reasons as in the other file.

Refactoring this would improve the test suite's quality.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/xqa.py (1)

301-329: Replace assert with explicit mask validation

assert is skipped under python -O, so speculative decoding could proceed with mask=None, feeding a null (CPU) pointer into the CUDA kernel and causing undefined behaviour or a hard crash. Please switch to an explicit runtime check and validate the tensor before launching the kernel.

@@
-    if q_seq_len > 1:
-        assert mask is not None, "Mask is required for speculative decoding"
-        run_sm90_fp8_mha = (
-            False  # TODO: mha_sm90.cu has precision issue with speculative decoding
-        )
+    if q_seq_len > 1:
+        if mask is None:
+            raise ValueError(
+                "mask must be provided when q_seq_len > 1 (speculative decoding)"
+            )
+        expected_mask_cols = ((q_seq_len + 31) // 32) * 2
+        check_shape_dtype_device(
+            mask,
+            (batch_size, q_seq_len, expected_mask_cols),
+            torch.uint16,
+            q.device,
+            "mask",
+        )
+        run_sm90_fp8_mha = False  # TODO: mha_sm90.cu has precision issue with speculative decoding
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a9f71bd and 95ce39b.

📒 Files selected for processing (8)
  • csrc/flashinfer_xqa_binding.cu (1 hunks)
  • csrc/xqa/mha.cu (3 hunks)
  • csrc/xqa/xqa_wrapper.cu (3 hunks)
  • flashinfer/decode.py (7 hunks)
  • flashinfer/jit/xqa.py (4 hunks)
  • flashinfer/xqa.py (11 hunks)
  • tests/attention/test_trtllm_gen_attention.py (3 hunks)
  • tests/attention/test_xqa_batch_decode.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
csrc/xqa/mha.cu (1)
csrc/xqa/mha_sm90.cu (20)
  • void (538-567)
  • void (569-574)
  • void (578-588)
  • void (1650-1684)
  • void (1703-1720)
  • void (1722-1737)
  • void (1748-1794)
  • void (1886-1907)
  • void (1909-1927)
  • void (1959-2041)
  • void (2090-2164)
  • void (2166-2185)
  • void (2188-2206)
  • void (2226-2242)
  • void (2246-2269)
  • void (2306-2330)
  • i (176-176)
  • i (178-178)
  • i (180-180)
  • i (182-182)
csrc/xqa/xqa_wrapper.cu (2)
csrc/xqa/mha_sm90.cu (2)
  • scratch (496-503)
  • scratch (496-496)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
tests/attention/test_xqa_batch_decode.py (2)
tests/attention/test_trtllm_gen_attention.py (1)
  • generate_causal_mask (342-385)
flashinfer/utils.py (1)
  • get_compute_capability (253-256)
flashinfer/xqa.py (1)
flashinfer/utils.py (4)
  • register_custom_op (314-323)
  • register_custom_op (333-352)
  • register_fake_op (325-329)
  • register_fake_op (354-359)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment on lines 342 to 385
def generate_causal_mask(
    batch_size: int,
    q_seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Generate causal attention mask for speculative decoding.
    Parameters
    ----------
    batch_size : int
        Batch size
    q_seq_len : int
        Query sequence length (number of speculative decoding tokens)
    device : torch.device
        Target device for the mask tensor
    Returns
    -------
    torch.Tensor
        Causal mask with shape [batch_size, q_seq_len, mask_size_per_row]
        where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units).
        Data type: torch.uint16
    """
    import numpy as np

    num_packed_masks_per_token = (q_seq_len + 31) // 32

    mask_np = np.zeros(
        (batch_size, q_seq_len, num_packed_masks_per_token), dtype=np.uint32
    )

    for q_pos in range(q_seq_len):
        for kv_pos in range(q_pos + 1):  # Causal: only see previous queries and self
            word_idx = kv_pos // 32
            bit_in_word = kv_pos % 32
            if word_idx < num_packed_masks_per_token:
                mask_np[:, q_pos, word_idx] |= np.uint32(1 << bit_in_word)

    mask_uint32 = torch.from_numpy(mask_np).to(device)
    mask_uint16 = mask_uint32.view(torch.uint16)

    return mask_uint16

⚠️ Potential issue | 🔴 Critical

Fix mask dtype reinterpretation crash

Tensor.view only accepts integer shape arguments; passing torch.uint16 raises a TypeError (“torch.dtype object cannot be interpreted as an integer”), so the new helper will fail before any test even runs. Rework the conversion so we reshape in NumPy (where .view(np.uint16) is valid) and only then materialize the torch tensor on the target device.

-    mask_uint32 = torch.from_numpy(mask_np).to(device)
-    mask_uint16 = mask_uint32.view(torch.uint16)
-
-    return mask_uint16
+    mask_uint16 = torch.from_numpy(
+        mask_np.view(np.uint16).reshape(
+            batch_size, q_seq_len, num_packed_masks_per_token * 2
+        )
+    ).to(device)
+    return mask_uint16
🤖 Prompt for AI Agents
In tests/attention/test_trtllm_gen_attention.py around lines 342-385, the code
currently calls mask_uint32.view(torch.uint16) which raises a TypeError because
torch.view expects integer shape arguments and torch.dtype cannot be used that
way; instead reinterpret the uint32 data as uint16 while still in NumPy (e.g.,
cast/view the numpy array to np.uint16 so the last dimension becomes
num_packed_masks_per_token * 2 to match mask_size_per_row), then call
torch.from_numpy(...) and move the tensor to the target device; ensure the numpy
array is contiguous before converting.


yzh119 commented Nov 19, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !148 has been created, and the CI pipeline #38770160 is currently running. I'll report back once the pipeline job completes.


num_packed_masks_per_token = (q_seq_len + 31) // 32

mask_np = np.zeros(
Collaborator

Can you vectorize the tensor construction loop? (and make it fully on GPU), we found that the process of creating tensor on cpu and then converting it to GPU tensor is always the bottleneck of unittests.

Collaborator Author

Done

(batch_size, q_seq_len, num_packed_masks_per_token), dtype=np.uint32
)

for q_pos in range(q_seq_len):
Collaborator

ditto

Collaborator Author

Done

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9d9e4fc and fd2527b.

📒 Files selected for processing (4)
  • csrc/xqa/xqa_wrapper.cu (2 hunks)
  • flashinfer/jit/xqa.py (4 hunks)
  • tests/attention/test_trtllm_gen_attention.py (3 hunks)
  • tests/attention/test_xqa_batch_decode.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/xqa/xqa_wrapper.cu
  • tests/attention/test_xqa_batch_decode.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_trtllm_gen_attention.py (1)
tests/attention/test_xqa_batch_decode.py (1)
  • generate_causal_mask (293-350)
flashinfer/jit/xqa.py (1)
flashinfer/jit/core.py (1)
  • gen_jit_spec (315-381)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/attention/test_trtllm_gen_attention.py (2)

883-887: LGTM: Mask generation logic is correct.

The conditional mask generation correctly supplies a causal mask for speculative decoding when q_len_per_req > 1 and passes None otherwise.


919-919: LGTM: Mask parameter correctly threaded through.

The mask is properly passed to the decode function, enabling speculative decoding support.

flashinfer/jit/xqa.py (4)

42-42: Good addition: q_seq_len parameter enables speculative decoding.

The new parameter with a sensible default of 1 (non-speculative) allows the XQA module to adapt to multi-token query scenarios.


102-114: Good refactor: Dynamic source selection based on target architecture.

Conditionally including SM90-specific sources (mha_sm90.cu and tensorMap.cpp) based on target architectures is cleaner and more maintainable than static lists. This prevents unnecessary compilation and linking of unused code for non-SM90 targets.


116-116: LGTM: JIT spec name and flags correctly updated.

The generated spec name now includes use_spec_dec to differentiate speculative decoding variants, and the corresponding flags are properly propagated to the CUDA compiler. This ensures correct JIT caching and prevents collisions between different kernel configurations.

Also applies to: 127-127


84-92: No action required—threshold heuristic is correctly aligned with kernel constraints.

Verification confirms the Python code properly implements the kernel's threshold requirement. The kernel enforces the constraint via static_assert(specDecQLen * headGrpSize <= 32, ...) in csrc/xqa/mha_sm90.cu (line 42), and the Python JIT code at lines 87–88 applies the identical condition (q_seq_len * head_group_ratio <= 32) before defining the SPEC_Q_SEQ_LEN compile flag. The heuristic is correct and well-aligned.
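
For readers unfamiliar with the JIT layer, a minimal standalone sketch of that heuristic follows; the exact flag spellings and the q_seq_len > 1 gate are assumptions for illustration, only the <= 32 bound is taken from the kernel's static_assert:

def spec_dec_flags(q_seq_len: int, head_group_ratio: int) -> tuple[bool, list]:
    # The SM90 kernel statically asserts specDecQLen * headGrpSize <= 32, so only
    # build the speculative-decoding variant when that bound holds and the query
    # actually spans multiple tokens. Flag names below are illustrative.
    use_spec_dec = q_seq_len > 1 and q_seq_len * head_group_ratio <= 32
    flags = ["-DSPEC_DEC=1", f"-DSPEC_Q_SEQ_LEN={q_seq_len}"] if use_spec_dec else []
    return use_spec_dec, flags


print(spec_dec_flags(4, 8))  # (True, ['-DSPEC_DEC=1', '-DSPEC_Q_SEQ_LEN=4'])
print(spec_dec_flags(8, 8))  # (False, []) because 8 * 8 = 64 exceeds the kernel bound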

Comment on lines +342 to +400
def generate_causal_mask(
    batch_size: int,
    q_seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Generate causal attention mask for speculative decoding.
    Parameters
    ----------
    batch_size : int
        Batch size
    q_seq_len : int
        Query sequence length (number of speculative decoding tokens)
    device : torch.device
        Target device for the mask tensor
    Returns
    -------
    torch.Tensor
        Causal mask with shape [batch_size, q_seq_len, mask_size_per_row]
        where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units).
        Data type: torch.uint16
    """
    num_packed_masks_per_token = (q_seq_len + 31) // 32

    q_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(1)
    kv_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(0)

    causal_bool_mask = kv_indices <= q_indices

    padded_seq_len = num_packed_masks_per_token * 32
    if padded_seq_len > q_seq_len:
        padding = torch.zeros(
            q_seq_len, padded_seq_len - q_seq_len, device=device, dtype=torch.bool
        )
        causal_bool_mask = torch.cat([causal_bool_mask, padding], dim=1)

    causal_bool_mask = causal_bool_mask.view(q_seq_len, num_packed_masks_per_token, 32)

    bit_positions = torch.tensor(
        [1 << i for i in range(32)], device=device, dtype=torch.int64
    )

    mask_uint32 = (
        (causal_bool_mask.to(torch.int64) * bit_positions).sum(dim=-1).to(torch.uint32)
    )

    mask_uint32 = (
        mask_uint32.unsqueeze(0)
        .expand(batch_size, q_seq_len, num_packed_masks_per_token)
        .contiguous()
    )

    mask_uint16 = mask_uint32.view(torch.uint16)

    return mask_uint16


🛠️ Refactor suggestion | 🟠 Major

Eliminate code duplication by extracting to shared utility.

The generate_causal_mask function is duplicated in tests/attention/test_xqa_batch_decode.py (lines 292-349). Extract this function to a shared test utility module (e.g., tests/test_helpers/mask_utils.py) to avoid maintaining identical code in multiple places. This will also ensure consistency when the critical .view(torch.uint16) bug flagged in previous reviews is fixed.

Additionally, the maintainer (yzh119) has requested vectorization and full GPU implementation to avoid CPU-to-GPU conversion bottlenecks in tests. Consider addressing this when refactoring.
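
As a rough illustration, the extraction could look like the sketch below. The module path comes from the suggestion above and is hypothetical, and the packing simply mirrors the vectorized helper quoted earlier in this thread:

# tests/test_helpers/mask_utils.py (hypothetical location)
import torch


def generate_causal_mask(
    batch_size: int, q_seq_len: int, device: torch.device
) -> torch.Tensor:
    """Packed causal mask shared by the XQA and trtllm-gen attention tests."""
    num_words = (q_seq_len + 31) // 32
    q = torch.arange(q_seq_len, device=device).unsqueeze(1)
    kv = torch.arange(num_words * 32, device=device).unsqueeze(0)
    allowed = (kv <= q).view(q_seq_len, num_words, 32)  # causal: kv position <= q position
    bit_positions = torch.tensor(
        [1 << i for i in range(32)], device=device, dtype=torch.int64
    )
    words = (allowed.to(torch.int64) * bit_positions).sum(dim=-1).to(torch.uint32)
    # Expand over the batch and reinterpret each uint32 word as two uint16 halves.
    return (
        words.unsqueeze(0)
        .expand(batch_size, q_seq_len, num_words)
        .contiguous()
        .view(torch.uint16)
    )

Both test files would then import this single helper instead of carrying their own copies.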


yzh119 commented Nov 19, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !148 has been updated with latest changes, and the CI pipeline #38782417 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/xqa/xqa_wrapper.cu (1)

53-104: Hopper FP8 MHA dispatch is disabled in host code; adopt runtime flag instead of CUDA_ARCH guard

The dispatch logic at lines 71–75 will not work as intended:

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900
  auto const mha_func = run_sm90_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;
#else
  auto const mha_func = &launchMHAFlashInfer;
#endif

Since xqa_wrapper is a host function, __CUDA_ARCH__ is not defined during host compilation, so this preprocessor branch always takes the #else path. The run_sm90_fp8_mha flag is ignored, and launchHopperF8MHAFlashInfer is never selected, even on SM90 GPUs.

Replace the preprocessor guard with a runtime condition:

auto const mha_func = run_sm90_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;

This makes the dispatch purely runtime-based and allows the Python caller's run_sm90_fp8_mha flag to actually select the Hopper FP8 path when appropriate.

♻️ Duplicate comments (2)
csrc/xqa/mha.cu (1)

1270-1285: SPEC_DEC attention sinks wiring looks correct; minor redundant condition remains

The SPEC_DEC-specific addAttentionSinksSpecDec implementation and its use in both single-CTA and multi-CTA reductions are consistent with the non-SPEC_DEC path; sinks are applied once per head, with per-row normalization via globalRowMax.

The condition if (headIdx < headGrpSize && idxHeadToken < rowsPerBlock) can be simplified, since headIdx is computed as idxHeadToken % headGrpSize and is therefore always < headGrpSize when headGrpSize > 0. Keeping only the idxHeadToken < rowsPerBlock bound check would slightly simplify the code.

This is purely a micro-cleanup; no functional change required.

Also applies to: 1890-1894, 2375-2379

tests/attention/test_xqa_batch_decode.py (1)

293-351: Causal mask helper is correct but duplicated; prefer a shared test utility

The new generate_causal_mask matches the expected packed uint16 causal-mask format and is used correctly to feed mask into xqa_batch_decode_with_kv_cache when q_len_per_req > 1.

However, this helper is effectively identical to the one in tests/attention/test_trtllm_gen_attention.py. To avoid divergence and make future changes easier, consider moving it into a shared test helper (e.g., tests/test_helpers/masks.py) and importing it from both test files.

Also applies to: 508-512, 530-531

🧹 Nitpick comments (3)
flashinfer/decode.py (2)

2062-2083: Mask parameter is correctly plumbed to the XQA backend; clarify / guard usage for others

Adding mask: Optional[torch.Tensor] = None here and forwarding it only when backend == "xqa" matches the docstring (“causal attention mask for xqa speculative decoding”).

Two minor improvements you might consider:

  • Explicitly document that mask is ignored for the trtllm-gen backend (or raise if it is passed together with backend="trtllm-gen" to avoid silent surprises).
  • Optionally assert that any provided mask is on the same device as query before forwarding, to fail fast on obvious misuse.

Also applies to: 2153-2155, 2195-2212
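
A minimal sketch of such a guard (names mirror the parameters discussed above; this is illustrative rather than the PR's code):

import torch


def _check_mask_args(backend: str, query: torch.Tensor, mask) -> None:
    # Fail fast on obvious misuse before dispatching: the mask is only consumed by
    # the XQA path and must live on the same device as the query.
    if mask is None:
        return
    if backend != "xqa":
        raise ValueError(f"mask is only supported with backend='xqa', got {backend!r}")
    if mask.device != query.device:
        raise ValueError(f"mask is on {mask.device}, expected {query.device}")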


2341-2357: q_len_per_req and mask handling in XQA batch decode are consistent; add a simple shape invariant

The speculative-decoding wiring here is coherent:

  • For q_len_per_req > 1, query is reshaped from [num_tokens, num_heads, head_dim] into [batch_size, q_len_per_req, num_heads, head_dim] and then into [batch_size, 1, q_len_per_req, num_heads, head_dim] via the unsqueeze(1) in the XQA call site.
  • q_len_per_req is forwarded as q_seq_len, and mask is passed through to flashinfer.xqa.xqa, which enforces the “mask required when q_seq_len > 1” invariant.

To make this more robust to API misuse, you might:

  • Add a quick check like assert query.shape[0] % q_len_per_req == 0 before the view, so incorrect q_len_per_req values fail loudly.
  • Optionally sanity-check mask when provided (dtype torch.uint16, last-dimension length matching the packed format) before forwarding, to surface shape mistakes at the Python boundary instead of failing deep in the CUDA kernel.

Also applies to: 2451-2484
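
One way to express those invariants at the Python boundary (an illustrative sketch; the expected shape and dtype follow the packed-mask format described elsewhere in this PR):

import torch


def _check_spec_dec_inputs(query: torch.Tensor, q_len_per_req: int, mask) -> None:
    num_tokens = query.shape[0]
    if num_tokens % q_len_per_req != 0:
        raise ValueError(
            f"{num_tokens} query tokens are not divisible by q_len_per_req={q_len_per_req}"
        )
    if q_len_per_req > 1:
        if mask is None:
            raise ValueError("mask is required when q_len_per_req > 1")
        batch_size = num_tokens // q_len_per_req
        expected_cols = ((q_len_per_req + 31) // 32) * 2  # packed uint16 words per row
        expected_shape = (batch_size, q_len_per_req, expected_cols)
        if mask.dtype != torch.uint16 or tuple(mask.shape) != expected_shape:
            raise ValueError(
                f"mask must be torch.uint16 with shape {expected_shape}, "
                f"got {mask.dtype} with shape {tuple(mask.shape)}"
            )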

flashinfer/xqa.py (1)

34-43: SPEC_DEC integration in XQA wrapper is coherent; consider adding lightweight mask validation

The new q_seq_len plumbing and speculative-decoding gating are wired cleanly:

  • q_seq_len participates in the get_xqa_module cache key, is forwarded to gen_xqa_module, and also influences the op name via use_spec_dec, so SPEC_DEC vs non-SPEC_DEC variants don’t collide.
  • The registered custom op signature matches the C++ xqa_wrapper (extra q_seq_len and mask args) and the Python xqa wrapper now forwards these fields consistently.
  • num_q_heads is inferred from q.shape[-2], which correctly handles both 4D ([B, beam, H, D]) and 5D ([B, beam, q_seq_len, H, D]) query layouts.
  • For q_seq_len > 1 you assert that mask is provided and force run_sm90_fp8_mha = False, which is a sensible guard given the noted precision issue in the SM90 MHA path.

Two small, non-blocking improvements:

  • At the Python level, add a quick check on mask when q_seq_len > 1 (e.g., dtype torch.uint16, 3D shape, and last dimension equal to the packed size implied by q_seq_len) so shape/dtype mistakes are caught before reaching the kernel.
  • In the docstring, it may be worth explicitly stating that q_seq_len > 1 implies SPEC_DEC mode and requires the packed causal mask, to make this contract clearer to external callers.

Also applies to: 52-59, 60-86, 146-167, 245-254, 292-301, 303-308, 329-331

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fd2527b and 07c8742.

📒 Files selected for processing (6)
  • csrc/flashinfer_xqa_binding.cu (1 hunks)
  • csrc/xqa/mha.cu (3 hunks)
  • csrc/xqa/xqa_wrapper.cu (3 hunks)
  • flashinfer/decode.py (7 hunks)
  • flashinfer/xqa.py (11 hunks)
  • tests/attention/test_xqa_batch_decode.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
flashinfer/xqa.py (2)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
flashinfer/utils.py (4)
  • register_custom_op (314-323)
  • register_custom_op (333-352)
  • register_fake_op (325-329)
  • register_fake_op (354-359)
tests/attention/test_xqa_batch_decode.py (2)
tests/attention/test_trtllm_gen_attention.py (1)
  • generate_causal_mask (342-399)
flashinfer/utils.py (1)
  • get_compute_capability (253-256)
csrc/xqa/xqa_wrapper.cu (2)
csrc/xqa/mha_sm90.cu (4)
  • scratch (501-508)
  • scratch (501-501)
  • launchHopperF8MHAFlashInfer (3044-3111)
  • launchHopperF8MHAFlashInfer (3044-3057)
csrc/xqa/mha.cu (2)
  • launchMHAFlashInfer (2598-2657)
  • launchMHAFlashInfer (2598-2612)
csrc/flashinfer_xqa_binding.cu (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
csrc/xqa/mha.cu (1)
csrc/xqa/mha_sm90.cu (20)
  • void (543-572)
  • void (574-579)
  • void (583-593)
  • void (1657-1691)
  • void (1710-1727)
  • void (1729-1744)
  • void (1755-1801)
  • void (1893-1914)
  • void (1916-1934)
  • void (1966-2048)
  • void (2097-2171)
  • void (2173-2192)
  • void (2195-2213)
  • void (2233-2249)
  • void (2253-2276)
  • void (2313-2337)
  • i (181-181)
  • i (183-183)
  • i (185-185)
  • i (187-187)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
csrc/flashinfer_xqa_binding.cu (2)

19-28: Confirm that MLA variant intentionally excludes speculative decoding parameters.

The MLA wrapper (xqa_wrapper_mla) does not include the new qSeqLen and mask parameters added to the non-MLA path. Please confirm this divergence is intentional—specifically, that speculative decoding is only supported in the non-MLA code path.


31-39: All signature changes verified as consistent across declaration, implementation, and call sites.

The function signature has been properly updated across all three layers:

  1. FFI Declaration (csrc/flashinfer_xqa_binding.cu:31-39): Includes int64_t qSeqLen and tvm::ffi::Optional<TensorView> mask
  2. Implementation (csrc/xqa/xqa_wrapper.cu:53-60): Matches declaration exactly with identical parameter positions
  3. Python Caller (flashinfer/xqa.py:87+): Passes all 23 parameters including q_seq_len and mask in the correct positions

The implementation correctly extracts and forwards the new parameters (qSeqLen, maskPtr) to the kernel under the SPEC_DEC conditional. The xqa_wrapper_mla variant intentionally excludes these parameters, reflecting its separate code path.
