From 9603844e32b6f95899c2f179cb09cc967d73c496 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Wed, 19 Nov 2025 18:55:55 -0800 Subject: [PATCH 01/15] Enable multi-stream MOE optimization in AutoDeploy Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 3 + .../auto_deploy/custom_ops/multi_stream.py | 223 ++++++++++++++++++ .../auto_deploy/models/patches/nemotron_h.py | 18 +- .../transform/library/multi_stream_moe.py | 90 +++++++ .../singlegpu/custom_ops/test_multi_stream.py | 134 +++++++++++ .../unit/singlegpu/test_multi_stream.py | 139 +++++++++++ 6 files changed, 596 insertions(+), 11 deletions(-) create mode 100644 tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/test_multi_stream.py diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 2bd93277002..880fd6834e4 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -161,6 +161,9 @@ transforms: ############################################################################################ fuse_causal_conv_activation: stage: compile + multi_stream_moe: + stage: compile + enabled: true compile_model: stage: compile run_per_gm: false diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py new file mode 100644 index 00000000000..14902b9b081 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py @@ -0,0 +1,223 @@ +""" +Custom ops to enable multi-stream execution. +""" + +from __future__ import annotations + +from threading import RLock +from typing import Any, Callable, Dict, Tuple + +import torch + + +class _Singleton(type): + _instances: Dict[type, Any] = {} + _lock = RLock() + + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + if cls not in cls._instances: + with cls._lock: + if cls not in cls._instances: # double-checked locking + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +# A singleton that holds the pointers to the cuda streams and events. +# In multi-gpu scenario, each GPU/rank has its own CudaStreamManager. +class CudaStreamManager(metaclass=_Singleton): + AUX_STREAM_NAME = "aux" + MAIN_STREAM_NAME = "main" + + def __init__(self) -> None: + # In case __init__ ever gets called twice, guard against re-init + if hasattr(self, "streams"): + return + + self._lock = RLock() + + # Events needed for stream synchronization + self.events: Dict[str, Any] = { + self.AUX_STREAM_NAME: torch.cuda.Event(), + self.MAIN_STREAM_NAME: torch.cuda.Event(), + } + + # Streams for multi-stream execution + self.aux_stream = torch.cuda.Stream() + self.streams: Dict[str, Any] = { + self.AUX_STREAM_NAME: self.aux_stream, + self.MAIN_STREAM_NAME: torch.cuda.default_stream(), + } + + +cuda_stream_manager = CudaStreamManager() + + +@torch.library.custom_op("auto_deploy::record_event", mutates_args=()) +def record_event(stream_name: str) -> None: + event = cuda_stream_manager.events[stream_name] + event.record() + + +@torch.library.custom_op("auto_deploy::wait_event", mutates_args=()) +def wait_event(event_name: str) -> None: + event = cuda_stream_manager.events[event_name] + event.wait() + + +def record_event_wrapper( + fn: Callable, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> torch.Tensor: + output = fn(*args, **kwargs) + torch.ops.auto_deploy.record_event(cuda_stream_manager.MAIN_STREAM_NAME) + return output + + +def aux_stream_wrapper( + fn: Callable, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> torch.Tensor: + stream_name = cuda_stream_manager.AUX_STREAM_NAME + with torch.cuda.stream(cuda_stream_manager.streams[stream_name]): + torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) + output = fn(*args, **kwargs) + torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) + torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) + return output + + +# bf16 +@torch.library.custom_op("auto_deploy::trtllm_moe_fused_aux", mutates_args=()) +def trtllm_moe_fused_aux( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w3_w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]): + torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) + output = torch.ops.auto_deploy.trtllm_moe_fused( + x, + selected_experts, + routing_weights, + w3_w1_stacked_weight, + w2_stacked_weight, + mlp_style, + act_fn, + ) + torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) + torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) + return output + + +@trtllm_moe_fused_aux.register_fake +def trtllm_moe_fused_aux_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w3_w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + return torch.empty_like(x) + + +# triton bf16 +@torch.library.custom_op("auto_deploy::triton_moe_fused_aux", mutates_args=()) +def triton_moe_fused_aux( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, +) -> torch.Tensor: + with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]): + torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) + output = torch.ops.auto_deploy.triton_moe_fused( + x, + selected_experts, + routing_weights, + w1_stacked_weight, + w2_stacked_weight, + ) + torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) + torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) + return output + + +@triton_moe_fused_aux.register_fake +def triton_moe_fused_aux_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, +) -> torch.Tensor: + return torch.empty_like(x) + + +# trtllm fp8 +@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused_aux", mutates_args=()) +def trtllm_quant_fp8_moe_fused_aux( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights + w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights + w3_weight: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp + w1_input_scale: torch.Tensor, # [E] stacked input scales + w2_input_scale: torch.Tensor, # [E] stacked input scales + w3_input_scale: torch.Tensor, # [E] or unused + w1_weight_scale: torch.Tensor, # [E] stacked weight scales + w2_weight_scale: torch.Tensor, # [E] stacked weight scales + w3_weight_scale: torch.Tensor, # [E] or unused + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]): + torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) + output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused( + x, + selected_experts, + routing_weights, + w1_weight, + w2_weight, + w3_weight, + w1_input_scale, + w2_input_scale, + w3_input_scale, + w1_weight_scale, + w2_weight_scale, + w3_weight_scale, + mlp_style, + act_fn, + ) + torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) + torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) + return output + + +@trtllm_quant_fp8_moe_fused_aux.register_fake +def trtllm_quant_fp8_moe_fused_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight: torch.Tensor, + w2_weight: torch.Tensor, + w3_weight: torch.Tensor, + w1_input_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + w3_input_scale: torch.Tensor, + w1_weight_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + w3_weight_scale: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py index fb6a6dafe7d..4dca4afd853 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py @@ -128,6 +128,10 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor): topk_indices, topk_weights = self.gate(hidden_states) x_flat = hidden_states.view(-1, hidden_states.shape[-1]) + # NOTE: So far we've seen that the dispatch order in eager code is the same as the node order in the exported graph. + # We dispatch shared expert first so that we can easily fork the execution of the routed experts + # (using the custom op below) to an auxiliary stream. + shared_out = self.shared_experts(residuals) # Check if this is a latent MOE (has fc1_latent_proj and fc2_latent_proj) has_latent_proj = hasattr(self, "fc1_latent_proj") and hasattr(self, "fc2_latent_proj") @@ -151,8 +155,8 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor): # Latent MOE: project back from latent space out_flat = self.fc2_latent_proj(out_flat) - out = out_flat.view(*orig_shape) - out = out + self.shared_experts(residuals) + routed_out = out_flat.view(*orig_shape) + out = shared_out + routed_out return out @@ -187,23 +191,15 @@ def get_model_from_config_patched(config, **kwargs): _config_from_pretrained_original = AutoConfig.from_pretrained _nemotron_h_base_model_tp_plan = { - # mamba SSM layer "in_proj": "mamba", "out_proj": "rowwise", - # attention layer "q_proj": "colwise", "k_proj": "colwise", "v_proj": "colwise", "o_proj": "rowwise", - # NOTE: consider not sharding shared experts and/or - # latent projections at all, keeping them replicated. - # To do so, comment out the corresponding entries. - # moe layer: SHARED experts "up_proj": "colwise", "down_proj": "rowwise", - # MoLE: latent projections: simple shard - "fc1_latent_proj": "gather", - "fc2_latent_proj": "gather", + # "*": "gather", } diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py new file mode 100644 index 00000000000..1c6015ce558 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py @@ -0,0 +1,90 @@ +"""Transform for multi-stream execution of MoE layers that have shared experts and routed experts.""" + +from typing import Callable, Dict, Tuple + +import torch +from torch.fx import GraphModule, Node + +from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import record_event_wrapper + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.node_utils import is_op +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +def _execute_op_in_aux_stream( + gm: GraphModule, op_dict: Dict[Callable, Callable] +) -> Tuple[GraphModule, int]: + graph = gm.graph + num_replaced = 0 + + # Collect targets first to avoid mutating while iterating + target_nodes: list[Node] = [] + for n in graph.nodes: + if is_op(n, op_dict.keys()): + target_nodes.append(n) + + for n in target_nodes: + target_input_node = None + for input_node in n.all_input_nodes: + if input_node.target == torch.ops.aten.view.default: + target_input_node = input_node + break + + if target_input_node is None: + raise ValueError(f"Target input node not found for node {n}") + with graph.inserting_before(target_input_node): + new_node = graph.call_function( + record_event_wrapper, + args=(target_input_node.target, *target_input_node.args), + kwargs=target_input_node.kwargs, + ) + target_input_node.replace_all_uses_with(new_node) + graph.erase_node(target_input_node) + with graph.inserting_after(n): + new_node = graph.call_function(op_dict[n.target], args=n.args, kwargs=n.kwargs) + n.replace_all_uses_with(new_node) + graph.erase_node(n) + num_replaced += 1 + if num_replaced: + graph.eliminate_dead_code() + graph.lint() + gm.recompile() + + return gm, num_replaced + + +@TransformRegistry.register("multi_stream_moe") +class MultiStreamMOE(BaseTransform): + """Multi-stream execution of MoE layers that have shared experts and routed experts.""" + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + with open("graph-before-moe.txt", "w") as f: + f.write(str(gm.graph)) + print("wrote graph to graph-before-moe.txt") + + op_dict = { + torch.ops.auto_deploy.trtllm_moe_fused: torch.ops.auto_deploy.trtllm_moe_fused_aux, + torch.ops.auto_deploy.triton_moe_fused: torch.ops.auto_deploy.triton_moe_fused_aux, + torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused: torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused_aux, + } + + gm, num_matches = _execute_op_in_aux_stream(gm, op_dict) + + info = TransformInfo( + skipped=False, + num_matches=num_matches, + is_clean=False, + has_valid_shapes=False, + ) + with open("graph-after-moe.txt", "w") as f: + f.write(str(gm.graph)) + print("wrote graph to graph-after-moe.txt") + return gm, info diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py new file mode 100644 index 00000000000..e008c4a99d8 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py @@ -0,0 +1,134 @@ +from typing import Tuple + +import torch +import torch.nn as nn +from torch.fx import GraphModule, Node + +from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import ( + aux_stream_wrapper, + record_event_wrapper, +) +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + + +@torch.library.custom_op("auto_deploy::multi_stream_linear", mutates_args=()) +def multi_stream_linear( + input: torch.Tensor, weight0: torch.Tensor, weight1: torch.Tensor +) -> torch.Tensor: + output = torch.ops.aten.linear(input, weight0) + output = torch.ops.aten.linear(output, weight1) + return output + + +@multi_stream_linear.register_fake +def multi_stream_linear_fake(input, weight0, weight1): + """Fake implementation of multi_stream_linear.""" + output = torch.ops.aten.linear(input, weight0) + return torch.ops.aten.linear(output, weight1) + + +def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tuple[GraphModule, int]: + """Traverse ``gm`` and replace all ``auto_deploy::multi_stream_linear`` ops with ``aux_stream_wrapper``. + + The replacement preserves the original args/kwargs of the node. + After rewriting, the graph is cleaned and recompiled. + + Args: + gm: The FX graph module to transform. + aux_stream_wrapper: A callable to replace the custom op with. + + Returns: + A tuple of (gm, num_replaced) + """ + graph = gm.graph + num_replaced = 0 + + # Collect targets first to avoid mutating while iterating + target_nodes: list[Node] = [] + for n in graph.nodes: + if is_op(n, torch.ops.auto_deploy.multi_stream_linear): + target_nodes.append(n) + + for n in target_nodes: + target_input_node = None + for input_node in n.all_input_nodes: + if len(input_node.users) > 1: + target_input_node = input_node + break + if target_input_node is None: + raise ValueError(f"Target input node not found for node {n}") + with graph.inserting_before(target_input_node): + new_node = graph.call_function( + record_event_wrapper, + args=(target_input_node.target, *target_input_node.args), + kwargs=target_input_node.kwargs, + ) + target_input_node.replace_all_uses_with(new_node) + graph.erase_node(target_input_node) + with graph.inserting_after(n): + new_node = graph.call_function( + aux_stream_wrapper, args=(n.target, *n.args), kwargs=n.kwargs + ) + n.replace_all_uses_with(new_node) + graph.erase_node(n) + num_replaced += 1 + + if num_replaced: + graph.eliminate_dead_code() + graph.lint() + gm.recompile() + + return gm, num_replaced + + +class ParallelTwoLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.fc10 = nn.Linear(in_dim, in_dim) + self.fc11 = nn.Linear(in_dim, out_dim) + self.fc2 = nn.Linear(in_dim, out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.relu(x) + y0 = self.fc2(x) + y1 = torch.ops.auto_deploy.multi_stream_linear(x, self.fc10.weight, self.fc11.weight) + return y0 + y1 + + +def test_multi_stream_linear(): + in_dim, out_dim = 128, 256 + + model = ( + nn.Sequential(ParallelTwoLinear(in_dim, out_dim), ParallelTwoLinear(out_dim, out_dim)) + .eval() + .to("cuda") + ) + + # Example input used for export + example_input = torch.randn(4, in_dim).to("cuda") + + # Export the graph + egm = torch.export.export(model, (example_input,)) + gm = egm.module() + + test_x = torch.randn(4, in_dim).to("cuda") + ref_output = model(test_x) + + # pattern matching and replace + gm, num_replaced = replace_multi_stream_linear_with_aux_stream_wrapper(gm) + + assert num_replaced == 2 + y = gm(test_x) + assert torch.allclose(y, ref_output) + + static_x = torch.randn(4, in_dim).to("cuda") + static_output = torch.randn(4, out_dim).to("cuda") + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + static_output.copy_(gm(static_x)) + + static_x.copy_(test_x) + graph.replay() + + assert torch.allclose(static_output, ref_output) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_multi_stream.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_multi_stream.py new file mode 100644 index 00000000000..5dffc0bd392 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_multi_stream.py @@ -0,0 +1,139 @@ +from typing import Tuple + +import torch +import torch.nn as nn +from torch.fx import GraphModule, Node + +from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import ( + aux_stream_wrapper, + record_event_wrapper, +) +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + + +@torch.library.custom_op("auto_deploy::multi_stream_linear", mutates_args=()) +def multi_stream_linear( + input: torch.Tensor, weight0: torch.Tensor, weight1: torch.Tensor +) -> torch.Tensor: + output = torch.ops.aten.linear(input, weight0) + output = torch.ops.aten.linear(output, weight1) + return output + + +@multi_stream_linear.register_fake +def multi_stream_linear_fake(input, weight0, weight1): + """Fake implementation of multi_stream_linear.""" + output = torch.ops.aten.linear(input, weight0) + return torch.ops.aten.linear(output, weight1) + + +def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tuple[GraphModule, int]: + """Traverse ``gm`` and replace all ``auto_deploy::multi_stream_linear`` ops with ``aux_stream_wrapper``. + + The replacement preserves the original args/kwargs of the node. + After rewriting, the graph is cleaned and recompiled. + + Args: + gm: The FX graph module to transform. + aux_stream_wrapper: A callable to replace the custom op with. + + Returns: + A tuple of (gm, num_replaced) + """ + graph = gm.graph + num_replaced = 0 + + # Collect targets first to avoid mutating while iterating + target_nodes: list[Node] = [] + for n in graph.nodes: + if is_op(n, torch.ops.auto_deploy.multi_stream_linear): + target_nodes.append(n) + + for n in target_nodes: + target_input_node = None + for input_node in n.all_input_nodes: + if len(input_node.users) > 1: + target_input_node = input_node + break + if target_input_node is None: + raise ValueError(f"Target input node not found for node {n}") + with graph.inserting_before(target_input_node): + new_node = graph.call_function( + record_event_wrapper, + args=(target_input_node.target, *target_input_node.args), + kwargs=target_input_node.kwargs, + ) + target_input_node.replace_all_uses_with(new_node) + graph.erase_node(target_input_node) + with graph.inserting_after(n): + new_node = graph.call_function( + aux_stream_wrapper, args=(n.target, *n.args), kwargs=n.kwargs + ) + n.replace_all_uses_with(new_node) + graph.erase_node(n) + num_replaced += 1 + + if num_replaced: + graph.eliminate_dead_code() + graph.lint() + gm.recompile() + + return gm, num_replaced + + +class ParallelTwoLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.fc10 = nn.Linear(in_dim, in_dim) + self.fc11 = nn.Linear(in_dim, out_dim) + self.fc2 = nn.Linear(in_dim, out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.relu(x) + y0 = self.fc2(x) + y1 = multi_stream_linear(x, self.fc10.weight, self.fc11.weight) + return y0 + y1 + + +in_dim, out_dim = 128, 256 + +model = ( + nn.Sequential(ParallelTwoLinear(in_dim, out_dim), ParallelTwoLinear(out_dim, out_dim)) + .eval() + .to("cuda") +) + +# Example input used for export +example_input = torch.randn(4, in_dim).to("cuda") + +# Export the graph +egm = torch.export.export(model, (example_input,)) +gm = egm.module() +output = gm(example_input) + +test_x = torch.randn(4, in_dim).to("cuda") +ref_output = model(test_x) + +# pattern matching and replace +gm, num_replaced = replace_multi_stream_linear_with_aux_stream_wrapper(gm) +print(f"Replaced {num_replaced} nodes") +print(gm.graph) +y = gm(test_x) +assert torch.allclose(y, ref_output) + +static_x = torch.randn(4, in_dim).to("cuda") +static_output = torch.randn(4, out_dim).to("cuda") + +graph = torch.cuda.CUDAGraph() +with torch.cuda.graph(graph): + static_output.copy_(gm(static_x)) + +static_x.copy_(test_x) +graph.replay() + +assert torch.allclose(static_output, ref_output) +for i in range(100): + gm(torch.randn(4, in_dim).to("cuda")) + +for i in range(100): + graph.replay() From 1a0f90b53db73af3eefa4ee858923bd03718f45d Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Wed, 19 Nov 2025 18:58:03 -0800 Subject: [PATCH 02/15] remove duplicate test file Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../unit/singlegpu/test_multi_stream.py | 139 ------------------ 1 file changed, 139 deletions(-) delete mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/test_multi_stream.py diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_multi_stream.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_multi_stream.py deleted file mode 100644 index 5dffc0bd392..00000000000 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_multi_stream.py +++ /dev/null @@ -1,139 +0,0 @@ -from typing import Tuple - -import torch -import torch.nn as nn -from torch.fx import GraphModule, Node - -from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import ( - aux_stream_wrapper, - record_event_wrapper, -) -from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op - - -@torch.library.custom_op("auto_deploy::multi_stream_linear", mutates_args=()) -def multi_stream_linear( - input: torch.Tensor, weight0: torch.Tensor, weight1: torch.Tensor -) -> torch.Tensor: - output = torch.ops.aten.linear(input, weight0) - output = torch.ops.aten.linear(output, weight1) - return output - - -@multi_stream_linear.register_fake -def multi_stream_linear_fake(input, weight0, weight1): - """Fake implementation of multi_stream_linear.""" - output = torch.ops.aten.linear(input, weight0) - return torch.ops.aten.linear(output, weight1) - - -def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tuple[GraphModule, int]: - """Traverse ``gm`` and replace all ``auto_deploy::multi_stream_linear`` ops with ``aux_stream_wrapper``. - - The replacement preserves the original args/kwargs of the node. - After rewriting, the graph is cleaned and recompiled. - - Args: - gm: The FX graph module to transform. - aux_stream_wrapper: A callable to replace the custom op with. - - Returns: - A tuple of (gm, num_replaced) - """ - graph = gm.graph - num_replaced = 0 - - # Collect targets first to avoid mutating while iterating - target_nodes: list[Node] = [] - for n in graph.nodes: - if is_op(n, torch.ops.auto_deploy.multi_stream_linear): - target_nodes.append(n) - - for n in target_nodes: - target_input_node = None - for input_node in n.all_input_nodes: - if len(input_node.users) > 1: - target_input_node = input_node - break - if target_input_node is None: - raise ValueError(f"Target input node not found for node {n}") - with graph.inserting_before(target_input_node): - new_node = graph.call_function( - record_event_wrapper, - args=(target_input_node.target, *target_input_node.args), - kwargs=target_input_node.kwargs, - ) - target_input_node.replace_all_uses_with(new_node) - graph.erase_node(target_input_node) - with graph.inserting_after(n): - new_node = graph.call_function( - aux_stream_wrapper, args=(n.target, *n.args), kwargs=n.kwargs - ) - n.replace_all_uses_with(new_node) - graph.erase_node(n) - num_replaced += 1 - - if num_replaced: - graph.eliminate_dead_code() - graph.lint() - gm.recompile() - - return gm, num_replaced - - -class ParallelTwoLinear(nn.Module): - def __init__(self, in_dim: int, out_dim: int): - super().__init__() - self.fc10 = nn.Linear(in_dim, in_dim) - self.fc11 = nn.Linear(in_dim, out_dim) - self.fc2 = nn.Linear(in_dim, out_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = torch.nn.functional.relu(x) - y0 = self.fc2(x) - y1 = multi_stream_linear(x, self.fc10.weight, self.fc11.weight) - return y0 + y1 - - -in_dim, out_dim = 128, 256 - -model = ( - nn.Sequential(ParallelTwoLinear(in_dim, out_dim), ParallelTwoLinear(out_dim, out_dim)) - .eval() - .to("cuda") -) - -# Example input used for export -example_input = torch.randn(4, in_dim).to("cuda") - -# Export the graph -egm = torch.export.export(model, (example_input,)) -gm = egm.module() -output = gm(example_input) - -test_x = torch.randn(4, in_dim).to("cuda") -ref_output = model(test_x) - -# pattern matching and replace -gm, num_replaced = replace_multi_stream_linear_with_aux_stream_wrapper(gm) -print(f"Replaced {num_replaced} nodes") -print(gm.graph) -y = gm(test_x) -assert torch.allclose(y, ref_output) - -static_x = torch.randn(4, in_dim).to("cuda") -static_output = torch.randn(4, out_dim).to("cuda") - -graph = torch.cuda.CUDAGraph() -with torch.cuda.graph(graph): - static_output.copy_(gm(static_x)) - -static_x.copy_(test_x) -graph.replay() - -assert torch.allclose(static_output, ref_output) -for i in range(100): - gm(torch.randn(4, in_dim).to("cuda")) - -for i in range(100): - graph.replay() From 4692586a7f764b2cfc1979fdc4dc6a11b56e6488 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Wed, 19 Nov 2025 19:02:14 -0800 Subject: [PATCH 03/15] revert tp plan changes Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../_torch/auto_deploy/models/patches/nemotron_h.py | 10 +++++++++- .../auto_deploy/transform/library/multi_stream_moe.py | 8 +------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py index 4dca4afd853..8248ab209f2 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py @@ -191,15 +191,23 @@ def get_model_from_config_patched(config, **kwargs): _config_from_pretrained_original = AutoConfig.from_pretrained _nemotron_h_base_model_tp_plan = { + # mamba SSM layer "in_proj": "mamba", "out_proj": "rowwise", + # attention layer "q_proj": "colwise", "k_proj": "colwise", "v_proj": "colwise", "o_proj": "rowwise", + # NOTE: consider not sharding shared experts and/or + # latent projections at all, keeping them replicated. + # To do so, comment out the corresponding entries. + # moe layer: SHARED experts "up_proj": "colwise", "down_proj": "rowwise", - # "*": "gather", + # MoLE: latent projections: simple shard + "fc1_latent_proj": "gather", + "fc2_latent_proj": "gather", } diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py index 1c6015ce558..a52d1de1708 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py @@ -66,10 +66,6 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - with open("graph-before-moe.txt", "w") as f: - f.write(str(gm.graph)) - print("wrote graph to graph-before-moe.txt") - op_dict = { torch.ops.auto_deploy.trtllm_moe_fused: torch.ops.auto_deploy.trtllm_moe_fused_aux, torch.ops.auto_deploy.triton_moe_fused: torch.ops.auto_deploy.triton_moe_fused_aux, @@ -84,7 +80,5 @@ def _apply( is_clean=False, has_valid_shapes=False, ) - with open("graph-after-moe.txt", "w") as f: - f.write(str(gm.graph)) - print("wrote graph to graph-after-moe.txt") + return gm, info From 948e5d49044e2ab05a4aeddeb20d838034fe6ac8 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Wed, 19 Nov 2025 21:44:14 -0800 Subject: [PATCH 04/15] remove redundant casts before rmsnorm Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../_torch/auto_deploy/custom_ops/multi_stream.py | 3 +++ tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py | 9 ++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py index 14902b9b081..de269acea47 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py @@ -64,6 +64,8 @@ def wait_event(event_name: str) -> None: event.wait() +# skip during compilation +@torch._dynamo.disable def record_event_wrapper( fn: Callable, *args: Tuple[Any, ...], @@ -74,6 +76,7 @@ def record_event_wrapper( return output +@torch._dynamo.disable def aux_stream_wrapper( fn: Callable, *args: Tuple[Any, ...], diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py index f4b98d49df0..708449ea732 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py @@ -113,21 +113,20 @@ def triton_rmsnorm_gated( # Flatten to (M, H), ensure last-dim contiguous, and run in fp32 x_shape = x.shape - x2 = x.to(torch.float32).reshape(-1, H) + x2 = x.reshape(-1, H) if x2.stride(-1) != 1: x2 = x2.contiguous() z2 = None if gate is not None: - z2 = gate.to(torch.float32).reshape(-1, H) + z2 = gate.reshape(-1, H) if z2.stride(-1) != 1: z2 = z2.contiguous() - - w = weight.to(torch.float32).contiguous() + assert weight.is_contiguous(), "weight must be contiguous" out2, _, _ = _layer_norm_fwd( x2, - w, + weight, None, # bias eps, z=z2, From 0439b3bbeedc0f903d8da16b9f05d3d654b1e815 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Thu, 20 Nov 2025 10:25:11 -0800 Subject: [PATCH 05/15] fused quant scale op Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../auto_deploy/custom_ops/fused_moe/trtllm_moe.py | 2 +- .../custom_ops/mamba/cuda_backend_causal_conv.py | 4 +--- .../custom_ops/mamba/triton_backend_mamba.py | 13 ++----------- 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 62e7b36dd94..f6f76fe1ca1 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -67,7 +67,7 @@ def trtllm_moe_fused_fake( return torch.empty_like(x) -# Todo: refactor this repeating code block +@torch.compile def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear).""" FP8_MIN = torch.finfo(torch.float8_e4m3fn).min diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py index 33a7eb2a284..5f0a6d429ea 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py @@ -204,9 +204,7 @@ def _cuda_cached_causal_conv1d( if y_dec.dim() == 3: y_dec = y_dec.squeeze(-1) - y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_( - y_dec.to(y_flat.dtype) - ) + y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec) # Custom op must not return an alias of any input; return a fresh tensor return y diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 3ab13309009..f8339865abc 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -43,10 +43,6 @@ def _triton_ssm_prepare_metadata( seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) - seq_start = torch.zeros_like(seq_len_sanitized) - if num_seq > 1: - seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0) - # Truncate slot indices to match active sequences slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long) # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch @@ -88,7 +84,6 @@ def _triton_ssm_prepare_metadata( return ( seq_len_sanitized, - seq_start, slot_idx_sanitized, use_initial_states, cu_seqlens, @@ -109,7 +104,6 @@ def _triton_ssm_prepare_metadata_fake( device = slot_idx.device # Always-correct shapes seq_len_fake = torch.empty_like(seq_len_sanitized) - seq_start_fake = torch.empty_like(seq_len_sanitized) slot_idx_fake = torch.empty(num_seq, dtype=torch.long, device=device) use_initial_states_fake = torch.empty(num_seq, dtype=torch.bool, device=device) cu_seqlens_fake = torch.empty(num_seq + 1, dtype=torch.int32, device=device) @@ -142,7 +136,6 @@ def _triton_ssm_prepare_metadata_fake( return ( seq_len_fake, - seq_start_fake, slot_idx_fake, use_initial_states_fake, cu_seqlens_fake, @@ -165,7 +158,6 @@ def _triton_cached_ssm( dt_bias: torch.Tensor, # [num_heads] # METADATA seq_len: torch.Tensor, # [num_seq] - seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] use_initial_states: torch.Tensor, # [num_seq] cu_seqlens: torch.Tensor, # [num_seq + 1] @@ -290,7 +282,6 @@ def _triton_cached_ssm_fake( dt_bias: torch.Tensor, # [num_heads] # METADATA seq_len: torch.Tensor, # [num_seq] - seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] use_initial_states: torch.Tensor, # [num_seq] cu_seqlens: torch.Tensor, # [num_seq + 1] @@ -340,9 +331,9 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: - # Returns: seq_len, seq_start, slot_idx, use_initial_states, + # Returns: seq_len, slot_idx, use_initial_states, # cu_seqlens, chunk_indices, chunk_offsets, seq_idx_prefill, batch_info_tensor - return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 9 + return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 8 @classmethod def get_cache_initializers( From 9c955929d815a46c4e02c8bcb2a073a8fa87b5b6 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:06:34 -0800 Subject: [PATCH 06/15] more quant moe opt Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../custom_ops/fused_moe/trtllm_moe.py | 24 ++++++---- .../auto_deploy/custom_ops/multi_stream.py | 9 ++++ .../transform/library/fused_moe.py | 46 +++++++++++++++---- 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index f6f76fe1ca1..686c77f0595 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -67,6 +67,8 @@ def trtllm_moe_fused_fake( return torch.empty_like(x) +# NOTE(suyogg): If compile ever fails because of this, just write a triton kernel +# for this function and use it as a custom op. @torch.compile def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear).""" @@ -107,6 +109,9 @@ def trtllm_quant_fp8_moe_fused( w1_weight_scale: torch.Tensor, # [E] stacked weight scales w2_weight_scale: torch.Tensor, # [E] stacked weight scales w3_weight_scale: torch.Tensor, # [E] or unused + gemm1_dequant: torch.Tensor, # [E] + gemm2_act_quant: torch.Tensor, # [E] + gemm2_dequant: torch.Tensor, # [E] mlp_style: str = "gated_mlp", act_fn: str = "silu", ) -> torch.Tensor: @@ -144,10 +149,10 @@ def trtllm_quant_fp8_moe_fused( x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0]) # Scales are stored in float32 - w1_weight_scale = w1_weight_scale.to(torch.float32) - w2_weight_scale = w2_weight_scale.to(torch.float32) + # w1_weight_scale = w1_weight_scale.to(torch.float32) + # w2_weight_scale = w2_weight_scale.to(torch.float32) w1_input_scale = w1_input_scale.to(torch.float32)[0] - w2_input_scale = w2_input_scale.to(torch.float32)[0] + # w2_input_scale = w2_input_scale.to(torch.float32)[0] # Prepare quant_scales for TensorRT-LLM FP8 format: # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale] @@ -158,14 +163,14 @@ def trtllm_quant_fp8_moe_fused( # - gemm1_input_dequant_scale: w1_input_scale # Compute combined scales - gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze() - gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32) - gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze() - gemm1_input_dequant = w1_input_scale.contiguous() + # gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze() + # gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32) + # gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze() + # gemm1_input_dequant = w1_input_scale.contiguous() assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D" assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D" - quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant] + quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale] # Ensure contiguous tensors selected_experts = selected_experts.int().contiguous() @@ -229,6 +234,9 @@ def trtllm_quant_fp8_moe_fused_fake( w1_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor, w3_weight_scale: torch.Tensor, + gemm1_dequant: torch.Tensor, + gemm2_act_quant: torch.Tensor, + gemm2_dequant: torch.Tensor, mlp_style: str, act_fn: str, ) -> torch.Tensor: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py index de269acea47..3629ea9c12e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py @@ -180,6 +180,9 @@ def trtllm_quant_fp8_moe_fused_aux( w1_weight_scale: torch.Tensor, # [E] stacked weight scales w2_weight_scale: torch.Tensor, # [E] stacked weight scales w3_weight_scale: torch.Tensor, # [E] or unused + gemm1_dequant: torch.Tensor, # [E] + gemm2_act_quant: torch.Tensor, # [E] + gemm2_dequant: torch.Tensor, # [E] mlp_style: str = "gated_mlp", act_fn: str = "silu", ) -> torch.Tensor: @@ -198,6 +201,9 @@ def trtllm_quant_fp8_moe_fused_aux( w1_weight_scale, w2_weight_scale, w3_weight_scale, + gemm1_dequant, + gemm2_act_quant, + gemm2_dequant, mlp_style, act_fn, ) @@ -220,6 +226,9 @@ def trtllm_quant_fp8_moe_fused_fake( w1_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor, w3_weight_scale: torch.Tensor, + gemm1_dequant: torch.Tensor, + gemm2_act_quant: torch.Tensor, + gemm2_dequant: torch.Tensor, mlp_style: str = "gated_mlp", act_fn: str = "silu", ) -> torch.Tensor: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index b62fd9f2b9c..a6e1b5ad383 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -653,20 +653,31 @@ def get_param_or_buffer(target): ) ) - w1_weight_scale_stacked = torch.stack( - [get_param_or_buffer(n.target) for n in w1_weight_scale], dim=0 + w1_weight_scale_stacked = ( + torch.stack([get_param_or_buffer(n.target) for n in w1_weight_scale], dim=0) + .to(torch.float32) + .contiguous() ) - w2_weight_scale_stacked = torch.stack( - [get_param_or_buffer(n.target) for n in w2_weight_scale], dim=0 + w2_weight_scale_stacked = ( + torch.stack([get_param_or_buffer(n.target) for n in w2_weight_scale], dim=0) + .to(torch.float32) + .contiguous() ) w3_weight_scale_stacked = ( - torch.stack([get_param_or_buffer(n.target) for n in w3_weight_scale], dim=0) - if w3_weight_scale - else torch.empty( - 0, device=w1_weight_scale_stacked.device, dtype=w1_weight_scale_stacked.dtype + ( + torch.stack([get_param_or_buffer(n.target) for n in w3_weight_scale], dim=0) + if w3_weight_scale + else torch.empty( + 0, device=w1_weight_scale_stacked.device, dtype=w1_weight_scale_stacked.dtype + ) ) + .to(torch.float32) + .contiguous() ) + gemm1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked[0]).squeeze() + gemm2_act_quant = (1.0 / w2_input_scale_stacked[0]).to(torch.float32) + gemm2_dequant = (w2_weight_scale_stacked * w2_input_scale_stacked[0]).squeeze() # Register stacked tensors as new parameters new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}" new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}" @@ -677,6 +688,9 @@ def get_param_or_buffer(target): new_key_w1_weight_scale = f"quant_moe_w1_weight_scale_stacked_{fused_key_counter}" new_key_w2_weight_scale = f"quant_moe_w2_weight_scale_stacked_{fused_key_counter}" new_key_w3_weight_scale = f"quant_moe_w3_weight_scale_stacked_{fused_key_counter}" + new_key_gemm1_dequant = f"quant_moe_gemm1_dequant_stacked_{fused_key_counter}" + new_key_gemm2_act_quant = f"quant_moe_gemm2_act_quant_stacked_{fused_key_counter}" + new_key_gemm2_dequant = f"quant_moe_gemm2_dequant_stacked_{fused_key_counter}" fused_key_counter += 1 @@ -705,7 +719,18 @@ def get_param_or_buffer(target): new_key_w3_weight_scale, torch.nn.Parameter(w3_weight_scale_stacked, requires_grad=False), ) - + gm.register_parameter( + new_key_gemm1_dequant, + torch.nn.Parameter(gemm1_dequant, requires_grad=False), + ) + gm.register_parameter( + new_key_gemm2_act_quant, + torch.nn.Parameter(gemm2_act_quant, requires_grad=False), + ) + gm.register_parameter( + new_key_gemm2_dequant, + torch.nn.Parameter(gemm2_dequant, requires_grad=False), + ) # Create new node with get_attr for stacked parameters with graph.inserting_before(node): new_node = graph.call_function( @@ -723,6 +748,9 @@ def get_param_or_buffer(target): graph.get_attr(new_key_w1_weight_scale), graph.get_attr(new_key_w2_weight_scale), graph.get_attr(new_key_w3_weight_scale), + graph.get_attr(new_key_gemm1_dequant), + graph.get_attr(new_key_gemm2_act_quant), + graph.get_attr(new_key_gemm2_dequant), ), kwargs=node.kwargs, ) From f78ae5ada0917baed0bcd69a64a388b80948b00c Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Sun, 23 Nov 2025 18:16:15 -0800 Subject: [PATCH 07/15] minor refactoring Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../transform/library/fused_moe.py | 50 +++++++++++-------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 85581b69dc0..bdfc752778e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -698,9 +698,6 @@ def get_param_or_buffer(target): .contiguous() ) - gemm1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked[0]).squeeze() - gemm2_act_quant = (1.0 / w2_input_scale_stacked[0]).to(torch.float32) - gemm2_dequant = (w2_weight_scale_stacked * w2_input_scale_stacked[0]).squeeze() # Register stacked tensors as new parameters new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}" new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}" @@ -711,9 +708,6 @@ def get_param_or_buffer(target): new_key_w1_weight_scale = f"quant_moe_w1_weight_scale_stacked_{fused_key_counter}" new_key_w2_weight_scale = f"quant_moe_w2_weight_scale_stacked_{fused_key_counter}" new_key_w3_weight_scale = f"quant_moe_w3_weight_scale_stacked_{fused_key_counter}" - new_key_gemm1_dequant = f"quant_moe_gemm1_dequant_stacked_{fused_key_counter}" - new_key_gemm2_act_quant = f"quant_moe_gemm2_act_quant_stacked_{fused_key_counter}" - new_key_gemm2_dequant = f"quant_moe_gemm2_dequant_stacked_{fused_key_counter}" fused_key_counter += 1 @@ -742,18 +736,34 @@ def get_param_or_buffer(target): new_key_w3_weight_scale, torch.nn.Parameter(w3_weight_scale_stacked, requires_grad=False), ) - gm.register_parameter( - new_key_gemm1_dequant, - torch.nn.Parameter(gemm1_dequant, requires_grad=False), - ) - gm.register_parameter( - new_key_gemm2_act_quant, - torch.nn.Parameter(gemm2_act_quant, requires_grad=False), - ) - gm.register_parameter( - new_key_gemm2_dequant, - torch.nn.Parameter(gemm2_dequant, requires_grad=False), - ) + additional_args = [] + if backend == "trtllm": + # For optimization reasons, we precompute a few additional arguments to the trtllm_quant_fp8_moe_fused op + # to avoid computing them at runtime. + gemm1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked[0]).squeeze() + gemm2_act_quant = (1.0 / w2_input_scale_stacked[0]).to(torch.float32) + gemm2_dequant = (w2_weight_scale_stacked * w2_input_scale_stacked[0]).squeeze() + + new_key_gemm1_dequant = f"quant_moe_gemm1_dequant_stacked_{fused_key_counter}" + new_key_gemm2_act_quant = f"quant_moe_gemm2_act_quant_stacked_{fused_key_counter}" + new_key_gemm2_dequant = f"quant_moe_gemm2_dequant_stacked_{fused_key_counter}" + gm.register_parameter( + new_key_gemm1_dequant, + torch.nn.Parameter(gemm1_dequant, requires_grad=False), + ) + gm.register_parameter( + new_key_gemm2_act_quant, + torch.nn.Parameter(gemm2_act_quant, requires_grad=False), + ) + gm.register_parameter( + new_key_gemm2_dequant, + torch.nn.Parameter(gemm2_dequant, requires_grad=False), + ) + additional_args = [ + graph.get_attr(new_key_gemm1_dequant), + graph.get_attr(new_key_gemm2_act_quant), + graph.get_attr(new_key_gemm2_dequant), + ] # Create new node with get_attr for stacked parameters with graph.inserting_before(node): new_node = graph.call_function( @@ -771,9 +781,7 @@ def get_param_or_buffer(target): graph.get_attr(new_key_w1_weight_scale), graph.get_attr(new_key_w2_weight_scale), graph.get_attr(new_key_w3_weight_scale), - graph.get_attr(new_key_gemm1_dequant), - graph.get_attr(new_key_gemm2_act_quant), - graph.get_attr(new_key_gemm2_dequant), + *additional_args, ), kwargs=node.kwargs, ) From 73e2b1b2c75ff7de8a499bbb45d293c212990511 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Sun, 23 Nov 2025 18:20:47 -0800 Subject: [PATCH 08/15] update some comments Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../auto_deploy/custom_ops/fused_moe/trtllm_moe.py | 9 +++++---- .../_torch/auto_deploy/custom_ops/multi_stream.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 686c77f0595..eab6a3bf224 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -130,6 +130,9 @@ def trtllm_quant_fp8_moe_fused( w1_weight_scale: Weight scales for w1 [E] w2_weight_scale: Weight scales for w2 [E] w3_weight_scale: Weight scales for w3 [E] + gemm1_dequant: Precomputed gemm1 dequant scale [E] + gemm2_act_quant: Precomputed gemm2 act quant scale [1] + gemm2_dequant: Precomputed gemm2 dequant scale [E] mlp_style: "gated_mlp" or "mlp" act_fn: "silu" for gated_mlp, "relu2" for mlp @@ -149,14 +152,12 @@ def trtllm_quant_fp8_moe_fused( x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0]) # Scales are stored in float32 - # w1_weight_scale = w1_weight_scale.to(torch.float32) - # w2_weight_scale = w2_weight_scale.to(torch.float32) - w1_input_scale = w1_input_scale.to(torch.float32)[0] - # w2_input_scale = w2_input_scale.to(torch.float32)[0] + w1_input_scale = w1_input_scale[0] # Prepare quant_scales for TensorRT-LLM FP8 format: # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale] # For gated MLP: + # These are precomputed in `fused_moe` transform # - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3) # - gemm2_act_quant_scale: 1 / w2_input_scale # - gemm2_dequant_scale: w2_weight_scale * w2_input_scale diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py index 3629ea9c12e..871374155e8 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py @@ -91,7 +91,7 @@ def aux_stream_wrapper( return output -# bf16 +# trtllm bf16 @torch.library.custom_op("auto_deploy::trtllm_moe_fused_aux", mutates_args=()) def trtllm_moe_fused_aux( x: torch.Tensor, @@ -213,7 +213,7 @@ def trtllm_quant_fp8_moe_fused_aux( @trtllm_quant_fp8_moe_fused_aux.register_fake -def trtllm_quant_fp8_moe_fused_fake( +def trtllm_quant_fp8_moe_fused_aux_fake( x: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor, From df1daca313f9987206cd4e9058c8d1439636e68f Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Sun, 23 Nov 2025 20:45:43 -0800 Subject: [PATCH 09/15] fix tests Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- examples/auto_deploy/nano_v3.yaml | 3 +++ tensorrt_llm/_torch/auto_deploy/config/default.yaml | 3 --- .../unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/auto_deploy/nano_v3.yaml b/examples/auto_deploy/nano_v3.yaml index 411037cc175..9d9acf6ef7f 100644 --- a/examples/auto_deploy/nano_v3.yaml +++ b/examples/auto_deploy/nano_v3.yaml @@ -15,6 +15,9 @@ transforms: detect_sharding: sharding_source: ['factory', 'heuristic'] sharding_dims: ['ep', 'bmm'] + multi_stream_moe: + stage: compile + enabled: true # tunable mamba cache dtype # --> use float32 for accuracy and default (null) for speed insert_cached_ssm_attention: diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index d486ff5061e..55416141e73 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -163,9 +163,6 @@ transforms: ############################################################################################ fuse_causal_conv_activation: stage: compile - multi_stream_moe: - stage: compile - enabled: true compile_model: stage: compile run_per_gm: false diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py index 4b1c373b0fc..917cdbaca2e 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py @@ -163,7 +163,6 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): dt, dt_bias, seq_len, - seq_start, slot_idx, use_initial_states, cu_seqlens, From 9a1cf7e5f9d53ca754f97663431b13c0b15aeede Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:00:05 -0800 Subject: [PATCH 10/15] update ad test Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../_torch/auto_deploy/_utils_test/_model_test_utils.py | 6 ++++++ .../unit/singlegpu/test_ad_build_small_single.py | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index 6b9bf92a9f7..af821955d49 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -502,6 +502,12 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1): "num_hidden_layers": 2, }, }, + "nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": { + "llm_models_subdir": "Nemotron-Nano-3-30B-A3.5B-dev-1024", + "model_kwargs": { + "num_hidden_layers": 8, + }, + }, } diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index 4e1e78bd97d..320dbdcfa62 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -186,6 +186,14 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): }, }, ), + ( + "nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024", + { + "transforms": { + "multi_stream_moe": {"stage": "compile", "enabled": True}, + }, + }, + ), ], ) def test_build_ad(model_hub_id: str, llm_extra_args: dict): From 5b411c59dc96967b1f5f57c1a143baa4a3dc630a Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Mon, 24 Nov 2025 10:31:16 -0800 Subject: [PATCH 11/15] fix test failures, address review comments Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> --- .../auto_deploy/custom_ops/fused_moe/trtllm_moe.py | 6 ------ .../_torch/auto_deploy/transform/library/fused_moe.py | 8 +++++++- .../auto_deploy/transform/library/multi_stream_moe.py | 10 +++------- .../unit/singlegpu/custom_ops/test_multi_stream.py | 4 +--- .../unit/singlegpu/custom_ops/test_trtllm_moe.py | 7 +++++++ 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index eab6a3bf224..8b130d98744 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -163,12 +163,6 @@ def trtllm_quant_fp8_moe_fused( # - gemm2_dequant_scale: w2_weight_scale * w2_input_scale # - gemm1_input_dequant_scale: w1_input_scale - # Compute combined scales - # gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze() - # gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32) - # gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze() - # gemm1_input_dequant = w1_input_scale.contiguous() - assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D" assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D" quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale] diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index bdfc752778e..534bd130ac3 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -697,7 +697,12 @@ def get_param_or_buffer(target): .to(torch.float32) .contiguous() ) - + assert torch.all(w1_input_scale_stacked[0] == w1_input_scale_stacked), ( + "All w1 scales should have the same value." + ) + assert torch.all(w2_input_scale_stacked[0] == w2_input_scale_stacked), ( + "All w2 scales should have the same value." + ) # Register stacked tensors as new parameters new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}" new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}" @@ -736,6 +741,7 @@ def get_param_or_buffer(target): new_key_w3_weight_scale, torch.nn.Parameter(w3_weight_scale_stacked, requires_grad=False), ) + additional_args = [] if backend == "trtllm": # For optimization reasons, we precompute a few additional arguments to the trtllm_quant_fp8_moe_fused op diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py index a52d1de1708..a0ec07777b3 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py @@ -3,7 +3,7 @@ from typing import Callable, Dict, Tuple import torch -from torch.fx import GraphModule, Node +from torch.fx import GraphModule from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import record_event_wrapper @@ -20,10 +20,7 @@ def _execute_op_in_aux_stream( num_replaced = 0 # Collect targets first to avoid mutating while iterating - target_nodes: list[Node] = [] - for n in graph.nodes: - if is_op(n, op_dict.keys()): - target_nodes.append(n) + target_nodes = [n for n in graph.nodes if is_op(n, op_dict.keys())] for n in target_nodes: target_input_node = None @@ -32,8 +29,7 @@ def _execute_op_in_aux_stream( target_input_node = input_node break - if target_input_node is None: - raise ValueError(f"Target input node not found for node {n}") + assert target_input_node is not None, f"Target input node not found for node {n}" with graph.inserting_before(target_input_node): new_node = graph.call_function( record_event_wrapper, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py index e008c4a99d8..972cf013a3b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py @@ -45,9 +45,7 @@ def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tupl # Collect targets first to avoid mutating while iterating target_nodes: list[Node] = [] - for n in graph.nodes: - if is_op(n, torch.ops.auto_deploy.multi_stream_linear): - target_nodes.append(n) + target_nodes = [n for n in graph.nodes if is_op(n, torch.ops.auto_deploy.multi_stream_linear)] for n in target_nodes: target_input_node = None diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 5eb6bcfaa38..3e13e28a0c5 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -404,6 +404,10 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1) mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + # compute quant_scales + gemm1_dequant = (w1_scales * hidden_states_scale).contiguous().squeeze().to(torch.float32) + gemm2_act_quant = (1.0 / w2_input_scale[0]).contiguous().to(torch.float32) + gemm2_dequant = (w2_scales * w2_input_scale[0]).contiguous().squeeze().to(torch.float32) ad_test_output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused( x, # Note! unquantized input is expected selected_experts.to(torch.int), @@ -417,6 +421,9 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE w1_weight_scale=w1_scales, w2_weight_scale=w2_scales, w3_weight_scale=w3_scales, + gemm1_dequant=gemm1_dequant, + gemm2_act_quant=gemm2_act_quant, + gemm2_dequant=gemm2_dequant, mlp_style=mlp_style, act_fn=activation_func, ) From 4783eb30d29af269b7ff9dc6e9ebb32a36fac695 Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:13:40 -0800 Subject: [PATCH 12/15] Precompute the A log for mamba layers Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 2 + .../transform/library/fuse_mamba_a_log.py | 140 ++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 55416141e73..5e1c43fff95 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -96,6 +96,8 @@ transforms: ############################################################################################ # RUN POST-LOAD FUSION AND OPTIMIZATIONS ############################################################################################ + fuse_mamba_a_log: + stage: post_load_fusion fuse_gemms: stage: post_load_fusion enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py new file mode 100644 index 00000000000..47bdb374eba --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py @@ -0,0 +1,140 @@ +"""Transform to fuse A_log into A for Mamba/NemotronH models.""" + +import operator +from typing import Tuple + +import torch +import torch.nn as nn +from torch.fx import GraphModule + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.logger import ad_logger +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +def _get_attr_by_name(obj, name): + for part in name.split("."): + obj = getattr(obj, part) + return obj + + +def _set_attr_by_name(obj, name, value): + parts = name.split(".") + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + + +@TransformRegistry.register("fuse_mamba_a_log") +class FuseMambaALog(BaseTransform): + """Fuse A_log parameter into A constant/parameter. + + Replaces: + A = -torch.exp(self.A_log.float()) + With: + A = self.A_fused + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + num_matches = 0 + + # Candidates for operations + exp_ops = {torch.exp, torch.ops.aten.exp.default, "exp"} + neg_ops = {operator.neg, torch.neg, torch.ops.aten.neg.default, "neg"} + + # We search bottom-up starting from A_log parameters to be more robust + # pattern: A_log -> [optional cast] -> exp -> neg + + # Snapshot nodes to avoid modification issues during iteration + nodes = list(gm.graph.nodes) + + for node in nodes: + if node.op != "get_attr": + continue + + if not node.target.endswith("A_log"): + continue + # Found an A_log node. Check its usage. + users = list(node.users.keys()) + for user in users: + # 1. Check for optional Cast + current_node = user + + # Skip cast/to nodes + exp_node = None + + # Walk forward looking for exp + cursor = current_node + for _ in range(3): # Max depth for casts + if (cursor.op == "call_function" and cursor.target in exp_ops) or ( + cursor.op == "call_method" and cursor.target == "exp" + ): + exp_node = cursor + break + + if len(cursor.users) != 1: + break + cursor = list(cursor.users.keys())[0] + + if not exp_node: + continue + + # 2. Check for Neg + if len(exp_node.users) != 1: + continue + + neg_node = list(exp_node.users.keys())[0] + is_neg = (neg_node.op == "call_function" and neg_node.target in neg_ops) or ( + neg_node.op == "call_method" and neg_node.target == "neg" + ) + + if not is_neg: + continue + # Found the pattern: node -> ... -> exp_node -> neg_node + num_matches += 1 + + # Perform Fusion + param_name = node.target + try: + a_log = _get_attr_by_name(gm, param_name) + except AttributeError: + ad_logger.warning(f"Could not find attribute {param_name} in gm.") + continue + + # Compute A_fused + with torch.no_grad(): + # Replicate the logic: -exp(a_log.float()) + a_fused = -torch.exp(a_log.float()) + + new_param_name = param_name.replace("A_log", "A_fused") + + # Check if we already created this param (if A_log used multiple times) + try: + _get_attr_by_name(gm, new_param_name) + except AttributeError: + _set_attr_by_name( + gm, new_param_name, nn.Parameter(a_fused, requires_grad=False) + ) + + # Replace usage + with gm.graph.inserting_before(neg_node): + new_node = gm.graph.create_node("get_attr", new_param_name) + + neg_node.replace_all_uses_with(new_node) + + if num_matches > 0: + gm.graph.eliminate_dead_code() + + return gm, TransformInfo( + skipped=False, + num_matches=num_matches, + is_clean=num_matches == 0, + has_valid_shapes=True, + ) From 06e2fb3a4ec52fc7aaef6d1c02cd401e046a237c Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:44:57 -0800 Subject: [PATCH 13/15] Perf update for mamba layers Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> --- .../mamba/cuda_backend_causal_conv.py | 36 +++++++++---------- .../custom_ops/mamba/triton_backend_mamba.py | 31 ++++++++++++---- 2 files changed, 43 insertions(+), 24 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py index 5f0a6d429ea..59069b1cf16 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py @@ -94,7 +94,7 @@ def cuda_causal_conv_prepare_metadata_fake( ) -@torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={}) +@torch.library.custom_op("auto_deploy::cuda_cached_causal_conv1d", mutates_args={"input"}) def _cuda_cached_causal_conv1d( # INPUTS (dense but may be flattened across sequences) input: torch.Tensor, # [b, s, c_in] @@ -114,13 +114,15 @@ def _cuda_cached_causal_conv1d( groups: int, padding_mode: str, activation: Optional[str], -) -> torch.Tensor: +) -> None: """Flattened cached causal conv that respects slot-indexed state caches (CUDA backend). Supports two layouts from the attention interface: - Generate-only: input is [b, 1, c_in]. We'll gather caches using slot_idx[:b]. - Flattened context/mixed: input is [1, total_s, c_in] and seq_len/seq_start describe per-sequence segments. We'll process each segment and scatter final states to caches. + + NOTE: This op modifies `input` in-place. """ b, s = input.shape[:2] num_seq = seq_len.shape[0] @@ -137,8 +139,6 @@ def _cuda_cached_causal_conv1d( # Flatten tokens bs = b * s inp_flat = input.reshape(bs, *input.shape[2:]) # [total_s, C_in] - y = torch.empty(b, s, weight.shape[0], device=input.device, dtype=input.dtype) - y_flat = y.view(bs, *y.shape[2:]) # Prepare weight as [dim, width] (depthwise) if weight.ndim == 3: @@ -155,6 +155,7 @@ def _cuda_cached_causal_conv1d( total_prefill_tokens = int(seq_len_prefill.sum().item()) # x_varlen: (dim, cu_seq_len) + # We must clone to make it contiguous for the kernel x_varlen = inp_flat[:total_prefill_tokens].transpose(0, 1).contiguous() # Metadata @@ -181,9 +182,8 @@ def _cuda_cached_causal_conv1d( pad_slot_id=PAD_SLOT_ID, ) # (dim, total_prefill_tokens) - # Scatter outputs back to y - y_prefill = y_varlen.transpose(0, 1) # [total_prefill_tokens, C_out] - y_flat[:total_prefill_tokens].copy_(y_prefill) + # Scatter outputs back to input buffer + inp_flat[:total_prefill_tokens] = y_varlen.transpose(0, 1) # DECODE: batch update for single-token sequences if num_decode > 0: @@ -191,7 +191,8 @@ def _cuda_cached_causal_conv1d( total_prefill_tokens : total_prefill_tokens + num_decode ] # [num_decode, C_in] - y_dec = causal_conv1d_update( + # causal_conv1d_update modifies x_decode in-place + causal_conv1d_update( x_decode, # [batch, dim] conv_state_cache, w2d, @@ -201,13 +202,9 @@ def _cuda_cached_causal_conv1d( conv_state_indices=slot_idx[num_prefill:].to(torch.int32), pad_slot_id=PAD_SLOT_ID, ) + # No copy needed! - if y_dec.dim() == 3: - y_dec = y_dec.squeeze(-1) - y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec) - - # Custom op must not return an alias of any input; return a fresh tensor - return y + return @_cuda_cached_causal_conv1d.register_fake @@ -231,9 +228,12 @@ def _cuda_cached_causal_conv1d_fake( padding_mode: str, activation: Optional[str], ): - return torch.empty( - input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype - ) + return + + +def cuda_cached_causal_conv1d_wrapper(input, *args, **kwargs): + torch.ops.auto_deploy.cuda_cached_causal_conv1d(input, *args, **kwargs) + return input @AttentionRegistry.register("cuda_causal_conv") @@ -259,7 +259,7 @@ def get_source_attention_op(cls) -> OpOverloadPacket: @classmethod def get_cached_attention_op(cls) -> MHACallable: - return torch.ops.auto_deploy.cuda_cached_causal_conv1d + return cuda_cached_causal_conv1d_wrapper @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index f8339865abc..70f074a1825 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -185,14 +185,12 @@ def _triton_cached_ssm( C_flat = C.reshape(bs, *C.shape[2:]) # [bs, G, N] dt_flat = dt.reshape(bs, dt.shape[2]) # [bs, H] - y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format) - y_flat = y.view(bs, *y.shape[2:]) - ssm_state_size = B.shape[3] num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist() # Prefill: concatenate tokens at the front and run combined scan + y_prefill = None if num_prefill > 0: hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D] B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] @@ -232,12 +230,15 @@ def _triton_cached_ssm( mamba_ssm_cache_dtype=ssm_state_cache.dtype, ) - y_flat[:num_prefill_tokens] = y_prefill[0].to(y_flat.dtype) ssm_state_cache.index_copy_( 0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype) ) + # y_prefill is [1, S_p, H, D] -> remove batch dim + y_prefill = y_prefill[0] + # Decode: batch single-token updates via selective_state_update + y_dec = None if num_decode > 0: slot_idx_decode = slot_idx[num_prefill:] @@ -265,9 +266,27 @@ def _triton_cached_ssm( state_batch_indices=slot_idx_decode, ) # [nd, H, D] - y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec.to(y_flat.dtype)) + # Combine results + if num_prefill > 0 and num_decode > 0: + # Concatenate prefill and decode outputs to form the final flattened output + # Both need to be the same dtype + y_flat = torch.cat( + [y_prefill.to(hidden_states.dtype), y_dec.to(hidden_states.dtype)], dim=0 + ) + elif num_prefill > 0: + y_flat = y_prefill.to(hidden_states.dtype) + elif num_decode > 0: + y_flat = y_dec.to(hidden_states.dtype) + else: + # Should not happen given input shapes, but handle empty case + y_flat = torch.empty( + 0, num_heads, head_dim, device=hidden_states.device, dtype=hidden_states.dtype + ) - return y + # Reshape back to [B, S, H, D] if needed, or return flat if layout allows + # The original code reshaped y_flat into y [b, s, h, d] via view at the start. + # We constructed y_flat directly, so we just view it back to original shape. + return y_flat.view(b, s, num_heads, head_dim) @_triton_cached_ssm.register_fake From ad94c87e4b243b8e6f497d88d2e577be12cad503 Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Thu, 20 Nov 2025 18:04:05 -0800 Subject: [PATCH 14/15] Re-enable the activation fusion Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> --- .../auto_deploy/transform/library/fuse_causal_conv.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py index 3acc8e1f80f..8020ea72e11 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py @@ -85,10 +85,17 @@ def _apply( ) -> Tuple[GraphModule, TransformInfo]: graph = gm.graph + # Import wrapper to match against + # We use the wrapper because the underlying op returns None (void) to avoid aliasing, + # but the wrapper returns the tensor to maintain graph data flow. + from ...custom_ops.mamba.cuda_backend_causal_conv import cuda_cached_causal_conv1d_wrapper + + target_op = cuda_cached_causal_conv1d_wrapper + # Step 1: Identify causal_conv + activation pattern matches = _match_causal_conv_activation_pattern( graph, - target_op=torch.ops.auto_deploy.cuda_cached_causal_conv1d, + target_op=target_op, ) # Step 2: Replace matched patterns with fused version @@ -98,7 +105,7 @@ def _apply( # Replace the last arg (activation=None) with activation_name new_args = list(conv_node.args[:-1]) + [activation_name] fused_node = graph.call_function( - torch.ops.auto_deploy.cuda_cached_causal_conv1d, + target_op, args=tuple(new_args), ) From 448700e03ab569739456f921e7d33efd8fed01dc Mon Sep 17 00:00:00 2001 From: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Date: Thu, 20 Nov 2025 21:58:29 -0800 Subject: [PATCH 15/15] Change the concat to copy_ to reduce the memory usage Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> --- .../custom_ops/mamba/triton_backend_mamba.py | 34 +++++++------------ 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 70f074a1825..37cce941dbe 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -189,8 +189,10 @@ def _triton_cached_ssm( num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist() - # Prefill: concatenate tokens at the front and run combined scan y_prefill = None + y_dec = None + + # Prefill: concatenate tokens at the front and run combined scan if num_prefill > 0: hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D] B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] @@ -234,11 +236,7 @@ def _triton_cached_ssm( 0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype) ) - # y_prefill is [1, S_p, H, D] -> remove batch dim - y_prefill = y_prefill[0] - # Decode: batch single-token updates via selective_state_update - y_dec = None if num_decode > 0: slot_idx_decode = slot_idx[num_prefill:] @@ -266,27 +264,19 @@ def _triton_cached_ssm( state_batch_indices=slot_idx_decode, ) # [nd, H, D] - # Combine results + # Dispatch return logic if num_prefill > 0 and num_decode > 0: - # Concatenate prefill and decode outputs to form the final flattened output - # Both need to be the same dtype - y_flat = torch.cat( - [y_prefill.to(hidden_states.dtype), y_dec.to(hidden_states.dtype)], dim=0 - ) + y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format) + y_flat = y.view(bs, *y.shape[2:]) + y_flat[:num_prefill_tokens].copy_(y_prefill[0]) + y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec) + return y elif num_prefill > 0: - y_flat = y_prefill.to(hidden_states.dtype) + return y_prefill[0].view(b, s, num_heads, head_dim).to(hidden_states.dtype) elif num_decode > 0: - y_flat = y_dec.to(hidden_states.dtype) + return y_dec.view(b, s, num_heads, head_dim).to(hidden_states.dtype) else: - # Should not happen given input shapes, but handle empty case - y_flat = torch.empty( - 0, num_heads, head_dim, device=hidden_states.device, dtype=hidden_states.dtype - ) - - # Reshape back to [B, S, H, D] if needed, or return flat if layout allows - # The original code reshaped y_flat into y [b, s, h, d] via view at the start. - # We constructed y_flat directly, so we just view it back to original shape. - return y_flat.view(b, s, num_heads, head_dim) + return torch.empty_like(hidden_states) @_triton_cached_ssm.register_fake