Sampling non contiguous #1916

base: main
Conversation
Signed-off-by: Cindy Zhang <[email protected]>
Hi @yzh119, I added support for …

Yes it should be applicable to all these kernels.
Walkthrough

The diff threads a new `uint32_t stride` parameter through CUDA kernels and host wrappers so indexing uses `row_idx * stride`, replaces generic checks with CUDA-specific `CHECK_CUDA` / `CHECK_LAST_DIM_CONTIGUOUS`, and adds a test variant to exercise contiguous and non-contiguous logits inputs.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as Python Host
    participant Lib as FlashInfer Host API
    participant CUDA as CUDA Kernel
    participant GPU as GPU Memory
    Note over Py,Lib: Call site builds tensors (contiguous or sliced)
    Py->>Lib: top_k_mask_logits(logits, mask, ..., stride=logits.strides[0])
    Lib->>CUDA: launch Kernel(data_ptr, batch, d, stride, ...)
    CUDA->>GPU: Load at base = row_idx * stride + col
    alt contiguous
        Note right of CUDA: stride == d (same as before)
    else non-contiguous
        Note right of CUDA: stride != d (handles padded layouts)
    end
```
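The indexing scheme in the diagram can be illustrated with a small NumPy sketch (the helper name `mask_logits_rows` and the toy masking logic are illustrative, not the real CUDA kernels): each row is loaded from the underlying buffer at `row_idx * stride`, so a batch-sliced view works without a copy.

```python
import numpy as np

def mask_logits_rows(flat, batch, d, stride, k):
    """Toy top-k mask over a flat buffer, loading each row at
    base = row_idx * stride (mirrors the kernel's indexing)."""
    out = np.empty((batch, d), dtype=flat.dtype)
    for row_idx in range(batch):
        row = flat[row_idx * stride : row_idx * stride + d]
        kth = np.sort(row)[-k]  # k-th largest value in the row
        out[row_idx] = np.where(row >= kth, row, -np.inf)
    return out

logits = np.arange(12, dtype=np.float32).reshape(4, 3)
sub = logits[::2, :]                        # non-contiguous batch dim
stride = sub.strides[0] // sub.itemsize     # element stride (here 6)
flat = logits.reshape(-1)                   # the underlying buffer
masked = mask_logits_rows(flat, sub.shape[0], sub.shape[1], stride, k=2)
```

With `stride == d` this degenerates to the old contiguous behavior, which is why existing call sites are unaffected.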
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
csrc/sampling.cu (1)
121-126: **Use `int` for the top-k array; the current `float` cast is inconsistent.** Kernels expect `IdType*` (`int`) for `top_k_arr`. Align the cast.

```diff
- has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value()->data) : nullptr, batch_size,
+ has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value()->data) : nullptr, batch_size,
```

tests/utils/test_sampling.py (1)
499-501: **Fix API name: use `top_k_renorm_probs` (plural).** The current call will fail.

```diff
- renormed_probs_ref = flashinfer.sampling.top_k_renorm_prob(probs, k)
+ renormed_probs_ref = flashinfer.sampling.top_k_renorm_probs(probs, k)
```
🧹 Nitpick comments (1)
csrc/renorm.cu (1)
24-41: **Add device compatibility checks for outputs.** To prevent silent device mismatches, add:
- CHECK_DEVICE(renorm_probs, probs) in top_p_renorm_probs/top_k_renorm_probs
- CHECK_DEVICE(mask_logits, logits) in top_k_mask_logits
Also applies to: 43-61, 63-81
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- csrc/renorm.cu (4 hunks)
- csrc/sampling.cu (11 hunks)
- include/flashinfer/sampling.cuh (49 hunks)
- tests/utils/test_sampling.py (1 hunks)
🔇 Additional comments (3)
csrc/sampling.cu (1)
48-50: **Contiguity checks + stride propagation look good.** `CHECK_CUDA`/`CHECK_LAST_DIM_CONTIGUOUS` plus passing `strides[0]` are correct and align with the new kernels.
Also applies to: 67-69, 87-89, 109-111, 134-136, 163-165
tests/utils/test_sampling.py (1)
481-492: **Good coverage for the non-contiguous batch dimension.** The slicing pattern (`x[::2, :]`) ensures last-dim contiguity while breaking batch contiguity. Nice.
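The effect of that slicing pattern can be verified directly with stride arithmetic (a NumPy sketch; the actual tests use PyTorch tensors, whose stride semantics are analogous):

```python
import numpy as np

x = np.zeros((8, 128), dtype=np.float32)
y = x[::2, :]  # every other batch row, same underlying buffer

# Last dimension stays contiguous: unit element stride.
assert y.strides[1] == y.itemsize
# Batch dimension is no longer contiguous: the row stride skips
# over two underlying rows (2 * 128 elements).
assert y.strides[0] == 2 * 128 * y.itemsize
# So the view as a whole is not C-contiguous.
assert not y.flags["C_CONTIGUOUS"]
```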
include/flashinfer/sampling.cuh (1)
247-259: **Stride-based indexing is correctly applied.** Using `row_idx * stride` for all loads while storing outputs at `row_idx * d` is consistent with non-contiguous leading dims.
Also applies to: 288-297, 1572-1590, 1759-1766, 1861-1866, 1988-1992
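The load/store asymmetry noted above — strided reads from the input view, dense writes to a freshly allocated output — can be sketched in NumPy (the helper `renorm_rows` is hypothetical; the real kernels do this in CUDA):

```python
import numpy as np

def renorm_rows(in_flat, out_flat, batch, d, stride):
    """Copy each row, renormalizing it to sum to 1.
    Loads at row_idx * stride, stores at row_idx * d."""
    for row_idx in range(batch):
        row = in_flat[row_idx * stride : row_idx * stride + d]
        out_flat[row_idx * d : row_idx * d + d] = row / row.sum()

probs = np.ones((4, 4), dtype=np.float32)
sub = probs[::2, :]                         # batch-sliced input view
out = np.empty(sub.size, dtype=np.float32)  # dense output buffer
renorm_rows(probs.reshape(-1), out,
            sub.shape[0], sub.shape[1],
            sub.strides[0] // sub.itemsize)
```

Because the output is dense, downstream consumers never see the input's padding.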
```diff
 template <typename T, typename IdType>
 cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr,
                                  uint32_t batch_size, uint32_t top_k_val, uint32_t d,
-                                 bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
-                                 cudaStream_t stream = 0) {
+                                 uint32_t stride, bool deterministic, uint64_t philox_seed,
+                                 uint64_t philox_offset, cudaStream_t stream = 0) {
```
**Align the TopKSamplingFromProb wrapper to `IdType* top_k_arr`.**

The kernel expects `IdType*`; the wrapper currently takes `T*`, causing type drift. Align with other wrappers (e.g., `TopKTopPSamplingFromProb`).

```diff
-template <typename T, typename IdType>
-cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr,
-                                 uint32_t batch_size, uint32_t top_k_val, uint32_t d,
-                                 uint32_t stride, bool deterministic, uint64_t philox_seed,
-                                 uint64_t philox_offset, cudaStream_t stream = 0) {
+template <typename T, typename IdType>
+cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, IdType* top_k_arr,
+                                 uint32_t batch_size, uint32_t top_k_val, uint32_t d,
+                                 uint32_t stride, bool deterministic, uint64_t philox_seed,
+                                 uint64_t philox_offset, cudaStream_t stream = 0) {
@@
-  void* args[] = {&probs, &output, &indices, &top_k_arr, &top_k_val,
-                  &d, &stride, &philox_seed, &philox_offset};
+  void* args[] = {&probs, &output, &indices, &top_k_arr, &top_k_val,
+                  &d, &stride, &philox_seed, &philox_offset};
```

Also applies to: 1459-1461
@yzh119 Should I also make changes to the softmax and speculative sampling kernels?

Hi @zcin, sure, you are welcome to work on that if you have bandwidth :)
📌 Description
Functions in sampling like `top_k_mask_logits` require the input tensor to be contiguous, but in reality only the last dimension needs to be contiguous. This PR adds support for tensors that are not contiguous in the batch dimension but contiguous in the last dimension.

🔍 Related Issues
closes #1866
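The relaxed requirement described above reduces to a simple predicate — only the last axis must have unit element stride. A NumPy sketch of what a check along the lines of `CHECK_LAST_DIM_CONTIGUOUS` verifies (the Python helper itself is illustrative, not part of the PR):

```python
import numpy as np

def last_dim_contiguous(t: np.ndarray) -> bool:
    """True when elements along the last axis are adjacent in memory,
    regardless of strides on the leading (batch) axes."""
    return t.strides[-1] == t.itemsize

x = np.zeros((8, 64), dtype=np.float32)
assert last_dim_contiguous(x)               # fully contiguous: OK
assert last_dim_contiguous(x[::2, :])       # batch-sliced view: still OK
assert not last_dim_contiguous(x[:, ::2])   # column-sliced: rejected
```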
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit

- API Changes: sampling kernels and host wrappers now accept a `uint32_t stride` parameter for batch-dimension row indexing.
- Improvements: generic input checks replaced with CUDA-specific `CHECK_CUDA` / `CHECK_LAST_DIM_CONTIGUOUS`.
- Tests: added a variant exercising both contiguous and non-contiguous logits inputs.