feat: support more head dim in RoPE kernel #2109
Conversation
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Added a device helper to handle partial RoPE quantization chunks, refactored kernel dispatch to compute vector/block sizes dynamically, routed RoPE-quantized paths through `RopeQuantize`, and expanded cos/sin cache tests with four new configurations.
Sequence Diagram(s)

sequenceDiagram
autonumber
participant Host
participant Dispatch as KernelDispatch
participant Device as GPU
Note over Host,Dispatch: Host launches RoPE/quant kernels
Host->>Dispatch: prepare params (head_dim, rotary_dim, no_rope_dim,...)
Dispatch-->>Device: select kernel (RopeQuantize / other)
alt RoPE-quantized path
Device->>Device: RopeQuantizeKernel executes
Device->>Device: call scale_store_partial_chunk for tail lanes
else Non-RoPE or full-chunk
Device->>Device: regular kernel vector loads/stores
end
Device-->>Host: return results / error (if head_dim < rotary_dim)
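The walkthrough mentions a device helper for partial chunks in the no-RoPE region. As a rough illustration of the idea, here is a minimal sketch of such a helper; the signature, template parameters, and zero-padding behavior are assumptions for this example, not the actual code in `include/flashinfer/pos_enc.cuh`.

```cuda
// Hypothetical sketch of a partial-chunk scale-and-store helper (names and
// signature assumed, not taken from the PR). It writes up to `chunk_valid`
// scaled elements from a register array and zero-fills the remainder, so the
// tail chunk of the no-RoPE region can be narrower than vec_size.
template <uint32_t vec_size, typename DType>
__device__ __forceinline__ void scale_store_partial_chunk_sketch(
    DType* dst, const float (&src)[vec_size], float scale, uint32_t chunk_valid) {
#pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) {
    // Elements at or past chunk_valid are outside the no-RoPE region for this
    // chunk; store zeros there (assuming dst is padded to full chunk width).
    dst[i] = (i < chunk_valid) ? static_cast<DType>(src[i] * scale)
                               : static_cast<DType>(0.f);
  }
}
```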
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~35 minutes
Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the flexibility and robustness of the Rotary Positional Embedding (RoPE) implementation by enabling support for arbitrary head dimensions within the `RopeQuantizeKernel`. It introduces a mechanism to gracefully handle partial data chunks in non-RoPE dimensions and refactors the `BatchQKApplyRotaryPosIdsCosSinCache` function to utilize this improved kernel. These changes ensure correct processing for a broader range of model configurations and simplify future maintenance.
Code Review
This pull request adds support for arbitrary head dimensions in the RoPE kernel by introducing a new helper function `scale_store_partial_chunk` to handle partial memory chunks and refactoring `BatchQKApplyRotaryPosIdsCosSinCache` to use the more general `RopeQuantize` kernel. This is a good simplification that reduces code duplication.
However, I've found a critical issue in how the non-RoPE tensor slices are handled. The pointer arithmetic used to create `q_nope_in` and `k_nope_in` is incorrect for multi-dimensional tensors, which will lead to incorrect memory accesses. I've also included a couple of suggestions to improve code clarity in the new helper function.
The added tests are good, but they seem to be passing despite the critical issue, which might indicate a problem with the test setup or reference implementation.
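The critical issue flagged above concerns how the no-RoPE slice pointers are derived. As an illustration of why a flat offset can go wrong on strided tensors, here is a small sketch; the variable names, strides, and function are hypothetical and not taken from `pos_enc.cuh`.

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Hypothetical example: computing the start of the no-RoPE portion of one
// (token, head) slice of q. Offsetting the base pointer by rotary_dim alone is
// only correct for a fully contiguous layout; with per-token/per-head strides,
// the strides must be applied first.
__host__ __device__ inline const half* q_nope_ptr_sketch(
    const half* q, uint32_t token_idx, uint32_t head_idx,
    uint32_t q_stride_n, uint32_t q_stride_h, uint32_t rotary_dim) {
  // Apply the token and head strides, then skip the rotary prefix within the
  // (assumed contiguous) head dimension.
  return q + token_idx * q_stride_n + head_idx * q_stride_h + rotary_dim;
}
```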
Actionable comments posted: 0
🧹 Nitpick comments (1)
include/flashinfer/pos_enc.cuh (1)
1052-1120: Consider clarifying the `bdx` template parameter usage.

The kernel dispatch sets the template parameter `bdx=1` (line 1097) while computing a runtime `bdx` value (line 1054). This works because the `rotary_dim` argument is explicitly passed to the RoPE functions, overriding the default `vec_size * bdx`. However, this discrepancy could be confusing for maintainability. Consider either:

- Using the computed `bdx` value as the template parameter (would require a `DISPATCH_BDX` macro), or
- Adding a comment explaining why the template `bdx` is set to 1 while the runtime `bdx` varies.
Example comment:

    // Template bdx=1 because rotary_dim is explicitly passed to RoPE functions
    auto kernel = RopeQuantizeKernel<INTERLEAVE, vec_size, 1, DType, IdType, QuantType>;
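For context, the kind of runtime computation the nitpick refers to might look like the sketch below. This is an assumed reconstruction for illustration: the helper name, the 128-bit starting width, and the exact formulas are not taken from the actual dispatch code.

```cuda
#include <cstdint>

// Hypothetical host-side helper: choose a vector width that divides
// rotary_dim, then derive how many x-threads (bdx) cover one rotary chunk.
inline void compute_vec_bdx_sketch(uint32_t rotary_dim, uint32_t elem_bits,
                                   uint32_t& vec_size, uint32_t& bdx) {
  // Start from a 128-bit vector load (e.g. 8 lanes for 16-bit elements) and
  // halve it until it evenly divides the rotary width.
  vec_size = 128 / elem_bits;
  while (vec_size > 1 && rotary_dim % vec_size != 0) {
    vec_size /= 2;
  }
  bdx = rotary_dim / vec_size;  // threads along x per rotary chunk
}
```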
📜 Review details
📒 Files selected for processing (2)
- `include/flashinfer/pos_enc.cuh` (7 hunks)
- `tests/attention/test_rope.py` (1 hunk)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (5)
`tests/attention/test_rope.py` (1)

303-306: LGTM! Test coverage expanded appropriately.

The new test configurations effectively validate the partial chunk handling introduced in this PR. They cover various scenarios where `no_rope_dim < rope_dim`, which exercises the new `scale_store_partial_chunk` logic for tail chunks.

`include/flashinfer/pos_enc.cuh` (4)
546-551: Correct usage of partial chunk handling.

The `chunk_valid` calculation properly handles tail chunks where `no_rope_dim` is not a multiple of `rope_dim`. The logic correctly computes the number of valid elements in the current chunk and handles the case where `elem_offset >= no_rope_dim` by setting `chunk_valid = 0`.

Also applies to: 566-571
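A sketch of the calculation described here, reconstructed from the review text rather than copied from the diff (the actual expression in the kernel may differ):

```cuda
#include <cstdint>

// Hypothetical reconstruction of the chunk_valid logic. The no-RoPE region is
// walked in rope_dim-wide chunks; the last chunk may be narrower, and a chunk
// starting at or beyond the region's end contains no valid elements.
__device__ __forceinline__ uint32_t chunk_valid_sketch(uint32_t elem_offset,
                                                       uint32_t no_rope_dim,
                                                       uint32_t rope_dim) {
  if (elem_offset >= no_rope_dim) {
    return 0;  // chunk lies entirely outside the no-RoPE region
  }
  // Remaining elements in the region, capped at one full chunk width.
  return min(no_rope_dim - elem_offset, rope_dim);
}
```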
1140-1196: Dispatch logic is consistent across variants.

The dynamic dispatch logic for `RopeQuantizeAppendPagedKVCache` (GQA/MHA) and `RopeQuantizeAppendPagedMLACache` follows the same pattern as `RopeQuantize`. The `total_blocks_y` calculation correctly accounts for the differences between GQA/MHA (includes V blocks) and MLA (no V blocks).

Note: The same `bdx` template parameter concern mentioned in the previous comment applies here as well.

Also applies to: 1214-1274
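The GQA/MHA versus MLA distinction in the grid sizing might look roughly like the following; the per-head block accounting here is an assumption for illustration, not the actual grid math in `pos_enc.cuh`.

```cuda
#include <cstdint>

// Hypothetical grid-sizing sketch: GQA/MHA also appends V into the paged
// cache, so V heads contribute extra y-blocks; MLA has no separate V blocks.
inline uint32_t total_blocks_y_sketch(uint32_t num_qo_heads, uint32_t num_kv_heads,
                                      uint32_t blocks_per_head, bool is_mla) {
  uint32_t qk_blocks = (num_qo_heads + num_kv_heads) * blocks_per_head;
  uint32_t v_blocks = is_mla ? 0u : num_kv_heads * blocks_per_head;
  return qk_blocks + v_blocks;
}
```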
236-292: Verify performance impact with targeted benchmarks for partial chunk scenarios.

The code logic is correct with proper boundary checks and zero-padding. However, verification confirms the author's concern: no performance data exists for this code path. The existing benchmark uses `head_size = rotary_dim` (both 128), meaning `no_rope_dim = 0`, so it doesn't exercise the partial chunk handling that this function addresses.

Before merging, run benchmarks with configurations where `no_rope_dim > 0` and `no_rope_dim < rope_dim` (e.g., `head_dim=192, rope_dim=128, no_rope_dim=64`) to quantify the performance impact of the element-by-element fallback path and zero-padding logic.
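A generic way to time such a configuration with CUDA events is sketched below. The kernel here is a placeholder standing in for the real RoPE-quantize launch; substitute the actual kernel, arguments, and grid configuration when benchmarking.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Placeholder for the kernel under test (e.g. head_dim=192, rope_dim=128,
// no_rope_dim=64); replace with the real launch when benchmarking.
__global__ void rope_quantize_stub() {}

int main() {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  rope_quantize_stub<<<1, 128>>>();  // warm-up launch, excluded from timing

  const int iters = 100;
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) {
    rope_quantize_stub<<<1, 128>>>();
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  printf("avg kernel latency: %.3f us\n", ms * 1000.f / iters);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return 0;
}
```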
1286-1312: Performance verification requires manual benchmarking; the routing change logic is correct and well-tested.

The routing to `RopeQuantize` is intentional, uniform across all callers, and thoroughly validated for correctness. Existing tests in `tests/attention/test_rope.py` verify the output against reference implementations for all relevant configurations (head_dim: 64/128/256, partial_rotary_factor: 0.25–1.0). However, the original review specifically requests performance profiling to detect regressions, which cannot be completed automatically in this environment; you must run performance benchmarks locally to measure kernel execution time and throughput across representative workloads.
yzh119
left a comment
LGTM overall, cc @kahyunnam for another look
/bot run
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Is there an issue with CI? Seems like it has been running for 2 days now 😅
Hi @raayandhar, the CI is finished (the result was not returned here for some reason). The PR itself does not bring any regressions and should be ready to merge. I'm running the benchmarks and will merge it as long as there is no performance regression.
LGTM! I do wonder if we're adding some extra, unnecessary overhead by still having the pointwise multiply by 1 (…). I agree we can merge when benchmarking looks OK, @yzh119. Thanks @raayandhar for the contribution!
There are indeed some performance regressions @raayandhar @kahyunnam. On H100, before this PR:

rope-latency:
seq_len FlashInfer Native vLLM
0 2.0 0.005936 0.062576 0.007968
1 4.0 0.005952 0.064256 0.008160
2 8.0 0.005888 0.069376 0.008128
3 16.0 0.006112 0.066160 0.008352
4 32.0 0.006240 0.066784 0.008576
5 64.0 0.006752 0.068608 0.009056
6 128.0 0.007808 0.075328 0.010464
7 256.0 0.009664 0.088256 0.012832
8 512.0 0.013472 0.115648 0.019904
9 1024.0 0.020896 0.170496 0.033728
10 2048.0 0.035712 0.290272 0.060896
11 4096.0 0.066240 0.523520 0.114400
12 8192.0 0.129952 0.985888 0.221632
13 16384.0 0.255168 1.897296 0.436032
14 32768.0 0.486576 3.715232 0.864640
15 65536.0 0.953376 7.342368 1.722112

After:

seq_len FlashInfer Native vLLM
0 2.0 0.005952 0.063488 0.007968
1 4.0 0.005952 0.064112 0.008128
2 8.0 0.005920 0.069440 0.008128
3 16.0 0.006272 0.067104 0.008384
4 32.0 0.006400 0.067552 0.008576
5 64.0 0.006688 0.068512 0.009056
6 128.0 0.007744 0.075424 0.010464
7 256.0 0.009760 0.088224 0.012832
8 512.0 0.013632 0.115712 0.019872
9 1024.0 0.021120 0.170720 0.033696
10 2048.0 0.036064 0.289760 0.060864
11 4096.0 0.066976 0.524288 0.114528
12 8192.0 0.128800 0.985664 0.221760
13 16384.0 0.259968 1.899248 0.435840
14 32768.0 0.621312 3.711968 0.864608
15 65536.0 1.758672 7.343424 1.722016
Oof, ok. I will go and investigate. Could you share your benchmarking scripts?
📌 Description
With the new changes we should be able to support arbitrary head dim using the `RopeQuantizeKernel`, and I have routed `BatchQKApplyRotaryPosIdsCosSinCache` to do so.

🔍 Related Issues
#2104
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Ran `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Added or updated relevant tests (unittest, etc.).

NOTE: There was a set of tests where I got an error which I know is related to my system. Unfortunately, I do not manage this system and it does not have Docker, so trying to fix this is a bit difficult. Hopefully someone else can verify my tests or run CI. All other tests were passing, and all the failing tests had that error.
Reviewer Notes
Please let me know if there's a smarter way to get around this hack or if other tests should be updated. Also I think we should remove the older kernel but let me know if we should do otherwise. I also need to test perf.