 )


-def _get_dist_ops():
+def _get_dist_ops(backend: str):
     """Get the appropriate distributed ops based on backend availability.

+    Args:
+        backend: The distributed backend to use. Can be 'auto', 'trtllm', or 'torch'.
+            'auto' will automatically select based on availability.
+
     Returns tuple of (all_gather_op, all_reduce_op) for the current backend.
     """
-    from ..distributed.trtllm import is_trtllm_op_available
+    from ..custom_ops.trtllm_dist import is_trtllm_op_available
+
+    # Handle DistBackend enum or string
+    if hasattr(backend, "value"):
+        backend = backend.value

-    if is_trtllm_op_available():
-        # Use TRT-LLM optimized ops in MPI mode
+    if backend == "trtllm":
+        # Force TRT-LLM ops
         return (
             torch.ops.auto_deploy.trtllm_dist_all_gather.default,
             torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
         )
-    else:
-        # Use PyTorch distributed ops in demollm mode
+    elif backend == "torch":
+        # Force PyTorch distributed ops
         return (
             torch.ops.auto_deploy.torch_dist_all_gather.default,
             torch.ops.auto_deploy.torch_dist_all_reduce.default,
         )
+    else:  # auto
+        # Automatically select based on availability
+        if is_trtllm_op_available():
+            # Use TRT-LLM optimized ops in MPI mode
+            return (
+                torch.ops.auto_deploy.trtllm_dist_all_gather.default,
+                torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
+            )
+        else:
+            # Use PyTorch distributed ops in demollm mode
+            return (
+                torch.ops.auto_deploy.torch_dist_all_gather.default,
+                torch.ops.auto_deploy.torch_dist_all_reduce.default,
+            )


 def _load_hook(
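A minimal, runnable sketch of the normalization step above (DistBackend mirrors the enum added near the end of this diff; resolve_backend is a hypothetical helper, not part of the codebase, reproducing the hasattr(backend, "value") check so enum members and plain strings select the same branch):

    from enum import Enum

    class DistBackend(Enum):
        AUTO = "auto"
        TRTLLM = "trtllm"
        TORCH = "torch"

    def resolve_backend(backend) -> str:
        # Same normalization as _get_dist_ops: accept the enum or its string value.
        return backend.value if hasattr(backend, "value") else backend

    assert resolve_backend(DistBackend.TORCH) == resolve_backend("torch") == "torch"
    assert resolve_backend(DistBackend.AUTO) == "auto"  # "auto" then defers to is_trtllm_op_available()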
@@ -251,6 +273,7 @@ def _insert_sharded_mamba(
     dim: int,
     rank: int,
     world_size: int,
+    dist_backend: str,
     add_dist: bool = False,
     min_local_shape: int = 1,
     weights_to_shard: Optional[list[str]] = None,
@@ -359,6 +382,7 @@ def _insert_sharded_mamba(
         dim=SplitDimension.COLUMN,
         rank=rank,
         world_size=world_size,
+        dist_backend=dist_backend,
         add_dist=False,
         min_local_shape=min_local_shape,
         fused_weight_dims=entry_fused_dims,
@@ -422,6 +446,7 @@ def _shard_parameter_node(
     dim: int,
     rank: int,
     world_size: int,
+    dist_backend: str,
     add_dist: bool = False,
     min_local_shape: int = 1,
     fused_weight_dims: Optional[list] = None,
@@ -507,7 +532,7 @@ def _shard_parameter_node(
         return

     # figure out the right dist op (backend-aware)
-    all_gather_op, all_reduce_op = _get_dist_ops()
+    all_gather_op, all_reduce_op = _get_dist_ops(dist_backend)
     dist_lookup = {
         0: (all_gather_op, -1),
         1: (all_reduce_op,),
@@ -595,6 +620,7 @@ class WeightShardingInfo(ShardingTransformInfo):
     layer_type: LayerType = LayerType.MLP
     # used for TP sharding of fused weights
     fused_weight_dims: Optional[list] = None
+    dist_backend: str = "auto"

     def quantization_cb(
         self,
@@ -644,6 +670,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
                 dim=self.split_dim.value,
                 rank=self.rank,
                 world_size=self.world_size,
+                dist_backend=self.dist_backend,
                 add_dist=self.dist_op is not None,
                 min_local_shape=self.min_local_shape,
                 fused_weight_dims=self.fused_weight_dims
@@ -658,6 +685,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
                 dim=self.split_dim.value,
                 rank=self.rank,
                 world_size=self.world_size,
+                dist_backend=self.dist_backend,
                 add_dist=self.dist_op is not None,
                 min_local_shape=self.min_local_shape,
                 fused_weight_dims=self.fused_weight_dims,
@@ -860,6 +888,7 @@ class BMMShardingInfo(ShardingTransformInfo):
     world_size: int
     start_idx: int
     end_idx: int
+    dist_backend: str = "auto"

     def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
         """Validate the transformation configuration."""
@@ -947,7 +976,7 @@ def slice_tensor(t: torch.Tensor) -> torch.Tensor:
         handle_tensor(node, rhs_tensor, 1, self.start_idx, self.end_idx)

         # Add all_gather node after BMM to collect results
-        all_gather_op, _ = _get_dist_ops()
+        all_gather_op, _ = _get_dist_ops(self.dist_backend)
         with gm.graph.inserting_after(node):
             gather_node = gm.graph.call_function(
                 all_gather_op,
@@ -962,6 +991,7 @@ def _insert_sharded_moe(
     node: Node,
     rank: int,
     world_size: int,
+    dist_backend: str,
     scale_names: Sequence[str] = (),
 ):
     """Update the torch_moe node with sharded weight lists,
@@ -1036,7 +1066,7 @@ def get_partition(lst, world_size, rank):
     node.args = tuple(args)

     # -- add an all_reduce node --
-    _, all_reduce_op = _get_dist_ops()
+    _, all_reduce_op = _get_dist_ops(dist_backend)
     with gm.graph.inserting_after(node):
         dist_node = gm.graph.call_function(all_reduce_op, args=(node,))
         node.replace_all_uses_with(dist_node)
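The all-reduce insertion above follows the usual torch.fx rewrite pattern that recurs throughout this file: create the collective node right after the sharded op and reroute its users to it; a fix-up of the collective's own input is also needed, shown explicitly in the sketch below. This is a self-contained sketch under assumptions (fake_all_reduce is a placeholder so the example runs in a single process; the real code uses the torch.ops.auto_deploy collectives):

    import operator

    import torch
    from torch.fx import symbolic_trace

    def fake_all_reduce(t: torch.Tensor) -> torch.Tensor:
        # Placeholder for torch.ops.auto_deploy.*_dist_all_reduce.
        return t

    class TinyModule(torch.nn.Module):
        def forward(self, x):
            return x * 2

    gm = symbolic_trace(TinyModule())
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is operator.mul:
            with gm.graph.inserting_after(node):
                red = gm.graph.call_function(fake_all_reduce, args=(node,))
            node.replace_all_uses_with(red)
            # replace_all_uses_with also rewired red's own argument; point it back at node.
            red.replace_input_with(red, node)
    gm.recompile()
    assert torch.equal(gm(torch.ones(4)), 2 * torch.ones(4))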
@@ -1066,6 +1096,7 @@ def _insert_sharded_mxfp4_mlp_ep(
     node: Node,
     rank: int,
     world_size: int,
+    dist_backend: str,
 ):
     """
     Transform a call to auto_deploy::triton_mxfp4_moe into:
@@ -1107,7 +1138,7 @@ def _insert_sharded_mxfp4_mlp_ep(
     node.args = args_ep

     # Add a dist all-reduce after the op (sum partial results across EP ranks)
-    _, all_reduce_op = _get_dist_ops()
+    _, all_reduce_op = _get_dist_ops(dist_backend)
     with gm.graph.inserting_after(node):
         red = gm.graph.call_function(all_reduce_op, args=(node,))
         node.replace_all_uses_with(red)
@@ -1120,6 +1151,7 @@ class EPShardingInfo(ShardingTransformInfo):

     rank: int
     world_size: int
+    dist_backend: str = "auto"

     @classmethod
     def from_node(cls, node: Node, **kwargs) -> "EPShardingInfo":
@@ -1138,7 +1170,7 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:

     def apply(self, gm: GraphModule, node: Node) -> None:
         """Apply EP sharding transformation to the graph module."""
-        _insert_sharded_moe(gm, node, self.rank, self.world_size, [])
+        _insert_sharded_moe(gm, node, self.rank, self.world_size, self.dist_backend, [])


 class MXFP4EPShardingInfo(EPShardingInfo):
@@ -1152,7 +1184,7 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
         return True

     def apply(self, gm: GraphModule, node: Node) -> None:
-        _insert_sharded_mxfp4_mlp_ep(gm, node, self.rank, self.world_size)
+        _insert_sharded_mxfp4_mlp_ep(gm, node, self.rank, self.world_size, self.dist_backend)


 class FP8EPShardingInfo(EPShardingInfo, QuantizationShardingMixin):
@@ -1168,7 +1200,9 @@ def scale_names(self) -> List[str]:
         return ["input_scale", "weight_scale"]

     def apply(self, gm: GraphModule, node: Node) -> None:
-        _insert_sharded_moe(gm, node, self.rank, self.world_size, self.scale_names())
+        _insert_sharded_moe(
+            gm, node, self.rank, self.world_size, self.dist_backend, self.scale_names()
+        )


 class NVFP4EPShardingInfo(EPShardingInfo, QuantizationShardingMixin):
@@ -1184,7 +1218,9 @@ def scale_names(self) -> List[str]:
         return ["input_scale", "weight_scale", "alpha"]

     def apply(self, gm: GraphModule, node: Node) -> None:
-        _insert_sharded_moe(gm, node, self.rank, self.world_size, self.scale_names())
+        _insert_sharded_moe(
+            gm, node, self.rank, self.world_size, self.dist_backend, self.scale_names()
+        )


 EP_SHARDING_RULES = [
@@ -1222,6 +1258,14 @@ class ShardingDim(Enum):
     BMM = "bmm"


+class DistBackend(Enum):
+    """Enum for distributed backend."""
+
+    AUTO = "auto"
+    TRTLLM = "trtllm"
+    TORCH = "torch"
+
+
 class ShardingConfig(BaseModel):
     """Configuration for sharding the model."""

@@ -1237,6 +1281,7 @@ class ShardingConfig(BaseModel):
     sharding_dims: List[ShardingDim] = Field(
         default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]
     )
+    dist_backend: DistBackend = Field(default=DistBackend.AUTO)
     weight_sharding_transforms: List[WeightShardingInfo] = Field(default_factory=list)
     parameter_update_transforms: List[ParameterUpdateInfo] = Field(default_factory=list)
     bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
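A hedged usage sketch of the new config field (MiniShardingConfig is a hypothetical stand-in, since the full ShardingConfig may require fields not shown in this diff): pydantic validates the field from either the enum member or its string value, and because _get_dist_ops unwraps .value, both representations reach the same op-selection branch.

    from enum import Enum

    from pydantic import BaseModel, Field

    class DistBackend(Enum):
        """Mirrors the enum added in this diff."""

        AUTO = "auto"
        TRTLLM = "trtllm"
        TORCH = "torch"

    class MiniShardingConfig(BaseModel):
        # Hypothetical reduced model exercising only the new field.
        dist_backend: DistBackend = Field(default=DistBackend.AUTO)

    assert MiniShardingConfig().dist_backend is DistBackend.AUTO
    assert MiniShardingConfig(dist_backend="trtllm").dist_backend is DistBackend.TRTLLM
    assert MiniShardingConfig(dist_backend=DistBackend.TORCH).dist_backend is DistBackend.TORCH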