diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index 4bc4b74fd5d..e24f645576a 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -119,15 +119,11 @@ transforms:
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
   # TODO (lucaslie): add backend selection as part of configurable inference optimizers
-  # check if we can fuse rmsnorm
   fuse_rmsnorm:
-    # TODO (lucaslie): add backend selection as part of configurable inference optimizers
-    # check if we can fuse rmsnorm
     stage: post_load_fusion
-    backend: flashinfer
+    rmsnorm_backend: flashinfer
+    gated_rmsnorm_backend: triton
     requires_shape_prop: true
-  fuse_gated_rmsnorm:
-    stage: post_load_fusion
 
 ############################################################################################
 # VISUALIZE GRAPH
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
index 88335de2056..f4b98d49df0 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
@@ -83,8 +83,8 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
     return torch.empty_like(input)
 
 
-@torch.library.custom_op("auto_deploy::torch_rmsnorm_gated", mutates_args=())
-def torch_rmsnorm_gated(
+@torch.library.custom_op("auto_deploy::triton_rmsnorm_gated", mutates_args=())
+def triton_rmsnorm_gated(
     x: torch.Tensor,
     weight: torch.Tensor,
     gate: torch.Tensor | None,
@@ -140,8 +140,8 @@ def torch_rmsnorm_gated(
     return out2.reshape(x_shape)
 
 
-@torch_rmsnorm_gated.register_fake
-def _torch_rmsnorm_gated_meta(
+@triton_rmsnorm_gated.register_fake
+def _triton_rmsnorm_gated_meta(
     x,
     weight,
     gate,
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
index 6cca71855cd..d50208ab54c 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
@@ -70,23 +70,28 @@ def _rms_norm_replacement(
 class FuseRMSNormConfig(TransformConfig):
     """Configuration for the RMSNorm fusion transform."""
 
-    backend: str = Field(
+    rmsnorm_backend: str = Field(
         default="flashinfer",
-        description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
+        description="Backend to use for RMSNorm computation ('flashinfer', 'triton', or 'torch').",
+    )
+    gated_rmsnorm_backend: str = Field(
+        default="triton",
+        description="Backend to use for gated RMSNorm computation (currently only 'triton').",
     )
 
 
 @TransformRegistry.register("fuse_rmsnorm")
 class FuseRMSNorm(BaseTransform):
-    """Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
+    """Matches and replaces RMSNorm patterns (regular and gated) in the graph with optimized implementations.
 
-    This function sets up pattern matching to identify RMSNorm operations in the graph
+    This function sets up pattern matching to identify both regular and gated RMSNorm operations in the graph
     and replaces them with optimized implementations.
     It uses dummy tensors to register the pattern matching rules.
 
     Args:
         gm: Input graph module to transform.
-        backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
+        rmsnorm_backend: Backend to use for regular RMSNorm computation ("flashinfer", "triton", or "torch").
+        gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton").
 
     Returns:
         Transformed graph module with optimized RMSNorm operations.
@@ -105,15 +110,23 @@ def _apply(
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        if self.config.backend.lower() not in _BACKEND_OPS:
+        # Validate rmsnorm_backend
+        if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS:
+            raise ValueError(
+                f"Invalid rmsnorm_backend, must be one of {list(_BACKEND_OPS)}, got {self.config.rmsnorm_backend}"
+            )
+
+        # Validate gated_rmsnorm_backend (currently only triton is supported)
+        if self.config.gated_rmsnorm_backend.lower() != "triton":
             raise ValueError(
-                f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}"
+                f"Invalid gated_rmsnorm_backend, currently only 'triton' is supported, "
+                f"got {self.config.gated_rmsnorm_backend}"
             )
 
         graph = gm.graph
         patterns = ADPatternMatcherPass()
 
-        # Create dummy tensors for pattern matching
+        # Pattern matching for regular RMSNorm
         bs = 2
         hidden_size = 512
 
@@ -131,17 +144,46 @@ def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float
             (torch.float32, torch.float32),
         ]
 
-        # Register patterns for each configuration
+        # Register patterns for regular RMSNorm
         for input_dtype, weight_dtype in configs:
             register_ad_pattern(
                 search_fn=_rms_norm_pattern,
-                replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
+                replace_fn=partial(_rms_norm_replacement, backend=self.config.rmsnorm_backend),
                 patterns=patterns,
                 dummy_args=dummy_args(input_dtype, weight_dtype),
                 op_ignore_types={},
                 scalar_workaround={"eps": 1e-6},
             )
 
+        # Pattern matching for gated RMSNorm
+        B, S, H = 2, 3, 4096
+        group_size = 512
+        eps = 1e-5
+
+        def make_dummy_args_gated(group_size: int, eps: float) -> list:
+            x = torch.randn(B, S, H, dtype=torch.float32)
+            w = torch.randn(H, dtype=torch.float32)
+            g = torch.randn(B, S, H, dtype=torch.float32)
+            return [x, w, g, eps, group_size]
+
+        op_ignore_types = {
+            torch.ops.aten.reshape.default: (int, list, tuple),
+            torch.ops.aten.view.default: (int, list, tuple),
+            torch.ops.aten.mean.dim: (list, tuple),
+            torch.ops.aten.to.dtype: (torch.dtype,),
+        }
+
+        # Register pattern for gated RMSNorm
+        register_ad_pattern(
+            search_fn=_gated_rmsnorm_pattern_ref,
+            replace_fn=partial(_gated_rmsnorm_replacement),
+            patterns=patterns,
+            dummy_args=make_dummy_args_gated(group_size, eps),
+            op_ignore_types=op_ignore_types,
+            scalar_workaround={"eps": eps, "group_size": group_size},
+            skip_duplicates=True,
+        )
+
         cnt = patterns.apply(graph)
 
         info = TransformInfo(skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=False)
@@ -177,61 +219,6 @@ def _gated_rmsnorm_replacement(
     eps: float,
     group_size: int,
 ) -> torch.Tensor:
-    return torch.ops.auto_deploy.torch_rmsnorm_gated(
+    return torch.ops.auto_deploy.triton_rmsnorm_gated(
         x, weight, gate, float(eps), int(group_size), False
     )
-
-
-@TransformRegistry.register("fuse_gated_rmsnorm")
-class FuseGatedRMSNorm(BaseTransform):
-    """
-    Fuse the NemotronH-style gated RMSNorm subgraph into a single custom op:
-      auto_deploy::torch_rmsnorm_gated(x, weight, gate, eps, group_size, norm_before_gate=False)
-    """
-
-    def _apply(
-        self,
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
-        graph = gm.graph
-        patterns = ADPatternMatcherPass()
-
-        B, S, H = 2, 3, 4096
-        group_size = 512
-        eps = 1e-5
-
-        def make_dummy_args(group_size: int, eps: float) -> list:
-            x = torch.randn(B, S, H, dtype=torch.float32)
-            w = torch.randn(H, dtype=torch.float32)
-            g = torch.randn(B, S, H, dtype=torch.float32)
-            return [x, w, g, eps, group_size]
-
-        op_ignore_types = {
-            torch.ops.aten.reshape.default: (int, list, tuple),
-            torch.ops.aten.view.default: (int, list, tuple),
-            torch.ops.aten.mean.dim: (list, tuple),
-            torch.ops.aten.to.dtype: (torch.dtype,),
-        }
-
-        register_ad_pattern(
-            search_fn=_gated_rmsnorm_pattern_ref,
-            replace_fn=partial(_gated_rmsnorm_replacement),
-            patterns=patterns,
-            dummy_args=make_dummy_args(group_size, eps),
-            op_ignore_types=op_ignore_types,
-            scalar_workaround={"eps": eps, "group_size": group_size},
-            skip_duplicates=True,
-        )
-
-        num = patterns.apply(graph)
-
-        info = TransformInfo(
-            skipped=False,
-            num_matches=num,
-            is_clean=False,
-            has_valid_shapes=False,
-        )
-        return gm, info
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
index 96b0ef072e3..35b293686d2 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py
@@ -24,7 +24,7 @@ def test_custom_op_matches_ref(B, T, H, group, use_gate, dtype):
     )
 
     # Custom op (currently returns fp32). Cast it back to x.dtype for apples-to-apples with ref.
-    y_op_fp32 = torch.ops.auto_deploy.torch_rmsnorm_gated(x, w, z, 1e-5, group, False)
+    y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False)
     y_op = y_op_fp32.to(x.dtype)
 
     assert y_ref.dtype == x.dtype and y_op.dtype == x.dtype
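Reviewer note: a minimal smoke-test sketch for exercising the renamed op outside the pattern matcher, modeled on the updated unit test. Shapes and dtypes are illustrative; it assumes that importing the `rms_norm` module registers the custom op and that a CUDA device is available for the Triton kernel.

```python
# Sketch only: mirrors the call in test_mamba_rms_norm.py after the rename.
import torch

# Assumption: importing the module runs @torch.library.custom_op registration.
import tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm  # noqa: F401

B, T, H, group = 2, 3, 4096, 512  # H must be divisible by group
x = torch.randn(B, T, H, dtype=torch.bfloat16, device="cuda")
w = torch.randn(H, dtype=torch.float32, device="cuda")
z = torch.randn(B, T, H, dtype=torch.bfloat16, device="cuda")  # gate

# Signature per the diff: (x, weight, gate, eps, group_size, norm_before_gate).
# The op currently returns fp32, so cast back to x.dtype as the test does.
y = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False).to(x.dtype)
assert y.shape == x.shape and y.dtype == x.dtype
```

This is the op that `fuse_rmsnorm` now emits when the gated pattern matches; the new `gated_rmsnorm_backend` key in `default.yaml` selects its backend (currently `triton` only), which is why the separate `fuse_gated_rmsnorm` transform could be deleted.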