
Conversation

@yzh119
Collaborator

@yzh119 yzh119 commented Nov 19, 2025

📌 Description

This PR refactors the outdated fa3 codebase. More specifically, for page_size > 1, the page offset calculation is now performed inside the kernel, removing the need for a standalone call to block_sparse_indices_to_vector_sparse_offsets, and the kv_offset calculation is optimized with prefetching and shuffling.

This PR also fixes the failed unittest on hopper.

However, the FA3 structure in our codebase is still terribly outdated and lacks important features such as IntraWGOverlap and RescaleOBeforeGemm; we will follow up in a later PR.
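
The in-kernel page-offset arithmetic can be sketched as follows. This is a hedged illustration, not the kernel code; `kv_offset`, `page_stride`, and `stride_n` are illustrative names following the PR's parameter naming.

```python
# Sketch of the page-offset arithmetic now done in-kernel (illustrative only):
# a flat KV index is split into (page_iter, entry_idx) with divmod(page_size),
# then resolved through the page table (kv_indices) into an element offset.
def kv_offset(kv_idx, page_size, kv_indices, page_stride, stride_n):
    page_iter, entry_idx = divmod(kv_idx, page_size)
    page_idx = kv_indices[page_iter]  # gather the physical page
    return page_idx * page_stride + entry_idx * stride_n

# page_size=16, three non-contiguous physical pages, 128 elements per token
offsets = [kv_offset(i, 16, [3, 0, 7], 16 * 128, 128) for i in (0, 16, 33)]
print(offsets)  # [6144, 0, 14464]
```

Previously this mapping was materialized up front into a vector-sparse offsets buffer; doing the divmod and gather inside the kernel removes that extra pass and buffer.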

🔍 Related Issues

This PR should fix #1647

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Added runtime validation enforcing K/V page-stride and stride_n consistency for paged sparse attention.
  • Refactor

    • Unified sparse attention to page-based addressing; removed legacy vector-sparse conversion and buffers.
    • Simplified workspace/wrapper APIs to use paged KV buffers end-to-end.
  • New Features

    • Expanded FP8 support (scales propagated, output dtype handling) and added ragged-KV dispatch/pathways.
  • Tests

    • Added FP8 ragged/paged prefill tests; removed obsolete vector-sparse unit test.


@coderabbitai
Contributor

coderabbitai bot commented Nov 19, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds page-level K/V stride and page_size fields across host and device params, enforces K/V stride and stride_n consistency at runtime in paged paths, removes block-sparse→vector-sparse conversion and its helpers/tests, refactors mainloops to use page-table/manual page-based K/V loads, and adds FP8 ragged/paged variants and tests.

Changes

Cohort / File(s) Summary
Paged params struct additions
csrc/batch_prefill_sm90_customize_config.jinja, include/flashinfer/attention/hopper/default_params.cuh
Added k_page_stride and v_page_stride (int64_t) fields to paged params structs to record page-level strides for K and V.
Kernel/host wiring: paged args & runtime checks
csrc/batch_prefill_sm90.cu, include/.../prefill_sm90.cuh
Populate and propagate k_page_stride, v_page_stride (and page_size where applicable) into kernel arguments; add runtime sanity checks requiring K/V page-stride and stride_n equality for sparse paged paths.
Sparse mainloop & quantized mainloop refactor
include/flashinfer/attention/hopper/sparse_mainloop.cuh, include/.../quantization/mainloop_sparse_load.cuh
Replace block-sparse gather tensors with page-table/manual page-based K/V addressing: add k_stride_n, k_page_stride, v_stride_n, v_page_stride, page_size to Arguments/Params and implement page-aware load helpers and cooperative prefetch/loading.
Removal of block→vector conversion (C/CUDA)
csrc/flashinfer_page_binding.cu, csrc/page.cu, include/flashinfer/page.cuh
Delete declaration, implementation, host launcher, and FFI export for block_sparse_indices_to_vector_sparse_offsets (kernel + host helper removed).
Python helper & tests removed
flashinfer/page.py, tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py
Removed Python wrapper block_sparse_indices_to_vector_sparse_offsets and its unit test.
Prefill & sparse Python unification & FP8 wiring
flashinfer/prefill.py, flashinfer/sparse.py
Drop vector-sparse buffers/branches (fa3-specific) and use paged_kv_indptr_buf/paged_kv_indices_buf exclusively; propagate FP8 scale tensors and add/track output dtype (o_data_type) in planning/run paths; update reset_workspace_buffer signatures.
FP8 / Ragged KV additions
csrc/batch_prefill_fp8_sm90.cu, csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja, include/.../quantization/prefill_sm90.cuh
Add Ragged KV dispatch declarations/paths and FP8 ragged/paged kernel variants and instantiation wiring; declare new dispatched templates.
Tests & benchmarks
tests/attention/test_batch_prefill_kernels.py, tests/attention/test_hopper.py, tests/attention/test_hopper_fp8_attention.py, benchmarks/bench_hopper_fp8_attention.py
Adjust LSE buffering, add page_size=16 parameterization, remove kv_indices padding, add FP8 ragged/paged tests and FP8 benchmarks/quantization helpers.
Misc small fix
flashinfer/triton/kernels/cascade.py
Index loop variable switched to 64-bit for correct addressing in variable_length_merge_states_kernel.
Epilogue sync simplification
include/flashinfer/attention/hopper/epilogue.cuh
Simplify barrier usage: unify arrival counts and always sync NUM_MMA_THREADS at epilogue barrier.

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python tests/bench
    participant Wrap as Wrapper (flashinfer/prefill.py / sparse.py)
    participant Host as Host planner (csrc/*, headers)
    participant Kernel as CUDA device mainloop

    Note over Py,Kernel: OLD — vector-sparse flow (removed)
    Py->>Wrap: reset_workspace_buffer(..., vector_sparse_buffers)
    Wrap->>Host: plan/run(..., vector_sparse_buffers)
    Host->>Kernel: invoke with BlockSparseIndexedGather tensors
    Kernel->>Kernel: block-sparse gather loads

    Note over Py,Kernel: NEW — paged/ragged + FP8
    Py->>Wrap: reset_workspace_buffer(..., paged_kv_indptr, paged_kv_indices, o_data_type?, fp8_scales?)
    Wrap->>Host: plan/run(..., paged_kv_indptr, paged_kv_indices, o_data_type, fp8_scales)
    Host->>Kernel: invoke with (K_ptr, V_ptr, kv_indices, k_page_stride, v_page_stride, k_stride_n, v_stride_n, page_size, ...)
    Kernel->>Kernel: compute page_idx via divmod(page_size)
    Kernel->>Kernel: load_kv_tile / cp_async_zfill using page strides and stride_n
    Kernel->>Kernel: proceed to MMA/epilogue (sync)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

  • Files needing extra attention:
    • include/flashinfer/attention/hopper/sparse_mainloop.cuh and include/.../quantization/mainloop_sparse_load.cuh — correctness of page arithmetic (fastdiv/divmod), load_kv_tile, cp_async_zfill semantics, prefetch and barrier ordering.
    • csrc/batch_prefill_sm90.cu — propagation of k_page_stride/v_page_stride and runtime consistency checks.
    • flashinfer/prefill.py and flashinfer/sparse.py — FP8 scale propagation, o_data_type handling, removal of vector-sparse code paths and signature changes affecting call sites and tests.
    • Tests modifying LSE handling and FP8 validation — ensure test inputs/quantization match kernel expectations.

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • djmmoss
  • cyx-6
  • wenscarl
  • IwakuraRein

Poem

🐇 I hop through pages, counting stride by stride,
K and V now march steady, always side by side.
Old vectors gone, page-tables hum and sing,
Scales in tow, the FP8 kernels spring.
A rabbit cheers — the kernels ping!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 39.13%, which is below the required 80.00% threshold. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (4 passed)
  • Title check ✅ Passed — The title clearly describes the main refactoring work and specific fix (fa3 codebase update and hopper unittest fix). It is concise and specific to the primary changes.
  • Description check ✅ Passed — The PR description provides context about the refactoring (moving page-offset calculation into the kernel, removing the standalone function), mentions the unittest fix, and references the linked issue. Pre-commit and test checklist items are marked complete.
  • Linked Issues check ✅ Passed — Issue #1647 reported a '_vector_sparse_indices_buffer' undersizing error in FA3. The PR removes the block_sparse_indices_to_vector_sparse_offsets function and vector sparse buffer dependencies entirely, moving offset calculation into the kernel and directly addressing the root cause.
  • Out of Scope Changes check ✅ Passed — All changes are scoped to the FA3 refactoring (removing vector sparse functions/buffers, adding page-stride parameters, updating kernel logic) and hopper unittest fixes. Minor changes to epilogue synchronization and triton cascade indexing are well-contained and support the overall kernel optimization goals.

@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the efficiency and correctness of the Flash Attention v3 (FA3) implementation, particularly for paged Key-Value (KV) caches with page sizes greater than one. By integrating page offset calculations directly into the kernel and optimizing KV offset handling with prefetching and shuffling, the codebase becomes more streamlined and performant. A critical bug affecting Hopper unittests has also been resolved, ensuring robust operation on the target architecture. These changes collectively contribute to a more optimized and reliable sparse attention mechanism.

Highlights

  • FA3 Codebase Refactoring: The pull request refactors the Flash Attention v3 (FA3) codebase, specifically for page_size > 1, by moving the page offset calculation directly inside the kernel, eliminating the need for a separate block_sparse_indices_to_vector_sparse_offsets function.
  • KV Offset Optimization: Optimizations for KV offset calculation have been implemented, incorporating prefetching and shuffling techniques to enhance performance.
  • Hopper Unittest Fix: A previously failing unittest on Hopper architecture has been identified and fixed, ensuring correctness on this specific hardware.
  • Parameter Updates for Paged KV Cache: New parameters, k_page_stride and v_page_stride, have been introduced to store the stride between pages for sparse paged KV cache, along with assertions to ensure K and V have consistent page and stride_n values for efficiency.
  • Removal of Redundant Buffer Allocations: The _vector_sparse_indices_buffer and _vector_sparse_indptr_buffer have been removed from Python-side classes (BatchPrefillWrapper, BatchPrefillWrapperSparse, BatchDecodeWrapperSparse) as their functionality is now handled directly within the kernel.
  • Test Updates and File Management: The test_batch_prefill_kernels.py file has been updated to correctly handle lse (log-sum-exp) buffer allocation when return_lse is enabled. Additionally, a test file related to the removed block_sparse_indices_to_vector_sparse_offsets function has been removed, and another test file has been renamed for better organization.

@yzh119
Collaborator Author

yzh119 commented Nov 19, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !151 has been created, and the CI pipeline #38780168 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request is a significant refactoring of the FA3 codebase. It removes the standalone block_sparse_indices_to_vector_sparse_offsets function and moves the page offset calculation directly into the CUDA kernel, which is a great simplification. The changes also include an optimization for kv_offset calculation using prefetching and shuffling, which should improve performance. The code removal across C++, Python, and header files is consistent and clean. I've found a couple of minor areas for code improvement to reduce redundancy, but overall the changes look solid and well-implemented.

int d_idx = get<1>(coord);
int kv_idx = kv_base_idx + kv_offset;

bool guard = kv_idx < kv_len && kv_offset < valid_tile_size;
Contributor


medium

The guard condition can be simplified. The check kv_idx < kv_len is redundant when use_predicate is true, as it's already implied by kv_offset < valid_tile_size. When use_predicate is false, valid_tile_size is CTA_KV, and kv_offset is always less than CTA_KV, so the guard is not needed for non-last tiles anyway. You can simplify this to just kv_offset < valid_tile_size.

        bool guard = kv_offset < valid_tile_size;
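
The equivalence claimed above can be spot-checked with a small script. The semantics are assumed here, mirroring the review's reasoning: on the predicated last tile, valid_tile_size = min(CTA_KV, kv_len - tile_start); otherwise it equals CTA_KV.

```python
# Spot-check: kv_idx < kv_len is implied by kv_offset < valid_tile_size
# under the assumed definition of valid_tile_size (hypothetical semantics).
CTA_KV = 8

def guards_equivalent(kv_len, tile_start):
    valid_tile_size = min(CTA_KV, kv_len - tile_start)
    for kv_offset in range(CTA_KV):
        kv_idx = tile_start + kv_offset
        old = kv_idx < kv_len and kv_offset < valid_tile_size
        new = kv_offset < valid_tile_size
        assert old == new
    return True

guards_equivalent(kv_len=13, tile_start=8)  # partial last tile
guards_equivalent(kv_len=16, tile_start=8)  # full tile
```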

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
flashinfer/prefill.py (1)

2109-2156: Paged KV run argument rewiring is reasonable; verify trtllm cum_seq_lens_kv semantics

Using _paged_kv_indptr_buf / _paged_kv_indices_buf directly in run_args keeps the Python wrapper aligned with the new paged-KV FFI signature, and _qo_indptr_buf is a natural fit for cum_seq_lens_q. The only subtle point is that _paged_kv_indptr_buf is in units of pages, while trtllm paged attention APIs traditionally expect cum_seq_lens_kv in tokens; if the trtllm-gen backend actually consumes those trailing args as cum-token lengths, it may need cumsum(seq_lens) instead of raw page indptr. Worth double-checking against the current trtllm kernel contract.

tests/attention/test_batch_prefill_kernels.py (1)

147-157: Good coverage of preallocated LSE path; consider also checking LSE values

Using lse_buffer = torch.empty_like(lse) and rerunning with out=o_buffer, lse=lse_buffer now exercises the buffered LSE write path, which should catch the Hopper regression. To fully validate it, you may also want to assert torch.testing.assert_close(lse, lse_buffer, ...) alongside the existing o vs o_buffer check.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b9964cc and 0bbd445.

📒 Files selected for processing (15)
  • csrc/batch_prefill_sm90.cu (1 hunks)
  • csrc/batch_prefill_sm90_customize_config.jinja (1 hunks)
  • csrc/flashinfer_page_binding.cu (0 hunks)
  • csrc/page.cu (0 hunks)
  • flashinfer/page.py (0 hunks)
  • flashinfer/prefill.py (3 hunks)
  • flashinfer/sparse.py (2 hunks)
  • include/flashinfer/attention/hopper/default_params.cuh (1 hunks)
  • include/flashinfer/attention/hopper/prefill_sm90.cuh (1 hunks)
  • include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (8 hunks)
  • include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1 hunks)
  • include/flashinfer/attention/hopper/sparse_mainloop.cuh (8 hunks)
  • include/flashinfer/page.cuh (0 hunks)
  • tests/attention/test_batch_prefill_kernels.py (1 hunks)
  • tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py (0 hunks)
💤 Files with no reviewable changes (5)
  • csrc/flashinfer_page_binding.cu
  • csrc/page.cu
  • flashinfer/page.py
  • tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py
  • include/flashinfer/page.cuh
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/batch_prefill_sm90.cu (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
  • TVM_FFI_ICHECK_EQ (167-171)
  • TVM_FFI_ICHECK_EQ (283-286)
tests/attention/test_batch_prefill_kernels.py (1)
flashinfer/prefill.py (6)
  • run (1924-1936)
  • run (1939-1951)
  • run (1953-2166)
  • run (2768-2778)
  • run (2781-2791)
  • run (2793-2939)
flashinfer/prefill.py (1)
flashinfer/page.py (1)
  • get_seq_lens (176-199)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)

102-108: LGTM: Correct handling of separate K and V strides.

This implementation correctly supports different memory layouts for K and V:

  1. Parameterized design: The load_kv_tile lambda (lines 232-267) accepts stride_n and page_stride as parameters rather than hardcoding them
  2. Separate calls: K and V loads pass their respective strides:
    • K: load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, ...) (line 275)
    • V: load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, ...) (line 298)
  3. Flexible addressing: Line 259 computes offsets using the passed-in parameters

This is the correct pattern for page-based sparse loading and avoids the stride assumption issue present in sparse_mainloop.cuh.

Also applies to: 118-124, 232-267

include/flashinfer/attention/hopper/sparse_mainloop.cuh (1)

110-112: Stride equality is already validated on the host side; v_page_stride is intentionally passed through for API consistency.

The v_page_stride parameter, while unused in the non-quantized sparse_mainloop.cuh kernel, is not a bug. An assertion in csrc/batch_prefill_sm90.cu line 235 validates that K and V page strides are equal at runtime, and the comment in the sparse mainloop (line 281) explicitly documents this assumption. The prefetch_kv_offset lambda correctly reuses the same offset computation for both K and V loads.

The parameter exists for API consistency with the quantized variant (mainloop_sparse_load.cuh), which does use v_page_stride separately. If API unification across quantized and non-quantized paths is intentional, no action is needed.

include/flashinfer/attention/hopper/default_params.cuh (1)

157-160: k_page_stride / v_page_stride fields look consistent

Adding explicit page-stride fields after nnz_qo matches the other Hopper paged params structs and keeps types/ordering coherent with the new sparse mainloop arguments.

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)

337-344: FP8 sparse mainloop argument rewiring looks correct

Switching K/V from get_gmem_layout to explicit {k_stride_n, k_page_stride, v_stride_n, v_page_stride, kv_indices, page_size} matches the updated sparse mainloop API and keeps Q/O layout handling unchanged.

flashinfer/prefill.py (1)

36-36: Importing get_seq_lens is appropriate

This import matches the later use of get_seq_lens in BatchPrefillWithPagedKVCacheWrapper.plan to derive KV sequence lengths from paged metadata.

include/flashinfer/attention/hopper/prefill_sm90.cuh (1)

382-386: Sparse prefill mainloop now correctly receives KV indices and paging metadata

Passing kv_indices, window_left, k_page_stride, v_page_stride, and page_size into SparseCollectiveMainloop::to_underlying_arguments lines up with the new paged sparse mainloop contract and keeps Q/K/V layouts unchanged.

csrc/batch_prefill_sm90_customize_config.jinja (1)

107-111: PagedParams gains explicit K/V page strides in the right place

Adding k_page_stride / v_page_stride after nnz_qo keeps this JIT-generated PagedParams struct aligned with the Hopper default params and with how batch_prefill_sm90.cu now fills these fields from paged_{k,v}_cache.stride(0).

csrc/batch_prefill_sm90.cu (1)

221-238: Page-stride wiring and K/V stride consistency checks make sense

Recording k_page_stride / v_page_stride from stride(0) in both layouts and then asserting that K/V share the same page stride and stride_n is a good guardrail for the sparse paged mainloop; it will surface mis-laid-out KV caches early with clear error messages rather than letting the kernel access mismatched layouts.
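
A minimal sketch of that guardrail in Python (illustrative only, not the actual TVM_FFI_ICHECK_EQ calls in batch_prefill_sm90.cu):

```python
# K and V caches must agree on page stride (stride(0)) and per-token stride_n
# before taking the sparse paged path; otherwise fail fast with a clear error.
def check_kv_strides(k_page_stride, k_stride_n, v_page_stride, v_stride_n):
    if k_page_stride != v_page_stride:
        raise ValueError("K/V page strides must match for the sparse paged path")
    if k_stride_n != v_stride_n:
        raise ValueError("K/V stride_n must match for the sparse paged path")

check_kv_strides(2048, 128, 2048, 128)  # consistent layout: passes silently
```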

Comment on lines +268 to +300
int64_t my_kv_offset[2]; // Rolling buffer: page_idx * page_stride + entry_idx * stride_n

// Group organization based on partition strategy
constexpr int NUM_KV_PER_ITER = decltype(size<1>(tKcK))::value; // e.g., 12
constexpr int KV_STRIDE = CTA_KV / NUM_KV_PER_ITER; // 96/12 = 8
constexpr int NUM_GROUPS = KV_STRIDE; // 8 groups (one per lane)
constexpr int THREADS_PER_GROUP = NUM_COPY_THREADS / NUM_GROUPS; // 128/8 = 16
constexpr int NUM_ITERS_PER_GROUP = NUM_KV_PER_ITER; // 12 iterations per group

int group_id = thread_idx / THREADS_PER_GROUP; // 0-7
int thread_in_group = thread_idx % THREADS_PER_GROUP; // 0-15

// Prefetch: compute page_idx * page_stride + entry_idx * stride_n
// NOTE: Assumes K and V have same strides (asserted on host side)
auto prefetch_kv_offset = [&](int kv_tile_idx, bool use_predicate) {
int kv_base_idx = kv_tile_idx * CTA_KV;
int buf_idx = kv_tile_idx % 2;

int kv_idx_read = kv_base_idx + group_id + thread_in_group * KV_STRIDE;
bool valid_read =
thread_in_group < NUM_ITERS_PER_GROUP && (!use_predicate || kv_idx_read < kv_len);

if (valid_read) {
// Use divmod to find page and offset within page
uint32_t page_iter, entry_idx;
mainloop_params.page_size.divmod(kv_idx_read, page_iter, entry_idx);
IdType page_idx = kv_indices_ptr[page_iter];
// Pre-compute: page_idx * page_stride + entry_idx * stride_n
my_kv_offset[buf_idx] = page_idx * k_page_stride + entry_idx * k_stride_n;
} else {
my_kv_offset[buf_idx] = 0;
}
};
Contributor


⚠️ Potential issue | 🟠 Major

Prefetch logic assumes K and V have identical strides.

The prefetch_kv_offset lambda computes my_kv_offset using only K strides (k_page_stride and k_stride_n on line 296), but this offset is later reused for both K and V loads in load_kv_with_gather. This hardcodes the assumption that K and V have identical memory layouts.

Compare with mainloop_sparse_load.cuh (lines 232-267), which correctly uses separate stride parameters in its load_kv_tile lambda, allowing K and V to have different layouts.

Consider refactoring to either:

  • Option 1: Compute separate offsets for K and V if they can differ
  • Option 2: Use a single set of stride parameters if layouts must be identical
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/sparse_mainloop.cuh around lines 268-300,
the prefetch lambda computes my_kv_offset using only K strides but the same
offset is later used for both K and V loads, incorrectly assuming identical K/V
layouts; fix by computing distinct offsets for K and V (or enforce
identical-layout at compile/runtime). Update the lambda to accept/use separate
stride parameters (e.g., k_page_stride/k_stride_n and v_page_stride/v_stride_n)
and write into two rolling buffers (my_kv_offset_k[2] and my_kv_offset_v[2]) so
load_kv_with_gather can use the correct offset for each tensor, or alternatively
add a clear static_assert/runtime check and comment that K and V must share
strides and keep single offset.
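
Option 1 can be sketched in Python (hypothetical names mirroring the kernel's parameters; the divmod arithmetic follows the prefetch lambda above):

```python
# Keep separate K and V offsets so the layouts may differ; with identical
# strides the two offsets coincide, which is what the host-side assertion
# guarantees today.
def prefetch_offsets(kv_idx, page_size, kv_indices,
                     k_page_stride, k_stride_n, v_page_stride, v_stride_n):
    page_iter, entry_idx = divmod(kv_idx, page_size)
    page_idx = kv_indices[page_iter]
    k_off = page_idx * k_page_stride + entry_idx * k_stride_n
    v_off = page_idx * v_page_stride + entry_idx * v_stride_n
    return k_off, v_off

# identical K/V strides -> identical offsets
k_off, v_off = prefetch_offsets(17, 16, [5, 2], 2048, 128, 2048, 128)
# different V strides -> the offsets diverge, so a shared buffer would be wrong
k2, v2 = prefetch_offsets(17, 16, [5, 2], 2048, 128, 4096, 256)
```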

Comment on lines +303 to +335
auto load_kv_with_gather = [&](auto&& tXsX, auto&& tXcX, DTypeKV* base_ptr, int kv_tile_idx,
int stage_idx, bool use_predicate) {
using Vec = AlignmentTypeKV;
constexpr int VecSize = sizeof(Vec) / sizeof(DTypeKV);

int kv_base_idx = kv_tile_idx * CTA_KV;
int buf_idx = kv_tile_idx % 2;

auto dst = recast<Vec>(flatten(tXsX(_, _, _, stage_idx)));
auto c = flatten(tXcX(_, _, _, kv_tile_idx));

constexpr unsigned FULL_MASK = 0xffffffff;

// Load using FA3-style shuffle with pre-computed offsets
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst); ++i) {
auto coord = c(VecSize * i);
int kv_offset = get<0>(coord);
int d_idx = get<1>(coord);
int kv_idx = kv_base_idx + kv_offset;
bool guard = !use_predicate || kv_idx < kv_len;

// Shuffle the pre-computed offset (page_idx * page_stride + entry_idx * stride_n)
int src_thread = group_id * THREADS_PER_GROUP + kv_offset / KV_STRIDE;
int64_t base_offset = __shfl_sync(FULL_MASK, my_kv_offset[buf_idx], src_thread);

// Final address: base_ptr + base_offset + d_idx
// where base_offset = page_idx * page_stride + entry_idx * stride_n
Vec const* src_ptr = reinterpret_cast<Vec const*>(base_ptr + base_offset + d_idx);
cutlass::arch::cp_async_zfill<sizeof(Vec), cutlass::arch::CacheOperation::Global>(
&dst(i), src_ptr, guard);
}
};
Contributor


⚠️ Potential issue | 🟠 Major

load_kv_with_gather reuses K offsets for V loads.

The load_kv_with_gather helper shuffles and reuses my_kv_offset (computed using K strides in prefetch_kv_offset) for both K and V loads:

  • Line 341: load_kv_with_gather(..., K_ptr_base, ...)
  • Line 367: load_kv_with_gather(..., V_ptr_base, ...)

This shuffle-based optimization is effective for performance but requires K and V to have identical page strides and per-token strides. If this constraint is enforced elsewhere, add an assertion or comment clarifying why separate v_page_stride parameters exist but are unused.

For reference, mainloop_sparse_load.cuh avoids this issue by passing stride parameters explicitly to its load_kv_tile helper.

🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/sparse_mainloop.cuh around lines 303 to
335, load_kv_with_gather reuses the K offsets (my_kv_offset) for V loads which
is only valid if K and V have identical page and per-token strides; update the
code to either (A) assert at runtime (or static_assert / debug check) that
v_page_stride == k_page_stride and per-token strides match and add a clear
comment explaining why v_page_stride parameter is unused, or (B) change the
caller/implementation so V uses its own computed offsets (compute a separate
my_v_offset in prefetch_v_offset and shuffle that for V loads) so K and V can
have different strides—pick one approach and apply consistently (add the
assertion/comment if you choose A; implement separate offset computation and use
it in the shuffle and cp_async_zfill calls if you choose B).

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)

476-591: LGTM!

The new ragged KV dispatch functions are correctly implemented:

  • Uses TMA load for contiguous ragged memory (consistent with single prefill)
  • Proper layout construction and scheduler setup
  • Head dimension dispatch covers all supported values (64, 128, 256)

Minor: The comment on line 497 ("NOTE(Zihao): nnz was useless here, we can just pass 0") reads as a debug/TODO note. Consider removing or rephrasing if the implementation is finalized.

flashinfer/prefill.py (1)

416-472: Consider aligning FP8 detection with tensor dtype check.

The FP8 detection here uses scale_q is not None (line 421), while other places in the codebase use is_float8(q). This could lead to inconsistency if:

  1. FP8 input is provided without scale tensors
  2. Non-FP8 input is accidentally provided with scale tensors

Consider using is_float8(q) for consistency, or add a validation that ensures FP8 inputs always have scale tensors.

-        # Check if FP8 by presence of scale tensors
-        is_fp8 = scale_q is not None
+        # Check if FP8 by tensor dtype
+        is_fp8 = is_float8(q)
+        if is_fp8 and scale_q is None:
+            raise ValueError("FP8 inputs require scale_q, scale_k, scale_v tensors")
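
The proposed validation can also be sketched independently of torch; check_fp8_args and the dtype strings below are illustrative stand-ins for is_float8 and the real torch.float8_* dtypes, not flashinfer API:

```python
# Treat inputs as FP8 only when the dtype says so, and require scale tensors
# in that case (hypothetical helper mirroring the suggested diff).
def check_fp8_args(q_dtype, scale_q):
    is_fp8 = q_dtype in ("float8_e4m3fn", "float8_e5m2")
    if is_fp8 and scale_q is None:
        raise ValueError("FP8 inputs require scale_q, scale_k, scale_v tensors")
    return is_fp8

assert check_fp8_args("float8_e4m3fn", scale_q=1.0) is True
assert check_fp8_args("float16", scale_q=None) is False
```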
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7886b7d and 876c386.

📒 Files selected for processing (6)
  • csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja (1 hunks)
  • csrc/batch_prefill_fp8_sm90.cu (3 hunks)
  • flashinfer/prefill.py (21 hunks)
  • include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (8 hunks)
  • include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
  • tests/attention/test_hopper_fp8_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/prefill.py (2)
flashinfer/page.py (1)
  • get_seq_lens (176-199)
flashinfer/utils.py (3)
  • canonicalize_torch_dtype (240-248)
  • check_shape_dtype_device (519-537)
  • is_float8 (157-158)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (17)
tests/attention/test_hopper_fp8_attention.py (3)

186-280: LGTM!

The test function is well-structured, following the established pattern for FP8 testing. It correctly:

  • Creates variable-length sequences for batch prefill
  • Generates FP16 reference output
  • Quantizes inputs to FP8
  • Compares MSE between FP16 and FP8 paths

283-403: LGTM!

The paged KV cache test is correctly implemented:

  • Proper page allocation and indptr/indices construction
  • Appropriate reshape-quantize-reshape pattern for paged KV tensors
  • Consistent with the ragged test structure

406-426: LGTM!

The __main__ block updates provide convenient local test execution with a reasonable subset of parameters for quick validation.

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)

337-344: LGTM!

The parameter changes correctly pass stride and page-size information directly to the sparse mainloop, aligning with the PR's objective of moving page offset calculation into the kernel.

csrc/batch_prefill_fp8_sm90.cu (2)

86-173: LGTM!

The BatchPrefillWithRaggedKVCacheSM90Run implementation is well-structured:

  • Proper plan info initialization and LSE validation
  • Correct layout-aware stride handling for NHD/HND
  • Appropriate static assertions for FP8 constraints
  • Consistent error handling pattern

231-243: LGTM!

The page stride handling is correct. Using stride(0) consistently retrieves the stride between pages regardless of the internal layout (NHD or HND), which is the intended behavior for sparse paged KV cache addressing.
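The stride(0) claim can be illustrated with a small sketch (NumPy arrays stand in for torch tensors; the shapes are hypothetical):

```python
import numpy as np

# Paged KV cache layouts: NHD = [num_pages, page_size, num_heads, head_dim],
# HND = [num_pages, num_heads, page_size, head_dim].
nhd = np.zeros((8, 16, 4, 64), dtype=np.float16)
hnd = np.zeros((8, 4, 16, 64), dtype=np.float16)

# The stride along dim 0 (in elements) is the distance between consecutive
# pages, and it is identical for both layouts: page_size * num_heads * head_dim.
nhd_page_stride = nhd.strides[0] // nhd.itemsize
hnd_page_stride = hnd.strides[0] // hnd.itemsize
```

This is why a single stride(0) value suffices for page addressing regardless of the internal NHD/HND layout.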

flashinfer/prefill.py (5)

1566-1567: LGTM!

The o_data_type parameter addition is well-implemented with proper canonicalization and caching for use in the run method.


2092-2102: LGTM!

The output allocation correctly uses the cached output data type with a safe fallback to q.dtype for backward compatibility.
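The fallback behavior described above reduces to the following sketch (resolve_out_dtype is an illustrative helper, not the actual API):

```python
def resolve_out_dtype(cached_o_dtype, q_dtype):
    # Use the dtype cached at plan() time; fall back to the query dtype
    # when the caller never specified o_data_type.
    return cached_o_dtype if cached_o_dtype is not None else q_dtype

out_a = resolve_out_dtype("float16", "float8_e4m3fn")  # explicit o_data_type wins
out_b = resolve_out_dtype(None, "float16")             # backward-compatible fallback
```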


2950-2959: LGTM!

The FP8 handling correctly bypasses the FP16 conversion for FA3 backend while maintaining backward compatibility with FA2 backend (which still shows a deprecation warning and converts to FP16).


3001-3003: LGTM!

The FP8 scale tensor extension follows the established pattern from the paged path.


2170-2189: Attempted to verify how callers pass FP8 scale tensors to run() via *args and whether the convention is documented, but the check was inconclusive. Please confirm the run() docstring documents the positional scale-argument convention for FP8 inputs and that all call sites follow it.

include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (6)

102-111: LGTM! Clear page-based addressing structure.

The new stride and page fields provide a clean interface for page-based K/V tensor addressing, replacing the previous layout-based approach.


118-127: LGTM! Efficient use of fast division.

Using uint_fastdiv for page_size enables efficient divmod operations in the hot path.


134-137: LGTM! Parameter forwarding is correct.

All new stride and page parameters are properly forwarded from Arguments to Params.


212-231: LGTM! Clean setup for page-based loading.

The coordinate tensor partitioning and parameter extraction properly prepare for the manual K/V loading path.


273-372: LGTM! Proper tile loading sequence and synchronization.

The tile loading pattern correctly applies predication only to the last tile while intermediate tiles load without bounds checking. Pipeline synchronization, V transpose coordination, and barrier usage are all properly structured.


232-266: Note: the repository could not be accessed to verify this directly; the following is based on the provided snippet and general reasoning:

Verify bounds check for page table access before accessing kv_indices_ptr.

The lambda correctly implements page-based addressing for K/V tiles, but the code at line 257 lacks validation that page_iter is within the bounds of kv_indices_ptr before array access. When page_iter is computed via divmod(kv_idx, page_size), the result could potentially exceed the allocated size of the page table if:

  • The page table was sized based on an incorrect upper bound for KV entries
  • Concurrent modifications affect the array size
  • Off-by-one errors exist in the page table allocation logic

To resolve this:

  1. Add an assertion or bounds check: CUTE_ASSERT(page_iter < num_pages) before line 257, or
  2. Verify that the page table allocation logic guarantees sufficient capacity for all possible page_iter values derived from valid kv_idx values
  3. Document the invariant that kv_indices_ptr must accommodate ceil_div(kv_len, page_size) entries

Regarding the guard condition at line 252 (!use_predicate || kv_idx < kv_len): This appears correct for intermediate vs. final tiles, but validate that kv_len accurately reflects the maximum valid index.
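The invariant in point 3 can be stated as a one-liner (a sketch of the ceil-div capacity requirement; the function name is illustrative):

```python
def required_pages(kv_len, page_size):
    # ceil_div: the page table must hold at least this many entries so that
    # every page_iter derived from a valid kv_idx stays in bounds.
    return (kv_len + page_size - 1) // page_size

full_and_partial = required_pages(1000, 16)  # 62 full pages + 1 partial page
exact_multiple = required_pages(1024, 16)
```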

Comment on lines +1 to +15
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"

namespace flashinfer {

{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
<{{ head_dim_qk }},
{{ mask_mode }},
/*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
/*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
{{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
{% endfor %}

}; // namespace flashinfer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix namespace closing syntax.

Line 15 uses }; to close the namespace, but namespaces should be closed with just } (no semicolon).

-};  // namespace flashinfer
+}  // namespace flashinfer
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"
namespace flashinfer {
{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
<{{ head_dim_qk }},
{{ mask_mode }},
/*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
/*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
{{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
{% endfor %}
} // namespace flashinfer
🤖 Prompt for AI Agents
In csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja around lines 1 to 15,
the namespace is closed using "};" but C++ namespace blocks should be closed
with a plain "}" (no semicolon); remove the trailing semicolon after the closing
brace so the file ends with "}" to correctly close the flashinfer namespace.

}
};

int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);

⚠️ Potential issue | 🟡 Minor

Remove or utilize the unused variable.

The variable valid_last_kv_tile_size is computed but never used in the subsequent code. If the intention is to optimize the guard condition in load_kv_tile (as suggested in the past review comment), this value should be passed to the lambda. Otherwise, this line should be removed.

Apply this diff if the variable is not needed:

-    int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
-

Or, if you want to use it for optimized bounds checking, update the load_kv_tile signature to accept it:

-    auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
-                            int tile_idx, int pipe_idx, bool use_predicate) {
+    auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
+                            int tile_idx, int pipe_idx, int valid_tile_size) {

And update the guard condition accordingly.

Suggested change
int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh
around line 268, the local variable valid_last_kv_tile_size is computed but
never used; either remove this line if not needed, or update load_kv_tile to
accept an extra parameter for the valid tile size and pass
valid_last_kv_tile_size into the lambda so it can replace the existing generic
guard logic. If choosing the second option, change the load_kv_tile signature to
take an int valid_tile_size, update all call sites, and use that value inside
the lambda for optimized bounds checking; otherwise simply delete the unused
variable declaration.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)

318-425: Critical: V loads reuse K's prefetch offsets with potentially different strides.

The V loading at lines 345, 363, and 390 reuses prefetch offsets computed for K (with k_stride_n and k_page_stride), but V should use v_stride_n and v_page_stride. This is evident from line 411, which explicitly prefetches V with v_stride_n and v_page_stride.

If K and V have different strides or page strides, V will be loaded from incorrect addresses, causing data corruption.

The API explicitly provides separate stride parameters for K and V (Arguments and Params structs), suggesting they can differ. Either:

  1. Add prefetch calls for V before each V load (lines 345, 363, 390) using v_stride_n and v_page_stride, OR
  2. Document and assert that k_stride_n == v_stride_n and k_page_stride == v_page_stride must hold

Apply this pattern to fix the V loads:

     if (kv_tile_idx == swa_begin_kv_tile_idx) {
-      // first tile is the last tile, reuse kv_tile_idx prefetch for V
+      // first tile is the last tile, prefetch for V
+      prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, true);
       pipeline_v.producer_acquire(smem_pipe_write);
       load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true);
     } else {
       // load second last k-tile and last v-tile
       // Prefetch for next K tile (kv_tile_idx - 1)
       prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false);
 
-      // Load V using prefetch from last K load (kv_tile_idx)
+      // Prefetch and load V for kv_tile_idx
+      prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, true);
       pipeline_v.producer_acquire(smem_pipe_write);
       load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true);
       for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) {
         // Prefetch for next K tile
         prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false);
 
-        // Load V using prefetch from previous K prefetch
+        // Prefetch and load V for kv_tile_idx
+        prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, false);
         pipeline_v.producer_acquire(smem_pipe_write);
         load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), false);
♻️ Duplicate comments (1)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)

316-316: Remove the unused variable.

As noted in the previous review, valid_last_kv_tile_size is computed but never used.

🧹 Nitpick comments (4)
benchmarks/bench_hopper_fp8_attention.py (2)

216-216: Document or validate page_size divisibility assumption.

Line 216 assumes seq_len is perfectly divisible by page_size. While the current test cases satisfy this (seq_len ∈ {1024, 2048, 4096, 8192} with page_size=16), the function might be called with other parameters in the future.

Consider adding a validation check:

+    assert seq_len % page_size == 0, f"seq_len ({seq_len}) must be divisible by page_size ({page_size})"
     num_pages = batch_size * seq_len // page_size

250-251: Consider making workspace buffer size configurable.

The 256MB workspace buffer is hardcoded for both FP16 and FP8 wrappers. While sufficient for current benchmark sizes, this might be inadequate for larger workloads or future test expansions.

Consider either:

  1. Making workspace size a parameter with a reasonable default
  2. Adding a comment documenting the size assumption
  3. Having the wrappers handle workspace allocation internally if supported

This is a minor point since the current sizes work for the benchmarks being run.

Also applies to: 268-269

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2)

331-344: Paged KV mainloop param wiring looks consistent

The new argument list (k/v strides, page stride, kv_indices, page_size) lines up with a paged/sparse K/V mainloop and matches the scheduler/block_coord usage in this kernel. From this file’s perspective the wiring looks correct; no blocking issues.

If Params::page_size is not already a 32‑bit type, consider documenting or static‑asserting the expected range to make the uint32_t cast here self‑evident to future readers.


477-550: Ragged KV kernel‑traits dispatch wiring looks correct; stale comment

The ragged‑KV kernel‑traits dispatch correctly switches to FP8CollectiveMainloop and reuses the BatchPrefill schedulers/arguments in the same way as the paged path, with Q/K/V layouts built via get_gmem_layout, so the host→device params plumbing looks coherent.

The comment on Line 499 saying “nnz was useless here, we can just pass 0” now contradicts the actual params.nnz_kv argument; consider updating or removing this note to avoid confusion about whether the first dimension of the K/V layout is semantically meaningful.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 876c386 and f89ae64.

📒 Files selected for processing (3)
  • benchmarks/bench_hopper_fp8_attention.py (4 hunks)
  • include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7 hunks)
  • include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_hopper_fp8_attention.py (3)
flashinfer/testing/utils.py (2)
  • bench_gpu_time (985-1046)
  • attention_tflops_per_sec_with_actual_seq_lens (421-454)
benchmarks/bench_block_sparse_attention.py (1)
  • flops (125-134)
benchmarks/bench_hopper_attention.py (3)
  • flops (46-55)
  • flops (107-116)
  • flops (187-196)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (16)
benchmarks/bench_hopper_fp8_attention.py (7)

27-38: LGTM: Correct per-head symmetric quantization implementation.

The quantization logic correctly handles both FP8 formats with appropriate ranges, computes per-head scales by taking max over dimensions (0, 2), and includes defensive clamping to prevent division by zero.
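A minimal NumPy sketch of the per-head symmetric scheme described above (448.0 is the finfo max of float8_e4m3fn; the helper name is illustrative):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def per_head_scales(x):
    # x: [seq_len, num_heads, head_dim]; symmetric scale per head,
    # max taken over dims (0, 2), clamped to avoid division by zero.
    amax = np.maximum(np.abs(x).max(axis=(0, 2)), 1e-6)
    return amax / FP8_E4M3_MAX

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4, 16)).astype(np.float32)
scales = per_head_scales(x)
# Scaled values fit the FP8 range after defensive clamping.
q = np.clip(x / scales[None, :, None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
```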


41-108: LGTM: Well-structured FP8 single prefill benchmark.

The benchmark correctly creates FP16 baseline tensors, quantizes them to FP8 with per-head scales, measures both paths using median GPU time, and reports meaningful performance metrics with speedup calculations.


111-201: LGTM: Correct batch ragged prefill benchmark implementation.

The ragged batch benchmark properly constructs indptr arrays for batch boundaries, configures wrappers with appropriate data types, and correctly passes quantization scales to the FP8 execution path.


233-238: LGTM: Correct paged KV quantization strategy.

Flattening the paged KV cache for quantization and then reshaping back is the right approach to maintain per-head quantization semantics across all pages while preserving the paged memory layout.


240-247: LGTM: Correct indptr and page table setup.

The indptr arrays and page indices are correctly constructed:

  • qo_indptr marks query batch boundaries (every seq_len tokens)
  • kv_indptr marks page batch boundaries (every seq_len // page_size pages)
  • kv_indices provides sequential page mapping
  • last_page_len assumes full pages, which is appropriate for uniform benchmark workloads
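The construction above can be sketched as follows (assuming uniform, fully-filled sequences, as in the benchmark; the helper name is illustrative):

```python
def build_paged_metadata(batch_size, seq_len, page_size):
    assert seq_len % page_size == 0, "uniform benchmark assumes full pages"
    pages_per_seq = seq_len // page_size
    qo_indptr = [i * seq_len for i in range(batch_size + 1)]       # token boundaries
    kv_indptr = [i * pages_per_seq for i in range(batch_size + 1)] # page boundaries
    kv_indices = list(range(batch_size * pages_per_seq))           # sequential mapping
    last_page_len = [page_size] * batch_size                       # all pages full
    return qo_indptr, kv_indptr, kv_indices, last_page_len

qo, kv, idx, last = build_paged_metadata(batch_size=2, seq_len=32, page_size=16)
```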

330-336: Clarify status of skipped single prefill benchmarks.

The single prefill benchmarks are commented out due to "compilation issues." Given the PR objectives mention fixing a failing Hopper unittest, is this related?

Please clarify:

  1. Are these compilation issues expected to be resolved in this PR or a follow-up?
  2. Should this be tracked with a TODO or issue reference?
  3. Is this related to the unittest fixes mentioned in the PR description?

342-356: LGTM: Comprehensive benchmark coverage.

The test configurations provide good coverage across different:

  • Head dimensions (128, 256)
  • Batch sizes (16-128)
  • Sequence lengths (1024-8192)
  • Both ragged and paged KV cache layouts

The parameter combinations maintain roughly constant total token counts, which is sensible for comparing performance across configurations.
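For reference, the attention FLOP count such benchmarks normalize by is roughly the following (a common approximation, not necessarily the exact formula used by attention_tflops_per_sec_with_actual_seq_lens):

```python
def approx_attention_flops(batch, seq_len, num_heads, head_dim_qk, head_dim_vo,
                           causal=True):
    # 2*N*N*d multiply-adds for QK^T plus the same for PV;
    # a causal mask roughly halves the work.
    total = 2 * batch * num_heads * seq_len * seq_len * (head_dim_qk + head_dim_vo)
    return total // 2 if causal else total

full = approx_attention_flops(1, 2, 1, 1, 1, causal=False)
masked = approx_attention_flops(1, 2, 1, 1, 1, causal=True)
```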

include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7)

102-108: LGTM: Page-based KV cache parameters added.

The addition of separate stride and page_stride parameters for K and V tensors, along with page_size, correctly supports the refactored page-based KV loading scheme.


118-124: LGTM: Efficient fastdiv used for page_size.

Using uint_fastdiv for page_size enables efficient divmod operations in the kernel hot path.


134-137: LGTM: Parameter forwarding is correct.

All new page-based parameters are correctly forwarded from Arguments to Params.


212-220: LGTM: Manual K/V loading setup is complete.

All required parameters for page-based K/V loading are correctly extracted and prepared.


232-259: LGTM: Well-documented thread organization for FA3-style prefetch.

The rolling buffer prefetch scheme and detailed thread organization comments are helpful for understanding this complex optimization. The NUM_KV_PER_ITER calculations appear correct.


260-280: LGTM: Page-based offset prefetch is correctly implemented.

The divmod-based page addressing and rolling buffer management are correctly implemented. The offset computation properly combines page-level and entry-level addressing.
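The offset computation referenced above amounts to the following host-side sketch of the in-kernel divmod addressing (names mirror the params but are illustrative):

```python
def kv_offset(kv_idx, page_size, kv_indices, page_stride, stride_n):
    # Split the logical kv index into (page, entry-within-page), then
    # translate through the page table and combine the two strides.
    page_iter, entry_idx = divmod(kv_idx, page_size)
    return kv_indices[page_iter] * page_stride + entry_idx * stride_n

# Page table maps logical pages 0..2 to physical pages 5, 2, 7.
offs = kv_offset(kv_idx=17, page_size=16, kv_indices=[5, 2, 7],
                 page_stride=1024, stride_n=64)
```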


282-314: LGTM: Shuffle-based offset loading is correctly implemented.

The shuffle-based offset sharing and cp_async_zfill with guard correctly implement the FA3-style optimized loading pattern.
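The rolling-buffer prefetch pattern can be sketched in Python, with compute_offset and load standing in for the divmod/page-table lookup and the async copy issue:

```python
def compute_offset(tile_idx):      # stand-in for divmod + page-table lookup
    return tile_idx * 10

def load(offset):                  # stand-in for issuing the async copy
    return f"tile@{offset}"

def pipelined_load(n_tiles):
    # Prefetch the next tile's offset while "loading" the current one,
    # mirroring the rolling-buffer scheme in the sparse mainloop.
    out = []
    nxt = compute_offset(0)
    for i in range(n_tiles):
        cur, nxt = nxt, (compute_offset(i + 1) if i + 1 < n_tiles else None)
        out.append(load(cur))
    return out

loaded = pipelined_load(3)
```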

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2)

462-476: CTA_KV=64 for HEAD_DIM=256 paged path seems reasonable; please benchmark

Reducing CTA_KV from 128→64 for the sparse paged path (with the accompanying comment about 64×64 FP8 transpose minimum) is a plausible trade‑off to cut page‑table lookups; launch shape and error handling remain consistent with other HEAD_DIM branches.

Please sanity‑check perf/occupancy for HEAD_DIM=256 on Hopper (especially long‑seq FA3 workloads) to ensure this smaller CTA_KV doesn’t introduce regressions compared to the previous configuration.


552-592: New BatchFP8PrefillWithRaggedKVCacheDispatched entrypoint matches existing patterns

This wrapper mirrors the single‑batch FP8 dispatch: HEAD_DIM specializations, USE_TMA_LOAD_KV=true for ragged K/V, and the same error‑reporting pattern as the paged variant. The trait choices (CTA_Q/CTA_KV/NUM_STAGES) are consistent with the non‑ragged FP8 paths.

Once the ragged‑KV tests are in place, it’d be good to run them for all HEAD_DIM (64/128/256) with large nnz_qo/nnz_kv configurations comparable to issue #1647 to confirm this new batch entrypoint behaves as expected on Hopper.

@yzh119 yzh119 requested a review from yongwww as a code owner November 27, 2025 08:17
@yzh119
Copy link
Collaborator Author

yzh119 commented Nov 27, 2025

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !151 has been updated with latest changes, and the CI pipeline #39246617 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Collaborator

[FAILED] Pipeline #39246617: 14/20 passed

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
flashinfer/triton/kernels/cascade.py (1)

150-154: 64‑bit iterator cast correctly fixes large‑range indexing; minor optional cleanup

Using iter_i64 = iter.to(tl.int64) for the s_ptr/v_ptr address arithmetic is the right way to avoid 32‑bit offset overflow when indptr spans large index ranges, and should make this kernel robust for the large block counts described in the linked issue.

If you want to shave a tiny bit more overhead in very long ranges, you could optionally hoist the indptr loads and casts once per pos (e.g., load start = tl.load(indptr + pos).to(tl.int64), end = tl.load(indptr + pos + 1).to(tl.int64) and iterate over that), but the current form is already fine and clearly correct.
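The overflow the cast guards against can be demonstrated by simulating 32-bit signed arithmetic:

```python
def as_int32(v):
    # Simulate two's-complement wraparound of a 32-bit signed value.
    v &= 0xFFFFFFFF
    return v - (1 << 32) if v >= (1 << 31) else v

iter_idx = 700_000_000          # a large flat index into s_ptr/v_ptr
elem_bytes = 4
correct = iter_idx * elem_bytes              # 2.8e9: needs 64-bit math
wrapped = as_int32(iter_idx * elem_bytes)    # what 32-bit offsets would give
```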

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6785b6e and 6e40359.

📒 Files selected for processing (1)
  • flashinfer/triton/kernels/cascade.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

@yzh119
Copy link
Collaborator Author

yzh119 commented Nov 27, 2025

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !151 has been updated with latest changes, and the CI pipeline #39277250 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Collaborator

[FAILED] Pipeline #39277250: 14/20 passed

@yzh119
Copy link
Collaborator Author

yzh119 commented Nov 28, 2025

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !151 has been updated with latest changes, and the CI pipeline #39288360 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
include/flashinfer/attention/hopper/epilogue.cuh (1)

200-204: Consider removing the unused write_warp_idx parameter.

The write_warp_idx variable is set but never used in write_O (lines 57-64) or write_tiled (lines 24-53). This appears to be leftover from the previous conditional barrier implementation that was removed in this refactor.

Apply this diff to remove the dead parameter:

     int write_warp_idx = NUM_WARPS - 1;
     TiledCopyO gmem_tiled_copy_O;
     write_O<NUM_COPY_THREADS>(epilogue_params.O_ptr, gmem_tiled_copy_O, epilogue_params.layout_O,
                               select<0, 1>(TileShape_PDV{}), sO, thread_idx, qo_tile_idx,
-                              qo_head_idx, qo_indptr, qo_len, write_warp_idx);
+                              qo_head_idx, qo_indptr, qo_len);

And update the write_O function signature:

 template <int NUM_COPY_THREADS, typename ElemO, typename TiledCopyO, typename LayoutO,
           typename TileShapeO, typename SMemO>
 __forceinline__ __device__ void write_O(ElemO* O, const TiledCopyO& tiled_copy_O,
                                         const LayoutO& layout_O, const TileShapeO& tile_shape_O,
                                         const SMemO& sO, int thread_idx, int qo_tile_idx,
-                                        int qo_head_idx, int qo_indptr, int qo_len,
-                                        int write_warp_idx) {
+                                        int qo_head_idx, int qo_indptr, int qo_len) {
   write_tiled<NUM_COPY_THREADS>(O, tiled_copy_O, layout_O, tile_shape_O, sO, thread_idx,
                                 qo_tile_idx, qo_head_idx, qo_indptr, qo_len);
 }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6e40359 and 12b8ad0.

📒 Files selected for processing (1)
  • include/flashinfer/attention/hopper/epilogue.cuh (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/attention/hopper/epilogue.cuh (2)

197-199: LGTM! Correct barrier synchronization pattern.

The sync correctly pairs with the arrive at lines 171-172, using the same barrier ID and participant count. The fence at line 170 combined with this barrier sync ensures that shared memory writes (line 169) are visible to all threads before write_O reads from shared memory at line 202. This uniform synchronization approach is clearer than conditional barriers.


171-172: Full code analysis was not possible, but CUTLASS NamedBarrier semantics do require exact thread-count matching for arrive/sync pairs, which validates the core concern. What checks out in the snippet:

  1. The arrive/sync counts match (both use NUM_MMA_THREADS) ✓
  2. The barrier synchronization pattern is correct ✓
  3. The LSE write between barriers is safe (global memory only) ✓

The remaining question is whether removing NUM_PRODUCER_THREADS from the barrier count is correct. Producer threads must either be:

  • No longer active during the epilogue phase, OR
  • Not participants in this specific synchronization

This could not be verified against the repository, so the verification request stands.


Verify the barrier participant count change is correct.

The barrier arrive uses NUM_MMA_THREADS which correctly matches the sync at lines 197-199. However, per the AI summary, the previous implementation used NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS. Confirm that producer threads are either no longer active during the epilogue phase or do not need to participate in this synchronization. CUTLASS NamedBarrier semantics require exact thread count matching for arrive/sync pairs.
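The count-matching requirement can be illustrated with a CPU analogue, where threading.Barrier plays the role of a NamedBarrier whose participant count is fixed at construction:

```python
import threading

NUM_MMA_THREADS = 4
barrier = threading.Barrier(NUM_MMA_THREADS)  # participant count fixed up front
events = []
lock = threading.Lock()

def worker(tid):
    with lock:
        events.append(("pre", tid))
    barrier.wait()  # analogous to arrive + sync: all participants must reach it
    with lock:
        events.append(("post", tid))

threads = [threading.Thread(target=worker, args=(i,))
           for i in range(NUM_MMA_THREADS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

If the constructed count exceeded the number of threads that actually reach the barrier (the analogue of over-counting producer threads), every participant would deadlock at `wait()`.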

@flashinfer-bot
Copy link
Collaborator

[FAILED] Pipeline #39288360: 15/20 passed

@yzh119 yzh119 mentioned this pull request Nov 28, 2025
5 tasks

Development

Successfully merging this pull request may close these issues.

Error "_vector_sparse_indices_buffer is not large enough" for VariableBlockSparseAttention

2 participants