Fix: Verify scales are not None for Cutlass FP8 FusedMoE #1961
Signed-off-by: Amir Klein <[email protected]>
Summary of Changes

This pull request addresses a critical stability issue in the Cutlass FP8 FusedMoE implementation by adding validation for the quantization scale tensors. By ensuring these tensors are not null, the change prevents segmentation faults.
Walkthrough

Added runtime null checks for the FP8 quantization parameters (fc1_dequant, fc2_quant, fc2_dequant, fc1_input_dequant) in the getQuantParams function before type validation, tightening input validation for quantization scales.
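The ordering matters because a type check has to dereference the tensor, so a null handle must be rejected first. A minimal sketch of that ordering follows. `FakeTensor`, `DType`, `TensorPtr`, and `checkFp8Scale` are hypothetical stand-ins, the float32 requirement is assumed for illustration, and exceptions take the place of the `TVM_FFI_ICHECK` macro used in the actual binding.

```cpp
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for the real tensor type; not the TVM FFI API.
enum class DType { Float32, Int32 };
struct FakeTensor {
  DType dtype;
};
using TensorPtr = std::shared_ptr<FakeTensor>;

// Validate one scale: null check first, then the dtype check that
// dereferences the pointer. Reversing the order would dereference null.
void checkFp8Scale(const TensorPtr& t, const std::string& name) {
  if (t.get() == nullptr) {
    throw std::invalid_argument("Expecting " + name + " to be non null");
  }
  // Assumed dtype requirement, chosen only to illustrate the second step.
  if (t->dtype != DType::Float32) {
    throw std::invalid_argument("Expecting " + name + " to be float32");
  }
}
```

With this shape, a scale passed as None from Python would surface as a readable error naming the offending parameter rather than a crash inside the type check.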
Code Review
This pull request adds important null pointer checks for quantization scales in the FP8 FusedMoE path, which is a good defensive measure to prevent potential segmentation faults. The change is correct and addresses the issue described. I've found a minor typo in one of the new error messages and have provided a suggestion to fix it.
    TVM_FFI_ICHECK(fc2_dequant.get() != nullptr)
        << "Expecting fc1fc2_dequant_dequant to be non null";
pls take a look at this comment from gemini
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
839-843: Add individual null checks to other quantization paths like the FP8 path.

The FP8 path (lines 803-808) checks each extracted tensor element for null (fc1_dequant.get() != nullptr), but other quantization modes only verify quant_scales.value().size() before directly accessing array elements. Since Python tests pass quant_scales=None to these other paths (NVFP4, W4A8_MXFP4_FP8, W4A8_MXFP4_MXFP8, BlockScaling, W4A16, INT4), they share the same segfault vulnerability.

Add per-item null checks after extracting each tensor from quant_scales.value() in:
- W4A8_MXFP4_FP8 (lines 839-843): after extracting fc1_weight_block, fc1_global, fc2_act_global, fc2_weight_block, fc2_global
- W4A8_MXFP4_MXFP8 (lines 904-907): after extracting fc1_weight_block, fc1_global, fc2_weight_block, fc2_global
- NVFP4 (lines 963-968): after extracting all 6 scale tensors
- BlockScaling (lines 1028-1029): after extracting fc1_scales, fc2_scales
- W4A16 (lines 1037-1038): after extracting fc1_weight_scales, fc2_weight_scales
- INT4 (lines 1048-1055): after extracting all 8 scale tensors
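
The per-item checks listed above could be factored into one helper instead of being repeated in each path. The following is only a sketch under assumptions: `Scale`, `ScalePtr`, `requireNonNullScales`, and the `std::optional<std::vector<...>>` container are hypothetical stand-ins, and the real binding would presumably use `TVM_FFI_ICHECK` on its own tensor types rather than exceptions.

```cpp
#include <cstddef>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a scale tensor handle; not the TVM FFI type.
struct Scale {};
using ScalePtr = std::shared_ptr<Scale>;

// Check that quant_scales is present, has the expected number of entries,
// and that every entry is non-null. `names` labels each entry in errors.
void requireNonNullScales(const std::optional<std::vector<ScalePtr>>& quant_scales,
                          const std::vector<std::string>& names) {
  if (!quant_scales.has_value()) {
    throw std::invalid_argument("Expecting quant scales to be provided");
  }
  const std::vector<ScalePtr>& scales = quant_scales.value();
  if (scales.size() != names.size()) {
    throw std::invalid_argument("Expecting " + std::to_string(names.size()) +
                                " quant scales");
  }
  for (std::size_t i = 0; i < scales.size(); ++i) {
    if (scales[i].get() == nullptr) {
      throw std::invalid_argument("Expecting " + names[i] + " to be non null");
    }
  }
}
```

A BlockScaling-style call site would pass the two names fc1_scales and fc2_scales, while the NVFP4 path would pass six, so the size check and the per-item null checks stay in one place.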
    TVM_FFI_ICHECK(fc1_dequant.get() != nullptr) << "Expecting fc1_dequant to be non null";
    TVM_FFI_ICHECK(fc2_quant.get() != nullptr) << "Expecting fc2_quant to be non null";
    TVM_FFI_ICHECK(fc2_dequant.get() != nullptr)
        << "Expecting fc1fc2_dequant_dequant to be non null";
    TVM_FFI_ICHECK(fc1_input_dequant.get() != nullptr)
        << "Expecting fc1_input_dequant to be non null";
Fix typo in error message on line 806.
The error message on line 806 reads "Expecting fc1fc2_dequant_dequant to be non null" but should say "Expecting fc2_dequant to be non null".
Apply this diff to fix the typo:
-    TVM_FFI_ICHECK(fc2_dequant.get() != nullptr)
-        << "Expecting fc1fc2_dequant_dequant to be non null";
+    TVM_FFI_ICHECK(fc2_dequant.get() != nullptr)
+        << "Expecting fc2_dequant to be non null";

📝 Committable suggestion
    TVM_FFI_ICHECK(fc1_dequant.get() != nullptr) << "Expecting fc1_dequant to be non null";
    TVM_FFI_ICHECK(fc2_quant.get() != nullptr) << "Expecting fc2_quant to be non null";
    TVM_FFI_ICHECK(fc2_dequant.get() != nullptr)
        << "Expecting fc2_dequant to be non null";
    TVM_FFI_ICHECK(fc1_input_dequant.get() != nullptr)
        << "Expecting fc1_input_dequant to be non null";
🤖 Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 803 to 808, the TVM_FFI_ICHECK error message for fc2_dequant
contains a typo ("Expecting fc1fc2_dequant_dequant to be non null"); update that
string to the correct text "Expecting fc2_dequant to be non null" so the check
reports the right variable name.
/bot run
[FAILED] Pipeline #36988944: 1/17 passed
📌 Description
Verify that the FP8 quant scales are non-null in the Cutlass FusedMoE path. Currently, if these tensors are passed as None from Python, the result is a segmentation fault.