
Commit 9e0b47a

split dist custom ops and use templated patterns
Signed-off-by: Eran Geva <[email protected]>
1 parent 46dd988 commit 9e0b47a
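
The diff below splits each dispatching distributed op into two atomic, backend-specific ops; picking between them is left to the caller (e.g. the graph-transform patterns referenced in the commit message). A minimal selection sketch, assuming the helper name pick_all_reduce_op is hypothetical - only the op names and is_trtllm_op_available() come from this commit:

import torch

from tensorrt_llm._torch.auto_deploy.distributed import trtllm as trtllm_dist


def pick_all_reduce_op(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: the backend check that used to live inside each
    # custom op now happens once, outside the op bodies.
    if trtllm_dist.is_trtllm_op_available():
        # MPI mode: TRT-LLM's optimized allreduce.
        return torch.ops.auto_deploy.trtllm_dist_all_reduce(t)
    # demollm mode: plain torch.distributed allreduce.
    return torch.ops.auto_deploy.torch_dist_all_reduce(t)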

File tree

7 files changed: +430, -125 lines

Lines changed: 61 additions & 13 deletions
@@ -1,4 +1,8 @@
-"""Custom ops required for implementing tensor parallelism."""
+"""Custom ops required for implementing tensor parallelism.
+
+This module defines atomic distributed ops - each op uses a specific backend
+(torch.distributed or TRT-LLM) without internal dispatch logic.
+"""
 
 from typing import List, Optional
 
@@ -7,38 +11,82 @@
 from ..distributed import common as dist
 from ..distributed import trtllm as trtllm_dist
 
+# ============================================================================
+# PyTorch Distributed Backend Ops (demollm mode)
+# ============================================================================
+
 
 @torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda")
-def all_gather(
+def torch_dist_all_gather(
     tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
 ) -> torch.Tensor:
-    """All gather followed by concat in dim = 0. This is the default nccl behavior."""
-    if trtllm_dist.is_trtllm_op_available():
-        return trtllm_dist.trtllm_allgather(tensor, dim=dim, sizes=sizes)
+    """All gather using PyTorch distributed backend.
+
+    This op always uses torch.distributed.all_gather and is used in demollm mode.
+    """
     tl = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
     dist.all_gather(tl, tensor)
     return torch.cat(tl, dim=dim)
 
 
-@all_gather.register_fake
-def all_gather_fake(tensor, dim=0):
+@torch_dist_all_gather.register_fake
+def torch_dist_all_gather_fake(tensor, dim=0, sizes=None):
     return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)
 
 
 @torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
-def all_reduce(t: torch.Tensor) -> torch.Tensor:
-    """All_reduce across the ranks. Reduction op is SUM.
+def torch_dist_all_reduce(t: torch.Tensor) -> torch.Tensor:
+    """All_reduce using PyTorch distributed backend. Reduction op is SUM.
+
+    This op always uses torch.distributed.all_reduce and is used in demollm mode.
 
     NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For
     efficient all_reduce ops one should write/replace it with a fused op.
     """
-    if trtllm_dist.is_trtllm_op_available():
-        return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM)
     t_res = t.clone()
     dist.all_reduce(t_res, op=dist.ReduceOp.SUM)
     return t_res
 
 
-@all_reduce.register_fake
-def all_reduce_fake(tensor):
+@torch_dist_all_reduce.register_fake
+def torch_dist_all_reduce_fake(tensor):
+    return torch.empty_like(tensor)
+
+
+# ============================================================================
+# TRT-LLM Backend Ops (MPI mode)
+# ============================================================================
+
+
+@torch.library.custom_op(
+    "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda"
+)
+def trtllm_dist_all_gather(
+    tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
+) -> torch.Tensor:
+    """All gather using TRT-LLM optimized backend.
+
+    This op always uses TRT-LLM's optimized allgather and is used in MPI mode.
+    """
+    return trtllm_dist.trtllm_allgather(tensor, dim=dim, sizes=sizes)
+
+
+@trtllm_dist_all_gather.register_fake
+def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None):
    return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)
+
+
+@torch.library.custom_op(
+    "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda"
+)
+def trtllm_dist_all_reduce(t: torch.Tensor) -> torch.Tensor:
+    """All_reduce using TRT-LLM optimized backend. Reduction op is SUM.
+
+    This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
+    """
+    return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM)
+
+
+@trtllm_dist_all_reduce.register_fake
+def trtllm_dist_all_reduce_fake(tensor):
     return torch.empty_like(tensor)
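
For reference, the renamed atomic ops are callable through the torch.ops.auto_deploy namespace. A minimal usage sketch in demollm mode - the shapes and the assumption that the process group is already initialized are illustrative, not part of the diff:

import torch

# Assumes torch.distributed is already initialized and each rank holds a
# local shard of shape [seq, hidden] on its GPU.
local = torch.randn(4, 8, device="cuda")

# Gather shards from all ranks and concatenate along dim 0.
gathered = torch.ops.auto_deploy.torch_dist_all_gather(local, dim=0)

# Sum-reduce the tensor across all ranks.
reduced = torch.ops.auto_deploy.torch_dist_all_reduce(local)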

tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py

Lines changed: 57 additions & 10 deletions
@@ -24,26 +24,73 @@ def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso
 @simple.register_fake
 def simple_fake(input, weight, bias):
     """Fake implementation of simple_linear."""
-    # return torch.empty(
-    #     input.shape[:-1] + (weight.shape[-1],), dtype=input.dtype, device=input.device
-    # )
     return torch.ops.aten.linear(input, weight, bias)
 
 
+# ============================================================================
+# Fused Linear + AllReduce Ops (Atomic - Backend Specific)
+# ============================================================================
+
+
 @torch.library.custom_op(
-    "auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
+    "auto_deploy::torch_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
 )
-def fused_linear_all_reduce(
+def torch_fused_linear_all_reduce(
     input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
 ) -> torch.Tensor:
-    """Fused linear followed by all_reduce on the output."""
+    """Fused linear + all_reduce using PyTorch backend.
+
+    This op always uses torch.distributed and is used in demollm mode.
+    """
     output = torch.ops.aten.linear(input, weight, bias)
-    if trtllm_dist.is_trtllm_op_available():
-        return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM)
     dist.all_reduce(output, op=dist.ReduceOp.SUM)
     return output
 
 
-@fused_linear_all_reduce.register_fake
-def fused_linear_all_reduce_fake(input, weight, bias):
+@torch_fused_linear_all_reduce.register_fake
+def torch_fused_linear_all_reduce_fake(input, weight, bias):
+    return torch.ops.aten.linear(input, weight, bias)
+
+
+@torch.library.custom_op(
+    "auto_deploy::trtllm_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
+)
+def trtllm_fused_linear_all_reduce(
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
+) -> torch.Tensor:
+    """Fused linear + all_reduce using TRT-LLM backend.
+
+    This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
+    """
+    output = torch.ops.aten.linear(input, weight, bias)
+    return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM)
+
+
+@trtllm_fused_linear_all_reduce.register_fake
+def trtllm_fused_linear_all_reduce_fake(input, weight, bias):
+    return torch.ops.aten.linear(input, weight, bias)
+
+
+# ============================================================================
+# Legacy op name for backward compatibility
+# ============================================================================
+
+
+@torch.library.custom_op(
+    "auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
+)
+def trtllm_dist_fused_linear_all_reduce(
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
+) -> torch.Tensor:
+    """Legacy name for trtllm_fused_linear_all_reduce.
+
+    Kept for backward compatibility with existing code.
+    This is an alias that directly implements the same logic.
+    """
+    output = torch.ops.aten.linear(input, weight, bias)
+    return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM)
+
+
+@trtllm_dist_fused_linear_all_reduce.register_fake
+def trtllm_dist_fused_linear_all_reduce_fake(input, weight, bias):
     return torch.ops.aten.linear(input, weight, bias)
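
The fused ops implement the usual row-parallel linear pattern: a local matmul on each rank's weight shard followed by a SUM all_reduce of the partial outputs. A minimal call sketch - the shapes and the omitted bias are illustrative assumptions:

import torch

inp = torch.randn(2, 16, device="cuda")      # local input shard [batch, in_features / world_size]
weight = torch.randn(32, 16, device="cuda")  # local weight shard [out_features, in_features / world_size]

# torch backend (demollm mode); the trtllm_* variant has the same signature.
out = torch.ops.auto_deploy.torch_fused_linear_all_reduce(inp, weight, None)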

tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py

Lines changed: 83 additions & 5 deletions
@@ -240,26 +240,104 @@ def fp8_linear_fake(
     return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)
 
 
+# ============================================================================
+# Fused FP8 Linear + AllReduce Ops (Atomic - Backend Specific)
+# ============================================================================
+
+
+@torch.library.custom_op("auto_deploy::torch_fused_fp8_linear_all_reduce", mutates_args=())
+@torch.compile(dynamic=True)
+def torch_fused_fp8_linear_all_reduce(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """Fused FP8 linear + all_reduce using PyTorch backend.
+
+    This op always uses torch.distributed and is used in demollm mode.
+    """
+    out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+        input, weight_fp8, bias, input_scale, weight_scale
+    )
+    dist.all_reduce(out, op=dist.ReduceOp.SUM)
+    return out
+
+
+@torch_fused_fp8_linear_all_reduce.register_fake
+def torch_fused_fp8_linear_all_reduce_fake(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.ops.auto_deploy.torch_quant_fp8_linear(
+        input, weight_fp8, bias, input_scale, weight_scale
+    )
+
+
+@torch.library.custom_op("auto_deploy::trtllm_fused_fp8_linear_all_reduce", mutates_args=())
+@torch.compile(dynamic=True)
+def trtllm_fused_fp8_linear_all_reduce(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """Fused FP8 linear + all_reduce using TRT-LLM backend.
+
+    This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
+    """
+    out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+        input, weight_fp8, bias, input_scale, weight_scale
+    )
+    return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
+
+
+@trtllm_fused_fp8_linear_all_reduce.register_fake
+def trtllm_fused_fp8_linear_all_reduce_fake(
+    input: torch.Tensor,
+    weight_fp8: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    input_scale: Optional[torch.Tensor] = None,
+    weight_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.ops.auto_deploy.torch_quant_fp8_linear(
+        input, weight_fp8, bias, input_scale, weight_scale
+    )
+
+
+# ============================================================================
+# Legacy op name for backward compatibility
+# ============================================================================
+
+
 @torch.library.custom_op("auto_deploy::torch_quant_fused_fp8_linear_all_reduce", mutates_args=())
 @torch.compile(dynamic=True)
-def fused_fp8_linear_all_reduce(
+def torch_quant_fused_fp8_linear_all_reduce(
     input: torch.Tensor,
     weight_fp8: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
     input_scale: Optional[torch.Tensor] = None,
     weight_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    """Legacy name for torch_fused_fp8_linear_all_reduce.
+
+    Kept for backward compatibility with existing code.
+    Defaults to torch backend (demollm mode).
+    """
     out = torch.ops.auto_deploy.torch_quant_fp8_linear(
         input, weight_fp8, bias, input_scale, weight_scale
     )
-    if trtllm_dist.is_trtllm_op_available():
-        return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
     dist.all_reduce(out, op=dist.ReduceOp.SUM)
     return out
 
 
-@fused_fp8_linear_all_reduce.register_fake
-def fused_fp8_linear_all_reduce_fake(
+@torch_quant_fused_fp8_linear_all_reduce.register_fake
+def torch_quant_fused_fp8_linear_all_reduce_fake(
     input: torch.Tensor,
     weight_fp8: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
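
Likewise for the FP8 variants: the op name and argument order below come from the diff, while the dtypes and per-tensor scale values are assumptions about what torch_quant_fp8_linear expects.

import torch

inp = torch.randn(2, 16, device="cuda", dtype=torch.half)
weight_fp8 = torch.randn(32, 16, device="cuda").to(torch.float8_e4m3fn)
input_scale = torch.tensor(1.0, device="cuda")   # assumed per-tensor scale
weight_scale = torch.tensor(1.0, device="cuda")  # assumed per-tensor scale

# torch backend (demollm mode); the trtllm_* variant has the same signature.
out = torch.ops.auto_deploy.torch_fused_fp8_linear_all_reduce(
    inp, weight_fp8, None, input_scale, weight_scale
)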
