
Conversation


@Fridah-nv (Collaborator) commented Oct 24, 2025

Summary by CodeRabbit

  • Refactor
    • Updated RMSNorm transform configuration structure with explicit backend selection for standard and gated normalization variants
    • Consolidated gated RMSNorm handling into the unified RMSNorm transform, improving configuration consistency

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-reuse-test --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
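For example, a few typical invocations built from the flags above (the stage and GPU names are the illustrative ones used throughout this help text):

```
/bot run --stage-list "A10-PyTorch-1" --disable-fail-fast
/bot run --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp"
/bot run --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1"
```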

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since insufficient care and validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action also kills all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since insufficient care and validation can break the top of tree.


coderabbitai bot commented Oct 24, 2025

📝 Walkthrough

This change refactors RMSNorm fusion configuration and implementation. The single backend field is split into separate rmsnorm_backend and gated_rmsnorm_backend configuration options. The separate FuseGatedRMSNorm transformer is removed and consolidated into FuseRMSNorm. The gated RMSNorm custom operation is renamed to reflect Triton backend routing and expanded with additional function parameters.
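As a sketch, the resulting fuse_rmsnorm entry in default.yaml looks roughly like this (the field names and values are from this PR; the surrounding nesting is assumed):

```yaml
fuse_rmsnorm:
  rmsnorm_backend: flashinfer
  gated_rmsnorm_backend: triton
  requires_shape_prop: true
```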

Changes

  • Configuration update (tensorrt_llm/_torch/auto_deploy/config/default.yaml): Replaced the fuse_rmsnorm backend field with a dual backend configuration, rmsnorm_backend: flashinfer and gated_rmsnorm_backend: triton, and removed the separate fuse_gated_rmsnorm transform block.
  • Custom ops implementation (tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py): Renamed the gated RMSNorm function and custom op from torch_rmsnorm_gated to triton_rmsnorm_gated, updated the signature to include eps, group_size, and norm_before_gate parameters, and renamed the fake registration from _torch_rmsnorm_gated_meta to _triton_rmsnorm_gated_meta.
  • Transform library refactoring (tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py): Renamed FuseRMSNormConfig.backend to rmsnorm_backend, added a new gated_rmsnorm_backend field, consolidated gated RMSNorm handling into the FuseRMSNorm transformer with new pattern registration and replacement logic, removed the standalone FuseGatedRMSNorm transformer, and updated validation and backend routing for both regular and gated RMSNorm paths.
  • Test update (tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py): Updated the operator reference from torch.ops.auto_deploy.torch_rmsnorm_gated to torch.ops.auto_deploy.triton_rmsnorm_gated; a call sketch of the renamed op follows this list.
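As referenced above, a minimal sketch of calling the renamed op, assuming the auto_deploy custom ops are registered and using the signature described in the summary (tensor shapes and the group size are illustrative):

```python
import torch
# Assumes tensorrt_llm's auto_deploy custom ops have been imported/registered.

x = torch.randn(2, 8, 64, device="cuda", dtype=torch.bfloat16)  # activations
w = torch.randn(64, device="cuda", dtype=torch.float32)         # RMSNorm weight
z = torch.randn_like(x)                                         # gate tensor

# Signature: (x, weight, gate, eps, group_size, norm_before_gate); returns fp32.
y_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, 64, False)
```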

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

The changes involve straightforward configuration and naming updates across multiple files, along with one complex refactoring in the transform library that consolidates two transformers into one with expanded pattern support and backend routing logic. The heterogeneity of simple versus complex changes, combined with the integration of gated RMSNorm handling within the main transformer, warrants careful review of the validation logic and pattern matching behavior.

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 28.57%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
  • Description Check ⚠️ Warning: The pull request description consists entirely of the repository template with placeholder comments. The "Description" field (which should explain the issue and solution) and the "Test Coverage" field (which should list the relevant tests that safeguard the changes) are both empty, and all PR Checklist items are unchecked. Although the raw summary shows substantial code changes spanning configuration files, public API modifications, and test updates, none of this is documented in the PR description. The author should fill in the "Description" section with what is changing and why (e.g., consolidating gated RMSNorm handling into FuseRMSNorm and renaming the backend op from torch_rmsnorm_gated to triton_rmsnorm_gated), list the relevant tests (such as the test file modified in this PR) under "Test Coverage", and check off the applicable checklist items.

✅ Passed checks (1 passed)

  • Title Check ✅ Passed: The title "[None][autodeploy] minor refactor to rmsnorm transforms" is directly related to the core changes: consolidating the fuse_rmsnorm and fuse_gated_rmsnorm transforms (merging the latter into the former with separate backend configurations), renaming the public custom op from torch_rmsnorm_gated to triton_rmsnorm_gated, and restructuring the RMSNorm transform configuration. It accurately captures the essence of the change as a refactoring of RMSNorm transforms without needing to enumerate every modified file.

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (2)

86-141: New gated Triton op is well-scoped; fix Python 3.8+ type hints.

The annotation gate: torch.Tensor | None requires Python 3.10+. Target is 3.8+; switch to Optional.

@@
-import torch
+import torch
+from typing import Optional
@@
-def triton_rmsnorm_gated(
+def triton_rmsnorm_gated(
     x: torch.Tensor,
     weight: torch.Tensor,
-    gate: torch.Tensor | None,
+    gate: Optional[torch.Tensor],
     eps: float,
     group_size: int,
     norm_before_gate: bool = False,
 ) -> torch.Tensor:

Also, keeping the fp32 return here and in the meta function is correct; no downcast in the op is needed. Based on learnings.


1-1: Add NVIDIA Apache-2.0 header.

Per coding guidelines, prepend the 2025 NVIDIA Apache-2.0 header.

tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)

1-1: Add NVIDIA Apache-2.0 header.

Prepend the 2025 NVIDIA header per guidelines.

🧹 Nitpick comments (7)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py (1)

27-29: Op rename looks good; add coverage for norm_before_gate=True.

Current test only exercises norm_before_gate=False. Add a param to exercise True as well.

-    # Custom op (currently returns fp32). Cast it back to x.dtype for apples-to-apples with ref.
-    y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False)
+    # Custom op (currently returns fp32). Cast it back to x.dtype for apples-to-apples with ref.
+    # Optionally parametrize norm_before_gate over [False, True].
+    y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False)

If you’d like, I can push a full parametrization patch.
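A minimal sketch of what that parametrization could look like (shapes and the group size are illustrative; the surrounding test structure and fixtures are assumed):

```python
import pytest
import torch
# Assumes tensorrt_llm's auto_deploy custom ops have been imported/registered.

@pytest.mark.parametrize("norm_before_gate", [False, True])
def test_triton_rmsnorm_gated(norm_before_gate):
    group = 64
    x = torch.randn(2, 8, 64, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(64, device="cuda", dtype=torch.float32)
    z = torch.randn_like(x)
    y_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(
        x, w, z, 1e-5, group, norm_before_gate
    )
    # The op intentionally returns fp32; cast back to x.dtype when comparing to a reference.
    assert y_fp32.dtype == torch.float32
```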

tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)

10-25: Avoid hard dependency on flashinfer at import-time.

Importing flashinfer at module import can break environments that only need Triton. Lazy-import inside the function or guard with try/except.

-import flashinfer
+try:
+    import flashinfer  # type: ignore
+except Exception:  # pragma: no cover
+    flashinfer = None
@@
-    input_flat = input.reshape(-1, input.shape[-1])
-    rmsnorm_flat = flashinfer.norm.rmsnorm(input_flat, weight, eps)
+    if flashinfer is None:
+        raise RuntimeError("flashinfer not available; set rmsnorm_backend != 'flashinfer'")
+    input_flat = input.reshape(-1, input.shape[-1])
+    rmsnorm_flat = flashinfer.norm.rmsnorm(input_flat, weight, eps)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (5)

85-98: Docstring should reflect multiple backends, not just FlashInfer.

Update wording to cover FlashInfer, Triton, and Torch.

-"""Graph transform to optimize RMSNorm execution using FlashInfer."""
+"""Graph transform to optimize RMSNorm execution using FlashInfer, Triton, or Torch backends.
+
+Supports both regular and gated RMSNorm paths."""

113-118: Tighten exception messages; satisfy TRY003 hint.

Shorten messages or move details to constants to avoid long f-strings in raises.

-        if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS:
-            raise ValueError(
-                f"Invalid rmsnorm_backend, must be one of {list(_BACKEND_OPS)}, got {self.config.rmsnorm_backend}"
-            )
+        if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS:
+            allowed = ", ".join(_BACKEND_OPS.keys())
+            raise ValueError(f"rmsnorm_backend must be one of [{allowed}]")
@@
-        if self.config.gated_rmsnorm_backend.lower() != "triton":
-            raise ValueError(
-                f"""Invalid gated_rmsnorm_backend, currently only 'triton' is supported,
-                got {self.config.gated_rmsnorm_backend}"""
-            )
+        if self.config.gated_rmsnorm_backend.lower() != "triton":
+            raise ValueError("gated_rmsnorm_backend must be 'triton'")

If Ruff is enabled, confirm TRY003 warnings are gone after this change.

Also applies to: 119-124


158-185: Create gated dummy tensors on CUDA to avoid device-cast mismatches.

Regular RMSNorm dummies are on CUDA; mirror that for gated to reduce to(device) artifacts during tracing.

-        def make_dummy_args_gated(group_size: int, eps: float) -> list:
-            x = torch.randn(B, S, H, dtype=torch.float32)
-            w = torch.randn(H, dtype=torch.float32)
-            g = torch.randn(B, S, H, dtype=torch.float32)
+        def make_dummy_args_gated(group_size: int, eps: float) -> list:
+            x = torch.randn(B, S, H, device="cuda", dtype=torch.float32)
+            w = torch.randn(H, device="cuda", dtype=torch.float32)
+            g = torch.randn(B, S, H, device="cuda", dtype=torch.float32)
             return [x, w, g, eps, group_size]

Alternatively, expand op_ignore_types to cover device args, but CUDA dummies are simpler and consistent with comments in register_ad_pattern.


176-185: Route gated backend via config for future extensibility.

You validate gated_rmsnorm_backend=='triton' but don’t thread it to the replacement. Plumb through now for consistency.

-        register_ad_pattern(
+        register_ad_pattern(
             search_fn=_gated_rmsnorm_pattern_ref,
-            replace_fn=partial(_gated_rmsnorm_replacement),
+            replace_fn=partial(_gated_rmsnorm_replacement, backend=self.config.gated_rmsnorm_backend),
             patterns=patterns,
             dummy_args=make_dummy_args_gated(group_size, eps),
             op_ignore_types=op_ignore_types,
             scalar_workaround={"eps": eps, "group_size": group_size},
             skip_duplicates=True,
         )
@@
-def _gated_rmsnorm_replacement(
+def _gated_rmsnorm_replacement(
     x: torch.Tensor,
     weight: torch.Tensor,
     gate: torch.Tensor,
     eps: float,
     group_size: int,
-) -> torch.Tensor:
-    return torch.ops.auto_deploy.triton_rmsnorm_gated(
-        x, weight, gate, float(eps), int(group_size), False
-    )
+    backend: str = "triton",  # currently unused; only 'triton' is supported today
+) -> torch.Tensor:
+    return torch.ops.auto_deploy.triton_rmsnorm_gated(
+        x, weight, gate, float(eps), int(group_size), False
+    )

No behavioral change today; it avoids touching this code again when adding more backends. Keeping fp32 from the op is intentional. Based on learnings.

Also applies to: 222-224


70-80: Minor naming: consider aligning _BACKEND_OPS docstrings with keys.

Keys mix "rms_norm" and "rmsnorm" in op names; not a bug, but a brief comment noting the intent would help avoid confusion.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2d86d6b and 476ad6c.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py:

  • Python code must target Python 3.8+.
  • Indent Python code with 4 spaces; do not use tabs.
  • Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
  • Python filenames should be snake_case (e.g., some_file.py).
  • Python classes use PascalCase names.
  • Functions and methods use snake_case names.
  • Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
  • Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
  • Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
  • Avoid shadowing variables from an outer scope.
  • Initialize all externally visible members of a class in the constructor.
  • Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
  • Use Google-style docstrings for classes and functions (Sphinx-parsable).
  • Document attributes and variables inline so they render under the class/function docstring.
  • Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
  • In try/except, catch the most specific exceptions possible.
  • For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
📚 Learning: 2025-10-20T17:09:21.560Z
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
🧬 Code graph analysis (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
  • triton_rmsnorm_gated (87-140)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2)
  • ADPatternMatcherPass (61-67)
  • register_ad_pattern (99-182)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
  • triton_rmsnorm_gated (87-140)
🪛 Ruff (0.14.1)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py

115-117: Avoid specifying long messages outside the exception class

(TRY003)


121-124: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (4)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)

143-159: Meta returns fp32 as intended.

Keeping meta dtype as fp32 matches the real op and downstream casting expectations. No change needed.

tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (3)

73-80: Config split LGTM.

rmsnorm_backend and gated_rmsnorm_backend fields are clear and documented.


222-224: No cast-back in replacement is correct.

Let downstream handle dtype; matches prior guidance and tests already cast for comparison.


129-157: Regular RMSNorm pattern registration: solid.

Use of partial with backend and scalar workaround is appropriate.

Comment on lines +124 to 126
rmsnorm_backend: flashinfer
gated_rmsnorm_backend: triton
requires_shape_prop: true

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Config split is clear; verify no stale uses of the old backend field.

Ensure no remaining consumers expect fuse_rmsnorm.backend or the removed fuse_gated_rmsnorm block.


🏁 Script executed:

```shell
#!/bin/bash
# Find old config key usages and removed transform references.
rg -nP '\bfuse_rmsnorm\b.*\bbackend\b|fuse_gated_rmsnorm|torch_rmsnorm_gated' -C2
```

Length of output: 0


I'll run broader searches to verify if there are any remaining references to the old config field names in the codebase.




---

🏁 Script executed:

```shell
#!/bin/bash
# Search more broadly for old config key references in Python files
# Look for fuse_rmsnorm, fuse_gated_rmsnorm patterns
rg -i 'fuse_rmsnorm|fuse_gated_rmsnorm' --type=py -C3
```

Length of output: 1482


🏁 Script executed:

```shell
#!/bin/bash
# Search for where config is loaded/parsed to see if old keys are referenced
rg -i 'rmsnorm.*backend|gated_rmsnorm' --type=py -C3
```

Length of output: 6538


🏁 Script executed:

```shell
#!/bin/bash
# Check if there are any YAML references to old config structure
fd -e yaml -e yml | xargs grep -i 'fuse_rmsnorm\|fuse_gated_rmsnorm' || echo "No matches in YAML files"
```

Length of output: 67


I need to check the full YAML config structure and how it's being loaded.




---

🏁 Script executed:

```shell
#!/bin/bash
# View the full default.yaml file to see the complete config structure
cat tensorrt_llm/_torch/auto_deploy/config/default.yaml
```

Length of output: 6055


🏁 Script executed:

```shell
#!/bin/bash
# Search for where config is loaded/parsed to understand the expected structure
rg -i 'fuse_rmsnorm|FuseRMSNormConfig' --type=py -A5 -B5 | head -100
```

Length of output: 4934


🏁 Script executed:

```shell
#!/bin/bash
# Check the test file to see what config it's providing
cat tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py | grep -A20 "fuse_rmsnorm"
```

Length of output: 557


Test file using removed config field; update required.

The test file tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py is still passing a single "backend" field, which no longer exists. The FuseRMSNormConfig class expects rmsnorm_backend and gated_rmsnorm_backend fields. The test must be updated to match the new config schema—either split the backend selection or use the appropriate field names for the variant being tested.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/config/default.yaml around lines 124-126 the
schema now uses rmsnorm_backend and gated_rmsnorm_backend instead of a single
backend field; update the test at
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
to stop passing "backend" and instead pass the correct field(s): use
"rmsnorm_backend" for regular RMSNorm tests and "gated_rmsnorm_backend" for
gated variants (or include both fields if the code path reads both), and adjust
any assertions/expected config keys in the test to reference the new field
names.
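A minimal sketch of that test update, assuming the test configures the transform via a plain dict (only the new field names are confirmed by this PR; the surrounding test structure is assumed):

```python
# Before: single backend field, no longer accepted by FuseRMSNormConfig.
# transform_config = {"backend": "flashinfer"}

# After: split backends matching the new schema.
transform_config = {
    "rmsnorm_backend": "flashinfer",
    "gated_rmsnorm_backend": "triton",
}
```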

