[None][feat] Autodeploy: Update the ssm to use slice #8667
base: main
Conversation
Signed-off-by: nvchenghaoz <[email protected]>
/bot run
PR_Github #22521 [ run ] triggered by Bot. Commit:
📝 Walkthrough

The changes refactor the Mamba Triton backend to replace index-based tensor selections with direct slicing operations for both the prefill and decode stages, removing the `prefill_idx` construct. Additionally, test cases for MoE Triton kernels are extended with `early_exit` parameterization to cover both balanced and imbalanced routing scenarios.
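As a rough illustration of the slicing refactor described in the walkthrough (a minimal sketch with made-up shapes, not the actual backend code): when prefill tokens occupy a contiguous prefix of the flat token buffer, an explicit index tensor such as `prefill_idx` can be replaced by a plain slice, which returns a view instead of allocating a gathered copy.

```python
import torch

# Hypothetical shapes: [total_tokens, num_heads, head_dim].
hs_flat = torch.randn(10, 4, 8)
total_prefill_tokens = 6

# Old style: build an index tensor and gather (extra allocation + gather kernel).
prefill_idx = torch.arange(total_prefill_tokens)
hs_prefill_old = hs_flat.index_select(0, prefill_idx)

# New style: the prefill tokens are a contiguous prefix, so a slice (a view) suffices.
hs_prefill_new = hs_flat[:total_prefill_tokens]

assert torch.equal(hs_prefill_old, hs_prefill_new)
```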
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings), ✅ Passed checks (1 passed).
Actionable comments posted: 1
🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (1)
137-137: Consider computing `total_prefill_tokens` once to avoid redundancy.

The value is recomputed here because it was previously computed inside the `if num_prefill > 0` block (line 82) and isn't in scope. Consider computing `total_prefill_tokens` before both conditional blocks to avoid the redundant calculation. Apply this diff to eliminate the redundant computation:
```diff
 # Prefill: concatenate tokens at the front and run combined scan
+total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
+
 if num_prefill > 0:
     seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
-    total_prefill_tokens = int(seq_len_prefill.sum().item())
     hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]
```

And remove the recomputation at line 137:
```diff
 # Decode: batch single-token updates via selective_state_update
 if num_decode > 0:
-    total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
     slot_idx_decode = slot_idx[num_prefill:].to(torch.long)
```
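Put together, the suggested restructuring would look roughly like the sketch below (hypothetical stand-in inputs; the variable names come from the diffs above, and the surrounding function is assumed):

```python
import torch

# Stand-in values for the function's inputs (hypothetical).
num_prefill, num_decode = 2, 3
seq_len = torch.tensor([4, 3, 1, 1, 1])
slot_idx = torch.arange(5)
hs_flat = torch.randn(int(seq_len.sum()), 4, 8)

# Hoisted above both branches, as the review suggests.
total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())

if num_prefill > 0:
    seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
    hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]

if num_decode > 0:
    # total_prefill_tokens is already in scope; no recomputation needed.
    slot_idx_decode = slot_idx[num_prefill:].to(torch.long)
```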
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py` (3 hunks)
- `tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py` (5 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files (a short sketch illustrating several of the Python rules above follows this list):
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
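As a quick illustration of several of the guidelines quoted above (a hypothetical module, not code from this repository):

```python
"""A hypothetical module illustrating several of the coding guidelines."""

MY_CONSTANT = 99  # Constants: upper SNAKE_CASE.
G_TOTAL_CALLS = 0  # Globals: upper SNAKE_CASE with a 'G' prefix.


class TokenCounter:
    """Counts tokens across batches (Google-style docstring).

    Attributes:
        total: Running token count, initialized in the constructor.
    """

    def __init__(self):
        self.total = 0  # Externally visible members initialized here.

    def add_batch(self, batch):
        """Adds a batch of per-request token counts.

        Args:
            batch: Any iterable of integer token counts.

        Returns:
            The updated running total.
        """
        try:  # Duck-typing check: keep the try body minimal...
            items = iter(batch)
        except TypeError:  # ...and catch the most specific exception.
            raise
        else:  # Main logic goes in the else branch.
            self.total += sum(items)
            return self.total
```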
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py (2)
78-125: LGTM! Good test coverage enhancement.

The parameterization of `early_exit` to test both balanced and imbalanced routing scenarios is well-designed. The imbalanced routing (concentrating 75% of tokens on the first 2 experts) will help validate the MoE kernel's behavior under skewed load distribution.
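For context, the imbalanced routing described above could be constructed along these lines (a hypothetical sketch, not the actual test code):

```python
import torch

num_tokens, num_experts = 64, 8
hot = int(0.75 * num_tokens)  # ~75% of tokens hit the first 2 experts

expert_ids = torch.empty(num_tokens, dtype=torch.long)
expert_ids[:hot] = torch.randint(0, 2, (hot,))               # skewed load
expert_ids[hot:] = torch.randint(2, num_experts, (num_tokens - hot,))

# Per-expert token counts make the skew visible.
print(torch.bincount(expert_ids, minlength=num_experts))
```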
237-348: LGTM! Consistent test parameterization for FP8 quantized MoE.

The parameterization follows the same sound pattern as the BF16 test, appropriately adjusted for the larger token counts in the FP8 test. The routing logic correctly implements both balanced and imbalanced scenarios.
`tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py` — review comment resolved.
PR_Github #22521 [ run ] completed with state
/bot run
2 similar comments
/bot run

/bot run
PR_Github #22568 [ run ] triggered by Bot. Commit:
PR_Github #22568 [ run ] completed with state
Summary by CodeRabbit
- Tests
- Performance