8 changes: 2 additions & 6 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -119,15 +119,11 @@ transforms:
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
# check if we can fuse rmsnorm
fuse_rmsnorm:
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
# check if we can fuse rmsnorm
stage: post_load_fusion
backend: flashinfer
rmsnorm_backend: flashinfer
gated_rmsnorm_backend: triton
requires_shape_prop: true
Comment on lines +124 to 126
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Config split is clear; verify no stale uses of the old backend field.

Ensure no remaining consumers expect fuse_rmsnorm.backend or the removed fuse_gated_rmsnorm block.


🏁 Script executed:

```shell
#!/bin/bash
# Find old config key usages and removed transform references.
rg -nP '\bfuse_rmsnorm\b.*\bbackend\b|fuse_gated_rmsnorm|torch_rmsnorm_gated' -C2
```

Length of output: 0


I'll run broader searches to verify if there are any remaining references to the old config field names in the codebase.




---

🏁 Script executed:

```shell
#!/bin/bash
# Search more broadly for old config key references in Python files
# Look for fuse_rmsnorm, fuse_gated_rmsnorm patterns
rg -i 'fuse_rmsnorm|fuse_gated_rmsnorm' --type=py -C3
```

Length of output: 1482


🏁 Script executed:

```shell
#!/bin/bash
# Search for where config is loaded/parsed to see if old keys are referenced
rg -i 'rmsnorm.*backend|gated_rmsnorm' --type=py -C3
```

Length of output: 6538


🏁 Script executed:

```shell
#!/bin/bash
# Check if there are any YAML references to old config structure
fd -e yaml -e yml | xargs grep -i 'fuse_rmsnorm\|fuse_gated_rmsnorm' || echo "No matches in YAML files"
```

Length of output: 67


I need to check the full YAML config structure and how it's being loaded.




---

🏁 Script executed:

```shell
#!/bin/bash
# View the full default.yaml file to see the complete config structure
cat tensorrt_llm/_torch/auto_deploy/config/default.yaml
```

Length of output: 6055


🏁 Script executed:

```shell
#!/bin/bash
# Search for where config is loaded/parsed to understand the expected structure
rg -i 'fuse_rmsnorm|FuseRMSNormConfig' --type=py -A5 -B5 | head -100
```

Length of output: 4934


🏁 Script executed:

```shell
#!/bin/bash
# Check the test file to see what config it's providing
cat tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py | grep -A20 "fuse_rmsnorm"
```

Length of output: 557


Test file using removed config field; update required.

The test file tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py is still passing a single "backend" field, which no longer exists. The FuseRMSNormConfig class expects rmsnorm_backend and gated_rmsnorm_backend fields. The test must be updated to match the new config schema—either split the backend selection or use the appropriate field names for the variant being tested.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/auto_deploy/config/default.yaml around lines 124-126 the
schema now uses rmsnorm_backend and gated_rmsnorm_backend instead of a single
backend field; update the test at
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
to stop passing "backend" and instead pass the correct field(s): use
"rmsnorm_backend" for regular RMSNorm tests and "gated_rmsnorm_backend" for
gated variants (or include both fields if the code path reads both), and adjust
any assertions/expected config keys in the test to reference the new field
names.
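
For reference, a minimal sketch of what the updated per-transform options could look like, assuming the test passes them as a plain dict. The field names, values, and stage below come from the new FuseRMSNormConfig and default.yaml in this diff; the surrounding dict layout is an assumption, not taken from the actual test file:

```python
# Hypothetical sketch of updated fuse_rmsnorm options for the test.
# "rmsnorm_backend" / "gated_rmsnorm_backend" replace the removed "backend" field;
# the dict structure itself is assumed, not copied from the test.
fuse_rmsnorm_options = {
    "stage": "post_load_fusion",
    # removed: "backend": "flashinfer"
    "rmsnorm_backend": "flashinfer",    # regular RMSNorm path
    "gated_rmsnorm_backend": "triton",  # gated RMSNorm path (only 'triton' supported)
}
```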

fuse_gated_rmsnorm:
stage: post_load_fusion

############################################################################################
# VISUALIZE GRAPH
8 changes: 4 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
@@ -83,8 +83,8 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
return torch.empty_like(input)


@torch.library.custom_op("auto_deploy::torch_rmsnorm_gated", mutates_args=())
def torch_rmsnorm_gated(
@torch.library.custom_op("auto_deploy::triton_rmsnorm_gated", mutates_args=())
def triton_rmsnorm_gated(
x: torch.Tensor,
weight: torch.Tensor,
gate: torch.Tensor | None,
@@ -140,8 +140,8 @@ def torch_rmsnorm_gated(
return out2.reshape(x_shape)


@torch_rmsnorm_gated.register_fake
def _torch_rmsnorm_gated_meta(
@triton_rmsnorm_gated.register_fake
def _triton_rmsnorm_gated_meta(
x,
weight,
gate,
119 changes: 53 additions & 66 deletions tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
@@ -70,23 +70,28 @@ def _rms_norm_replacement(
class FuseRMSNormConfig(TransformConfig):
"""Configuration for the RMSNorm fusion transform."""

backend: str = Field(
rmsnorm_backend: str = Field(
default="flashinfer",
description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
description="Backend to use for RMSNorm computation ('flashinfer', 'triton', or 'torch').",
)
gated_rmsnorm_backend: str = Field(
default="triton",
description="Backend to use for gated RMSNorm computation (currently only 'triton').",
)


@TransformRegistry.register("fuse_rmsnorm")
class FuseRMSNorm(BaseTransform):
"""Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
"""Matches and replaces RMSNorm patterns (regular and gated) in the graph with optimized implementations.

This function sets up pattern matching to identify RMSNorm operations in the graph
This function sets up pattern matching to identify both regular and gated RMSNorm operations in the graph
and replaces them with optimized implementations. It uses dummy tensors to register
the pattern matching rules.

Args:
gm: Input graph module to transform.
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
rmsnorm_backend: Backend to use for regular RMSNorm computation ("flashinfer", "triton", or "torch").
gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton").

Returns:
Transformed graph module with optimized RMSNorm operations.
@@ -105,15 +110,23 @@ def _apply(
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
if self.config.backend.lower() not in _BACKEND_OPS:
# Validate rmsnorm_backend
if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS:
raise ValueError(
f"Invalid rmsnorm_backend, must be one of {list(_BACKEND_OPS)}, got {self.config.rmsnorm_backend}"
)

# Validate gated_rmsnorm_backend (currently only triton is supported)
if self.config.gated_rmsnorm_backend.lower() != "triton":
raise ValueError(
f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}"
f"""Invalid gated_rmsnorm_backend, currently only 'triton' is supported,
got {self.config.gated_rmsnorm_backend}"""
)

graph = gm.graph
patterns = ADPatternMatcherPass()

# Create dummy tensors for pattern matching
# Pattern matching for regular RMSNorm
bs = 2
hidden_size = 512

@@ -131,17 +144,46 @@ def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float =
(torch.float32, torch.float32),
]

# Register patterns for each configuration
# Register patterns for regular RMSNorm
for input_dtype, weight_dtype in configs:
register_ad_pattern(
search_fn=_rms_norm_pattern,
replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
replace_fn=partial(_rms_norm_replacement, backend=self.config.rmsnorm_backend),
patterns=patterns,
dummy_args=dummy_args(input_dtype, weight_dtype),
op_ignore_types={},
scalar_workaround={"eps": 1e-6},
)

# Pattern matching for gated RMSNorm
B, S, H = 2, 3, 4096
group_size = 512
eps = 1e-5

def make_dummy_args_gated(group_size: int, eps: float) -> list:
x = torch.randn(B, S, H, dtype=torch.float32)
w = torch.randn(H, dtype=torch.float32)
g = torch.randn(B, S, H, dtype=torch.float32)
return [x, w, g, eps, group_size]

op_ignore_types = {
torch.ops.aten.reshape.default: (int, list, tuple),
torch.ops.aten.view.default: (int, list, tuple),
torch.ops.aten.mean.dim: (list, tuple),
torch.ops.aten.to.dtype: (torch.dtype,),
}

# Register pattern for gated RMSNorm
register_ad_pattern(
search_fn=_gated_rmsnorm_pattern_ref,
replace_fn=partial(_gated_rmsnorm_replacement),
patterns=patterns,
dummy_args=make_dummy_args_gated(group_size, eps),
op_ignore_types=op_ignore_types,
scalar_workaround={"eps": eps, "group_size": group_size},
skip_duplicates=True,
)

cnt = patterns.apply(graph)

info = TransformInfo(skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=False)
@@ -177,61 +219,6 @@ def _gated_rmsnorm_replacement(
eps: float,
group_size: int,
) -> torch.Tensor:
return torch.ops.auto_deploy.torch_rmsnorm_gated(
return torch.ops.auto_deploy.triton_rmsnorm_gated(
x, weight, gate, float(eps), int(group_size), False
)


@TransformRegistry.register("fuse_gated_rmsnorm")
class FuseGatedRMSNorm(BaseTransform):
"""
Fuse the NemotronH-style gated RMSNorm subgraph into a single custom op:
auto_deploy::torch_rmsnorm_gated(x, weight, gate, eps, group_size, norm_before_gate=False)
"""

def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
graph = gm.graph
patterns = ADPatternMatcherPass()

B, S, H = 2, 3, 4096
group_size = 512
eps = 1e-5

def make_dummy_args(group_size: int, eps: float) -> list:
x = torch.randn(B, S, H, dtype=torch.float32)
w = torch.randn(H, dtype=torch.float32)
g = torch.randn(B, S, H, dtype=torch.float32)
return [x, w, g, eps, group_size]

op_ignore_types = {
torch.ops.aten.reshape.default: (int, list, tuple),
torch.ops.aten.view.default: (int, list, tuple),
torch.ops.aten.mean.dim: (list, tuple),
torch.ops.aten.to.dtype: (torch.dtype,),
}

register_ad_pattern(
search_fn=_gated_rmsnorm_pattern_ref,
replace_fn=partial(_gated_rmsnorm_replacement),
patterns=patterns,
dummy_args=make_dummy_args(group_size, eps),
op_ignore_types=op_ignore_types,
scalar_workaround={"eps": eps, "group_size": group_size},
skip_duplicates=True,
)

num = patterns.apply(graph)

info = TransformInfo(
skipped=False,
num_matches=num,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
@@ -24,7 +24,7 @@ def test_custom_op_matches_ref(B, T, H, group, use_gate, dtype):
)

# Custom op (currently returns fp32). Cast it back to x.dtype for apples-to-apples with ref.
y_op_fp32 = torch.ops.auto_deploy.torch_rmsnorm_gated(x, w, z, 1e-5, group, False)
y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False)
y_op = y_op_fp32.to(x.dtype)

assert y_ref.dtype == x.dtype and y_op.dtype == x.dtype