[None][autodeploy] minor refactor to rmsnorm transforms #8657
Conversation
Signed-off-by: Fridah-nv <[email protected]>
📝 Walkthrough

This change refactors RMSNorm fusion configuration and implementation. The single `backend` field is split into separate `rmsnorm_backend` and `gated_rmsnorm_backend` configuration options. The separate `FuseGatedRMSNorm` transformer is removed and consolidated into `FuseRMSNorm`. The gated RMSNorm custom operation is renamed to reflect Triton backend routing and expanded with additional function parameters.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

The changes involve straightforward configuration and naming updates across multiple files, along with one complex refactoring in the transform library that consolidates two transformers into one with expanded pattern support and backend routing logic. The heterogeneity of simple versus complex changes, combined with the integration of gated RMSNorm handling within the main transformer, warrants careful review of the validation logic and pattern-matching behavior.

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (2)
86-141: New gated Triton op is well-scoped; fix Python 3.8+ type hints.

The annotation `gate: torch.Tensor | None` requires Python 3.10+. Target is 3.8+; switch to `Optional`.

```diff
-import torch
+import torch
+from typing import Optional
@@
-def triton_rmsnorm_gated(
+def triton_rmsnorm_gated(
     x: torch.Tensor,
     weight: torch.Tensor,
-    gate: torch.Tensor | None,
+    gate: Optional[torch.Tensor],
     eps: float,
     group_size: int,
     norm_before_gate: bool = False,
 ) -> torch.Tensor:
```

Also, keeping the fp32 return here and in the meta function is correct; no downcast in the op is needed. Based on learnings.
1-1: Add NVIDIA Apache-2.0 header.

Per coding guidelines, prepend the 2025 NVIDIA Apache-2.0 header.
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)
1-1: Add NVIDIA Apache-2.0 header.

Prepend the 2025 NVIDIA header per guidelines.
🧹 Nitpick comments (7)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py (1)
27-29: Op rename looks good; add coverage for norm_before_gate=True.

Current test only exercises norm_before_gate=False. Add a param to exercise True as well.
```diff
-    # Custom op (currently returns fp32). Cast it back to x.dtype for apples-to-apples with ref.
-    y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False)
+    # Custom op (currently returns fp32). Cast it back to x.dtype for apples-to-apples with ref.
+    # Optionally parametrize norm_before_gate over [False, True].
+    y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False)
```

If you’d like, I can push a full parametrization patch.
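For illustration, a minimal sketch of such a parametrization, assuming pytest and the test's existing tensor setup; the shapes, dtypes, and the bare shape assertion are placeholders, not the repo's actual fixtures or reference comparison:

```python
import pytest
import torch


@pytest.mark.parametrize("norm_before_gate", [False, True])
def test_triton_rmsnorm_gated_norm_before_gate(norm_before_gate):
    # Illustrative inputs; the real test would reuse its existing x, w, z, group setup.
    x = torch.randn(2, 8, 64, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(64, device="cuda", dtype=torch.bfloat16)
    z = torch.randn_like(x)
    group = 64

    # Custom op returns fp32; cast back to x.dtype before comparing against the reference.
    y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, norm_before_gate)
    y_op = y_op_fp32.to(x.dtype)

    assert y_op.shape == x.shape  # the real test would compare against its reference path here
```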
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
10-25: Avoid hard dependency on flashinfer at import-time.

Importing flashinfer at module import can break environments that only need Triton. Lazy-import inside the function or guard with try/except.
```diff
-import flashinfer
+try:
+    import flashinfer  # type: ignore
+except Exception:  # pragma: no cover
+    flashinfer = None
@@
-    input_flat = input.reshape(-1, input.shape[-1])
-    rmsnorm_flat = flashinfer.norm.rmsnorm(input_flat, weight, eps)
+    if flashinfer is None:
+        raise RuntimeError("flashinfer not available; set rmsnorm_backend != 'flashinfer'")
+    input_flat = input.reshape(-1, input.shape[-1])
+    rmsnorm_flat = flashinfer.norm.rmsnorm(input_flat, weight, eps)
```

tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (5)
85-98: Docstring should reflect multiple backends, not just FlashInfer.

Update wording to cover FlashInfer, Triton, and Torch.
-"""Graph transform to optimize RMSNorm execution using FlashInfer.""" +"""Graph transform to optimize RMSNorm execution using FlashInfer, Triton, or Torch backends. + +Supports both regular and gated RMSNorm paths."""
113-118: Tighten exception messages; satisfy TRY003 hint.

Shorten messages or move details to constants to avoid long f-strings in raises.
```diff
-        if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS:
-            raise ValueError(
-                f"Invalid rmsnorm_backend, must be one of {list(_BACKEND_OPS)}, got {self.config.rmsnorm_backend}"
-            )
+        if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS:
+            allowed = ", ".join(_BACKEND_OPS.keys())
+            raise ValueError(f"rmsnorm_backend must be one of [{allowed}]")
@@
-        if self.config.gated_rmsnorm_backend.lower() != "triton":
-            raise ValueError(
-                f"""Invalid gated_rmsnorm_backend, currently only 'triton' is supported,
-                got {self.config.gated_rmsnorm_backend}"""
-            )
+        if self.config.gated_rmsnorm_backend.lower() != "triton":
+            raise ValueError("gated_rmsnorm_backend must be 'triton'")
```

If Ruff is enabled, confirm TRY003 warnings are gone after this change.
Also applies to: 119-124
158-185: Create gated dummy tensors on CUDA to avoid device-cast mismatches.

Regular RMSNorm dummies are on CUDA; mirror that for gated to reduce to(device) artifacts during tracing.
```diff
-        def make_dummy_args_gated(group_size: int, eps: float) -> list:
-            x = torch.randn(B, S, H, dtype=torch.float32)
-            w = torch.randn(H, dtype=torch.float32)
-            g = torch.randn(B, S, H, dtype=torch.float32)
+        def make_dummy_args_gated(group_size: int, eps: float) -> list:
+            x = torch.randn(B, S, H, device="cuda", dtype=torch.float32)
+            w = torch.randn(H, device="cuda", dtype=torch.float32)
+            g = torch.randn(B, S, H, device="cuda", dtype=torch.float32)
             return [x, w, g, eps, group_size]
```

Alternatively, expand op_ignore_types to cover device args, but CUDA dummies are simpler and consistent with comments in register_ad_pattern.
176-185: Route gated backend via config for future extensibility.

You validate gated_rmsnorm_backend == 'triton' but don’t thread it to the replacement. Plumb through now for consistency.
```diff
-        register_ad_pattern(
+        register_ad_pattern(
             search_fn=_gated_rmsnorm_pattern_ref,
-            replace_fn=partial(_gated_rmsnorm_replacement),
+            replace_fn=partial(_gated_rmsnorm_replacement, backend=self.config.gated_rmsnorm_backend),
             patterns=patterns,
             dummy_args=make_dummy_args_gated(group_size, eps),
             op_ignore_types=op_ignore_types,
             scalar_workaround={"eps": eps, "group_size": group_size},
             skip_duplicates=True,
         )
@@
-def _gated_rmsnorm_replacement(
+def _gated_rmsnorm_replacement(
     x: torch.Tensor,
     weight: torch.Tensor,
     gate: torch.Tensor,
     eps: float,
     group_size: int,
-) -> torch.Tensor:
-    return torch.ops.auto_deploy.triton_rmsnorm_gated(
-        x, weight, gate, float(eps), int(group_size), False
-    )
+) -> torch.Tensor:
+    return torch.ops.auto_deploy.triton_rmsnorm_gated(
+        x, weight, gate, float(eps), int(group_size), False
+    )
```

No behavioral change today; it avoids touching this code again when adding more backends. Keeping fp32 from the op is intentional. Based on learnings.
Also applies to: 222-224
70-80: Minor naming: consider aligning _BACKEND_OPS docstrings with keys.

Keys mix "rms_norm" and "rmsnorm" in op names; not a bug, but a brief comment noting the intent would help avoid confusion.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (2 hunks)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (4 hunks)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
📚 Learning: 2025-10-20T17:09:21.560Z
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
🧬 Code graph analysis (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
triton_rmsnorm_gated(87-140)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2)
ADPatternMatcherPass(61-67), register_ad_pattern(99-182)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
triton_rmsnorm_gated(87-140)
🪛 Ruff (0.14.1)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
115-117: Avoid specifying long messages outside the exception class
(TRY003)
121-124: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (4)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
143-159: Meta returns fp32 as intended.

Keeping meta dtype as fp32 matches the real op and downstream casting expectations. No change needed.
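For reference, a hedged sketch of a meta/fake kernel that keeps fp32; the decorator and registration path are assumptions about how the op might be registered, not a claim about the repo's actual mechanism:

```python
import torch


# Assumed registration path; the actual code may use torch.library.custom_op,
# register_fake, or the older impl_abstract API.
@torch.library.register_fake("auto_deploy::triton_rmsnorm_gated")
def _triton_rmsnorm_gated_meta(x, weight, gate, eps, group_size, norm_before_gate=False):
    # Output shape mirrors the input; dtype stays fp32 to match the real Triton kernel.
    return torch.empty_like(x, dtype=torch.float32)
```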
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (3)
73-80: Config split LGTM.

rmsnorm_backend and gated_rmsnorm_backend fields are clear and documented.
222-224: No cast-back in replacement is correct.

Let downstream handle dtype; matches prior guidance and tests already cast for comparison.
129-157: Regular RMSNorm pattern registration: solid.

Use of partial with backend and scalar workaround is appropriate.
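As a rough sketch of that pattern, with a plain-PyTorch RMSNorm standing in for all backends (the real table maps to the auto_deploy custom ops, so treat every name below as illustrative):

```python
from functools import partial

import torch


def _torch_rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Plain-PyTorch RMSNorm used here only as a stand-in for the backend-specific ops.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return ((x.float() * torch.rsqrt(var + eps)) * weight.float()).to(x.dtype)


# Hypothetical stand-in for the transform's _BACKEND_OPS table.
_BACKEND_OPS = {
    "flashinfer": _torch_rmsnorm_ref,
    "triton": _torch_rmsnorm_ref,
    "torch": _torch_rmsnorm_ref,
}


def _rmsnorm_replacement(x: torch.Tensor, weight: torch.Tensor, eps: float, backend: str) -> torch.Tensor:
    return _BACKEND_OPS[backend.lower()](x, weight, eps)


# partial pre-binds the configured backend so the registered replacement keeps the
# same (x, weight, eps) signature as the search pattern.
replace_fn = partial(_rmsnorm_replacement, backend="flashinfer")
```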
tensorrt_llm/_torch/auto_deploy/config/default.yaml (lines under review):

```yaml
rmsnorm_backend: flashinfer
gated_rmsnorm_backend: triton
requires_shape_prop: true
```
🧩 Analysis chain
Config split is clear; verify no stale uses of the old backend field.
Ensure no remaining consumers expect fuse_rmsnorm.backend or the removed fuse_gated_rmsnorm block.
🏁 Script executed:
```shell
#!/bin/bash
# Find old config key usages and removed transform references.
rg -nP '\bfuse_rmsnorm\b.*\bbackend\b|fuse_gated_rmsnorm|torch_rmsnorm_gated' -C2
```

Length of output: 0
I'll run broader searches to verify if there are any remaining references to the old config field names in the codebase.
---
🏁 Script executed:
```shell
#!/bin/bash
# Search more broadly for old config key references in Python files
# Look for fuse_rmsnorm, fuse_gated_rmsnorm patterns
rg -i 'fuse_rmsnorm|fuse_gated_rmsnorm' --type=py -C3
```

Length of output: 1482
🏁 Script executed:
```shell
#!/bin/bash
# Search for where config is loaded/parsed to see if old keys are referenced
rg -i 'rmsnorm.*backend|gated_rmsnorm' --type=py -C3
```

Length of output: 6538
🏁 Script executed:
```shell
#!/bin/bash
# Check if there are any YAML references to old config structure
fd -e yaml -e yml | xargs grep -i 'fuse_rmsnorm\|fuse_gated_rmsnorm' || echo "No matches in YAML files"
```

Length of output: 67
I need to check the full YAML config structure and how it's being loaded.
---
🏁 Script executed:
```shell
#!/bin/bash
# View the full default.yaml file to see the complete config structure
cat tensorrt_llm/_torch/auto_deploy/config/default.yaml
```

Length of output: 6055
🏁 Script executed:
```shell
#!/bin/bash
# Search for where config is loaded/parsed to understand the expected structure
rg -i 'fuse_rmsnorm|FuseRMSNormConfig' --type=py -A5 -B5 | head -100
```

Length of output: 4934
🏁 Script executed:
```shell
#!/bin/bash
# Check the test file to see what config it's providing
cat tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py | grep -A20 "fuse_rmsnorm"
```

Length of output: 557
Test file using removed config field; update required.
The test file tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py is still passing a single "backend" field, which no longer exists. The FuseRMSNormConfig class expects rmsnorm_backend and gated_rmsnorm_backend fields. The test must be updated to match the new config schema—either split the backend selection or use the appropriate field names for the variant being tested.
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/config/default.yaml around lines 124-126 the
schema now uses rmsnorm_backend and gated_rmsnorm_backend instead of a single
backend field; update the test at
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
to stop passing "backend" and instead pass the correct field(s): use
"rmsnorm_backend" for regular RMSNorm tests and "gated_rmsnorm_backend" for
gated variants (or include both fields if the code path reads both), and adjust
any assertions/expected config keys in the test to reference the new field
names.
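A rough illustration of that update, with the old single field replaced by the two new ones; the surrounding dictionary shape is an assumption about how the test builds its transform config, so copy whatever the existing test actually passes:

```python
# Before (old schema, no longer accepted by FuseRMSNormConfig):
# "fuse_rmsnorm": {"backend": "flashinfer"}

# After (sketch of the split backend fields):
transform_config = {
    "fuse_rmsnorm": {
        "rmsnorm_backend": "flashinfer",    # regular RMSNorm path
        "gated_rmsnorm_backend": "triton",  # gated path; only 'triton' is currently supported
    }
}
```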
Summary by CodeRabbit
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user friendly way for developers to interact with a Jenkins server.
Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.
`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

`--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

`--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

`--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.

`--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

`--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

`--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

`--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

`--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

`--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

`--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

`--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

`--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

`--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

`--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see
docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill
`kill`

Kill all running builds associated with pull request.
skip
`skip --comment COMMENT`

Skip testing for latest commit on pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline
`reuse-pipeline`

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.