8 | 8 |
9 | 9 | import torch |
10 | 10 |
| 11 | +# use trtllm distributed ops to improve TP performance |
| 12 | +from ....mapping import Mapping |
| 13 | +from ...distributed import AllReduce, allgather |
| 14 | +from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy |
11 | 15 | from ..distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi |
12 | 16 |
13 | | -# use trtllm distributed ops to improve TP performance if possible |
14 | | -try: |
15 | | - from ....mapping import Mapping |
16 | | - from ...distributed import AllReduce, allgather |
17 | | - from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy |
18 | | - |
19 | | - # Cache AllReduce modules to avoid recreating on every call |
20 | | - # This is critical for CUDA graph compatibility - recreating modules during |
21 | | - # warmup causes hangs due to workspace allocation with CPU synchronization |
22 | | - _allreduce_cache = {} |
23 | | - |
24 | | - def trtllm_allgather(tensor, dim, sizes=None): |
25 | | - rank, world_size = get_rank_world_size() |
| 17 | +# Cache AllReduce modules to avoid recreating on every call |
| 18 | +# This is critical for CUDA graph compatibility - recreating modules during |
| 19 | +# warmup causes hangs due to workspace allocation with CPU synchronization |
| 20 | +_allreduce_cache = {} |
| 21 | + |
| 22 | + |
| 23 | +def trtllm_allgather(tensor, dim, sizes=None): |
| 24 | + rank, world_size = get_rank_world_size() |
| 25 | + p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) |
| 26 | + return allgather(tensor, p_config, dim=dim, sizes=sizes) |
| 27 | + |
| 28 | + |
| 29 | +def trtllm_allreduce(tensor, op, all_reduce_params=None): |
| 30 | + rank, world_size = get_rank_world_size() |
| 31 | + assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op." |
| 32 | + |
| 33 | + # Cache key includes rank, world_size, and dtype to handle different configurations |
| 34 | + cache_key = (rank, world_size, tensor.dtype) |
| 35 | + if cache_key not in _allreduce_cache: |
26 | 36 | p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) |
27 | | - return allgather(tensor, p_config, dim=dim, sizes=sizes) |
28 | | - |
29 | | - def trtllm_allreduce(tensor, op, all_reduce_params=None): |
30 | | - rank, world_size = get_rank_world_size() |
31 | | - assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op." |
32 | | - |
33 | | - # Cache key includes rank, world_size, and dtype to handle different configurations |
34 | | - cache_key = (rank, world_size, tensor.dtype) |
35 | | - if cache_key not in _allreduce_cache: |
36 | | - p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) |
37 | | - # Use Strategy.AUTO for optimal performance |
38 | | - _allreduce_cache[cache_key] = AllReduce( |
39 | | - mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype |
40 | | - ) |
41 | | - |
42 | | - torch_op = _allreduce_cache[cache_key] |
43 | | - return torch_op(tensor, all_reduce_params=all_reduce_params) |
44 | | - |
45 | | - # ============================================================================ |
46 | | - # TRT-LLM Backend Ops (MPI mode) |
47 | | - # ============================================================================ |
48 | | - |
49 | | - @torch.library.custom_op( |
50 | | - "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda" |
51 | | - ) |
52 | | - def trtllm_dist_all_gather( |
53 | | - tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None |
54 | | - ) -> torch.Tensor: |
55 | | - """All gather using TRT-LLM optimized backend. |
| 37 | + # Use the NCCL all-reduce strategy |
| 38 | + _allreduce_cache[cache_key] = AllReduce( |
| 39 | + mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype |
| 40 | + ) |
56 | 41 |
57 | | - This op always uses TRT-LLM's optimized allgather and is used in MPI mode. |
58 | | - """ |
59 | | - return trtllm_allgather(tensor, dim=dim, sizes=sizes) |
| 42 | + torch_op = _allreduce_cache[cache_key] |
| 43 | + return torch_op(tensor, all_reduce_params=all_reduce_params) |
60 | 44 |
61 | | - @trtllm_dist_all_gather.register_fake |
62 | | - def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None): |
63 | | - return torch.cat([torch.empty_like(tensor) for _ in range(get_world_size())], dim=dim) |
64 | 45 |
65 | | - @torch.library.custom_op( |
66 | | - "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda" |
67 | | - ) |
68 | | - def trtllm_dist_all_reduce(t: torch.Tensor) -> torch.Tensor: |
69 | | - """All_reduce using TRT-LLM optimized backend. Reduction op is SUM. |
| 46 | +# ============================================================================ |
| 47 | +# TRT-LLM Backend Ops (MPI mode) |
| 48 | +# ============================================================================ |
70 | 49 |
71 | | - This op always uses TRT-LLM's optimized allreduce and is used in MPI mode. |
72 | | - """ |
73 | | - return trtllm_allreduce(t, op=ReduceOp.SUM) |
74 | 50 |
75 | | - @trtllm_dist_all_reduce.register_fake |
76 | | - def trtllm_dist_all_reduce_fake(tensor): |
77 | | - return torch.empty_like(tensor) |
| 51 | +@torch.library.custom_op( |
| 52 | + "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda" |
| 53 | +) |
| 54 | +def trtllm_dist_all_gather( |
| 55 | + tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None |
| 56 | +) -> torch.Tensor: |
| 57 | + """All gather using TRT-LLM optimized backend. |
| 58 | +
| 59 | + This op always uses TRT-LLM's optimized allgather and is used in MPI mode. |
| 60 | + """ |
| 61 | + return trtllm_allgather(tensor, dim=dim, sizes=sizes) |
| 62 | + |
| 63 | + |
| 64 | +@trtllm_dist_all_gather.register_fake |
| 65 | +def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None): |
| 66 | + return torch.cat([torch.empty_like(tensor) for _ in range(get_world_size())], dim=dim) |
78 | 67 |
79 | | - # TRT-LLM fused op (atomic - always uses TRT-LLM backend) |
80 | | - @torch.library.custom_op( |
81 | | - "dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda" |
82 | | - ) |
83 | | - def trtllm_fused_allreduce_residual_rmsnorm( |
84 | | - tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float |
85 | | - ) -> tuple[torch.Tensor, torch.Tensor]: |
86 | | - """Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel. |
87 | | -
88 | | - This op always uses TRT-LLM's fused kernel and is used in MPI mode. |
89 | | - """ |
90 | | - all_reduce_params = AllReduceParams( |
91 | | - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, |
92 | | - bias=None, |
93 | | - residual=residual, |
94 | | - norm_weight=norm_weight, |
95 | | - eps=eps, |
96 | | - ) |
97 | | - return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params) |
98 | 68 |
99 | | - @trtllm_fused_allreduce_residual_rmsnorm.register_fake |
100 | | - def trtllm_fused_allreduce_residual_rmsnorm_fake( |
101 | | - tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float |
102 | | - ) -> tuple[torch.Tensor, torch.Tensor]: |
103 | | - return torch.empty_like(tensor), torch.empty_like(tensor) |
| 69 | +@torch.library.custom_op( |
| 70 | + "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda" |
| 71 | +) |
| 72 | +def trtllm_dist_all_reduce(t: torch.Tensor) -> torch.Tensor: |
| 73 | + """All-reduce using the TRT-LLM optimized backend. Reduction op is SUM. |
104 | 74 |
105 | | - TRTLLM_OP_AVAILABLE = True |
106 | | -except ImportError: |
| 75 | + This op always uses TRT-LLM's optimized allreduce and is used in MPI mode. |
| 76 | + """ |
| 77 | + return trtllm_allreduce(t, op=ReduceOp.SUM) |
107 | 78 |
108 | | - def trtllm_allgather(tensor, dim, sizes=None): |
109 | | - raise ImportError("TRT-LLM is not available.") |
110 | 79 |
111 | | - def trtllm_allreduce(tensor, op, all_reduce_params=None): |
112 | | - raise ImportError("TRT-LLM is not available.") |
| 80 | +@trtllm_dist_all_reduce.register_fake |
| 81 | +def trtllm_dist_all_reduce_fake(tensor): |
| 82 | + return torch.empty_like(tensor) |
| 83 | + |
| 84 | + |
| 85 | +# TRT-LLM fused op (atomic - always uses TRT-LLM backend) |
| 86 | +@torch.library.custom_op( |
| 87 | + "dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda" |
| 88 | +) |
| 89 | +def trtllm_fused_allreduce_residual_rmsnorm( |
| 90 | + tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float |
| 91 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 92 | + """Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel. |
| 93 | +
| 94 | + This op always uses TRT-LLM's fused kernel and is used in MPI mode. |
| 95 | + """ |
| 96 | + all_reduce_params = AllReduceParams( |
| 97 | + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, |
| 98 | + bias=None, |
| 99 | + residual=residual, |
| 100 | + norm_weight=norm_weight, |
| 101 | + eps=eps, |
| 102 | + ) |
| 103 | + return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params) |
| 104 | + |
113 | 105 |
114 | | - TRTLLM_OP_AVAILABLE = False |
| 106 | +@trtllm_fused_allreduce_residual_rmsnorm.register_fake |
| 107 | +def trtllm_fused_allreduce_residual_rmsnorm_fake( |
| 108 | + tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float |
| 109 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 110 | + return torch.empty_like(tensor), torch.empty_like(tensor) |
115 | 111 |
116 | 112 |
117 | 113 | def is_trtllm_op_available(): |
118 | 114 | """Check if TRT-LLM ops are available and running with MPI.""" |
119 | | - return TRTLLM_OP_AVAILABLE and is_ompi() |
| 115 | + return is_ompi() |
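
For context, a minimal usage sketch (not part of the diff) of the ops registered above. Ops created with torch.library.custom_op are reachable under torch.ops.<namespace>.<name>; the row-parallel layer, shard shapes, and fallback below are hypothetical illustrations rather than this repository's actual call sites, and is_trtllm_op_available / the registered ops are assumed to be imported from this module.

    import torch

    def row_parallel_output(partial: torch.Tensor) -> torch.Tensor:
        # partial: this rank's partial matmul result from a row-parallel linear
        if is_trtllm_op_available():  # defined in this module: True only under MPI
            # Sums the partial results across ranks; the AllReduce module is reused
            # from the (rank, world_size, dtype) cache, so CUDA graph capture does
            # not trigger new workspace allocation.
            return torch.ops.auto_deploy.trtllm_dist_all_reduce(partial)
        # Non-MPI fallback (e.g. torch.distributed) would go here.
        return partial

    def gather_rows(shard: torch.Tensor, sizes=None) -> torch.Tensor:
        # Concatenates per-rank shards along dim 0; `sizes` handles uneven shards.
        return torch.ops.auto_deploy.trtllm_dist_all_gather(shard, dim=0, sizes=sizes)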
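
The fused dist::trtllm_fused_allreduce_residual_rmsnorm op collapses three steps into one kernel launch. As a reading aid only, a rough unfused PyTorch equivalent is sketched below; it assumes a SUM all-reduce and that the fused kernel returns (normalized_output, residual_sum) in that order, and it uses torch.distributed purely as a stand-in for the TRT-LLM kernel, so treat it as a reference, not the implementation.

    import torch
    import torch.distributed as dist

    def unfused_reference(partial, residual, norm_weight, eps):
        # (1) SUM all-reduce of the partial tensor across tensor-parallel ranks
        #     (requires an initialized process group; stand-in for trtllm_allreduce)
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)
        # (2) residual add
        residual_sum = partial + residual
        # (3) RMSNorm over the hidden dimension, accumulated in fp32
        x = residual_sum.to(torch.float32)
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        normed = norm_weight * x.to(residual_sum.dtype)
        return normed, residual_sum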