Bump tvm ffi to stable version 0.1.0 #1960
## 📌 Description

This PR fixes the dev container after #1880.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
Walkthrough

This PR migrates internal tensor access from pointer-style fields to TensorView accessor methods.
Sequence Diagram(s): omitted — changes are API-accessor migration and dependency updates; no control-flow changes warranting sequence diagrams.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed).
Summary of Changes

Hello @cyx-6, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request undertakes a comprehensive refactoring of the FlashInfer codebase to integrate the stable 0.1.0 release of the `tvm-ffi` library.
Code Review

This pull request bumps the `tvm-ffi` dependency to the stable version 0.1.0 and updates the codebase to be compatible with the new API. The changes are mostly mechanical, replacing pointer-style access (`->`) with member access (`.`) and updating method calls like `shape` to `size`, `data` to `data_ptr`, etc. This is a necessary and well-executed update to align with the new `tvm-ffi` version. I've also identified a potential logic issue in `csrc/tgv_gemm.cu` related to tensor shapes that should be addressed.
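For readers less familiar with the migration pattern described above, the sketch below contrasts the two access styles on a toy wrapper. `MiniTensor` and its members are hypothetical stand-ins, not the real tvm-ffi `TensorView` API; only the shape of the change is meant to match the PR.

```cpp
// Minimal sketch of the pointer-style vs. accessor-style tensor access the
// review describes. MiniTensor is a hypothetical stand-in, not tvm-ffi's
// TensorView.
#include <cstdint>
#include <vector>

struct MiniTensor {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  void* data = nullptr;

  // Accessor-style interface, analogous to what the new code calls.
  int64_t ndim() const { return static_cast<int64_t>(shape.size()); }
  int64_t size(int64_t i) const { return shape[static_cast<size_t>(i)]; }
  int64_t stride(int64_t i) const { return strides[static_cast<size_t>(i)]; }
  void* data_ptr() const { return data; }
};

// Old style: reach into raw fields through a pointer.
int64_t last_dim_old(const MiniTensor* t) { return t->shape[t->shape.size() - 1]; }

// New style: go through accessor methods on a value/reference.
int64_t last_dim_new(const MiniTensor& t) { return t.size(t.ndim() - 1); }
```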
Actionable comments posted: 21
♻️ Duplicate comments (1)
csrc/cudnn_sdpa_kernel_launcher.cu (1)
386-387: Same pointer arithmetic pattern as lines 344-345. Same concern as lines 344-345 regarding pointer arithmetic on `data_ptr()`. Consider explicit casting to `std::byte*` before arithmetic.
🧹 Nitpick comments (19)
csrc/cutlass_mla.cu (1)
26-45: LGTM! Clean migration to the new TensorView accessor API. The migration from pointer-based field access to method-based accessors is implemented correctly and consistently throughout the function:
- Device access via `.device()`
- Shape dimensions via `.size(i)`
- Data type via `.dtype()`
- Data pointers via `.data_ptr()`
All CUDA API calls and kernel invocations are properly updated to use the new accessors.
Optional: Consider caching the device object. `q_nope_pe.device()` is called three times (lines 26, 27, 29). While not a critical issue, you could cache it for slightly cleaner code:

+ const DLDevice device = q_nope_pe.device();
- cudaSetDevice(q_nope_pe.device().device_id);
- const cudaStream_t stream = get_stream(q_nope_pe.device());
+ cudaSetDevice(device.device_id);
+ const cudaStream_t stream = get_stream(device);
- int device_index = q_nope_pe.device().device_id;
+ int device_index = device.device_id;

flashinfer/jit/activation.py (1)
37-38: API migration looks correct; consider fixing spacing. The migration from `input->shape[input->ndim - 1]` to `input.size(input.ndim() -1)` is correct. Minor spacing inconsistency: `ndim() -1` has a space only before the minus sign. For consistency, use `ndim() - 1`. Apply this diff to fix spacing:

- int d = input.size(input.ndim() -1) / 2;
- int64_t num_tokens = input.numel() / input.size(input.ndim() -1);
+ int d = input.size(input.ndim() - 1) / 2;
+ int64_t num_tokens = input.numel() / input.size(input.ndim() - 1);

csrc/group_gemm_fp8_groupwise_sm100.cu (2)
97-98: Make max_m extraction robust to 1D/2D SFA (match SM120 logic). Avoid assuming SFA is 2D; mirror the SM120 approach.

- int max_m = SFA.size(1);
+ int max_m = SFA.size(SFA.ndim() > 1 ? 1 : 0);

109-116: Workspace byte-size assumes 1D; add shape guard. size(0) is only correct for 1D buffers. Guard or compute nbytes.

+ TVM_FFI_ICHECK_EQ(int_workspace_buffer.ndim(), 1)
+     << "int_workspace_buffer must be 1D";
+ TVM_FFI_ICHECK_EQ(float_workspace_buffer.ndim(), 1)
+     << "float_workspace_buffer must be 1D";
  auto status = flashinfer::group_gemm::CutlassFP8GroupwiseScaledGroupGEMMSM100<
      SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K, MMA_SM>(
      static_cast<int*>(int_workspace_buffer.data_ptr()),
      get_element_size(int_workspace_buffer) * int_workspace_buffer.size(0),
      static_cast<float*>(float_workspace_buffer.data_ptr()),
      get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0),

csrc/batch_mla_sm90_run.cu (2)
48-58: Avoid silent narrowing of strides to 32-bit. stride() is typically 64-bit. Casting to unsigned int can overflow for very large tensors. Either keep 64-bit in Params or add explicit range checks before casts.

- unsigned int q_nope_stride_n = q_nope.stride(0);
- unsigned int q_nope_stride_h = q_nope.stride(1);
+ TVM_FFI_ICHECK_LE(q_nope.stride(0), std::numeric_limits<unsigned int>::max());
+ TVM_FFI_ICHECK_LE(q_nope.stride(1), std::numeric_limits<unsigned int>::max());
+ unsigned int q_nope_stride_n = static_cast<unsigned int>(q_nope.stride(0));
+ unsigned int q_nope_stride_h = static_cast<unsigned int>(q_nope.stride(1));
  ...
- unsigned int o_stride_n = o.stride(0);
- unsigned int o_stride_h = o.stride(1);
+ TVM_FFI_ICHECK_LE(o.stride(0), std::numeric_limits<unsigned int>::max());
+ TVM_FFI_ICHECK_LE(o.stride(1), std::numeric_limits<unsigned int>::max());
+ unsigned int o_stride_n = static_cast<unsigned int>(o.stride(0));
+ unsigned int o_stride_h = static_cast<unsigned int>(o.stride(1));

101-103: Make uint_fastdiv inputs explicit 32-bit. Explicit casts avoid ambiguity and make intent clear.

- params.num_heads = uint_fastdiv(num_heads);
- params.block_size = uint_fastdiv(page_size);
+ params.num_heads = uint_fastdiv(static_cast<unsigned int>(num_heads));
+ params.block_size = uint_fastdiv(static_cast<unsigned int>(page_size));
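A reusable way to make these 64-to-32-bit conversions explicit is a small checked-cast helper. The sketch below is a generic illustration only; `narrow_u32` is a made-up name, not a helper that exists in the codebase, and its range check mirrors what the diffs above spell out inline.

```cpp
// Hypothetical checked-narrowing helper: fails loudly instead of silently
// truncating a 64-bit stride/extent that does not fit in 32 bits.
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

inline unsigned int narrow_u32(int64_t v, const char* what) {
  if (v < 0 || v > static_cast<int64_t>(std::numeric_limits<unsigned int>::max())) {
    throw std::out_of_range(std::string(what) + " does not fit in unsigned int");
  }
  return static_cast<unsigned int>(v);
}

// Usage sketch (variable names mirror the diff above, purely for illustration):
// unsigned int q_nope_stride_n = narrow_u32(q_nope.stride(0), "q_nope.stride(0)");
```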
csrc/group_gemm_fp8_groupwise_sm120.cu (1)

97-99: Workspace byte-size assumes 1D; add guards. Assert 1D for both workspace buffers before using size(0).

- int max_m = SFA.size(SFA.ndim() > 1 ? 1 : 0);
+ int max_m = SFA.size(SFA.ndim() > 1 ? 1 : 0);
+ TVM_FFI_ICHECK_EQ(int_workspace_buffer.ndim(), 1)
+     << "int_workspace_buffer must be 1D";
+ TVM_FFI_ICHECK_EQ(float_workspace_buffer.ndim(), 1)
+     << "float_workspace_buffer must be 1D";

Also applies to: 110-117
csrc/sampling.cu (2)
41-43: Workspace nbytes assumes 1D; add guard. Use size(0) only if the buffer is 1D.

- temperature_val, workspace_buffer.data_ptr(),
+ temperature_val, workspace_buffer.data_ptr(),
  get_element_size(workspace_buffer) * workspace_buffer.size(0), enable_pdl, stream);
+ TVM_FFI_ICHECK_EQ(workspace_buffer.ndim(), 1)
+     << "workspace_buffer must be 1D";
111-117: Minor: be consistent with explicit casts and 1D workspace expectations. No functional issue; optionally add explicit casts where kernels expect 32-bit and assert 1D work buffers where used.

Also applies to: 135-141, 167-169
csrc/norm.cu (1)
180-186: Dtype checks are good; consider early-returning before dispatch for clarity. No bug; optional readability tweak to place dtype asserts before stream creation/dispatch.
csrc/batch_mla_config.jinja (1)
16-16: Accessor migration to data_ptr() is correct; consider a dtype guard. Optional: assert profiler_buffer.dtype is uint64 to avoid misaligned casts in debug builds.
.github/workflows/nightly-release.yml (1)
101-101: tvm-ffi range OK; consider centralizing the spec. Define TVM_FFI_SPEC="apache-tvm-ffi>=0.1,<0.2" at job/workflow level and reuse it to avoid drift between workflows.
csrc/page.cu (2)
73-77: Add explicit include for std::equal. Relying on transitive includes is brittle; explicitly include to ensure std::equal is available.
You can add near the top of this file:
#include <algorithm>
163-168: Potential narrowing from int64 to unsigned int. append_ckv.size(0) and kv_last_page_len.size(0) return int64; casting to unsigned int may overflow on very large tensors. If the kernel ABI requires 32-bit, add a guard:

- unsigned int nnz = append_ckv.size(0);
- unsigned int batch_size = kv_last_page_len.size(0);
+ int64_t nnz64 = append_ckv.size(0);
+ int64_t batch_size64 = kv_last_page_len.size(0);
+ TVM_FFI_ICHECK_LE(nnz64, std::numeric_limits<unsigned int>::max());
+ TVM_FFI_ICHECK_LE(batch_size64, std::numeric_limits<unsigned int>::max());
+ unsigned int nnz = static_cast<unsigned int>(nnz64);
+ unsigned int batch_size = static_cast<unsigned int>(batch_size64);

csrc/single_prefill.cu (1)
41-41: Remove or silence unused variable head_dim_qk. It's computed but never used; this may trigger warnings.

- unsigned int head_dim_qk = q.size(2);
+ // head_dim_qk not needed; template-dispatched HEAD_DIM_QK handles static dim.
+ // (remove to avoid unused variable warning)

Alternatively:

[[maybe_unused]] const auto head_dim_qk = q.size(2);
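If the `[[maybe_unused]]` route is taken instead, the attribute (standard since C++17) silences the unused-variable warning while keeping the name around; a tiny standalone illustration:

```cpp
// Standalone illustration of [[maybe_unused]] (C++17): the compiler will not
// warn about head_dim_qk even when no code path reads it.
#include <vector>

int main() {
  std::vector<int> shape{8, 16, 128};
  [[maybe_unused]] const auto head_dim_qk = shape[2];
  return 0;
}
```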
pyproject.toml (1)

30-30: Move apache-tvm-ffi to runtime dependencies; no build-time usage detected. Verification confirms apache-tvm-ffi is not imported or used in build_backend.py. Since this dependency is not consumed during the build process, moving it from build requirements to runtime dependencies keeps the PEP 517 build environment lean without impact.
csrc/tgv_gemm.cu (1)
192-201: Macro implementation verified; const-correctness is an optional improvement. The `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16` macro at csrc/tvm_ffi_utils.h:83 correctly supports both fp16 and bf16 (via `_DISPATCH_CASE_F16` and `_DISPATCH_CASE_BF16`) and fails loudly on unsupported dtypes with `TVM_FFI_ICHECK(false)`. The macro name is misleading but the implementation is sound.

Regarding const-correctness: `mat1_ptr`, `mat2_ptr`, and `bias_ptr` in the kernel call (line 202) are read-only inputs. The function signature at line 89 could declare these as `const` pointers to improve const-correctness, though this is not a functional issue.

csrc/nvshmem_binding.cu (2)
34-36: Avoid potential unaligned write for the nvshmem unique ID. Write/read the unique ID via a local, properly aligned object and memcpy into the TensorView buffer (keeps the CPU check intact).

Apply this diff:

- TVM_FFI_ICHECK_EQ(uid.device().device_type, kDLCPU);
- nvshmemx_uniqueid_t* uid_ptr = reinterpret_cast<nvshmemx_uniqueid_t*>(uid.data_ptr());
- *uid_ptr = NVSHMEMX_UNIQUEID_INITIALIZER;
- nvshmemx_get_uniqueid(uid_ptr);
+ TVM_FFI_ICHECK_EQ(uid.device().device_type, kDLCPU);
+ nvshmemx_uniqueid_t tmp = NVSHMEMX_UNIQUEID_INITIALIZER;
+ nvshmemx_get_uniqueid(&tmp);
+ std::memcpy(uid.data_ptr(), &tmp, sizeof(tmp));

Add once near the includes:

+ #include <cstring>  // for std::memcpy
45-47: Use an aligned local unique ID and validate rank/world_size before init. Prevent alignment issues and guard inputs; also cast to the expected nvshmem types.

Apply this diff:

- TVM_FFI_ICHECK_EQ(uid.device().device_type, kDLCPU);
- nvshmemx_uniqueid_t* uid_ptr = reinterpret_cast<nvshmemx_uniqueid_t*>(uid.data_ptr());
- nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
- nvshmemx_set_attr_uniqueid_args(rank, world_size, uid_ptr, &attr);
+ TVM_FFI_ICHECK_EQ(uid.device().device_type, kDLCPU);
+ TVM_FFI_ICHECK_GT(world_size, 0) << "world_size must be > 0";
+ TVM_FFI_ICHECK_GE(rank, 0) << "rank must be >= 0";
+ TVM_FFI_ICHECK_LT(rank, world_size) << "rank must be < world_size";
+ nvshmemx_uniqueid_t tmp{};
+ std::memcpy(&tmp, uid.data_ptr(), sizeof(tmp));
+ nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
+ int r = static_cast<int>(rank);
+ int ws = static_cast<int>(world_size);
+ nvshmemx_set_attr_uniqueid_args(r, ws, &tmp, &attr);
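The memcpy-through-a-local pattern suggested above is a general defense against unaligned access when an API hands back an opaque byte buffer. Below is a self-contained sketch; `UniqueId` and `fill_unique_id` are dummy stand-ins for `nvshmemx_uniqueid_t` and the nvshmem call, used only to show the shape of the pattern.

```cpp
// Sketch: write a possibly-underaligned caller buffer via a properly aligned
// local object, then copy bytes. UniqueId is a hypothetical stand-in type.
#include <cstring>

struct UniqueId {
  alignas(16) unsigned char bytes[128];
};

// Dummy stand-in for the API call that fills the id.
void fill_unique_id(UniqueId* out) { std::memset(out->bytes, 0xAB, sizeof(out->bytes)); }

void write_id_to_buffer(void* dest_buffer) {
  UniqueId tmp{};                               // local object: alignment guaranteed
  fill_unique_id(&tmp);                         // API writes into the aligned local
  std::memcpy(dest_buffer, &tmp, sizeof(tmp));  // byte copy tolerates any alignment
}
```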
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (67)
.github/workflows/nightly-release.yml (1 hunks)
.github/workflows/release.yml (1 hunks)
csrc/batch_attention.cu (4 hunks)
csrc/batch_decode.cu (3 hunks)
csrc/batch_decode_mla_cute_sm80.cu (3 hunks)
csrc/batch_decode_mla_plan.cu (2 hunks)
csrc/batch_decode_mla_run.cu (1 hunks)
csrc/batch_mla_config.jinja (1 hunks)
csrc/batch_mla_plan.cu (1 hunks)
csrc/batch_mla_run.cu (2 hunks)
csrc/batch_mla_sm90_plan.cu (1 hunks)
csrc/batch_mla_sm90_run.cu (2 hunks)
csrc/batch_prefill.cu (3 hunks)
csrc/batch_prefill_fp8_sm90.cu (4 hunks)
csrc/batch_prefill_sm90.cu (6 hunks)
csrc/blackwell_fmha_plan.cu (1 hunks)
csrc/bmm_fp8.cu (1 hunks)
csrc/cascade.cu (3 hunks)
csrc/cudnn_sdpa_kernel_launcher.cu (19 hunks)
csrc/cutlass_mla.cu (1 hunks)
csrc/fmha_cutlass_sm100.cu (2 hunks)
csrc/fp4_gemm_cutlass.cu (3 hunks)
csrc/fp4_gemm_cutlass_sm120.cu (4 hunks)
csrc/fp8_gemm_cutlass.cu (3 hunks)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (14 hunks)
csrc/gemm_groupwise_sm100.cu (1 hunks)
csrc/gemm_groupwise_sm120.cu (2 hunks)
csrc/group_gemm.cu (1 hunks)
csrc/group_gemm_fp8_groupwise_sm100.cu (2 hunks)
csrc/group_gemm_fp8_groupwise_sm120.cu (2 hunks)
csrc/group_gemm_mxfp4_groupwise_sm100.cu (2 hunks)
csrc/group_gemm_sm90.cu (1 hunks)
csrc/norm.cu (6 hunks)
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (9 hunks)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (8 hunks)
csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp (6 hunks)
csrc/nvshmem_binding.cu (4 hunks)
csrc/page.cu (5 hunks)
csrc/pod.cu (8 hunks)
csrc/quantization.cu (2 hunks)
csrc/renorm.cu (3 hunks)
csrc/rope.cu (8 hunks)
csrc/sampling.cu (8 hunks)
csrc/single_decode.cu (3 hunks)
csrc/single_prefill.cu (2 hunks)
csrc/single_prefill_fp8_sm90.cu (1 hunks)
csrc/single_prefill_sm90.cu (1 hunks)
csrc/tgv_gemm.cu (3 hunks)
csrc/trtllm_allreduce.cu (3 hunks)
csrc/trtllm_allreduce_fusion.cu (1 hunks)
csrc/trtllm_alltoall.cu (8 hunks)
csrc/trtllm_fmha_kernel_launcher.cu (3 hunks)
csrc/trtllm_fused_moe_kernel_launcher.cu (25 hunks)
csrc/trtllm_gemm_runner.cu (3 hunks)
csrc/trtllm_low_latency_gemm_runner.cu (3 hunks)
csrc/trtllm_mnnvl_allreduce.cu (3 hunks)
csrc/trtllm_moe_allreduce_fusion.cu (4 hunks)
csrc/tvm_ffi_utils.h (2 hunks)
csrc/vllm_custom_all_reduce.cu (3 hunks)
csrc/xqa/xqa_wrapper.cu (1 hunks)
flashinfer-cubin/pyproject.toml (1 hunks)
flashinfer-jit-cache/pyproject.toml (1 hunks)
flashinfer/gemm.py (1 hunks)
flashinfer/jit/activation.py (2 hunks)
flashinfer/jit/attention/utils.py (2 hunks)
pyproject.toml (1 hunks)
requirements.txt (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (55)
csrc/batch_attention.cu (1)
  csrc/tvm_ffi_utils.h (3): get_element_size (276-276), get_element_size (278-280), get_stream (272-274)
csrc/batch_prefill_sm90.cu (1)
  csrc/tvm_ffi_utils.h (3): get_element_size (276-276), get_element_size (278-280), get_stream (272-274)
csrc/trtllm_allreduce_fusion.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/batch_mla_sm90_plan.cu (1)
  csrc/tvm_ffi_utils.h (3): get_element_size (276-276), get_element_size (278-280), get_stream (272-274)
csrc/trtllm_allreduce.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/page.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/single_prefill_fp8_sm90.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/group_gemm.cu (1)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/batch_decode_mla_cute_sm80.cu (1)
  csrc/tvm_ffi_utils.h (3): get_element_size (276-276), get_element_size (278-280), get_stream (272-274)
csrc/xqa/xqa_wrapper.cu (2)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
  csrc/xqa/mha.cu (2): launchMHAFlashInfer (2657-2743), launchMHAFlashInfer (2657-2669)
csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp (1)
  csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2): input (579-660), input (579-584)
csrc/batch_mla_plan.cu (1)
  csrc/tvm_ffi_utils.h (3): get_element_size (276-276), get_element_size (278-280), get_stream (272-274)
csrc/quantization.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/renorm.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/trtllm_alltoall.cu (3)
  csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (6): input (579-660), input (579-584), output (230-398), output (230-238), output (400-572), output (400-411)
  csrc/tvm_ffi_utils.h (1): get_current_stream (266-270)
  csrc/trtllm_alltoall_prepare.cu (8): computeCountAndIndice (524-552), computeCountAndIndice (524-528), computeCumsum (554-561), computeCumsum (554-555), moveIndice (563-572), moveIndice (563-566), allToAllMetadata (574-610), allToAllMetadata (574-579)
csrc/batch_decode_mla_plan.cu (1)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/single_prefill.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/fp8_gemm_cutlass.cu (1)
  csrc/tvm_ffi_utils.h (2): get_stream (272-274), encode_dlpack_dtype (29-31)
csrc/norm.cu (2)
  csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2): input (579-660), input (579-584)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (2)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
  csrc/nv_internal/cpp/kernels/quantization.cu (4): invokeBlockScaleInterleave (303-314), invokeBlockScaleInterleave (303-305), invokeBlockScaleInterleaveReverse (317-326), invokeBlockScaleInterleaveReverse (317-318)
csrc/bmm_fp8.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/fp4_gemm_cutlass_sm120.cu (1)
  csrc/tvm_ffi_utils.h (2): get_stream (272-274), encode_dlpack_dtype (29-31)
csrc/nvshmem_binding.cu (1)
  csrc/tvm_ffi_utils.h (4): get_element_size (276-276), get_element_size (278-280), get_stream (272-274), encode_dlpack_dtype (29-31)
csrc/group_gemm_fp8_groupwise_sm100.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/batch_decode_mla_run.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/batch_prefill_fp8_sm90.cu (1)
  csrc/tvm_ffi_utils.h (3): get_element_size (276-276), get_element_size (278-280), get_stream (272-274)
csrc/batch_decode.cu (1)
  csrc/tvm_ffi_utils.h (3): get_element_size (276-276), get_element_size (278-280), get_stream (272-274)
csrc/group_gemm_sm90.cu (1)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/cutlass_mla.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/single_decode.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/fmha_cutlass_sm100.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
  csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h (2): MXFP8MXFP4 (345-354), FP8BlockScaling (388-392)
csrc/group_gemm_fp8_groupwise_sm120.cu (1)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/trtllm_fmha_kernel_launcher.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/sampling.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/trtllm_gemm_runner.cu (1)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/group_gemm_mxfp4_groupwise_sm100.cu (1)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/vllm_custom_all_reduce.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (2): get_stream (272-274), encode_dlpack_dtype (29-31)
csrc/pod.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/trtllm_mnnvl_allreduce.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/gemm_groupwise_sm100.cu (1)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/rope.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
  include/flashinfer/trtllm/fused_moe/runner.h (7): num_experts (263-263), getMaxPermutedPaddedCount (110-114), top_k (270-270), intermediate_size (275-275), getMaxNumCtasInBatchDim (81-108), local_num_experts (277-277), hidden_size (265-265)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/trtllm_low_latency_gemm_runner.cu (1)
  csrc/trtllm_gemm_runner.cu (8): m (111-126), m (111-111), m (128-179), m (128-130), m (181-236), m (181-181), m (238-250), m (238-238)
csrc/batch_mla_run.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/trtllm_moe_allreduce_fusion.cu (2)
  flashinfer/comm/cuda_ipc.py (1): cudaSetDevice (149-150)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/single_prefill_sm90.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/batch_prefill.cu (1)
  csrc/tvm_ffi_utils.h (3): get_element_size (276-276), get_element_size (278-280), get_stream (272-274)
csrc/tgv_gemm.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/blackwell_fmha_plan.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/cudnn_sdpa_kernel_launcher.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/cascade.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/gemm_groupwise_sm120.cu (1)
  csrc/tvm_ffi_utils.h (3): get_stream (272-274), get_element_size (276-276), get_element_size (278-280)
csrc/batch_mla_sm90_run.cu (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (1)
  csrc/tvm_ffi_utils.h (1): get_stream (272-274)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: build-flashinfer-jit-cache (13.0, aarch64)
- GitHub Check: build-flashinfer-jit-cache (12.9, aarch64)
- GitHub Check: build-flashinfer-jit-cache (12.8, aarch64)
- GitHub Check: build-flashinfer-cubin
- GitHub Check: build (cu128, arm64)
- GitHub Check: build (cu129, arm64)
- GitHub Check: build (cu126, amd64)
- GitHub Check: Deploy Docs
.github/workflows/release.yml (Outdated)

python -m pip install --upgrade pip
pip install build twine wheel
pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b15
pip install setuptools>=61.0 requests filelock torch tqdm numpy "apache-tvm-ffi>=0.1,<0.2"
🧩 Analysis chain
Quote fix looks right; align setuptools minimum across projects.
This job installs setuptools>=61.0 while top-level pyproject uses >=77. Ensure flashinfer-cubin (and other subprojects) don’t require newer features; otherwise bump here for consistency.
🏁 Script executed:
#!/bin/bash
fd -a 'pyproject.toml' -x sh -c "echo {}; rg -n 'setuptools>=' {} || true"
Length of output: 502
Upgrade setuptools in workflow to match root project requirements.
The workflow installs setuptools>=61.0, but the root pyproject.toml and flashinfer-jit-cache both require setuptools>=77. This will cause build failures if the workflow builds the root project or flashinfer-jit-cache. Upgrade line 139 to setuptools>=77 for consistency with all subprojects' build requirements.
🤖 Prompt for AI Agents
.github/workflows/release.yml around line 139: the workflow currently installs
setuptools>=61.0 which conflicts with the root pyproject.toml and
flashinfer-jit-cache that require setuptools>=77; update the pip install line to
require setuptools>=77 so the CI installs a compatible setuptools version for
building root and subprojects, keeping the rest of the package list unchanged.
cudaSetDevice(q_nope.device().device_id);
const cudaStream_t stream = get_stream(q_nope.device());
Assert all tensors share the same device as q_nope
You set device/stream from q_nope. Add device equality checks for q_pe, caches, kv_indices, and o/maybe_lse to prevent cross-device misuse.
+ CHECK_DEVICE(q_pe, q_nope);
+ CHECK_DEVICE(ckv_cache, q_nope);
+ CHECK_DEVICE(kpe_cache, q_nope);
+ CHECK_DEVICE(kv_indices, q_nope);
+ CHECK_DEVICE(o, q_nope);
+ if (maybe_lse.has_value()) CHECK_DEVICE(maybe_lse.value(), q_nope);
cudaSetDevice(q_nope.device().device_id);
const cudaStream_t stream = get_stream(q_nope.device());
🤖 Prompt for AI Agents
In csrc/batch_mla_sm90_run.cu around lines 59-61, you set the CUDA device/stream
from q_nope but do not verify other tensors are on the same device; add explicit
checks that q_pe, caches, kv_indices, and o (and maybe_lse if present) have the
same device as q_nope (compare their device().device_id values) and fail fast
(assert/log+return) if any mismatch is found; optionally also ensure their
associated streams (if any) are compatible with the chosen stream to avoid
cross-device stream misuse.
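The same-device requirement requested here can be phrased as a small helper over DLPack-style device descriptors. The sketch below uses a minimal `Dev` struct and a hypothetical `check_same_device` function rather than the real `DLDevice`/`TensorView` types; it only illustrates the comparison being asked for.

```cpp
// Sketch of a same-device check over DLPack-style descriptors. Dev is a
// minimal stand-in for DLDevice (device_type + device_id); check_same_device
// is a hypothetical helper, not part of the codebase.
#include <stdexcept>
#include <string>

struct Dev {
  int device_type;  // e.g. CPU vs CUDA
  int device_id;    // ordinal within that device type
};

inline void check_same_device(const Dev& a, const Dev& ref, const char* name) {
  if (a.device_type != ref.device_type || a.device_id != ref.device_id) {
    throw std::invalid_argument(std::string(name) +
                                " is not on the same device as the reference tensor");
  }
}

// Usage sketch: call once per input before touching any device pointers,
// e.g. check_same_device(q_pe_dev, q_nope_dev, "q_pe");
```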
uint16_t* k_ptr = reinterpret_cast<uint16_t*>(k.data_ptr() + batch_offset_k);
uint16_t* v_ptr = reinterpret_cast<uint16_t*>(v.data_ptr() + batch_offset_v);
🧩 Analysis chain
Verify pointer arithmetic safety on `data_ptr()`.

The code performs pointer arithmetic directly on `data_ptr()` results. If `data_ptr()` returns `void*`, this is non-standard C++ (though supported as a compiler extension). For portability and clarity, consider casting to `std::byte*` or `int8_t*` first, similar to the pattern used elsewhere in this file (lines 232-235):
🏁 Script executed:
#!/bin/bash
# Check the return type of data_ptr() method in TensorView
rg -n "data_ptr\s*\(" --type=cpp -A 2 -B 2 -g '!*test*' -g '!*build*' | head -50
Length of output: 5307
---
🏁 Script executed:
#!/bin/bash
# Extract context around lines 344-345 in the target file
cat csrc/cudnn_sdpa_kernel_launcher.cu | sed -n '330,360p' | cat -n
Length of output: 2231
🏁 Script executed:
#!/bin/bash
# Extract context around lines 232-235 to compare pointer arithmetic patterns
cat csrc/cudnn_sdpa_kernel_launcher.cu | sed -n '225,245p' | cat -n
Length of output: 1668
🏁 Script executed:
#!/bin/bash
# Search for where batch_offset_k and batch_offset_v are calculated
rg -n "batch_offset_[kv]\s*=" csrc/cudnn_sdpa_kernel_launcher.cu -B 2 -A 2
Length of output: 730
🏁 Script executed:
#!/bin/bash
# Look for data_ptr() method definition in the codebase
rg -n "\.data_ptr\(\)" csrc/cudnn_sdpa_kernel_launcher.cu -B 1 -A 1
Length of output: 4246
🏁 Script executed:
#!/bin/bash
# Find the TensorView struct/class definition to understand data_ptr() return type
fd -e hpp -e h | xargs rg -l "data_ptr" | head -10
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Check if there's documentation or method signature for data_ptr()
rg -n "data_ptr.*\(\)" --type=cpp --type=hpp -B 2 -A 2 | grep -A 4 "return.*void\*\|return.*\*" | head -30
Length of output: 27
🏁 Script executed:
#!/bin/bash
# Search web documentation for PyTorch/similar tensor data_ptr() return type
echo "Searching for data_ptr() patterns to infer return type"
Length of output: 54
🏁 Script executed:
#!/bin/bash
# Find TensorView definition or similar tensor wrapper class
rg -n "class TensorView\|struct TensorView" --type=h --type=cpp -A 10
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Check lines 386-387 to confirm same pattern issue
cat csrc/cudnn_sdpa_kernel_launcher.cu | sed -n '380,395p' | cat -n
Length of output: 1163
🏁 Script executed:
#!/bin/bash
# Search for any existing casts to byte* for arithmetic in the file
rg "std::byte\*" csrc/cudnn_sdpa_kernel_launcher.cu -n -B 1 -A 1
Length of output: 989
🏁 Script executed:
#!/bin/bash
# Check if there are header files that might define TensorView
fd -name "*.h" -o -name "*.hpp" | head -20
Length of output: 230
🌐 Web query:
PyTorch tensor data_ptr() return type void pointer
💡 Result:
Short answer:
- In the C++ API, at::Tensor::data_ptr() (non-template) is a void* (deprecated). Use the template form data_ptr() which returns T*. [1][2]
- In Python, torch.Tensor.data_ptr() returns an integer memory address (the pointer value). [3]
References:
- PyTorch C++ docs / custom ops examples showing data_ptr() usage. [1]
- Runtime API docs noting void* data_ptr() is deprecated and recommending const_data_ptr / mutable_data_ptr or template data_ptr(). [2]
- torch.Storage / docs describing tensor.data_ptr() (Python) and storage data_ptr behavior. [3]
Apply defensive byte-pointer casting for pointer arithmetic on `data_ptr()` results.

The code performs byte-offset arithmetic directly on `data_ptr()` returns (void*), which is non-standard C++. The codebase already uses the correct defensive pattern elsewhere: wrap with `reinterpret_cast<std::byte*>()` before arithmetic.

- Lines 344–345: `reinterpret_cast<uint16_t*>(reinterpret_cast<std::byte*>(k.data_ptr()) + batch_offset_k);`
- Lines 386–387: `reinterpret_cast<uint16_t*>(reinterpret_cast<std::byte*>(q.data_ptr()) + batch_offset_q);`
See the correct pattern at lines 232–235 and 156–157 in the same file.
🤖 Prompt for AI Agents
csrc/cudnn_sdpa_kernel_launcher.cu lines 344–345 (and similarly 386–387): the
code performs pointer arithmetic directly on k.data_ptr() and v.data_ptr()
(void*), which is non-standard; update both casts to first reinterpret_cast the
data_ptr() to std::byte* and add the byte offset, then reinterpret_cast the
resulting pointer to uint16_t* (same defensive pattern used at lines 232–235 and
156–157) so pointer arithmetic is done on a character-sized type before the
final typed pointer cast.
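As a standalone illustration of why the byte cast matters: arithmetic on a `void*` is not valid standard C++ (GCC accepts it as an extension and treats it like `char*`), so offsets should be applied through a byte-sized pointer type before the final cast. The sketch below is independent of the launcher code; names are illustrative only.

```cpp
// Offsetting into an untyped buffer: cast to a byte-sized pointer before the
// arithmetic, then cast to the element type you actually want to access.
#include <cstddef>
#include <cstdint>
#include <vector>

uint16_t* element_at_byte_offset(void* base, std::ptrdiff_t byte_offset) {
  // void* + offset is ill-formed in standard C++; std::byte* arithmetic is fine.
  std::byte* bytes = static_cast<std::byte*>(base);
  return reinterpret_cast<uint16_t*>(bytes + byte_offset);
}

int main() {
  std::vector<uint16_t> buf(8, 0);
  // Write the element that starts 4 bytes into the buffer (i.e. buf[2]).
  *element_at_byte_offset(buf.data(), 4) = 42;
  return buf[2] == 42 ? 0 : 1;
}
```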
TVM_FFI_ICHECK_EQ(swiglu_beta.value().size(0), num_experts_on_rank)
    "swiglu_beta must have num_experts_on_rank elements.";
Critical: Missing stream insertion operator.
The error message string on line 486 is not connected to the check statement on line 485. The `<<` operator is missing, which means the error message won't be displayed if the check fails.
Apply this diff to fix the issue:
TVM_FFI_ICHECK_EQ(swiglu_beta.value().size(0), num_experts_on_rank)
- "swiglu_beta must have num_experts_on_rank elements.";
+ << "swiglu_beta must have num_experts_on_rank elements.";
🤖 Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 485 to 486, the TVM_FFI_ICHECK_EQ invocation is missing the stream
insertion operator before the error message string, so the message won't be
appended to the check; update the statement to insert the missing << between the
check and the string literal (ensuring the full expression ends with a
semicolon) so the error message "swiglu_beta must have num_experts_on_rank
elements." is correctly attached to the check.
SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K>(
    static_cast<void*>(float_workspace_buffer->data),
    static_cast<void*>(float_workspace_buffer.data_ptr()),
    get_element_size(float_workspace_buffer) * float_workspace_buffer.numel(),
🧩 Analysis chain
Verify buffer size calculation: `.numel()` vs `.size(0)`.

The AI summary indicates this line should use `float_workspace_buffer.size(0)` (replacing `->shape[0]`), but the actual code uses `float_workspace_buffer.numel()`.

- `.size(0)` returns only the first dimension
- `.numel()` returns the total number of elements (product of all dimensions)

For calculating total buffer size in bytes, `.numel()` is semantically correct. However, please verify:

- Is `float_workspace_buffer` always a 1D tensor? (If yes, both are equivalent)
- Was `.numel()` the intended change, or should it be `.size(0)` to match the original behavior?
🏁 Script executed:
#!/bin/bash
# Search for how float_workspace_buffer is created/allocated to understand its dimensionality
rg -n -C5 'float_workspace_buffer' --type cpp --type cuda
Length of output: 28
---
🏁 Script executed:
#!/bin/bash
# Search for float_workspace_buffer creation/allocation across CUDA files
# Use glob patterns instead of unrecognized file type
rg -n -C5 'float_workspace_buffer' --glob='*.cu' --glob='*.cuh'
Length of output: 83305
Fix buffer size calculation: replace `.numel()` with `.size(0)` for consistency.

Line 126 uses `float_workspace_buffer.numel()`, but this is inconsistent with the codebase pattern. All similar buffer size calculations throughout the codebase use `.size(0)`:

- `csrc/gemm_groupwise_sm100.cu:107`: `.size(0)`
- `csrc/group_gemm_sm90.cu:65`: `.size(0)`
- `csrc/group_gemm_fp8_groupwise_sm120.cu:113`: `.size(0)`
- All batch_prefill, batch_decode, and batch_attention files: `.size(0)`

Since `float_workspace_buffer` is treated as a 1D tensor throughout the codebase, change line 126 to use `.size(0)` to match:
get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0),
🤖 Prompt for AI Agents
In csrc/gemm_groupwise_sm120.cu around line 126, the buffer size calculation
uses float_workspace_buffer.numel() which is inconsistent with the codebase
pattern; replace .numel() with .size(0) so the multiplication becomes
get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0),
matching other files and treating the buffer as a 1D tensor.
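The distinction between `.size(0)` and `.numel()` only vanishes for 1-D tensors. A tiny standalone sketch of the two quantities over a plain shape vector (no tensor library involved):

```cpp
// size(0) is the first extent; numel is the product of all extents.
// The two coincide exactly when the tensor is 1-D.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int64_t numel(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  std::vector<int64_t> flat{1024};       // 1-D workspace buffer
  std::vector<int64_t> tiled{128, 256};  // 2-D tensor
  assert(numel(flat) == flat[0]);        // identical for 1-D
  assert(numel(tiled) == 128 * 256);     // differs from tiled[0] == 128
  return 0;
}
```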
unsigned int batch_size = output.size(0);
unsigned int vocab_size = probs.size(1);
cudaSetDevice(probs->device.device_id);
auto stream = get_stream(probs->device);
cudaSetDevice(probs.device().device_id);
auto stream = get_stream(probs.device());
cudaError_t status = sampling::SamplingFromProb(
    static_cast<float*>(probs->data), static_cast<int*>(output->data),
    maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
    static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
    maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value().data_ptr()) : nullptr,
    batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
sampling_from_probs: add device/shape checks for output (and indices if present)
Mirror the logits path validations.
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
- unsigned int batch_size = output.size(0);
- unsigned int vocab_size = probs.size(1);
+ unsigned int batch_size = output.size(0);
+ unsigned int vocab_size = probs.size(1);
+ CHECK_DEVICE(output, probs);
+ TVM_FFI_ICHECK_EQ(batch_size, probs.size(0));
+ CHECK_DIM(1, output); // output: (batch_size)
+ if (maybe_indices.has_value()) {
+ CHECK_DEVICE(maybe_indices.value(), probs);
+ CHECK_DIM(1, maybe_indices.value());
+ TVM_FFI_ICHECK_EQ(maybe_indices.value().size(0), batch_size);
+ }
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In csrc/sampling.cu around lines 68 to 76, the call into
sampling::SamplingFromProb lacks the same device, dtype and shape validations we
perform on the logits path; add checks that output is on the same CUDA device as
probs, has dtype int32, is 1D and its size(0) equals probs.size(0) (batch_size);
if maybe_indices has_value() ensure maybe_indices is on the same device, dtype
int32 and its shape/numel matches the expected indices layout (or is
nullptr-allowed), and raise clear errors (AT_ERROR / TORCH_CHECK) on mismatch
before calling cudaSetDevice/get_stream and SamplingFromProb so we avoid
device/dtype/shape mismatches at runtime.
TVM_FFI_ICHECK_EQ(mat1.device().device_type, kDLCUDA) << "mat1 tensor must be on CUDA";
TVM_FFI_ICHECK_EQ(mat2.device().device_type, kDLCUDA) << "mat2 tensor must be on CUDA";
TVM_FFI_ICHECK_EQ(mat1.ndim(), 2) << "mat1 tensor must be 2D (M, K)";
TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 tensor must be 2D (K, N)";
TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(0)) << "mat1.K must match mat2.K";
TVM_FFI_ICHECK_EQ(mat1.dtype(), mat2.dtype()) << "mat1 and mat2 must have the same dtype";
Block cross-device hazards; validate out tensor early.
Add CUDA device_id parity checks (mat1/mat2/out) and validate out’s device/ndim to prevent invalid cross-GPU pointers and stream mismatches.
TVM_FFI_ICHECK_EQ(mat1.device().device_type, kDLCUDA) << "mat1 tensor must be on CUDA";
TVM_FFI_ICHECK_EQ(mat2.device().device_type, kDLCUDA) << "mat2 tensor must be on CUDA";
TVM_FFI_ICHECK_EQ(mat1.ndim(), 2) << "mat1 tensor must be 2D (M, K)";
TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 tensor must be 2D (K, N)";
TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(0)) << "mat1.K must match mat2.K";
TVM_FFI_ICHECK_EQ(mat1.dtype(), mat2.dtype()) << "mat1 and mat2 must have the same dtype";
+ TVM_FFI_ICHECK_EQ(out.device().device_type, kDLCUDA) << "out tensor must be on CUDA";
+ TVM_FFI_ICHECK_EQ(out.ndim(), 2) << "out tensor must be 2D (N, M)";
+ TVM_FFI_ICHECK_EQ(mat1.device().device_id, mat2.device().device_id)
+ << "mat1 and mat2 must be on the same CUDA device";
+ TVM_FFI_ICHECK_EQ(mat1.device().device_id, out.device().device_id)
+ << "out must be on the same CUDA device as inputs";
🤖 Prompt for AI Agents
In csrc/tgv_gemm.cu around lines 122 to 128, add validation for the output
tensor to prevent cross-GPU pointer and stream mismatches: check that
out.device().device_type == kDLCUDA, verify out.device().device_id equals
mat1.device().device_id and mat2.device().device_id (or otherwise ensure all
three tensors are on the same CUDA device), assert out.ndim() == 2, assert
out.size(0) == mat1.size(0) and out.size(1) == mat2.size(1), and assert
out.dtype() == mat1.dtype() (and mat2.dtype() if desired); use TVM_FFI_ICHECK_EQ
(or TVM_FFI_ICHECK) with clear messages for each condition.
TVM_FFI_ICHECK_EQ(bias.value().device().device_type, kDLCUDA) << "Bias tensor must be on CUDA";
TVM_FFI_ICHECK_EQ(bias.value().ndim(), 1) << "Bias tensor must be 1D (M,)";
TVM_FFI_ICHECK_EQ(bias.value().size(0), M) << "Bias tensor must have M elements";
TVM_FFI_ICHECK_EQ(bias.value().dtype(), mat1.dtype())
    << "Bias tensor must have the same dtype as input matrices";
TVM_FFI_ICHECK_EQ(bias.value()->strides[0], 1) << "Bias tensor must be M contiguous";
TVM_FFI_ICHECK_EQ(bias.value().stride(0), 1) << "Bias tensor must be M contiguous";
}
Ensure bias is on the same CUDA device.
Missing device_id parity for bias can cause invalid accesses.
if (bias.has_value()) {
TVM_FFI_ICHECK_EQ(bias.value().device().device_type, kDLCUDA) << "Bias tensor must be on CUDA";
TVM_FFI_ICHECK_EQ(bias.value().ndim(), 1) << "Bias tensor must be 1D (M,)";
TVM_FFI_ICHECK_EQ(bias.value().size(0), M) << "Bias tensor must have M elements";
TVM_FFI_ICHECK_EQ(bias.value().dtype(), mat1.dtype())
<< "Bias tensor must have the same dtype as input matrices";
TVM_FFI_ICHECK_EQ(bias.value().stride(0), 1) << "Bias tensor must be M contiguous";
+ TVM_FFI_ICHECK_EQ(bias.value().device().device_id, out.device().device_id)
+ << "Bias tensor must be on the same CUDA device as inputs";
}
🤖 Prompt for AI Agents
In csrc/tgv_gemm.cu around lines 161 to 167, the bias checks validate device
type but not device_id which can lead to cross-device access; add a check that
bias.value().device().device_id equals mat1.device().device_id (or the expected
CUDA device) and emit a clear error message like "Bias tensor must be on the
same CUDA device as input matrices" if they differ; place this check alongside
the other TVM_FFI_ICHECK_EQ validations so it runs before any CUDA memory
access.
int stride_A_M = mat1.stride(0);
int stride_A_K = mat1.stride(1);
int stride_A_L = M * K;
// B [K, N] column major
int stride_B_N = mat2->strides[1];
int stride_B_K = mat2->strides[0];
int stride_B_N = mat2.stride(1);
int stride_B_K = mat2.stride(0);
int stride_B_L = N * K;
// original C [N, M] row major
int stride_C_M = out->strides[1];
int stride_C_N = out->strides[0];
int stride_C_M = out.stride(1);
int stride_C_N = out.stride(0);
int stride_C_L = M * N;
Enforce layout assumptions or compute L-stride safely.
You hardcode the L strides as M*K, N*K, and M*N, which is only correct for contiguous row/column-major layouts. Either enforce those layouts or derive the strides from metadata. The minimal fix below enforces the layouts explicitly.
// manually calculate the L stride
// A [M, K] row major
+ TVM_FFI_ICHECK_EQ(mat1.stride(1), 1) << "mat1 must be row-major contiguous (stride(1) == 1)";
int stride_A_M = mat1.stride(0);
int stride_A_K = mat1.stride(1);
int stride_A_L = M * K;
// B [K, N] column major
+ TVM_FFI_ICHECK_EQ(mat2.stride(0), 1) << "mat2 must be column-major contiguous (stride(0) == 1)";
int stride_B_N = mat2.stride(1);
int stride_B_K = mat2.stride(0);
int stride_B_L = N * K;
// original C [N, M] row major
+ TVM_FFI_ICHECK_EQ(out.stride(1), 1) << "out must be row-major contiguous (stride(1) == 1)";
int stride_C_M = out.stride(1);
int stride_C_N = out.stride(0);
int stride_C_L = M * N;
Optional: also guard stride magnitudes before casting to int to avoid UB on very large tensors.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In csrc/tgv_gemm.cu around lines 176 to 186, the code hardcodes "L" strides
(stride_A_L = M*K, stride_B_L = N*K, stride_C_L = M*N) which is only valid for
contiguous row/column-major tensors; update the code to either assert/enforce
the expected layouts (e.g., require mat1 row-major contiguous and mat2
column-major contiguous and out row-major/column-major as needed) or compute the
L-stride from the tensor metadata (derive stride_X_L from the tensor's stride()
values or storage size) instead of using M*K/N*K/M*N; additionally, before
casting strides to int, check magnitude (or use safe integer conversion) to
avoid undefined behavior on very large tensors.
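For reference, the contiguity conditions that the suggested `TVM_FFI_ICHECK_EQ` guards encode can be derived directly from shape and strides. The sketch below is plain C++ with an illustrative `View2D` struct, not the project's types; it shows the row-major and column-major tests for a 2-D view and why the hardcoded batch strides are only valid under them.

```cpp
// 2-D contiguity checks from shape + strides (measured in elements, not bytes).
// Row-major contiguous:    strides == {cols, 1}
// Column-major contiguous: strides == {1, rows}
#include <cstdint>

struct View2D {
  int64_t rows, cols;
  int64_t stride0, stride1;  // elements to skip per step along dim 0 / dim 1
};

bool is_row_major_contiguous(const View2D& v) {
  return v.stride1 == 1 && v.stride0 == v.cols;
}

bool is_col_major_contiguous(const View2D& v) {
  return v.stride0 == 1 && v.stride1 == v.rows;
}

// Only under one of these layouts does a full matrix occupy exactly
// rows * cols consecutive elements, which is what hardcoded batch ("L")
// strides like M*K / N*K / M*N implicitly assume.
```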
void bf16_gemm(TensorView mat1, TensorView mat2, std::optional<TensorView> bias, int64_t tactic,
               TensorView out, bool pdl) {
  // Check that inputs are bfloat16 for backward compatibility
  TVM_FFI_ICHECK_EQ(mat1->dtype, dl_bfloat16) << "mat1 tensor must be bfloat16";
  TVM_FFI_ICHECK_EQ(mat2->dtype, dl_bfloat16) << "mat2 tensor must be bfloat16";
  TVM_FFI_ICHECK_EQ(mat1.dtype(), dl_bfloat16) << "mat1 tensor must be bfloat16";
  TVM_FFI_ICHECK_EQ(mat2.dtype(), dl_bfloat16) << "mat2 tensor must be bfloat16";
  tgv_gemm(mat1, mat2, bias, tactic, out, pdl);
Fix type mismatch: std::optional vs tvm::ffi::Optional (compile blocker).
bf16_gemm takes std::optional but tgv_gemm expects tvm::ffi::Optional, causing an implicit-conversion mismatch at the call.
-void bf16_gemm(TensorView mat1, TensorView mat2, std::optional<TensorView> bias, int64_t tactic,
+void bf16_gemm(TensorView mat1, TensorView mat2, Optional<TensorView> bias, int64_t tactic,
TensorView out, bool pdl) {
// Check that inputs are bfloat16 for backward compatibility
TVM_FFI_ICHECK_EQ(mat1.dtype(), dl_bfloat16) << "mat1 tensor must be bfloat16";
TVM_FFI_ICHECK_EQ(mat2.dtype(), dl_bfloat16) << "mat2 tensor must be bfloat16";
tgv_gemm(mat1, mat2, bias, tactic, out, pdl);
}
🤖 Prompt for AI Agents
In csrc/tgv_gemm.cu around lines 211 to 216, bf16_gemm currently accepts a
std::optional<TensorView> for bias but calls tgv_gemm which expects a
tvm::ffi::Optional<TensorView>, causing a type mismatch; fix by either changing
the bf16_gemm signature to accept tvm::ffi::Optional<TensorView> (preferred) or
convert the std::optional to tvm::ffi::Optional before calling tgv_gemm (e.g.,
construct a tvm::ffi::Optional with has_value and value when present); ensure
includes/using are present for tvm::ffi::Optional and update any callers of
bf16_gemm if you change its signature.
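If changing the `bf16_gemm` signature is undesirable, the conversion mentioned in the prompt can be written generically. The sketch below converts a `std::optional<T>` into any optional-like destination that is default-constructible (meaning "empty") and constructible from `T`; `to_optional_like` and `MaybeInt` are hypothetical, and whether `tvm::ffi::Optional` meets exactly these two requirements is an assumption here.

```cpp
// Generic conversion from std::optional<T> to an optional-like destination
// type. Assumes Dst() means "empty" and Dst(T) wraps a value; whether
// tvm::ffi::Optional satisfies this exactly is an assumption of the sketch.
#include <optional>

template <typename Dst, typename T>
Dst to_optional_like(const std::optional<T>& src) {
  return src.has_value() ? Dst(*src) : Dst();
}

// Toy destination type used only to demonstrate the helper.
struct MaybeInt {
  bool has = false;
  int value = 0;
  MaybeInt() = default;
  explicit MaybeInt(int v) : has(true), value(v) {}
};

int main() {
  std::optional<int> a = 7, b;
  MaybeInt x = to_optional_like<MaybeInt>(a);  // wraps 7
  MaybeInt y = to_optional_like<MaybeInt>(b);  // empty
  return (x.has && x.value == 7 && !y.has) ? 0 : 1;
}
```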
/bot run
[FAILED] Pipeline #36993872: 0/17 passed
/bot run
[CANCELING] Pipeline #37003140: canceled
/bot run
📌 Description
This PR bumps tvm-ffi to the stable version 0.1.0 and updates the FlashInfer code base accordingly.
🔍 Related Issues
#1939
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit