Conversation

@zcin commented Oct 12, 2025

📌 Description

Sampling functions such as top_k_mask_logits currently require the input tensor to be fully contiguous, but in reality only the last dimension needs to be contiguous. This PR adds support for tensors that are non-contiguous in the batch dimension but contiguous in the last dimension.
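For illustration, a minimal sketch of the kind of input this enables (assuming a CUDA build of flashinfer with this change applied; the call mirrors the top_k_mask_logits API referenced in #1866):

import torch
import flashinfer

full = torch.randn(8, 128, device="cuda")
logits = full[::2, :]              # a view: batch stride 256, last-dim stride 1

assert not logits.is_contiguous()  # non-contiguous in the batch dimension
assert logits.stride(-1) == 1      # but contiguous in the last dimension

# Previously this raised "RuntimeError: logits must be contiguous";
# with this PR the strided view is handled directly.
masked = flashinfer.sampling.top_k_mask_logits(logits, 10)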

🔍 Related Issues

closes #1866

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • API Changes

    • Sampling host/kernel interfaces now accept an explicit stride parameter to control memory indexing.
  • Improvements

    • Better handling of non-contiguous tensor layouts across sampling operations.
    • Stronger CUDA-side input validation to catch layout/contiguity issues earlier.
  • Tests

    • Added parameterized tests covering both contiguous and non-contiguous inputs for sampling logic.

@zcin marked this pull request as ready for review October 12, 2025 19:04
@zcin (Author) commented Oct 12, 2025

Hi @yzh119, I added support for top_k_mask_logits along with test cases. If the approach looks good, should I also make the same changes to the other relevant sampling kernels?

@yzh119 (Collaborator) commented Oct 13, 2025

> Hi @yzh119, I added support for top_k_mask_logits along with test cases. If the approach looks good, should I also make the same changes to the other relevant sampling kernels?

Yes, it should be applicable to all of these kernels.

@coderabbitai (bot) commented Oct 17, 2025

Walkthrough

The diff threads a new uint32_t stride parameter through the CUDA kernels and host wrappers, using row_idx * stride for indexing; replaces generic input checks with the CUDA-specific CHECK_CUDA / CHECK_LAST_DIM_CONTIGUOUS; and adds a test variant exercising both contiguous and non-contiguous logits inputs.
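In Python terms, the new validation is roughly equivalent to the following sketch (the actual macros are C++ checks in csrc/; the function name here is illustrative):

import torch

def check_sampling_input(t: torch.Tensor) -> None:
    # CHECK_CUDA: the tensor must live on a CUDA device.
    assert t.is_cuda, "input must be a CUDA tensor"
    # CHECK_LAST_DIM_CONTIGUOUS: only the innermost dimension must be
    # densely packed; the batch dimension may have an arbitrary stride.
    assert t.stride(-1) == 1, "last dimension must be contiguous"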

Changes

  • CUDA kernel call-sites & checks (csrc/renorm.cu, csrc/sampling.cu)
    Replace generic input checks with CHECK_CUDA and CHECK_LAST_DIM_CONTIGUOUS. Update internal calls to pass the leading-dimension stride (probs->strides[0], logits->strides[0]) as an extra argument to the renorm/sampling routines.
  • Kernel/header signatures & indexing (include/flashinfer/sampling.cuh)
    Add a uint32_t stride parameter to many kernels and host wrappers (SamplingFromLogits*, SamplingFromProb*, TopK*, TopP*, MinP*, TopPRenormProb*, TopKMaskLogits*, TopKRenormProb*, GetMinMaxValue/GetMaxValue, OnlineSoftmax variants). Replace row_idx * d with row_idx * stride and update the kernel launch argument arrays.
  • Tests (tests/utils/test_sampling.py)
    Parametrize test_top_k_mask_logits over contiguous (True/False); construct logits either as a contiguous tensor or via extra-dim slicing to exercise non-contiguous input handling (see the sketch below).
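Based on that summary, the parametrized test plausibly looks something like this sketch (exact shapes, k, and tolerances are assumptions):

import pytest
import torch
import flashinfer

@pytest.mark.parametrize("contiguous", [True, False])
def test_top_k_mask_logits(contiguous):
    batch_size, vocab_size, k = 4, 128, 10
    if contiguous:
        logits = torch.randn(batch_size, vocab_size, device="cuda")
    else:
        # Extra-dim slicing: the batch dimension becomes strided while
        # the last dimension stays contiguous.
        logits = torch.randn(batch_size, 2, vocab_size, device="cuda")[:, 0, :]
        assert not logits.is_contiguous()
    masked = flashinfer.sampling.top_k_mask_logits(logits, k)
    # The strided view must match the result on a contiguous copy.
    ref = flashinfer.sampling.top_k_mask_logits(logits.contiguous(), k)
    torch.testing.assert_close(masked, ref)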

Sequence Diagram

sequenceDiagram
    participant Py as Python Host
    participant Lib as FlashInfer Host API
    participant CUDA as CUDA Kernel
    participant GPU as GPU Memory

    Note over Py,Lib: Call site builds tensors (contiguous or sliced)
    Py->>Lib: top_k_mask_logits(logits, mask, ..., stride=logits.strides[0])
    Lib->>CUDA: launch Kernel(data_ptr, batch, d, stride, ...)
    CUDA->>GPU: Load at base = row_idx * stride + col
    alt contiguous
        Note right of CUDA: stride == d (same as before)
    else non-contiguous
        Note right of CUDA: stride != d (handles padded layouts)
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 I hopped through kernels, found a stride,

row by row I learned to glide.
Contiguous paths and sliced parade,
now every layout's softly made.
A carrot for tests—no more collide!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Title Check — ❓ Inconclusive: The title "Sampling non contiguous" is vague and imprecise. While it references a real aspect of the changeset (supporting non-contiguous tensors), the incomplete phrasing lacks clarity: a reader scanning the history would not immediately understand that this PR adds support for non-contiguous batch dimensions in sampling functions. A more descriptive title such as "Support non-contiguous batch dimensions in sampling kernels" or "Allow non-contiguous tensors in sampling functions" would help.
✅ Passed checks (3 passed)

  • Linked Issues Check — ✅ Passed: The code changes directly address the requirement in issue #1866, which requested support for non-contiguous tensors in sampling functions. The implementation replaces full contiguity checks with CHECK_LAST_DIM_CONTIGUOUS across all sampling kernels, introduces a stride parameter to handle non-contiguous batch dimensions, and updates all kernel signatures and call sites accordingly. Test coverage has been added via parametrized tests in test_sampling.py that specifically validate non-contiguous tensor handling. The approach has been applied uniformly across all sampling kernels, as recommended in the review discussion.
  • Out of Scope Changes Check — ✅ Passed: All code changes are directly related to the stated objective of supporting non-contiguous tensors in sampling functions: input validation is relaxed to last-dimension-only contiguity checks, stride parameters are added to kernel and host function signatures to handle arbitrary batch-dimension strides, all call sites and kernel implementations use the stride parameter, and comprehensive test cases cover non-contiguous tensor handling. No extraneous or unrelated modifications are present.
  • Description Check — ✅ Passed: The pull request description is well structured and follows the provided template. It explains that sampling functions now support tensors that are non-contiguous in the batch dimension but contiguous in the last dimension, references the related issue (#1866), and confirms completion of the pre-commit checks and tests. All required sections from the template are present and filled out.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 807813a and 7c4ff34.

📒 Files selected for processing (1)
  • csrc/renorm.cu (4 hunks)
🔇 Additional comments (3)
csrc/renorm.cu (3)

26-27: Previous critical issue resolved; current changes LGTM.

The past review correctly identified that line 38 had an undefined identifier probs_strides[0]. The current code correctly uses probs->strides[0], resolving the compile error.

The added CUDA-specific validations (CHECK_CUDA, CHECK_LAST_DIM_CONTIGUOUS) and stride parameter properly support non-contiguous batch dimensions while ensuring last-dimension contiguity required by the kernel.

Also applies to: 38-38


45-46: LGTM!

The changes follow the same pattern as top_p_renorm_probs: CUDA-specific validations ensure device compatibility and last-dimension contiguity, while the stride parameter enables correct indexing for non-contiguous batch dimensions.

Also applies to: 57-57


65-66: LGTM! Directly resolves issue #1866.

The changes follow the established pattern and directly address the RuntimeError reported in issue #1866 by:

  1. Validating CUDA device placement and last-dimension contiguity
  2. Passing stride information to enable correct indexing for non-contiguous batch dimensions

Also applies to: 77-77



@coderabbitai (bot) left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
csrc/sampling.cu (1)

121-126: Use int for the top-k array; the current float cast is inconsistent.

Kernels expect IdType* (int) for top_k_arr. Align the cast.

-      has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value()->data) : nullptr, batch_size,
+      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value()->data) : nullptr, batch_size,
tests/utils/test_sampling.py (1)

499-501: Fix API name: use top_k_renorm_probs (plural).

Current call will fail.

-    renormed_probs_ref = flashinfer.sampling.top_k_renorm_prob(probs, k)
+    renormed_probs_ref = flashinfer.sampling.top_k_renorm_probs(probs, k)
🧹 Nitpick comments (1)
csrc/renorm.cu (1)

24-41: Add device compatibility checks for outputs.

To prevent silent device mismatches, add:

  • CHECK_DEVICE(renorm_probs, probs) in top_p_renorm_probs/top_k_renorm_probs
  • CHECK_DEVICE(mask_logits, logits) in top_k_mask_logits

Also applies to: 43-61, 63-81

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4b55b26 and 807813a.

📒 Files selected for processing (4)
  • csrc/renorm.cu (4 hunks)
  • csrc/sampling.cu (11 hunks)
  • include/flashinfer/sampling.cuh (49 hunks)
  • tests/utils/test_sampling.py (1 hunks)
🔇 Additional comments (3)
csrc/sampling.cu (1)

48-50: Contiguity checks + stride propagation look good.

CHECK_CUDA/CHECK_LAST_DIM_CONTIGUOUS plus passing strides[0] are correct and align with the new kernels.

Also applies to: 67-69, 87-89, 109-111, 134-136, 163-165

tests/utils/test_sampling.py (1)

481-492: Good coverage for non‑contiguous batch dimension.

The slicing pattern (x[::2, :]) ensures last-dim contiguity while breaking batch contiguity. Nice.
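
For reference, a quick sketch of the layout that x[::2, :] produces:

import torch

x = torch.randn(8, 128)
y = x[::2, :]
print(y.shape)            # torch.Size([4, 128])
print(y.stride())         # (256, 1): batch stride doubled, last dim dense
print(y.is_contiguous())  # False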

include/flashinfer/sampling.cuh (1)

247-259: Stride-based indexing is correctly applied.

Using row_idx * stride for all loads while storing outputs as row_idx * d is consistent with non-contiguous leading dims.

Also applies to: 288-297, 1572-1590, 1759-1766, 1861-1866, 1988-1992

Comment on lines 1439 to +1451

 template <typename T, typename IdType>
 cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr,
                                  uint32_t batch_size, uint32_t top_k_val, uint32_t d,
-                                 bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
-                                 cudaStream_t stream = 0) {
+                                 uint32_t stride, bool deterministic, uint64_t philox_seed,
+                                 uint64_t philox_offset, cudaStream_t stream = 0) {
⚠️ Potential issue | 🟠 Major

Align the TopKSamplingFromProb wrapper to IdType* top_k_arr.

The kernel expects IdType*; the wrapper currently takes T*, causing type drift. Align it with the other wrappers (e.g., TopKTopPSamplingFromProb).

-template <typename T, typename IdType>
-cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr,
-                                 uint32_t batch_size, uint32_t top_k_val, uint32_t d,
-                                 uint32_t stride, bool deterministic, uint64_t philox_seed,
-                                 uint64_t philox_offset, cudaStream_t stream = 0) {
+template <typename T, typename IdType>
+cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, IdType* top_k_arr,
+                                 uint32_t batch_size, uint32_t top_k_val, uint32_t d,
+                                 uint32_t stride, bool deterministic, uint64_t philox_seed,
+                                 uint64_t philox_offset, cudaStream_t stream = 0) {
@@
-    void* args[] = {&probs, &output, &indices,     &top_k_arr,    &top_k_val,
-                    &d,     &stride, &philox_seed, &philox_offset};
+    void* args[] = {&probs, &output, &indices,     &top_k_arr,    &top_k_val,
+                    &d,     &stride, &philox_seed, &philox_offset};

Also applies to: 1459-1461

🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 1447-1451 (and similarly at
1459-1461), the TopKSamplingFromProb wrapper declares top_k_arr as T* while the
kernel expects IdType*; update the function signature(s) to use IdType*
top_k_arr to match the kernel and other wrappers (e.g.,
TopKTopPSamplingFromProb), and update any internal uses/casts of top_k_arr in
the wrapper to treat it as IdType* (adjust template instantiation or
reinterpret_casts if necessary) and ensure all call sites and overloads are
updated accordingly to avoid type drift.

Signed-off-by: Cindy Zhang <[email protected]>
@zcin (Author) commented Oct 17, 2025

@yzh119 Should I also make changes to the softmax and speculative sampling kernels?

@yzh119 (Collaborator) commented Oct 20, 2025

Hi @zcin, sure, you're welcome to work on that if you have bandwidth :)



Development

Successfully merging this pull request may close these issues.

RuntimeError: logits must be contiguous in flashinfer/sampling.py:375
