Commit 2f3b295

removed try/except in trtllm_dist.py
Signed-off-by: Eran Geva <[email protected]>
1 parent aeb7410 commit 2f3b295

1 file changed: +87 −91 lines

tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_dist.py

Lines changed: 87 additions & 91 deletions
@@ -8,112 +8,108 @@

 import torch

+# use trtllm distributed ops to improve TP performance if possible
+from ....mapping import Mapping
+from ...distributed import AllReduce, allgather
+from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
 from ..distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi

-# use trtllm distributed ops to improve TP performance if possible
-try:
-    from ....mapping import Mapping
-    from ...distributed import AllReduce, allgather
-    from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
-
-    # Cache AllReduce modules to avoid recreating on every call
-    # This is critical for CUDA graph compatibility - recreating modules during
-    # warmup causes hangs due to workspace allocation with CPU synchronization
-    _allreduce_cache = {}
-
-    def trtllm_allgather(tensor, dim, sizes=None):
-        rank, world_size = get_rank_world_size()
+# Cache AllReduce modules to avoid recreating on every call
+# This is critical for CUDA graph compatibility - recreating modules during
+# warmup causes hangs due to workspace allocation with CPU synchronization
+_allreduce_cache = {}
+
+
+def trtllm_allgather(tensor, dim, sizes=None):
+    rank, world_size = get_rank_world_size()
+    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+    return allgather(tensor, p_config, dim=dim, sizes=sizes)
+
+
+def trtllm_allreduce(tensor, op, all_reduce_params=None):
+    rank, world_size = get_rank_world_size()
+    assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
+
+    # Cache key includes rank, world_size, and dtype to handle different configurations
+    cache_key = (rank, world_size, tensor.dtype)
+    if cache_key not in _allreduce_cache:
         p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-        return allgather(tensor, p_config, dim=dim, sizes=sizes)
-
-    def trtllm_allreduce(tensor, op, all_reduce_params=None):
-        rank, world_size = get_rank_world_size()
-        assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
-
-        # Cache key includes rank, world_size, and dtype to handle different configurations
-        cache_key = (rank, world_size, tensor.dtype)
-        if cache_key not in _allreduce_cache:
-            p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-            # Use Strategy.AUTO for optimal performance
-            _allreduce_cache[cache_key] = AllReduce(
-                mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype
-            )
-
-        torch_op = _allreduce_cache[cache_key]
-        return torch_op(tensor, all_reduce_params=all_reduce_params)
-
-    # ============================================================================
-    # TRT-LLM Backend Ops (MPI mode)
-    # ============================================================================
-
-    @torch.library.custom_op(
-        "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda"
-    )
-    def trtllm_dist_all_gather(
-        tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
-    ) -> torch.Tensor:
-        """All gather using TRT-LLM optimized backend.
+        # Use Strategy.AUTO for optimal performance
+        _allreduce_cache[cache_key] = AllReduce(
+            mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype
+        )

-        This op always uses TRT-LLM's optimized allgather and is used in MPI mode.
-        """
-        return trtllm_allgather(tensor, dim=dim, sizes=sizes)
+    torch_op = _allreduce_cache[cache_key]
+    return torch_op(tensor, all_reduce_params=all_reduce_params)

-    @trtllm_dist_all_gather.register_fake
-    def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None):
-        return torch.cat([torch.empty_like(tensor) for _ in range(get_world_size())], dim=dim)

-    @torch.library.custom_op(
-        "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda"
-    )
-    def trtllm_dist_all_reduce(t: torch.Tensor) -> torch.Tensor:
-        """All_reduce using TRT-LLM optimized backend. Reduction op is SUM.
+# ============================================================================
+# TRT-LLM Backend Ops (MPI mode)
+# ============================================================================

-        This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
-        """
-        return trtllm_allreduce(t, op=ReduceOp.SUM)

-    @trtllm_dist_all_reduce.register_fake
-    def trtllm_dist_all_reduce_fake(tensor):
-        return torch.empty_like(tensor)
+@torch.library.custom_op(
+    "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda"
+)
+def trtllm_dist_all_gather(
+    tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
+) -> torch.Tensor:
+    """All gather using TRT-LLM optimized backend.
+
+    This op always uses TRT-LLM's optimized allgather and is used in MPI mode.
+    """
+    return trtllm_allgather(tensor, dim=dim, sizes=sizes)
+
+
+@trtllm_dist_all_gather.register_fake
+def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None):
+    return torch.cat([torch.empty_like(tensor) for _ in range(get_world_size())], dim=dim)

-    # TRT-LLM fused op (atomic - always uses TRT-LLM backend)
-    @torch.library.custom_op(
-        "dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
-    )
-    def trtllm_fused_allreduce_residual_rmsnorm(
-        tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel.
-
-        This op always uses TRT-LLM's fused kernel and is used in MPI mode.
-        """
-        all_reduce_params = AllReduceParams(
-            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-            bias=None,
-            residual=residual,
-            norm_weight=norm_weight,
-            eps=eps,
-        )
-        return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
+@torch.library.custom_op(
+    "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda"
+)
+def trtllm_dist_all_reduce(t: torch.Tensor) -> torch.Tensor:
+    """All_reduce using TRT-LLM optimized backend. Reduction op is SUM.

-    @trtllm_fused_allreduce_residual_rmsnorm.register_fake
-    def trtllm_fused_allreduce_residual_rmsnorm_fake(
-        tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        return torch.empty_like(tensor), torch.empty_like(tensor)
+    This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
+    """
+    return trtllm_allreduce(t, op=ReduceOp.SUM)

-    TRTLLM_OP_AVAILABLE = True
-except ImportError:

-    def trtllm_allgather(tensor, dim, sizes=None):
-        raise ImportError("TRT-LLM is not available.")
+@trtllm_dist_all_reduce.register_fake
+def trtllm_dist_all_reduce_fake(tensor):
+    return torch.empty_like(tensor)
+
+
+# TRT-LLM fused op (atomic - always uses TRT-LLM backend)
+@torch.library.custom_op(
+    "dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
+)
+def trtllm_fused_allreduce_residual_rmsnorm(
+    tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel.
+
+    This op always uses TRT-LLM's fused kernel and is used in MPI mode.
+    """
+    all_reduce_params = AllReduceParams(
+        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+        bias=None,
+        residual=residual,
+        norm_weight=norm_weight,
+        eps=eps,
+    )
+    return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
+

-    def trtllm_allreduce(tensor, op, all_reduce_params=None):
-        raise ImportError("TRT-LLM is not available.")
+@trtllm_fused_allreduce_residual_rmsnorm.register_fake
+def trtllm_fused_allreduce_residual_rmsnorm_fake(
+    tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(tensor), torch.empty_like(tensor)

-    TRTLLM_OP_AVAILABLE = False


 def is_trtllm_op_available():
     """Check if TRT-LLM ops are available and running with MPI."""
-    return TRTLLM_OP_AVAILABLE and is_ompi()
+    return is_ompi()
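
With the try/except removed, importing trtllm_dist.py now registers the auto_deploy custom ops unconditionally (an ImportError is no longer swallowed into TRTLLM_OP_AVAILABLE), and is_trtllm_op_available() reduces to the is_ompi() check. Below is a minimal caller-side sketch under that assumption; the all_reduce_sum wrapper and the torch.distributed fallback are illustrative only and are not provided by this file.

import torch
import torch.distributed as dist

# Importing the module runs the torch.library.custom_op registrations shown in the diff.
from tensorrt_llm._torch.auto_deploy.custom_ops import trtllm_dist


def all_reduce_sum(t: torch.Tensor) -> torch.Tensor:
    """Sum-reduce t across ranks, preferring the TRT-LLM op when running under MPI."""
    if trtllm_dist.is_trtllm_op_available():
        # Dispatch through the registered custom op (SUM is the only supported reduction).
        return torch.ops.auto_deploy.trtllm_dist_all_reduce(t)
    # Hypothetical fallback for illustration; assumes a default process group is initialized.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t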
