1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -79,6 +79,7 @@ transforms:
    sharding_source: ['factory','heuristic']
    support_partial_config: true
    sharding_dims: ['tp', 'ep', 'bmm']
    dist_backend: auto
    requires_shape_prop: true
  sharding_transform_executor:
    stage: sharding
9 changes: 5 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
@@ -17,13 +17,12 @@ The table below lists the operators ordered by their backend.
| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation |
| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer |
| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
@@ -38,4 +37,6 @@ The table below lists the operators ordered by their backend.
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT LLM fused MoE implementation |
| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT LLM fused linear layer followed by all-reduce operation |
| `torch.ops.auto_deploy.trtllm_dist_all_gather` | Distributed all-gather operation (TRT-LLM backend, MPI mode) |
| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | Distributed all-reduce operation (TRT-LLM backend, MPI mode) |
| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) |
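
A minimal usage sketch of the atomic distributed ops listed above, assuming an already-initialized multi-rank environment (torch.distributed in demollm mode for the `torch_dist_*` ops, MPI with TRT-LLM kernels for the `trtllm_dist_*` ops); the tensor shapes and epsilon value are illustrative only:

```python
import torch

# Illustrative shapes; requires CUDA and an initialized multi-rank runtime.
x = torch.randn(8, 16, device="cuda")

# PyTorch backend (demollm mode): plain torch.distributed collectives.
gathered = torch.ops.auto_deploy.torch_dist_all_gather(x, dim=0)
summed = torch.ops.auto_deploy.torch_dist_all_reduce(x)

# TRT-LLM backend (MPI mode): same signatures, optimized kernels.
gathered_trt = torch.ops.auto_deploy.trtllm_dist_all_gather(x, dim=0)
summed_trt = torch.ops.auto_deploy.trtllm_dist_all_reduce(x)

# Fused all-reduce + residual add + RMSNorm; returns two tensors
# (see the op definition in trtllm_dist.py below for the fused semantics).
residual = torch.randn_like(x)
norm_weight = torch.ones(16, device="cuda")
out0, out1 = torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm(
    x, residual, norm_weight, 1e-6
)
```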
44 changes: 0 additions & 44 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py

This file was deleted.

25 changes: 0 additions & 25 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py
@@ -4,9 +4,6 @@

import torch

from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist


@torch.library.custom_op("auto_deploy::torch_linear_simple", mutates_args=())
def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
@@ -24,26 +21,4 @@ def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
@simple.register_fake
def simple_fake(input, weight, bias):
"""Fake implementation of simple_linear."""
# return torch.empty(
# input.shape[:-1] + (weight.shape[-1],), dtype=input.dtype, device=input.device
# )
return torch.ops.aten.linear(input, weight, bias)


@torch.library.custom_op(
    "auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
)
def fused_linear_all_reduce(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Fused linear followed by all_reduce on the output."""
    output = torch.ops.aten.linear(input, weight, bias)
    if trtllm_dist.is_trtllm_op_available():
        return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM)
    dist.all_reduce(output, op=dist.ReduceOp.SUM)
    return output


@fused_linear_all_reduce.register_fake
def fused_linear_all_reduce_fake(input, weight, bias):
    return torch.ops.aten.linear(input, weight, bias)
33 changes: 0 additions & 33 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
@@ -9,8 +9,6 @@

from tensorrt_llm._torch.autotuner import autotune

from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist
from .torch_libs.float8_python_api import addmm_float8_unwrapped

TRTLLM_FP4_OP_AVAILABLE = True
@@ -240,37 +238,6 @@ def fp8_linear_fake(
    return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)


@torch.library.custom_op("auto_deploy::torch_quant_fused_fp8_linear_all_reduce", mutates_args=())
@torch.compile(dynamic=True)
def fused_fp8_linear_all_reduce(
input: torch.Tensor,
weight_fp8: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
dist.all_reduce(out, op=dist.ReduceOp.SUM)
return out


@fused_fp8_linear_all_reduce.register_fake
def fused_fp8_linear_all_reduce_fake(
input: torch.Tensor,
weight_fp8: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)


class FP8Linear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
52 changes: 52 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/torch_dist.py
@@ -0,0 +1,52 @@
"""Custom ops required for implementing tensor parallelism.

This module defines atomic distributed ops - each op uses a specific backend
(torch.distributed or TRT-LLM) without internal dispatch logic.
"""

from typing import List, Optional

import torch

from ..distributed import common as dist

# ============================================================================
# PyTorch Distributed Backend Ops (demollm mode)
# ============================================================================


@torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda")
def torch_dist_all_gather(
tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:
"""All gather using PyTorch distributed backend.

This op always uses torch.distributed.all_gather and is used in demollm mode.
"""
tl = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
dist.all_gather(tl, tensor)
return torch.cat(tl, dim=dim)


@torch_dist_all_gather.register_fake
def torch_dist_all_gather_fake(tensor, dim=0, sizes=None):
return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)


@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
def torch_dist_all_reduce(t: torch.Tensor) -> torch.Tensor:
"""All_reduce using PyTorch distributed backend. Reduction op is SUM.

This op always uses torch.distributed.all_reduce and is used in demollm mode.

NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For
efficient all_reduce ops one should write/replace it with a fused op.
"""
t_res = t.clone()
dist.all_reduce(t_res, op=dist.ReduceOp.SUM)
return t_res


@torch_dist_all_reduce.register_fake
def torch_dist_all_reduce_fake(tensor):
return torch.empty_like(tensor)
115 changes: 115 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_dist.py
@@ -0,0 +1,115 @@
"""TRT-LLM distributed operations and fused kernels.

This module defines atomic TRT-LLM-specific ops that use optimized kernels.
The torch fallback variants are defined separately to enable multi-pattern matching.
"""

from typing import List, Optional

import torch

# use trtllm distributed ops to improve TP performance if possible
from ....mapping import Mapping
from ...distributed import AllReduce, allgather
from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
from ..distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi

# Cache AllReduce modules to avoid recreating on every call
# This is critical for CUDA graph compatibility - recreating modules during
# warmup causes hangs due to workspace allocation with CPU synchronization
_allreduce_cache = {}


def trtllm_allgather(tensor, dim, sizes=None):
    rank, world_size = get_rank_world_size()
    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
    return allgather(tensor, p_config, dim=dim, sizes=sizes)


def trtllm_allreduce(tensor, op, all_reduce_params=None):
    rank, world_size = get_rank_world_size()
    assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."

    # Cache key includes rank, world_size, and dtype to handle different configurations
    cache_key = (rank, world_size, tensor.dtype)
    if cache_key not in _allreduce_cache:
        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
        # Use Strategy.AUTO for optimal performance
        _allreduce_cache[cache_key] = AllReduce(
            mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype
        )

    torch_op = _allreduce_cache[cache_key]
    return torch_op(tensor, all_reduce_params=all_reduce_params)


# ============================================================================
# TRT-LLM Backend Ops (MPI mode)
# ============================================================================


@torch.library.custom_op(
    "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda"
)
def trtllm_dist_all_gather(
    tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:
    """All gather using TRT-LLM optimized backend.

    This op always uses TRT-LLM's optimized allgather and is used in MPI mode.
    """
    return trtllm_allgather(tensor, dim=dim, sizes=sizes)


@trtllm_dist_all_gather.register_fake
def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None):
    return torch.cat([torch.empty_like(tensor) for _ in range(get_world_size())], dim=dim)


@torch.library.custom_op(
    "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda"
)
def trtllm_dist_all_reduce(t: torch.Tensor) -> torch.Tensor:
    """All_reduce using TRT-LLM optimized backend. Reduction op is SUM.

    This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
    """
    return trtllm_allreduce(t, op=ReduceOp.SUM)


@trtllm_dist_all_reduce.register_fake
def trtllm_dist_all_reduce_fake(tensor):
    return torch.empty_like(tensor)


# TRT-LLM fused op (atomic - always uses TRT-LLM backend)
@torch.library.custom_op(
    "dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
)
def trtllm_fused_allreduce_residual_rmsnorm(
    tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel.

    This op always uses TRT-LLM's fused kernel and is used in MPI mode.
    """
    all_reduce_params = AllReduceParams(
        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
        bias=None,
        residual=residual,
        norm_weight=norm_weight,
        eps=eps,
    )
    return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)


@trtllm_fused_allreduce_residual_rmsnorm.register_fake
def trtllm_fused_allreduce_residual_rmsnorm_fake(
    tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(tensor), torch.empty_like(tensor)


def is_trtllm_op_available():
"""Check if TRT-LLM ops are available and running with MPI."""
return is_ompi()
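
As a rough illustration of how the new `dist_backend: auto` config setting and these atomic ops fit together, here is a hypothetical selection helper; the name `select_all_reduce` and the idea of resolving the backend via `is_trtllm_op_available()` are assumptions, and the actual dispatch lives in the sharding transform and may differ:

```python
from typing import Callable

import torch

# Hypothetical helper, not part of this module. Assumes the custom ops above are
# already registered by importing the auto_deploy custom_ops package, and that the
# module path below matches the new file added in this change.
from tensorrt_llm._torch.auto_deploy.custom_ops.trtllm_dist import is_trtllm_op_available


def select_all_reduce() -> Callable[[torch.Tensor], torch.Tensor]:
    # MPI mode -> TRT-LLM optimized kernel; demollm mode -> torch.distributed fallback.
    if is_trtllm_op_available():
        return torch.ops.auto_deploy.trtllm_dist_all_reduce
    return torch.ops.auto_deploy.torch_dist_all_reduce
```

In demollm mode the returned op would be the debugging-oriented `torch_dist_all_reduce` defined in torch_dist.py above.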