enable xqa speculative decoding #2105
base: main
Conversation
Signed-off-by: Qidi Sang <[email protected]>
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

The PR stabilizes speculative-decoding parameters across the XQA stack: public/FFI xqa_wrapper signatures now always include qSeqLen and an Optional mask, a SPEC_DEC-specific attention-sinks device routine was added, the JIT and Python layers compute and propagate SPEC_DEC and the mask, and tests were extended to generate and pass causal masks for multi-token queries.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant DecodePy as decode.py
    participant XQAPY as xqa.py
    participant JitXQA as jit/xqa.py
    participant Binding as flashinfer_xqa_binding.cu
    participant Kernel as xqa_wrapper / mha
    User->>DecodePy: xqa_batch_decode_with_kv_cache(q_seq_len, mask)
    DecodePy->>XQAPY: xqa(..., q_seq_len, mask)
    XQAPY->>JitXQA: get_xqa_module(..., q_seq_len)
    JitXQA->>Binding: compiled module (flags include SPEC_DEC when triggered)
    XQAPY->>Binding: xqa_wrapper(..., qSeqLen, mask, semaphores, scratch)
    Binding->>Kernel: invoke xqa_wrapper with maskPtr (nullable)
    alt SPEC_DEC active
        Kernel->>Kernel: addAttentionSinksSpecDec(...) (SPEC_DEC per-head logic)
    else SPEC_DEC inactive
        Kernel->>Kernel: addAttentionSinks(...) (existing logic)
    end
    Kernel-->>Binding: output
    Binding-->>XQAPY: output
    XQAPY-->>DecodePy: decoded output
    DecodePy-->>User: result
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FlashInfer library by integrating speculative decoding capabilities into its XQA kernel. This feature allows the system to process multiple predicted tokens in a single step, potentially leading to substantial speedups in large language model inference. The changes span from low-level CUDA kernel modifications to higher-level Python API adjustments, ensuring a seamless and efficient user experience for speculative decoding.

Highlights
Code Review
This pull request enables speculative decoding for XQA kernels, which is a great enhancement. The changes across the C++ bindings, Python wrappers, and JIT compilation logic are consistent and well-implemented. The API is updated cleanly to support speculative decoding as a runtime option.
I've included a few review comments with suggestions for minor improvements, mainly focusing on code simplification in a CUDA kernel and improving maintainability in the test helpers by addressing code duplication and style. Overall, this is a solid contribution.
```cpp
// In SPEC_DEC, layout is [token0_head0, token0_head1, ..., token1_head0, ...]
// Extract head index from head-token index
uint32_t headIdx = idxHeadToken % headGrpSize;
if (headIdx < headGrpSize && idxHeadToken < rowsPerBlock) {
```
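For context on the row layout referenced in the snippet above, the index arithmetic can be sketched as follows (a hypothetical Python illustration of the arithmetic only; split_head_token_index is not a function in the PR):

```python
def split_head_token_index(idx_head_token: int, head_grp_size: int) -> tuple:
    """Decompose a flat SPEC_DEC row index under the layout
    [token0_head0, token0_head1, ..., token1_head0, ...]."""
    token_idx = idx_head_token // head_grp_size
    head_idx = idx_head_token % head_grp_size  # always < head_grp_size by construction
    return token_idx, head_idx
```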
```python
def generate_causal_mask(
    batch_size: int,
    q_seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Generate causal attention mask for speculative decoding.

    Parameters
    ----------
    batch_size : int
        Batch size
    q_seq_len : int
        Query sequence length (number of speculative decoding tokens)
    device : torch.device
        Target device for the mask tensor

    Returns
    -------
    torch.Tensor
        Causal mask with shape [batch_size, q_seq_len, mask_size_per_row]
        where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units).
        Data type: torch.uint16
    """
    import numpy as np

    num_packed_masks_per_token = (q_seq_len + 31) // 32

    mask_np = np.zeros(
        (batch_size, q_seq_len, num_packed_masks_per_token), dtype=np.uint32
    )

    for q_pos in range(q_seq_len):
        for kv_pos in range(q_pos + 1):  # Causal: only see previous queries and self
            word_idx = kv_pos // 32
            bit_in_word = kv_pos % 32
            if word_idx < num_packed_masks_per_token:
                mask_np[:, q_pos, word_idx] |= np.uint32(1 << bit_in_word)

    mask_uint32 = torch.from_numpy(mask_np).to(device)
    mask_uint16 = mask_uint32.view(torch.uint16)

    return mask_uint16
```
The helper function generate_causal_mask has a few areas for improvement:

- Import location: The `import numpy as np` statement is inside the function. According to PEP 8, imports should be at the top of the file to make dependencies clear and avoid re-importing.
- Redundant check: The check `if word_idx < num_packed_masks_per_token:` on line 379 is redundant. Given `kv_pos <= q_pos < q_seq_len`, `word_idx` will always be less than `num_packed_masks_per_token`. Removing this simplifies the logic.

Consider refactoring this function to address these points (see the sketch below).
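A minimal sketch of how the two points above could be addressed (assuming numpy is imported at module level; this mirrors the helper's existing logic rather than the PR's final code):

```python
import numpy as np
import torch


def generate_causal_mask(batch_size, q_seq_len, device):
    """Packed causal mask: bit kv_pos of word (kv_pos // 32) is set when kv_pos <= q_pos."""
    num_words = (q_seq_len + 31) // 32
    mask_np = np.zeros((batch_size, q_seq_len, num_words), dtype=np.uint32)
    for q_pos in range(q_seq_len):
        for kv_pos in range(q_pos + 1):
            # kv_pos <= q_pos < q_seq_len, so kv_pos // 32 < num_words always holds.
            mask_np[:, q_pos, kv_pos // 32] |= np.uint32(1 << (kv_pos % 32))
    # Reinterpret each 32-bit word as two 16-bit halves before moving to the device.
    mask_uint16 = mask_np.view(np.uint16).reshape(batch_size, q_seq_len, num_words * 2)
    return torch.from_numpy(mask_uint16).to(device)
```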
```python
def generate_causal_mask(
    batch_size: int,
    q_seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Generate causal attention mask for speculative decoding.

    Parameters
    ----------
    batch_size : int
        Batch size
    q_seq_len : int
        Query sequence length (number of speculative decoding tokens)
    device : torch.device
        Target device for the mask tensor

    Returns
    -------
    torch.Tensor
        Causal mask with shape [batch_size, q_seq_len, mask_size_per_row]
        where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units).
        Data type: torch.uint16
    """
    import numpy as np

    num_packed_masks_per_token = (q_seq_len + 31) // 32

    mask_np = np.zeros(
        (batch_size, q_seq_len, num_packed_masks_per_token), dtype=np.uint32
    )

    for q_pos in range(q_seq_len):
        for kv_pos in range(q_pos + 1):  # Causal: only see previous queries and self
            word_idx = kv_pos // 32
            bit_in_word = kv_pos % 32
            if word_idx < num_packed_masks_per_token:
                mask_np[:, q_pos, word_idx] |= np.uint32(1 << bit_in_word)

    mask_uint32 = torch.from_numpy(mask_np).to(device)
    mask_uint16 = mask_uint32.view(torch.uint16)

    return mask_uint16
```
This generate_causal_mask function has a few areas for improvement:

- Code duplication: This function is identical to the one in tests/attention/test_trtllm_gen_attention.py. To improve maintainability, consider moving this helper to a shared test utility file (e.g., in tests/test_helpers/) and importing it where needed.
- Import location: The `import numpy as np` statement is inside the function. Per PEP 8, it should be at the top of the file.
- Redundant check: The check `if word_idx < num_packed_masks_per_token:` on line 330 is redundant for the same reasons as in the other file.

Refactoring this would improve the test suite's quality.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/xqa.py (1)
301-329: Replace `assert` with explicit mask validation

`assert` is skipped under `python -O`, so speculative decoding could proceed with `mask=None`, feeding a null (CPU) pointer into the CUDA kernel and causing undefined behaviour or a hard crash. Please switch to an explicit runtime check and validate the tensor before launching the kernel.

```diff
@@
-    if q_seq_len > 1:
-        assert mask is not None, "Mask is required for speculative decoding"
-        run_sm90_fp8_mha = (
-            False  # TODO: mha_sm90.cu has precision issue with speculative decoding
-        )
+    if q_seq_len > 1:
+        if mask is None:
+            raise ValueError(
+                "mask must be provided when q_seq_len > 1 (speculative decoding)"
+            )
+        expected_mask_cols = ((q_seq_len + 31) // 32) * 2
+        check_shape_dtype_device(
+            mask,
+            (batch_size, q_seq_len, expected_mask_cols),
+            torch.uint16,
+            q.device,
+            "mask",
+        )
+        run_sm90_fp8_mha = False  # TODO: mha_sm90.cu has precision issue with speculative decoding
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- csrc/flashinfer_xqa_binding.cu (1 hunks)
- csrc/xqa/mha.cu (3 hunks)
- csrc/xqa/xqa_wrapper.cu (3 hunks)
- flashinfer/decode.py (7 hunks)
- flashinfer/jit/xqa.py (4 hunks)
- flashinfer/xqa.py (11 hunks)
- tests/attention/test_trtllm_gen_attention.py (3 hunks)
- tests/attention/test_xqa_batch_decode.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
csrc/xqa/mha.cu (1)
  csrc/xqa/mha_sm90.cu (20): void (538-567), void (569-574), void (578-588), void (1650-1684), void (1703-1720), void (1722-1737), void (1748-1794), void (1886-1907), void (1909-1927), void (1959-2041), void (2090-2164), void (2166-2185), void (2188-2206), void (2226-2242), void (2246-2269), void (2306-2330), i (176-176), i (178-178), i (180-180), i (182-182)
csrc/xqa/xqa_wrapper.cu (2)
  csrc/xqa/mha_sm90.cu (2): scratch (496-503), scratch (496-496)
  csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1): enable_pdl (220-220)
tests/attention/test_xqa_batch_decode.py (2)
  tests/attention/test_trtllm_gen_attention.py (1): generate_causal_mask (342-385)
  flashinfer/utils.py (1): get_compute_capability (253-256)
flashinfer/xqa.py (1)
  flashinfer/utils.py (4): register_custom_op (314-323), register_custom_op (333-352), register_fake_op (325-329), register_fake_op (354-359)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
```python
def generate_causal_mask(
    batch_size: int,
    q_seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Generate causal attention mask for speculative decoding.

    Parameters
    ----------
    batch_size : int
        Batch size
    q_seq_len : int
        Query sequence length (number of speculative decoding tokens)
    device : torch.device
        Target device for the mask tensor

    Returns
    -------
    torch.Tensor
        Causal mask with shape [batch_size, q_seq_len, mask_size_per_row]
        where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units).
        Data type: torch.uint16
    """
    import numpy as np

    num_packed_masks_per_token = (q_seq_len + 31) // 32

    mask_np = np.zeros(
        (batch_size, q_seq_len, num_packed_masks_per_token), dtype=np.uint32
    )

    for q_pos in range(q_seq_len):
        for kv_pos in range(q_pos + 1):  # Causal: only see previous queries and self
            word_idx = kv_pos // 32
            bit_in_word = kv_pos % 32
            if word_idx < num_packed_masks_per_token:
                mask_np[:, q_pos, word_idx] |= np.uint32(1 << bit_in_word)

    mask_uint32 = torch.from_numpy(mask_np).to(device)
    mask_uint16 = mask_uint32.view(torch.uint16)

    return mask_uint16
```
Fix mask dtype reinterpretation crash
Tensor.view only accepts integer shape arguments; passing torch.uint16 raises a TypeError (“torch.dtype object cannot be interpreted as an integer”), so the new helper will fail before any test even runs. Rework the conversion so we reshape in NumPy (where .view(np.uint16) is valid) and only then materialize the torch tensor on the target device.
```diff
-    mask_uint32 = torch.from_numpy(mask_np).to(device)
-    mask_uint16 = mask_uint32.view(torch.uint16)
-
-    return mask_uint16
+    mask_uint16 = torch.from_numpy(
+        mask_np.view(np.uint16).reshape(
+            batch_size, q_seq_len, num_packed_masks_per_token * 2
+        )
+    ).to(device)
+    return mask_uint16
```

🤖 Prompt for AI Agents
In tests/attention/test_trtllm_gen_attention.py around lines 342-385, the code
currently calls mask_uint32.view(torch.uint16) which raises a TypeError because
torch.view expects integer shape arguments and torch.dtype cannot be used that
way; instead reinterpret the uint32 data as uint16 while still in NumPy (e.g.,
cast/view the numpy array to np.uint16 so the last dimension becomes
num_packed_masks_per_token * 2 to match mask_size_per_row), then call
torch.from_numpy(...) and move the tensor to the target device; ensure the numpy
array is contiguous before converting.
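As a quick illustration of the reinterpretation the prompt describes (a standalone snippet; it assumes a little-endian host):

```python
import numpy as np

words = np.array([[0b1011]], dtype=np.uint32)  # one packed 32-bit word: bits 0, 1, 3 set
halves = words.view(np.uint16)                 # reinterpreted as two 16-bit halves, shape (1, 2)
print(halves)                                  # [[11  0]] on a little-endian machine
```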
/bot run
```python
    num_packed_masks_per_token = (q_seq_len + 31) // 32

    mask_np = np.zeros(
```
Can you vectorize the tensor construction loop? (and make it fully on GPU), we found that the process of creating tensor on cpu and then converting it to GPU tensor is always the bottleneck of unittests.
Done
```python
        (batch_size, q_seq_len, num_packed_masks_per_token), dtype=np.uint32
    )

    for q_pos in range(q_seq_len):
```
ditto
Done
Signed-off-by: Qidi Sang <[email protected]>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- csrc/xqa/xqa_wrapper.cu (2 hunks)
- flashinfer/jit/xqa.py (4 hunks)
- tests/attention/test_trtllm_gen_attention.py (3 hunks)
- tests/attention/test_xqa_batch_decode.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/xqa/xqa_wrapper.cu
- tests/attention/test_xqa_batch_decode.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_trtllm_gen_attention.py (1)
  tests/attention/test_xqa_batch_decode.py (1): generate_causal_mask (293-350)
flashinfer/jit/xqa.py (1)
  flashinfer/jit/core.py (1): gen_jit_spec (315-381)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/attention/test_trtllm_gen_attention.py (2)
883-887: LGTM: Mask generation logic is correct. The conditional mask generation correctly supplies a causal mask for speculative decoding when `q_len_per_req > 1` and passes `None` otherwise.

919-919: LGTM: Mask parameter correctly threaded through. The mask is properly passed to the decode function, enabling speculative decoding support.
flashinfer/jit/xqa.py (4)
42-42: Good addition: q_seq_len parameter enables speculative decoding. The new parameter with a sensible default of 1 (non-speculative) allows the XQA module to adapt to multi-token query scenarios.

102-114: Good refactor: Dynamic source selection based on target architecture. Conditionally including SM90-specific sources (mha_sm90.cu and tensorMap.cpp) based on target architectures is cleaner and more maintainable than static lists. This prevents unnecessary compilation and linking of unused code for non-SM90 targets.

116-116: LGTM: JIT spec name and flags correctly updated. The generated spec name now includes use_spec_dec to differentiate speculative decoding variants, and the corresponding flags are properly propagated to the CUDA compiler. This ensures correct JIT caching and prevents collisions between different kernel configurations. Also applies to: 127-127.

84-92: No action required: threshold heuristic is correctly aligned with kernel constraints. Verification confirms the Python code properly implements the kernel's threshold requirement. The kernel enforces the constraint via `static_assert(specDecQLen * headGrpSize <= 32, ...)` in csrc/xqa/mha_sm90.cu (line 42), and the Python JIT code at lines 87–88 applies the identical condition (`q_seq_len * head_group_ratio <= 32`) before defining the SPEC_Q_SEQ_LEN compile flag. The heuristic is correct and well-aligned.
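A schematic of the gating described in the comment above (a sketch only; the exact flag names and emission logic in flashinfer/jit/xqa.py may differ, and SPEC_DEC / SPEC_Q_SEQ_LEN are taken from the review text):

```python
def spec_dec_nvcc_flags(q_seq_len: int, head_group_ratio: int) -> list:
    """Only pin a compile-time SPEC_Q_SEQ_LEN when the kernel-side
    static_assert(specDecQLen * headGrpSize <= 32) can hold."""
    flags = []
    if q_seq_len > 1:
        flags.append("-DSPEC_DEC")
        if q_seq_len * head_group_ratio <= 32:
            flags.append(f"-DSPEC_Q_SEQ_LEN={q_seq_len}")
    return flags


# e.g. q_seq_len=4, head_group_ratio=8 -> 4 * 8 = 32 <= 32, so both flags are emitted
print(spec_dec_nvcc_flags(4, 8))
```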
```python
def generate_causal_mask(
    batch_size: int,
    q_seq_len: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Generate causal attention mask for speculative decoding.

    Parameters
    ----------
    batch_size : int
        Batch size
    q_seq_len : int
        Query sequence length (number of speculative decoding tokens)
    device : torch.device
        Target device for the mask tensor

    Returns
    -------
    torch.Tensor
        Causal mask with shape [batch_size, q_seq_len, mask_size_per_row]
        where mask_size_per_row = divUp(q_seq_len, 32) * 2 (in uint16_t units).
        Data type: torch.uint16
    """
    num_packed_masks_per_token = (q_seq_len + 31) // 32

    q_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(1)
    kv_indices = torch.arange(q_seq_len, device=device, dtype=torch.int32).unsqueeze(0)

    causal_bool_mask = kv_indices <= q_indices

    padded_seq_len = num_packed_masks_per_token * 32
    if padded_seq_len > q_seq_len:
        padding = torch.zeros(
            q_seq_len, padded_seq_len - q_seq_len, device=device, dtype=torch.bool
        )
        causal_bool_mask = torch.cat([causal_bool_mask, padding], dim=1)

    causal_bool_mask = causal_bool_mask.view(q_seq_len, num_packed_masks_per_token, 32)

    bit_positions = torch.tensor(
        [1 << i for i in range(32)], device=device, dtype=torch.int64
    )

    mask_uint32 = (
        (causal_bool_mask.to(torch.int64) * bit_positions).sum(dim=-1).to(torch.uint32)
    )

    mask_uint32 = (
        mask_uint32.unsqueeze(0)
        .expand(batch_size, q_seq_len, num_packed_masks_per_token)
        .contiguous()
    )

    mask_uint16 = mask_uint32.view(torch.uint16)

    return mask_uint16
```
🛠️ Refactor suggestion | 🟠 Major
Eliminate code duplication by extracting to shared utility.
The generate_causal_mask function is duplicated in tests/attention/test_xqa_batch_decode.py (lines 292-349). Extract this function to a shared test utility module (e.g., tests/test_helpers/mask_utils.py) to avoid maintaining identical code in multiple places. This will also ensure consistency when the critical .view(torch.uint16) bug flagged in previous reviews is fixed.
Additionally, the maintainer (yzh119) has requested vectorization and full GPU implementation to avoid CPU-to-GPU conversion bottlenecks in tests. Consider addressing this when refactoring.
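One way the extraction could look (a sketch; the module path is the reviewer's example rather than an existing file, and the body repeats the vectorized helper quoted above):

```python
# tests/test_helpers/mask_utils.py (hypothetical shared location)
import torch


def generate_causal_mask(batch_size: int, q_seq_len: int, device: torch.device) -> torch.Tensor:
    """Packed causal mask in uint16 units, built entirely on the target device."""
    num_words = (q_seq_len + 31) // 32
    q_idx = torch.arange(q_seq_len, device=device).unsqueeze(1)
    kv_idx = torch.arange(num_words * 32, device=device).unsqueeze(0)
    allowed = (kv_idx <= q_idx).view(q_seq_len, num_words, 32)
    words = (allowed.long() << torch.arange(32, device=device)).sum(-1).to(torch.uint32)
    mask = words.unsqueeze(0).expand(batch_size, -1, -1).contiguous()
    # Reinterpret the 32-bit words as uint16 units, as in the updated helper above.
    return mask.view(torch.uint16)


# Both test files would then share it:
# from tests.test_helpers.mask_utils import generate_causal_mask
```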
/bot run
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/xqa/xqa_wrapper.cu (1)
53-104: Hopper FP8 MHA dispatch is disabled in host code; adopt a runtime flag instead of the CUDA_ARCH guard

The dispatch logic at lines 71–75 will not work as intended:

```cpp
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900
  auto const mha_func = run_sm90_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;
#else
  auto const mha_func = &launchMHAFlashInfer;
#endif
```

Since `xqa_wrapper` is a host function, `__CUDA_ARCH__` is not defined during host compilation, so this preprocessor branch always takes the `#else` path. The `run_sm90_fp8_mha` flag is ignored, and `launchHopperF8MHAFlashInfer` is never selected, even on SM90 GPUs.

Replace the preprocessor guard with a runtime condition:

```cpp
auto const mha_func = run_sm90_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer;
```

This makes the dispatch purely runtime-based and allows the Python caller's `run_sm90_fp8_mha` flag to actually select the Hopper FP8 path when appropriate.
♻️ Duplicate comments (2)
csrc/xqa/mha.cu (1)
1270-1285: SPEC_DEC attention sinks wiring looks correct; minor redundant condition remains

The SPEC_DEC-specific `addAttentionSinksSpecDec` implementation and its use in both single-CTA and multi-CTA reductions are consistent with the non-SPEC_DEC path; sinks are applied once per head, with per-row normalization via `globalRowMax`.

The condition `if (headIdx < headGrpSize && idxHeadToken < rowsPerBlock)` can be simplified, since `headIdx` is computed as `idxHeadToken % headGrpSize` and is therefore always `< headGrpSize` when `headGrpSize > 0`. Keeping only the `idxHeadToken < rowsPerBlock` bound check would slightly simplify the code.

This is purely a micro-cleanup; no functional change required.

Also applies to: 1890-1894, 2375-2379
tests/attention/test_xqa_batch_decode.py (1)
293-351: Causal mask helper is correct but duplicated; prefer a shared test utility

The new `generate_causal_mask` matches the expected packed uint16 causal-mask format and is used correctly to feed `mask` into `xqa_batch_decode_with_kv_cache` when `q_len_per_req > 1`.

However, this helper is effectively identical to the one in tests/attention/test_trtllm_gen_attention.py. To avoid divergence and make future changes easier, consider moving it into a shared test helper (e.g., tests/test_helpers/masks.py) and importing it from both test files.

Also applies to: 508-512, 530-531
🧹 Nitpick comments (3)
flashinfer/decode.py (2)
2062-2083: Mask parameter is correctly plumbed to the XQA backend; clarify / guard usage for others

Adding `mask: Optional[torch.Tensor] = None` here and forwarding it only when `backend == "xqa"` matches the docstring ("causal attention mask for xqa speculative decoding").

Two minor improvements you might consider:

- Explicitly document that `mask` is ignored for the trtllm-gen backend (or raise if it is passed together with `backend="trtllm-gen"`) to avoid silent surprises.
- Optionally assert that any provided `mask` is on the same device as `query` before forwarding, to fail fast on obvious misuse.

Also applies to: 2153-2155, 2195-2212

2341-2357: q_len_per_req and mask handling in XQA batch decode are consistent; add a simple shape invariant

The speculative-decoding wiring here is coherent:

- For `q_len_per_req > 1`, `query` is reshaped from `[num_tokens, num_heads, head_dim]` into `[batch_size, q_len_per_req, num_heads, head_dim]` and then into `[batch_size, 1, q_len_per_req, num_heads, head_dim]` via the `unsqueeze(1)` in the XQA call site.
- `q_len_per_req` is forwarded as `q_seq_len`, and `mask` is passed through to `flashinfer.xqa.xqa`, which enforces the "mask required when `q_seq_len > 1`" invariant.

To make this more robust to API misuse, you might (see the sketch after this comment):

- Add a quick check like `assert query.shape[0] % q_len_per_req == 0` before the `view`, so incorrect `q_len_per_req` values fail loudly.
- Optionally sanity-check `mask` when provided (dtype `torch.uint16`, last-dimension length matching the packed format) before forwarding, to surface shape mistakes at the Python boundary instead of failing deep in the CUDA kernel.

Also applies to: 2451-2484
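A sketch of the kind of fail-fast checks suggested above (hypothetical helper; names are illustrative and not part of the PR):

```python
from typing import Optional

import torch


def _validate_spec_dec_inputs(
    query: torch.Tensor, q_len_per_req: int, mask: Optional[torch.Tensor]
) -> None:
    """Surface shape/dtype mistakes at the Python boundary instead of inside the kernel."""
    if query.shape[0] % q_len_per_req != 0:
        raise ValueError(
            f"num_tokens ({query.shape[0]}) must be divisible by q_len_per_req ({q_len_per_req})"
        )
    if q_len_per_req > 1:
        if mask is None:
            raise ValueError("mask is required when q_len_per_req > 1 (speculative decoding)")
        batch_size = query.shape[0] // q_len_per_req
        expected_cols = ((q_len_per_req + 31) // 32) * 2  # packed uint16 words per query row
        if mask.dtype != torch.uint16 or tuple(mask.shape) != (batch_size, q_len_per_req, expected_cols):
            raise ValueError(
                "mask must be uint16 with shape [batch_size, q_len_per_req, divUp(q_len_per_req, 32) * 2]"
            )
        if mask.device != query.device:
            raise ValueError("mask must be on the same device as query")
```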
flashinfer/xqa.py (1)
34-43: SPEC_DEC integration in XQA wrapper is coherent; consider adding lightweight mask validation

The new `q_seq_len` plumbing and speculative-decoding gating are wired cleanly:

- `q_seq_len` participates in the `get_xqa_module` cache key, is forwarded to `gen_xqa_module`, and also influences the op name via `use_spec_dec`, so SPEC_DEC vs non-SPEC_DEC variants don't collide.
- The registered custom op signature matches the C++ `xqa_wrapper` (extra `q_seq_len` and `mask` args) and the Python `xqa` wrapper now forwards these fields consistently. `num_q_heads` is inferred from `q.shape[-2]`, which correctly handles both 4D ([B, beam, H, D]) and 5D ([B, beam, q_seq_len, H, D]) query layouts.
- For `q_seq_len > 1` you assert that `mask` is provided and force `run_sm90_fp8_mha = False`, which is a sensible guard given the noted precision issue in the SM90 MHA path.

Two small, non-blocking improvements:

- At the Python level, add a quick check on `mask` when `q_seq_len > 1` (e.g., dtype `torch.uint16`, 3D shape, and last dimension equal to the packed size implied by `q_seq_len`) so shape/dtype mistakes are caught before reaching the kernel.
- In the docstring, it may be worth explicitly stating that `q_seq_len > 1` implies SPEC_DEC mode and requires the packed causal mask, to make this contract clearer to external callers.

Also applies to: 52-59, 60-86, 146-167, 245-254, 292-301, 303-308, 329-331
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/flashinfer_xqa_binding.cu (1 hunks)
- csrc/xqa/mha.cu (3 hunks)
- csrc/xqa/xqa_wrapper.cu (3 hunks)
- flashinfer/decode.py (7 hunks)
- flashinfer/xqa.py (11 hunks)
- tests/attention/test_xqa_batch_decode.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
flashinfer/xqa.py (2)
  flashinfer/jit/core.py (1): build_and_load (300-312)
  flashinfer/utils.py (4): register_custom_op (314-323), register_custom_op (333-352), register_fake_op (325-329), register_fake_op (354-359)
tests/attention/test_xqa_batch_decode.py (2)
  tests/attention/test_trtllm_gen_attention.py (1): generate_causal_mask (342-399)
  flashinfer/utils.py (1): get_compute_capability (253-256)
csrc/xqa/xqa_wrapper.cu (2)
  csrc/xqa/mha_sm90.cu (4): scratch (501-508), scratch (501-501), launchHopperF8MHAFlashInfer (3044-3111), launchHopperF8MHAFlashInfer (3044-3057)
  csrc/xqa/mha.cu (2): launchMHAFlashInfer (2598-2657), launchMHAFlashInfer (2598-2612)
csrc/flashinfer_xqa_binding.cu (1)
  csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1): enable_pdl (220-220)
csrc/xqa/mha.cu (1)
  csrc/xqa/mha_sm90.cu (20): void (543-572), void (574-579), void (583-593), void (1657-1691), void (1710-1727), void (1729-1744), void (1755-1801), void (1893-1914), void (1916-1934), void (1966-2048), void (2097-2171), void (2173-2192), void (2195-2213), void (2233-2249), void (2253-2276), void (2313-2337), i (181-181), i (183-183), i (185-185), i (187-187)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
csrc/flashinfer_xqa_binding.cu (2)
19-28: Confirm that MLA variant intentionally excludes speculative decoding parameters. The MLA wrapper (`xqa_wrapper_mla`) does not include the new `qSeqLen` and `mask` parameters added to the non-MLA path. Please confirm this divergence is intentional: specifically, that speculative decoding is only supported in the non-MLA code path.

31-39: All signature changes verified as consistent across declaration, implementation, and call sites. The function signature has been properly updated across all three layers:

- FFI declaration (csrc/flashinfer_xqa_binding.cu:31-39): includes `int64_t qSeqLen` and `tvm::ffi::Optional<TensorView> mask`
- Implementation (csrc/xqa/xqa_wrapper.cu:53-60): matches the declaration exactly, with identical parameter positions
- Python caller (flashinfer/xqa.py:87+): passes all 23 parameters, including `q_seq_len` and `mask`, in the correct positions

The implementation correctly extracts and forwards the new parameters (`qSeqLen`, `maskPtr`) to the kernel under the SPEC_DEC conditional. The `xqa_wrapper_mla` variant intentionally excludes these parameters, reflecting its separate code path.
📌 Description
Enable XQA with speculative decoding and add a mask tensor parameter to trtllm_batch_decode_with_kv_cache.
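For reviewers, a rough sketch of the intended call pattern (argument names follow the review discussion, other arguments are elided, and generate_causal_mask refers to the test helper added in this PR; treat this as an illustration rather than documented API):

```python
import torch

batch_size, q_len_per_req, num_heads, head_dim = 2, 4, 8, 128  # illustrative sizes

# Flat query layout: one row per (request, draft token).
query = torch.randn(
    batch_size * q_len_per_req, num_heads, head_dim, device="cuda", dtype=torch.bfloat16
)

# Packed causal mask in uint16 units; required whenever q_len_per_req > 1.
mask = generate_causal_mask(batch_size, q_len_per_req, device=query.device)
assert mask.shape == (batch_size, q_len_per_req, ((q_len_per_req + 31) // 32) * 2)

# output = xqa_batch_decode_with_kv_cache(
#     query, kv_cache, workspace_buffer, ...,  # existing arguments unchanged
#     q_len_per_req=q_len_per_req,
#     mask=mask,
# )
```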
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
API
JIT / Build
Backend / Runtime
Tests