diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
index cfbf97eca90..d6de54b22d2 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
@@ -292,7 +292,7 @@ def detect_sharding_from_factory_config(
     num_simple_shards = 0
     num_row_col_shards = 0
 
-    for lin_node in filtered_nodes(gm.graph.nodes, is_linear_op):
+    for lin_node in filtered_nodes(gm.graph.nodes, [is_linear_op, is_fake_quantized_linear_op]):
         # use node's weight name to get the module name
         module_name = lin_node.args[1].target
 
@@ -368,7 +368,7 @@ def detect_sharding_from_factory_config(
                     )
                     num_row_col_shards += 1
                 else:
-                    ad_logger.warning("Invalid sharding config. Skipping.")
+                    ad_logger.warning(f"Unsupported sharding action {config}. Skipping.")
             else:
                 # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
                 ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
@@ -387,7 +387,19 @@ def detect_sharding_from_factory_config(
                     )
                     num_simple_shards += 1
                 else:
-                    ad_logger.warning("Invalid sharding config. Skipping.")
+                    ad_logger.warning(
+                        f"Unsupported sharding action {config}. Fallback to simple shard"
+                    )
+                    sharding_config.tp_transforms.append(
+                        TPShardingInfo.from_node(
+                            lin_node,
+                            split_dim=SplitDimension.COLUMN,
+                            rank=rank,
+                            world_size=world_size,
+                            dist_op="all_gather",
+                            min_local_shape=1,
+                        )
+                    )
 
             # after successful match, break the loop
             break
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
index 38348fe64f7..a4cfdb18fad 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
@@ -239,6 +239,12 @@ def filtered_nodes(
         for node in nodes:
             if target(node):
                 yield node
+    elif isinstance(target, Iterable) and all(isinstance(t, Callable) for t in target):
+        for node in nodes:
+            for t in target:
+                if t(node):
+                    yield node
+                    break
     else:
         # Handle the case where target or ops contains operations
         operations = ops if ops is not None else target
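
For reference, here is a minimal standalone sketch of the semantics added by the new `elif` branch in `filtered_nodes`: when given an iterable of predicates, it yields every node matched by at least one predicate, and the inner `break` ensures each node is yielded at most once even if several predicates match. The `Node` class, the toy predicates, and the simplified signature below are illustrative stand-ins, not the actual TensorRT-LLM definitions.

```python
from collections.abc import Callable, Iterable
from typing import Union

class Node:
    """Stand-in for torch.fx.Node; only carries the attribute the toy predicates read."""
    def __init__(self, op_name: str):
        self.op_name = op_name

def is_linear_op(node: Node) -> bool:  # hypothetical predicate
    return node.op_name == "linear"

def is_fake_quantized_linear_op(node: Node) -> bool:  # hypothetical predicate
    return node.op_name == "fake_quant_linear"

def filtered_nodes(nodes, target: Union[Callable, Iterable[Callable]]):
    """Simplified sketch of the single- and multi-predicate branches."""
    if isinstance(target, Callable):
        # Original behavior: a single predicate.
        for node in nodes:
            if target(node):
                yield node
    elif isinstance(target, Iterable) and all(isinstance(t, Callable) for t in target):
        # New behavior: yield nodes matching any predicate, at most once per node.
        for node in nodes:
            for t in target:
                if t(node):
                    yield node
                    break

nodes = [Node("linear"), Node("relu"), Node("fake_quant_linear")]
matched = list(filtered_nodes(nodes, [is_linear_op, is_fake_quantized_linear_op]))
assert [n.op_name for n in matched] == ["linear", "fake_quant_linear"]
```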