
Conversation

@raayandhar
Contributor

@raayandhar raayandhar commented Nov 19, 2025

📌 Description

With these changes, the RopeQuantizeKernel should support arbitrary head dimensions, and I have routed BatchQKApplyRotaryPosIdsCosSinCache through it.

🔍 Related Issues

#2104

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.). NOTE: There was a set of tests where I got this error:
/tmp/tmp_v7jd1rh/cuda_utils.c:6:10: fatal error: Python.h: No such file or directory
    6 | #include <Python.h>
      |          ^~~~~~~~~~
compilation terminated.

which I know is specific to my system. Unfortunately I do not manage this system and it does not have Docker, so fixing this is a bit difficult. Hopefully someone else can verify my tests or run CI. All other tests were passing, and all of the failing tests hit this same error.

Reviewer Notes

Please let me know if there's a smarter way to get around this hack or if other tests should be updated. I also think we should remove the older kernel, but let me know if you'd prefer otherwise. I still need to test performance.

Summary by CodeRabbit

  • Bug Fixes

    • Fixed edge-case handling for rotary position embeddings when embedding dimensions don't align with processing chunk sizes, avoiding incorrect writes and ensuring correct tail-chunk behavior.
    • Added stricter validation for rotary dimensions vs. head dimensions to surface errors early.
    • Improved robustness of dynamic kernel selection and execution.
  • Tests

    • Expanded test coverage for rotary position embedding configurations across various dimension, batch, and interleave combinations.

@coderabbitai
Contributor

coderabbitai bot commented Nov 19, 2025


Walkthrough

Added a device helper to handle partial RoPE quantization chunks, refactored kernel dispatch to compute vector/block sizes dynamically, routed RoPE-quantized paths through RopeQuantize, and expanded cos/sin cache tests with four new configurations.

Changes

  • RoPE quantization & kernels (include/flashinfer/pos_enc.cuh):
    Added a scale_store_partial_chunk device helper to zero-pad and safely store tail lanes; replaced per-element vector loads/stores with guarded partial-chunk writes across RopeQuantizeKernel and the Q/K RoPE paths; reworked dynamic dispatch (vec_size, bdx, bdy, block/grid) and routed RoPE-quantized flows through RopeQuantize.
  • RoPE tests (tests/attention/test_rope.py):
    Added four new test configurations to the test_rope_cos_sin_cache parameter set, increasing coverage for head sizes, rotary_dims, batch/sequence shapes, and interleaving modes.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Host
    participant Dispatch as KernelDispatch
    participant Device as GPU
    Note over Host,Dispatch: Host launches RoPE/quant kernels
    Host->>Dispatch: prepare params (head_dim, rotary_dim, no_rope_dim,...)
    Dispatch-->>Device: select kernel (RopeQuantize / other)
    alt RoPE-quantized path
        Device->>Device: RopeQuantizeKernel executes
        Device->>Device: call scale_store_partial_chunk for tail lanes
    else Non-RoPE or full-chunk
        Device->>Device: regular kernel vector loads/stores
    end
    Device-->>Host: return results / error (if head_dim < rotary_dim)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~35 minutes

  • Inspect scale_store_partial_chunk for correct zero-padding and bounds/write correctness.
  • Verify all kernel launches and templates use revised dynamic dispatch parameters (vec_size, bdx, bdy, grid/block).
  • Confirm RoPE-quantized routing and the head_dim < rotary_dim guard are integrated and error cases handled.
  • Run the updated test_rope_cos_sin_cache cases locally to validate behavior.

Suggested reviewers

  • yzh119

Poem

🐇 I nibble code in evening light,

I guard the tails so bytes sleep tight,
Dispatch decides the kernel lane,
Partial chunks now dance in chain,
A hopping patch — precise and bright ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title accurately summarizes the main objective of the PR, extending RoPE kernel support to handle more head dimensions.
  • Description check ✅ Passed: The description includes all key template sections: a clear summary of changes, a related issue link, completed pre-commit and test checklist items, and reviewer notes requesting feedback on implementation details.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1bee170 and 6805e9c.

📒 Files selected for processing (1)
  • include/flashinfer/pos_enc.cuh (7 hunks)
🔇 Additional comments (5)
include/flashinfer/pos_enc.cuh (5)

236-290: LGTM with a performance note.

The helper function correctly handles partial chunks when no_rope_dim is not a multiple of rope_dim. The fast path avoids overhead when possible, and the element-wise fallback ensures correctness at boundaries.

Note that the element-wise load/store path (lines 262-271, 281-288) will have performance overhead compared to vectorized operations. This is unavoidable for partial chunks, but consider monitoring the performance impact in cases where no_rope_dim is significantly smaller than rope_dim (e.g., head_dim=192, rope_dim=128, no_rope_dim=64).

Consider performance testing with various head_dim/rope_dim combinations to ensure the overhead is acceptable.
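
For readers without the diff open, here is a minimal sketch of what such a partial-chunk helper could look like; the name, signature, and use of flashinfer's vec_t below are assumptions for illustration, not the actual code in pos_enc.cuh:

template <uint32_t vec_size, typename DType>
__device__ __forceinline__ void scale_store_partial_chunk_sketch(DType* dst, const DType* src,
                                                                  float scale, uint32_t valid) {
  if (valid == vec_size) {
    // Fast path: the whole chunk is in bounds, so use a vectorized load / scale / store.
    vec_t<float, vec_size> v;
    v.cast_load(src);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) v[i] *= scale;
    v.cast_store(dst);
  } else {
    // Tail chunk: fall back to element-wise stores so no lane writes past the valid range.
    for (uint32_t i = 0; i < valid; ++i) {
      dst[i] = static_cast<DType>(static_cast<float>(src[i]) * scale);
    }
  }
}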


544-549: LGTM!

The usage of scale_store_partial_chunk correctly handles partial chunks in both K and Q non-RoPE processing paths. The chunk_valid calculation properly bounds the valid elements, preventing out-of-bounds writes.

Also applies to: 564-569
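
A rough sketch of the chunk_valid bound being described, using hypothetical names (tx, chunk_start) that may not match the kernel:

// Each thread owns one vec_size-wide slice of the current non-RoPE chunk.
uint32_t elem_offset = chunk_start + tx * vec_size;  // first element this thread would write
uint32_t chunk_valid =
    (elem_offset < no_rope_dim) ? min(vec_size, no_rope_dim - elem_offset) : 0u;
// Only the first chunk_valid lanes are stored; a full chunk takes the vectorized fast path.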


1050-1118: LGTM with a minor observation.

The refactored dispatch logic correctly computes thread block dimensions dynamically, enabling support for arbitrary head dimensions. The calculation ensures at least 128 threads per block for good GPU occupancy.

One edge case to note: if rope_dim is very small (e.g., 0), bdx would be 0 before line 1053. However, line 1053 (bdx = std::max(1u, bdx)) correctly guards against this, ensuring bdx >= 1.
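
As a hedged illustration of the launch-dimension arithmetic being reviewed (the heuristic and names below are assumptions, not the exact logic in pos_enc.cuh):

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Pick a vector width from the element size, then size the block so bdx * bdy >= 128 threads.
template <typename DType>
dim3 pick_block_dims_sketch(uint32_t rope_dim) {
  uint32_t vec_size = std::max<uint32_t>(16 / sizeof(DType), 1);  // prefer 16-byte vectors
  uint32_t bdx = std::max(1u, rope_dim / vec_size);               // guards against rope_dim == 0
  uint32_t bdy = std::max(1u, 128u / bdx);                        // keep roughly 128 threads per block
  return dim3(bdx, bdy);
}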


1138-1195: LGTM!

Both functions correctly implement dynamic dispatch for paged cache operations. The chunking strategy (rope_chunks = 1, no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim) processes the full RoPE dimension in a single block, which is a reasonable optimization for these cache-append operations.

The MLA-specific handling correctly sets num_kv_heads = 1 as noted in the kernel comments.

Also applies to: 1213-1282
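
In concrete terms, the chunking described above is a ceil division over the non-RoPE tail; a short sketch with assumed variable names:

uint32_t rope_chunks = 1;                                           // full RoPE dim handled in one chunk
uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim;  // ceil(no_rope_dim / rope_dim)
// Example: head_dim = 192, rope_dim = 128 -> no_rope_dim = 64, no_rope_chunks = 1 (a partial chunk)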


1294-1319: LGTM! Pointer arithmetic is correct.

The function correctly routes to the RopeQuantize kernel by:

  1. Validating that head_dim >= rotary_dim to prevent invalid configurations
  2. Slicing input/output pointers to separate RoPE and non-RoPE dimensions
  3. Using identity quantization scales (1.0f) to preserve the data type

Regarding the pointer arithmetic concern from past reviews (lines 1306-1311): The approach is mathematically correct. Offsetting the base pointer by rotary_dim and then adding the strided offsets in the kernel produces the same address as computing the strided offsets first and then adding rotary_dim, due to the commutativity of addition. The strides passed to the kernel (line 1316) are correctly the original tensor strides.

This routing strategy is a good example of code reuse, leveraging the generalized RopeQuantize kernel for the non-quantized case.
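
To make the routing concrete, here is a minimal sketch of the pointer slicing and identity scales described above; the variable names are hypothetical and the real call site passes many more parameters:

// The RoPE slice of each head starts at the base pointer; the non-RoPE tail starts rotary_dim
// elements later. The original tensor strides are reused unchanged, so
//   (q + rotary_dim) + h * q_stride_h + n * q_stride_n
// addresses the same element as (q + h * q_stride_h + n * q_stride_n) + rotary_dim.
DType* q_rope_in = q;
DType* q_nope_in = q + rotary_dim;
DType* k_rope_in = k;
DType* k_nope_in = k + rotary_dim;
// Identity quantization scales keep the output in the original dtype (no quantization is applied).
const float quant_scale_q = 1.0f;
const float quant_scale_kv = 1.0f;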


@gemini-code-assist
Contributor

Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and robustness of the Rotary Positional Embedding (RoPE) implementation by enabling support for arbitrary head dimensions within the RopeQuantizeKernel. It introduces a mechanism to gracefully handle partial data chunks in non-RoPE dimensions and refactors the BatchQKApplyRotaryPosIdsCosSinCache function to utilize this improved kernel. These changes ensure correct processing for a broader range of model configurations and simplify future maintenance.

Highlights

  • Arbitrary Head Dimension Support: The RoPE kernel (RopeQuantizeKernel) now supports arbitrary head dimensions, specifically addressing cases where the non-RoPE portion of the head dimension is not a multiple of the rope_dim.
  • Partial Chunk Handling: Introduced a new scale_store_partial_chunk helper function to correctly handle partial vector loads and stores for non-RoPE dimensions, preventing errors when no_rope_dim is not a multiple of rope_dim.
  • Kernel Routing and Refactoring: The BatchQKApplyRotaryPosIdsCosSinCache function has been refactored to route its operations through the more general RopeQuantize kernel, simplifying the codebase and leveraging the new arbitrary head dimension support.
  • Dynamic Block Dimension (bdx): The DISPATCH_ROPE_DIM macro was removed from RopeQuantize and RopeQuantizeAppendPagedKVCache, allowing the bdx (block dimension x) to be dynamically calculated at runtime, improving flexibility for different rope_dim values.
  • Expanded Test Coverage: New test cases have been added to test_rope.py to validate the RoPE kernel's behavior with various head dimensions, including those that previously posed challenges.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for arbitrary head dimensions in the RoPE kernel by introducing a new helper function scale_store_partial_chunk to handle partial memory chunks and refactoring BatchQKApplyRotaryPosIdsCosSinCache to use the more general RopeQuantize kernel. This is a good simplification that reduces code duplication.

However, I've found a critical issue in how the non-RoPE tensor slices are handled. The pointer arithmetic used to create q_nope_in and k_nope_in is incorrect for multi-dimensional tensors, which will lead to incorrect memory accesses. I've also included a couple of suggestions to improve code clarity in the new helper function.

The added tests are good, but they seem to be passing despite the critical issue, which might indicate a problem with the test setup or reference implementation.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
include/flashinfer/pos_enc.cuh (1)

1052-1120: Consider clarifying the bdx template parameter usage.

The kernel dispatch sets the template parameter bdx=1 (line 1097) while computing a runtime bdx value (line 1054). This works because the rotary_dim argument is explicitly passed to the RoPE functions, overriding the default vec_size * bdx. However, this discrepancy could be confusing for maintainability.

Consider either:

  1. Using the computed bdx value as the template parameter (would require DISPATCH_BDX macro), or
  2. Adding a comment explaining why the template bdx is set to 1 while runtime bdx varies

Example comment:

// Template bdx=1 because rotary_dim is explicitly passed to RoPE functions
auto kernel = RopeQuantizeKernel<INTERLEAVE, vec_size, 1, DType, IdType, QuantType>;
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1c4b522 and 1bee170.

📒 Files selected for processing (2)
  • include/flashinfer/pos_enc.cuh (7 hunks)
  • tests/attention/test_rope.py (1 hunks)
🔇 Additional comments (5)
tests/attention/test_rope.py (1)

303-306: LGTM! Test coverage expanded appropriately.

The new test configurations effectively validate the partial chunk handling introduced in this PR. They cover various scenarios where no_rope_dim < rope_dim, which exercises the new scale_store_partial_chunk logic for tail chunks.

include/flashinfer/pos_enc.cuh (4)

546-551: Correct usage of partial chunk handling.

The chunk_valid calculation properly handles tail chunks where no_rope_dim is not a multiple of rope_dim. The logic correctly computes the number of valid elements in the current chunk and handles the case where elem_offset >= no_rope_dim by setting chunk_valid = 0.

Also applies to: 566-571


1140-1196: Dispatch logic is consistent across variants.

The dynamic dispatch logic for RopeQuantizeAppendPagedKVCache (GQA/MHA) and RopeQuantizeAppendPagedMLACache follows the same pattern as RopeQuantize. The total_blocks_y calculation correctly accounts for the differences between GQA/MHA (includes V blocks) and MLA (no V blocks).

Note: The same bdx template parameter concern mentioned in the previous comment applies here as well.

Also applies to: 1214-1274


236-292: Verify performance impact with targeted benchmarks for partial chunk scenarios.

The code logic is correct with proper boundary checks and zero-padding. However, verification confirms the author's concern: no performance data exists for this code path. The existing benchmark uses head_size = rotary_dim (both 128), meaning no_rope_dim = 0, so it doesn't exercise the partial chunk handling that this function addresses.

Before merging, run benchmarks with configurations where no_rope_dim > 0 and no_rope_dim < rope_dim (e.g., head_dim=192, rope_dim=128, no_rope_dim=64) to quantify the performance impact of the element-by-element fallback path and zero-padding logic.


1286-1312: Performance verification requires manual benchmarking—the routing change logic is correct and well-tested.

The routing to RopeQuantize is intentional, uniform across all callers, and thoroughly validated for correctness. Existing tests in tests/attention/test_rope.py verify the output against reference implementations for all relevant configurations (head_dim: 64/128/256, partial_rotary_factor: 0.25–1.0). However, the original review specifically requests performance profiling to detect regressions, which cannot be completed automatically in this environment—you must run performance benchmarks locally to measure kernel execution time and throughput across representative workloads.

Collaborator

@yzh119 yzh119 left a comment


LGTM overall, cc @kahyunnam for another look

@yzh119
Collaborator

yzh119 commented Nov 19, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !149 has been created, and the CI pipeline #38773414 is currently running. I'll report back once the pipeline job completes.

@raayandhar
Contributor Author

raayandhar commented Nov 21, 2025

Is there an issue with CI? Seems like it has been running for 2 days now 😅

@yzh119
Collaborator

yzh119 commented Nov 21, 2025

Hi @raayandhar, the CI has finished (the result was not reported back here for some reason); the PR itself does not introduce any regressions and should be ready to merge.

I'm running the benchmarks and will merge it as long as there is no performance regression.

@kahyunnam
Contributor

kahyunnam commented Nov 21, 2025

LGTM overall, cc @kahyunnam for another look

This LGTM! I do wonder whether we're adding some unnecessary overhead by keeping the pointwise multiply by 1 (/*quant_scale_q=*/1.0f, /*quant_scale_kv=*/1.0f), but I don't think that's a big deal.

I agree we can merge once benchmarking looks OK, @yzh119. Thanks @raayandhar for the contribution!

@kahyunnam kahyunnam self-requested a review November 21, 2025 02:19
@yzh119
Collaborator

yzh119 commented Nov 22, 2025

There are indeed some performance regressions, @raayandhar @kahyunnam:

On H100, before this PR:

rope-latency:
    seq_len  FlashInfer    Native      vLLM
0       2.0    0.005936  0.062576  0.007968
1       4.0    0.005952  0.064256  0.008160
2       8.0    0.005888  0.069376  0.008128
3      16.0    0.006112  0.066160  0.008352
4      32.0    0.006240  0.066784  0.008576
5      64.0    0.006752  0.068608  0.009056
6     128.0    0.007808  0.075328  0.010464
7     256.0    0.009664  0.088256  0.012832
8     512.0    0.013472  0.115648  0.019904
9    1024.0    0.020896  0.170496  0.033728
10   2048.0    0.035712  0.290272  0.060896
11   4096.0    0.066240  0.523520  0.114400
12   8192.0    0.129952  0.985888  0.221632
13  16384.0    0.255168  1.897296  0.436032
14  32768.0    0.486576  3.715232  0.864640
15  65536.0    0.953376  7.342368  1.722112

After:

    seq_len  FlashInfer    Native      vLLM
0       2.0    0.005952  0.063488  0.007968
1       4.0    0.005952  0.064112  0.008128
2       8.0    0.005920  0.069440  0.008128
3      16.0    0.006272  0.067104  0.008384
4      32.0    0.006400  0.067552  0.008576
5      64.0    0.006688  0.068512  0.009056
6     128.0    0.007744  0.075424  0.010464
7     256.0    0.009760  0.088224  0.012832
8     512.0    0.013632  0.115712  0.019872
9    1024.0    0.021120  0.170720  0.033696
10   2048.0    0.036064  0.289760  0.060864
11   4096.0    0.066976  0.524288  0.114528
12   8192.0    0.128800  0.985664  0.221760
13  16384.0    0.259968  1.899248  0.435840
14  32768.0    0.621312  3.711968  0.864608
15  65536.0    1.758672  7.343424  1.722016

@raayandhar
Contributor Author

There are indeed some performance regressions, @raayandhar @kahyunnam: [benchmark tables quoted above]
Oof ok, I will go and investigate. Could you share your benchmarking scripts?

@yzh119
Collaborator

yzh119 commented Nov 22, 2025

https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_rope.py
