Commit 81bb39d

fixed tests, readme, removed rms torch pattern
Signed-off-by: Eran Geva <[email protected]>
1 parent 9e0b47a commit 81bb39d

File tree

7 files changed: +26 / -44 lines

tensorrt_llm/_torch/auto_deploy/custom_ops/README.md

Lines changed: 12 additions & 4 deletions
@@ -17,13 +17,15 @@ The table below lists the operators ordered by their backend.
 | `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported |
 | `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
 | `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
-| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
-| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
+| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation (PyTorch backend, demollm mode) |
+| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation (PyTorch backend, demollm mode) |
+| `torch.ops.auto_deploy.torch_fused_linear_all_reduce` | Fused linear layer followed by all-reduce (PyTorch backend, demollm mode) |
+| `torch.ops.auto_deploy.torch_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce (PyTorch backend, demollm mode) |
 | `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
 | `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
 | `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
 | `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
-| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation |
+| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Legacy name for `torch_fused_fp8_linear_all_reduce` |
 | `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer |
 | `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
 | `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
@@ -38,4 +40,10 @@ The table below lists the operators ordered by their backend.
 | `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
 | `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
 | `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT LLM fused MoE implementation |
-| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT LLM fused linear layer followed by all-reduce operation |
+| `torch.ops.auto_deploy.trtllm_dist_all_gather` | Distributed all-gather operation (TRT-LLM backend, MPI mode) |
+| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | Distributed all-reduce operation (TRT-LLM backend, MPI mode) |
+| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | Legacy name for `trtllm_fused_linear_all_reduce` |
+| `torch.ops.auto_deploy.trtllm_fused_linear_all_reduce` | Fused linear layer followed by all-reduce (TRT-LLM backend, MPI mode) |
+| `torch.ops.auto_deploy.trtllm_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce (TRT-LLM backend, MPI mode) |
+| `torch.ops.dist.torch_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (PyTorch backend, demollm mode) |
+| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) |
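For orientation, the fused `*_allreduce_residual_rmsnorm` entries collapse an all-reduce, a residual add, and an RMSNorm into a single op. Below is a minimal sketch of the unfused sequence they target; the exact argument order and eps handling of the fused op are assumptions inferred from this commit, not a documented signature:

```python
import torch

def unfused_allreduce_residual_rmsnorm(x, residual, weight, eps=1e-5):
    # The three-step sequence the fusion transform looks for in the graph.
    x = torch.ops.auto_deploy.trtllm_dist_all_reduce(x)  # all-reduce across ranks
    y = residual + x                                      # residual add
    normed = y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps) * weight  # RMSNorm
    return normed, y

# Hypothetical call shape of the fused replacement (argument order assumed):
# normed, y = torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm(x, residual, weight, eps)
```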

tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py

Lines changed: 1 addition & 34 deletions
@@ -92,17 +92,6 @@ def replacement_fn(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor
 # Instantiate Pattern and Replacement Functions
 # ============================================================================

-# Torch backend (demollm mode)
-_allreduce_residual_rmsnorm_pattern_torch = _make_allreduce_residual_rmsnorm_pattern(
-    torch.ops.auto_deploy.torch_dist_all_reduce, add_order="residual_first"
-)
-_allreduce_residual_rmsnorm_pattern2_torch = _make_allreduce_residual_rmsnorm_pattern(
-    torch.ops.auto_deploy.torch_dist_all_reduce, add_order="x_first"
-)
-_allreduce_residual_rmsnorm_repl_torch = _make_allreduce_residual_rmsnorm_replacement(
-    torch.ops.dist.torch_fused_allreduce_residual_rmsnorm
-)
-
 # TRT-LLM backend (MPI mode)
 _allreduce_residual_rmsnorm_pattern_trtllm = _make_allreduce_residual_rmsnorm_pattern(
     torch.ops.auto_deploy.trtllm_dist_all_reduce, add_order="residual_first"
@@ -149,29 +138,7 @@ def _apply(
     op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)}
     scalar_workaround = {"eps": 0.1253}

-    # Register BOTH torch and trtllm patterns
-    # The pattern matcher will find whichever is present in the graph
-
-    # Torch backend patterns (residual + x)
-    register_ad_pattern(
-        search_fn=_allreduce_residual_rmsnorm_pattern_torch,
-        replace_fn=_allreduce_residual_rmsnorm_repl_torch,
-        patterns=patterns,
-        dummy_args=dummy_args,
-        op_ignore_types=op_ignore_types,
-        scalar_workaround=scalar_workaround,
-    )
-
-    # Torch backend patterns (x + residual)
-    register_ad_pattern(
-        search_fn=_allreduce_residual_rmsnorm_pattern2_torch,
-        replace_fn=_allreduce_residual_rmsnorm_repl_torch,
-        patterns=patterns,
-        dummy_args=dummy_args,
-        op_ignore_types=op_ignore_types,
-        scalar_workaround=scalar_workaround,
-    )
-
+    # Register only trtllm patterns
     # TRT-LLM backend patterns (residual + x)
     register_ad_pattern(
         search_fn=_allreduce_residual_rmsnorm_pattern_trtllm,
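The two TRT-LLM patterns that remain registered differ only in `add_order` because the FX pattern matcher matches traced argument order literally: `residual + x` and `x + residual` produce `add` nodes with swapped arguments, so each ordering needs its own registered pattern. A small self-contained illustration (not code from this repo):

```python
import torch
import torch.fx

def add_residual_first(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    return residual + x  # traced as call_function[operator.add](residual, x)

def add_x_first(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    return x + residual  # traced as call_function[operator.add](x, residual)

# The traced graphs record different argument orders, so a single registered
# pattern would miss one of the two forms.
print(torch.fx.symbolic_trace(add_residual_first).graph)
print(torch.fx.symbolic_trace(add_x_first).graph)
```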

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 2 additions & 1 deletion
@@ -1004,7 +1004,8 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Tra
     base_size = bmm_batch_size // world_size
     remainder = bmm_batch_size % world_size

-    # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment.
+    # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather/trtllm_dist_all_gather
+    # doesn't support uneven splits at the moment.
     if remainder:
         ad_logger.warning(
             f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. "

tensorrt_llm/_torch/auto_deploy/transform/library/visualization.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
 # TODO(yudong): make custom_ops configurable
 CUSTOM_OPS = (
     torch.ops.auto_deploy.torch_dist_all_reduce.default,
+    torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
     torch.ops.aten.slice.Tensor,
     torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default,
     torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce.default,

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py

Lines changed: 3 additions & 2 deletions
@@ -881,7 +881,8 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
         # Check if the distribution is balanced
         remainder = bmm_batch_size % self.world_size

-        # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment.
+        # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather/trtllm_dist_all_gather
+        # doesn't support uneven splits at the moment.
         if remainder:
             ad_logger.warning(
                 f"BMM batch size {bmm_batch_size} is not divisible by world size {self.world_size}. "
@@ -1070,7 +1071,7 @@ def _insert_sharded_mxfp4_mlp_ep(
     Transform a call to auto_deploy::triton_mxfp4_moe into:
       - sharded expert parameters along dim 0 (this rank's slice),
       - call to auto_deploy::triton_mxfp4_moe_ep(..., local_lo, local_hi),
-      - followed by torch_dist_all_reduce.
+      - followed by torch_dist_all_reduce/trtllm_dist_all_reduce.

    Expects the original op signature:
      (hidden_states,
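The `local_lo`/`local_hi` arguments in the docstring bound this rank's slice of the expert dimension. A hypothetical sketch of how such a contiguous expert range could be derived per rank (the actual partitioning in `_insert_sharded_mxfp4_mlp_ep` is not shown in this diff and may handle remainders differently):

```python
def local_expert_range(num_experts: int, world_size: int, rank: int) -> tuple[int, int]:
    # Even, contiguous split of experts across ranks
    # (assumes num_experts % world_size == 0).
    per_rank = num_experts // world_size
    local_lo = rank * per_rank
    local_hi = local_lo + per_rank
    return local_lo, local_hi

# Example: 32 experts over 4 ranks -> rank 1 owns experts [8, 16); each rank
# computes only its experts, then partial outputs are summed with
# torch_dist_all_reduce / trtllm_dist_all_reduce.
assert local_expert_range(32, 4, 1) == (8, 16)
```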

tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py

Lines changed: 2 additions & 0 deletions
@@ -9,13 +9,15 @@

 def _run_all_reduce_test(rank, world_size):
     x = torch.ones(10, 10).to("cuda")
+    # Test torch backend (demollm mode with Python multiprocessing)
     y = torch.ops.auto_deploy.torch_dist_all_reduce(x)

     assert torch.equal(x * world_size, y)


 def _run_all_gather_test(rank, world_size):
     x = torch.ones(10, 10).to("cuda")
+    # Test torch backend (demollm mode with Python multiprocessing)
     y = torch.ops.auto_deploy.torch_dist_all_gather(x)

     assert torch.sum(y) == world_size * torch.sum(x)
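The new comments tie these helpers to the torch backend, which runs in demollm mode via Python multiprocessing. A rough, self-contained sketch of such a per-rank launch using plain `torch.distributed` (the repo's actual spawning/test harness is not shown in this diff):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int, port: int):
    # Hypothetical per-rank setup; only illustrates the multiprocessing-style
    # launch the comments above refer to.
    torch.cuda.set_device(rank)
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://127.0.0.1:{port}",
        rank=rank,
        world_size=world_size,
    )
    x = torch.ones(10, 10, device="cuda")
    dist.all_reduce(x)  # defaults to SUM, mirroring the all-reduce test above
    assert torch.equal(x, torch.full((10, 10), float(world_size), device="cuda"))
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size, 29500), nprocs=world_size)
```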

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py

Lines changed: 5 additions & 3 deletions
@@ -37,7 +37,8 @@ def __init__(self, hidden_size, dtype):
         self.norm = RMSNorm(hidden_size, 1e-5, dtype)

     def forward(self, x, residual):
-        x = torch.ops.auto_deploy.torch_dist_all_reduce.default(x)
+        # Use trtllm backend ops when running with MPI/TRT-LLM
+        x = torch.ops.auto_deploy.trtllm_dist_all_reduce.default(x)
         y = residual + x
         normed = self.norm(y)
         return normed, y
@@ -51,7 +52,8 @@ def __init__(self, hidden_size, dtype):
         self.norm = RMSNorm(hidden_size, 1e-5, dtype)

     def forward(self, x, residual):
-        x = torch.ops.auto_deploy.torch_dist_all_reduce.default(x)
+        # Use trtllm backend ops when running with MPI/TRT-LLM
+        x = torch.ops.auto_deploy.trtllm_dist_all_reduce.default(x)
         y = x + residual
         normed = self.norm(y)
         return normed, y
@@ -94,7 +96,7 @@ def _test_allreduce_fusion(port: int, ModuleCls):
     # Check if fused node in the graph
     has_fused_node = False
     for node in gm_transformed.graph.nodes:
-        if is_op(node, torch.ops.dist.fused_allreduce_residual_rmsnorm):
+        if is_op(node, torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm):
             has_fused_node = True
     assert has_fused_node, "Fused node not found."
