Commit aeb7410

added dist_backend arg
Signed-off-by: Eran Geva <[email protected]>
1 parent 572a89c commit aeb7410

File tree

4 files changed (+249, -14 lines changed)

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ transforms:
     sharding_source: ['factory','heuristic']
     support_partial_config: true
     sharding_dims: ['tp', 'ep', 'bmm']
+    dist_backend: auto
     requires_shape_prop: true
   sharding_transform_executor:
     stage: sharding
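
For reference, the new key sits alongside the other detect_sharding options and accepts the same three values as the DistBackend enum introduced in sharding_utils.py below. A quick, illustrative way to check the parsed value (assumes PyYAML and a checkout-relative path; the transforms -> detect_sharding nesting is inferred from the hunk context above):

# Hypothetical sanity check, not part of the commit.
import yaml  # assumes PyYAML is installed

with open("tensorrt_llm/_torch/auto_deploy/config/default.yaml") as f:
    cfg = yaml.safe_load(f)

backend = cfg["transforms"]["detect_sharding"]["dist_backend"]  # key path inferred from the diff
assert backend in ("auto", "trtllm", "torch")
print(f"detect_sharding requests the {backend!r} distributed backend")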

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 16 additions & 0 deletions
@@ -39,6 +39,7 @@
 )
 from ...utils.sharding_utils import (
     BMMShardingInfo,
+    DistBackend,
     EPShardingInfo,
     LayerType,
     ParameterUpdateInfo,
@@ -134,6 +135,7 @@ def _process_simple_shard(
                 world_size=world_size,
                 dist_op="all_gather",
                 min_local_shape=1,
+                dist_backend=sharding_config.dist_backend,
             )
         )
     )
@@ -152,6 +154,7 @@ class ShardingTransformConfig(TransformConfig):
     sharding_dims: List[ShardingDim] = Field(
         default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]
     )
+    dist_backend: DistBackend = Field(default=DistBackend.AUTO)
 
 
 @TransformRegistry.register("detect_sharding")
@@ -209,6 +212,7 @@ def _apply(
         sharding_config.support_partial_config = self.config.support_partial_config
         sharding_config.sharding_dims = self.config.sharding_dims
         sharding_config.sharding_source = self.config.sharding_source
+        sharding_config.dist_backend = self.config.dist_backend
 
         sharding_config.validate_config()
 
@@ -348,6 +352,7 @@ def _process_ssm_sharding(
                 dist_op=None,
                 min_local_shape=min_local_shape,
                 fused_weight_dims=fused_weight_dims["in_proj"],
+                dist_backend=sharding_config.dist_backend,
             )
         )
 
@@ -386,6 +391,7 @@ def _process_ssm_sharding(
                 dist_op=None,
                 min_local_shape=min_local_shape,
                 fused_weight_dims=fused_dims,
+                dist_backend=sharding_config.dist_backend,
             )
         )
 
@@ -422,6 +428,7 @@ def _process_ssm_sharding(
                 rank=rank,
                 world_size=world_size,
                 dist_op="all_reduce",
+                dist_backend=sharding_config.dist_backend,
             )
         )
     return 1
@@ -448,6 +455,7 @@ def _process_column_sharding(
                 world_size=world_size,
                 dist_op=None,  # for column sharding, no dist op is performed
                 min_local_shape=min_local_shape,
+                dist_backend=sharding_config.dist_backend,
             )
         )
 
@@ -581,6 +589,7 @@ def detect_sharding_from_factory_config(
                 world_size=world_size,
                 dist_op=None,
                 min_local_shape=min_local_shape,
+                dist_backend=sharding_config.dist_backend,
             )
         )
         num_row_col_shards += 1
@@ -593,6 +602,7 @@ def detect_sharding_from_factory_config(
                 world_size=world_size,
                 dist_op="all_reduce",
                 min_local_shape=min_local_shape,
+                dist_backend=sharding_config.dist_backend,
             )
         )
         num_row_col_shards += 1
@@ -606,6 +616,7 @@ def detect_sharding_from_factory_config(
                 dist_op=None,
                 min_local_shape=min_local_shape,
                 layer_type=LayerType.MAMBA,
+                dist_backend=sharding_config.dist_backend,
             )
         )
         num_row_col_shards += 1
@@ -626,6 +637,7 @@ def detect_sharding_from_factory_config(
                 world_size=world_size,
                 dist_op=None,
                 min_local_shape=min_local_shape,
+                dist_backend=sharding_config.dist_backend,
             )
         )
     elif col_row_action == "rowwise":
@@ -637,6 +649,7 @@ def detect_sharding_from_factory_config(
                 world_size=world_size,
                 dist_op="all_reduce",
                 min_local_shape=min_local_shape,
+                dist_backend=sharding_config.dist_backend,
             )
         )
         num_row_col_shards += 1
@@ -951,6 +964,7 @@ def detect_column_row_shard(
                 world_size=world_size,
                 dist_op="all_reduce",
                 min_local_shape=min_local_shape,
+                dist_backend=sharding_config.dist_backend,
             )
         )
 
@@ -1028,6 +1042,7 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Tra
                 world_size=world_size,
                 start_idx=start_idx,
                 end_idx=end_idx,
+                dist_backend=sharding_config.dist_backend,
             )
         )
         ad_logger.debug(
@@ -1070,6 +1085,7 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Transfo
                 node,
                 rank=rank,
                 world_size=world_size,
+                dist_backend=sharding_config.dist_backend,
             )
         )
         num_moe_patterns += 1
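
The hunks above are one mechanical change applied at every detection site: ShardingTransformConfig gains a dist_backend field, _apply copies it onto the shared ShardingConfig, and each helper forwards sharding_config.dist_backend into the transform info it records. A self-contained sketch of that flow, using hypothetical Mini* stand-ins rather than the real classes:

# Hypothetical stand-ins mirroring the propagation added in this file:
# per-transform config -> shared sharding config -> per-node transform info.
from enum import Enum
from typing import List

from pydantic import BaseModel, Field


class DistBackend(Enum):
    AUTO = "auto"
    TRTLLM = "trtllm"
    TORCH = "torch"


class MiniTransformConfig(BaseModel):
    dist_backend: DistBackend = Field(default=DistBackend.AUTO)


class MiniWeightShardingInfo(BaseModel):
    dist_backend: str = "auto"


class MiniShardingConfig(BaseModel):
    dist_backend: DistBackend = Field(default=DistBackend.AUTO)
    weight_sharding_transforms: List[MiniWeightShardingInfo] = Field(default_factory=list)


# _apply(): the YAML string is coerced to the enum and copied onto the shared config ...
config = MiniTransformConfig(dist_backend="trtllm")
sharding_config = MiniShardingConfig()
sharding_config.dist_backend = config.dist_backend

# ... and each detection helper forwards it into the transform info it appends.
sharding_config.weight_sharding_transforms.append(
    MiniWeightShardingInfo(dist_backend=sharding_config.dist_backend.value)
)
assert sharding_config.weight_sharding_transforms[0].dist_backend == "trtllm"

In the commit itself the enum is forwarded as-is (the transform infos declare dist_backend: str, and _get_dist_ops unwraps enum values), so the explicit .value above is only a convenience of this sketch.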

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py

Lines changed: 59 additions & 14 deletions
@@ -29,25 +29,47 @@
 )
 
 
-def _get_dist_ops():
+def _get_dist_ops(backend: str):
     """Get the appropriate distributed ops based on backend availability.
 
+    Args:
+        backend: The distributed backend to use. Can be 'auto', 'trtllm', or 'torch'.
+            'auto' will automatically select based on availability.
+
     Returns tuple of (all_gather_op, all_reduce_op) for the current backend.
     """
-    from ..distributed.trtllm import is_trtllm_op_available
+    from ..custom_ops.trtllm_dist import is_trtllm_op_available
+
+    # Handle DistBackend enum or string
+    if hasattr(backend, "value"):
+        backend = backend.value
 
-    if is_trtllm_op_available():
-        # Use TRT-LLM optimized ops in MPI mode
+    if backend == "trtllm":
+        # Force TRT-LLM ops
         return (
             torch.ops.auto_deploy.trtllm_dist_all_gather.default,
             torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
         )
-    else:
-        # Use PyTorch distributed ops in demollm mode
+    elif backend == "torch":
+        # Force PyTorch distributed ops
         return (
             torch.ops.auto_deploy.torch_dist_all_gather.default,
             torch.ops.auto_deploy.torch_dist_all_reduce.default,
         )
+    else:  # auto
+        # Automatically select based on availability
+        if is_trtllm_op_available():
+            # Use TRT-LLM optimized ops in MPI mode
+            return (
+                torch.ops.auto_deploy.trtllm_dist_all_gather.default,
+                torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
+            )
+        else:
+            # Use PyTorch distributed ops in demollm mode
+            return (
+                torch.ops.auto_deploy.torch_dist_all_gather.default,
+                torch.ops.auto_deploy.torch_dist_all_reduce.default,
            )
 
 
 def _load_hook(
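
The selection rule introduced here: an explicit 'trtllm' or 'torch' is honored unconditionally, and only 'auto' consults is_trtllm_op_available(). A standalone sketch of that decision, where pick_backend and the trtllm_available flag are stand-ins and strings stand in for the torch.ops.auto_deploy.* handles the real helper returns:

def pick_backend(backend: str, trtllm_available: bool) -> str:
    """Illustrative stand-in for the branch structure of _get_dist_ops."""
    if backend == "trtllm":
        return "trtllm"  # forced, regardless of availability
    if backend == "torch":
        return "torch"  # forced PyTorch distributed ops
    # "auto": prefer the fused TRT-LLM ops when they are available (MPI mode),
    # otherwise fall back to the PyTorch distributed ops (demollm mode).
    return "trtllm" if trtllm_available else "torch"


assert pick_backend("auto", trtllm_available=True) == "trtllm"
assert pick_backend("auto", trtllm_available=False) == "torch"
assert pick_backend("torch", trtllm_available=True) == "torch"
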
@@ -251,6 +273,7 @@ def _insert_sharded_mamba(
     dim: int,
     rank: int,
     world_size: int,
+    dist_backend: str,
     add_dist: bool = False,
     min_local_shape: int = 1,
     weights_to_shard: Optional[list[str]] = None,
@@ -359,6 +382,7 @@ def _insert_sharded_mamba(
             dim=SplitDimension.COLUMN,
             rank=rank,
             world_size=world_size,
+            dist_backend=dist_backend,
             add_dist=False,
             min_local_shape=min_local_shape,
             fused_weight_dims=entry_fused_dims,
@@ -422,6 +446,7 @@ def _shard_parameter_node(
     dim: int,
     rank: int,
     world_size: int,
+    dist_backend: str,
    add_dist: bool = False,
     min_local_shape: int = 1,
     fused_weight_dims: Optional[list] = None,
@@ -507,7 +532,7 @@ def _shard_parameter_node(
         return
 
     # figure out the right dist op (backend-aware)
-    all_gather_op, all_reduce_op = _get_dist_ops()
+    all_gather_op, all_reduce_op = _get_dist_ops(dist_backend)
     dist_lookup = {
         0: (all_gather_op, -1),
         1: (all_reduce_op,),
@@ -595,6 +620,7 @@ class WeightShardingInfo(ShardingTransformInfo):
     layer_type: LayerType = LayerType.MLP
     # used for TP sharding of fused weights
     fused_weight_dims: Optional[list] = None
+    dist_backend: str = "auto"
 
     def quantization_cb(
         self,
@@ -644,6 +670,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
                 dim=self.split_dim.value,
                 rank=self.rank,
                 world_size=self.world_size,
+                dist_backend=self.dist_backend,
                 add_dist=self.dist_op is not None,
                 min_local_shape=self.min_local_shape,
                 fused_weight_dims=self.fused_weight_dims
@@ -658,6 +685,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
                 dim=self.split_dim.value,
                 rank=self.rank,
                 world_size=self.world_size,
+                dist_backend=self.dist_backend,
                 add_dist=self.dist_op is not None,
                 min_local_shape=self.min_local_shape,
                 fused_weight_dims=self.fused_weight_dims,
@@ -860,6 +888,7 @@ class BMMShardingInfo(ShardingTransformInfo):
     world_size: int
     start_idx: int
     end_idx: int
+    dist_backend: str = "auto"
 
     def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
         """Validate the transformation configuration."""
@@ -947,7 +976,7 @@ def slice_tensor(t: torch.Tensor) -> torch.Tensor:
             handle_tensor(node, rhs_tensor, 1, self.start_idx, self.end_idx)
 
         # Add all_gather node after BMM to collect results
-        all_gather_op, _ = _get_dist_ops()
+        all_gather_op, _ = _get_dist_ops(self.dist_backend)
         with gm.graph.inserting_after(node):
             gather_node = gm.graph.call_function(
                 all_gather_op,
@@ -962,6 +991,7 @@ def _insert_sharded_moe(
     node: Node,
     rank: int,
     world_size: int,
+    dist_backend: str,
     scale_names: Sequence[str] = (),
 ):
     """Update the torch_moe node with sharded weight lists,
@@ -1036,7 +1066,7 @@ def get_partition(lst, world_size, rank):
     node.args = tuple(args)
 
     # -- add an all_reduce node --
-    _, all_reduce_op = _get_dist_ops()
+    _, all_reduce_op = _get_dist_ops(dist_backend)
     with gm.graph.inserting_after(node):
         dist_node = gm.graph.call_function(all_reduce_op, args=(node,))
         node.replace_all_uses_with(dist_node)
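
Both MoE helpers use the same torch.fx surgery shown in this hunk: create the collective call right after the producer node, then reroute the producer's consumers to it. A self-contained toy version of the pattern (torch.relu stands in for the backend-specific all_reduce op; the replace_input_with call undoes the self-reference that replace_all_uses_with introduces in the new node's own args):

# Toy illustration of the graph-insertion pattern; not TensorRT-LLM code.
import torch
from torch.fx import symbolic_trace


def f(x):
    return x * 2


gm = symbolic_trace(f)
mul = next(n for n in gm.graph.nodes if n.op == "call_function")  # the x * 2 node

with gm.graph.inserting_after(mul):
    post = gm.graph.call_function(torch.relu, args=(mul,))  # stand-in for all_reduce

# Reroute every consumer of `mul` (including the graph output) to `post`, then
# restore `post`'s own input, which the blanket replacement also rewrote.
mul.replace_all_uses_with(post)
post.replace_input_with(post, mul)

gm.graph.lint()
gm.recompile()
x = torch.tensor([-1.0, 2.0])
assert torch.equal(gm(x), torch.relu(x * 2))
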
@@ -1066,6 +1096,7 @@ def _insert_sharded_mxfp4_mlp_ep(
     node: Node,
     rank: int,
     world_size: int,
+    dist_backend: str,
 ):
     """
     Transform a call to auto_deploy::triton_mxfp4_moe into:
@@ -1107,7 +1138,7 @@ def _insert_sharded_mxfp4_mlp_ep(
     node.args = args_ep
 
     # Add a dist all-reduce after the op (sum partial results across EP ranks)
-    _, all_reduce_op = _get_dist_ops()
+    _, all_reduce_op = _get_dist_ops(dist_backend)
     with gm.graph.inserting_after(node):
         red = gm.graph.call_function(all_reduce_op, args=(node,))
         node.replace_all_uses_with(red)
@@ -1120,6 +1151,7 @@ class EPShardingInfo(ShardingTransformInfo):
 
     rank: int
     world_size: int
+    dist_backend: str = "auto"
 
     @classmethod
     def from_node(cls, node: Node, **kwargs) -> "EPShardingInfo":
@@ -1138,7 +1170,7 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
 
     def apply(self, gm: GraphModule, node: Node) -> None:
         """Apply EP sharding transformation to the graph module."""
-        _insert_sharded_moe(gm, node, self.rank, self.world_size, [])
+        _insert_sharded_moe(gm, node, self.rank, self.world_size, self.dist_backend, [])
 
 
 class MXFP4EPShardingInfo(EPShardingInfo):
@@ -1152,7 +1184,7 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
         return True
 
     def apply(self, gm: GraphModule, node: Node) -> None:
-        _insert_sharded_mxfp4_mlp_ep(gm, node, self.rank, self.world_size)
+        _insert_sharded_mxfp4_mlp_ep(gm, node, self.rank, self.world_size, self.dist_backend)
 
 
 class FP8EPShardingInfo(EPShardingInfo, QuantizationShardingMixin):
@@ -1168,7 +1200,9 @@ def scale_names(self) -> List[str]:
         return ["input_scale", "weight_scale"]
 
     def apply(self, gm: GraphModule, node: Node) -> None:
-        _insert_sharded_moe(gm, node, self.rank, self.world_size, self.scale_names())
+        _insert_sharded_moe(
+            gm, node, self.rank, self.world_size, self.dist_backend, self.scale_names()
+        )
 
 
 class NVFP4EPShardingInfo(EPShardingInfo, QuantizationShardingMixin):
@@ -1184,7 +1218,9 @@ def scale_names(self) -> List[str]:
         return ["input_scale", "weight_scale", "alpha"]
 
     def apply(self, gm: GraphModule, node: Node) -> None:
-        _insert_sharded_moe(gm, node, self.rank, self.world_size, self.scale_names())
+        _insert_sharded_moe(
+            gm, node, self.rank, self.world_size, self.dist_backend, self.scale_names()
+        )
 
 
 EP_SHARDING_RULES = [
@@ -1222,6 +1258,14 @@ class ShardingDim(Enum):
     BMM = "bmm"
 
 
+class DistBackend(Enum):
+    """Enum for distributed backend."""
+
+    AUTO = "auto"
+    TRTLLM = "trtllm"
+    TORCH = "torch"
+
+
 class ShardingConfig(BaseModel):
     """Configuration for sharding the model."""
 
@@ -1237,6 +1281,7 @@ class ShardingConfig(BaseModel):
     sharding_dims: List[ShardingDim] = Field(
         default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]
     )
+    dist_backend: DistBackend = Field(default=DistBackend.AUTO)
     weight_sharding_transforms: List[WeightShardingInfo] = Field(default_factory=list)
     parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list)
     bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
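
Note the dual typing these last hunks leave in place: ShardingConfig stores a DistBackend enum while the transform-info classes default to the plain string "auto", and both eventually reach _get_dist_ops, whose hasattr(backend, "value") unwrap accepts either form. A standalone illustration (normalize is a hypothetical helper, not part of the commit):

from enum import Enum


class DistBackend(Enum):
    AUTO = "auto"
    TRTLLM = "trtllm"
    TORCH = "torch"


def normalize(backend) -> str:
    """Mirror the enum-or-string handling at the top of _get_dist_ops."""
    if hasattr(backend, "value"):  # a DistBackend member (or any Enum)
        backend = backend.value
    return backend


assert normalize(DistBackend.TORCH) == "torch"  # enum path: ShardingConfig.dist_backend
assert normalize("torch") == "torch"  # string path: e.g. WeightShardingInfo.dist_backend default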
