[Group Partitioner] leverage group partitioner for config-based partitioner #12845

Merged · 10 commits · Jul 28, 2025
Changes from all commits
1 change: 1 addition & 0 deletions exir/backend/canonical_partitioners/TARGETS
@@ -18,6 +18,7 @@ runtime.python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/exir/backend:partitioner",
+        ":group_partitioner_lib",
     ],
 )
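
Note: the new :group_partitioner_lib dependency presumably corresponds to the group_partitioner module from which pattern_op_partitioner.py imports GroupBasedPartitioner below.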

78 changes: 68 additions & 10 deletions exir/backend/canonical_partitioners/config_partitioner.py
@@ -10,16 +10,24 @@
 import torch
 from executorch.exir.backend.backend_details import ExportedProgram
 from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
-    generate_partitions_from_list_of_nodes,
+    generate_grouped_partitions_from_list_of_nodes,
 )
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
     Partitioner,
     PartitionResult,
 )

+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.fx.passes.infra.partitioner import Partition


+def is_constant_data(ep: ExportedProgram, node: torch.fx.Node) -> bool:
+    return (
+        is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
+    )
+
+
 def format_target_name(target_name: str) -> str:
     """
     We remove the dialect name space from the target name. We generally

@@ -100,6 +108,35 @@ def get_partition(
         pass


+class DSJ:
+    """
+    Disjoint set union data structure used to find connected components in the graph.
+    """
+
+    def __init__(self):
+        self.parent = {}
+
+    def find(self, x):
+        self.parent.setdefault(x, x)
+        if self.parent[x] != x:
+            self.parent[x] = self.find(self.parent[x])
+        return self.parent[x]
+
+    def union(self, x, y):
+        self.parent[self.find(x)] = self.find(y)
+
+    def contains(self, x):
+        return x in self.parent
+
+    def gen_groups(self):
+        groups = {}
+        for node in self.parent.keys():
+            root = self.find(node)
+            groups.setdefault(root, set()).add(node)
+
+        return [list(group) for group in groups.values()]
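
A quick illustration of how this DSJ merges overlapping matches (an editorial sketch, not part of the diff; the names are placeholders):

dsj = DSJ()
dsj.union("conv", "bn")    # first matched partition: {conv, bn}
dsj.union("bn", "relu")    # second partition shares bn with the first
print(dsj.gen_groups())    # one merged group, e.g. [["conv", "bn", "relu"]]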


 class ConfigerationBasedPartitioner(Partitioner):
     def __init__(
         self,

@@ -162,23 +199,44 @@ def filter_fn(node: torch.fx.Node) -> bool:
     def get_matched_nodes_from_configs(
         self, ep: ExportedProgram
     ) -> List[List[torch.fx.Node]]:
+        # disjoint set union for merging partitions
+        dsj = DSJ()
+
         # gather supported nodes
-        matched_nodes = []
         gm = ep.graph_module
         for node in gm.graph.nodes:
-            if node.op == "call_function":
-                target = format_target_name(node.target.__name__)
-                if target in self.target_partitioner_configs:
-                    node_config = self.target_partitioner_configs[target]
-                    if node_config.check_constraints(node, ep):
-                        matched_nodes.append(node_config.get_partition(node, ep))
+            if node.op != "call_function":
+                continue
+            target = format_target_name(node.target.__name__)
+
+            if target not in self.target_partitioner_configs:
+                continue
+
+            node_config = self.target_partitioner_configs[target]
+            if not node_config.check_constraints(node, ep):
+                continue
+
+            partition_candidate = node_config.get_partition(node, ep)
+            partition = []
+            for node in partition_candidate:
+                # partitioner infra copies constant data across partitions, so it
+                # is ok if this partition doesn't have it
+                if is_constant_data(ep, node) and dsj.contains(node):
+                    continue
+                partition.append(node)
+
+            # Union overlaps into a single group
+            if len(partition) > 0:
+                dsj.find(partition[0])
+                for i in range(1, len(partition)):
+                    dsj.union(partition[0], partition[i])

-        return matched_nodes
+        return dsj.gen_groups()
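
To make the constant-data branch concrete (a hypothetical example, not from the PR): if one config yields the partition [conv, weight] and a later config yields [linear, weight] for the same lifted weight node, the second pass drops weight because the DSJ already contains it. conv and linear then stay in separate groups unless they share a non-constant node, and the partitioner infra copies the constant into each partition that consumes it.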

     def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
         matched_nodes = self.get_matched_nodes_from_configs(ep)
         # create partitions
-        partitions = generate_partitions_from_list_of_nodes(
+        partitions = generate_grouped_partitions_from_list_of_nodes(
             ep.graph_module,
             matched_nodes,
         )
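
With this swap, the merged groups from get_matched_nodes_from_configs are handed to the group-based partitioner as explicit node groups rather than flattened into a single node list, so nodes that were unioned above stay in the same partition.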
48 changes: 48 additions & 0 deletions exir/backend/canonical_partitioners/pattern_op_partitioner.py
@@ -8,6 +8,10 @@
 from typing import List, Optional

 import torch
+
+from executorch.exir.backend.canonical_partitioners.group_partitioner import (
+    GroupBasedPartitioner,
+)
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

@@ -56,6 +60,50 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
     return partition_list


+def generate_grouped_partitions_from_list_of_nodes(
+    graph_module: torch.fx.GraphModule,
+    pattern_list: Optional[List[List[torch.fx.Node]]] = None,
+    op_support: Optional[OperatorSupportBase] = None,
+) -> List[Partition]:
+    final_op_support: Optional[OperatorSupportBase] = op_support
+
+    if pattern_list is not None:
+        # Tag all the nodes in these patterns
+        for node_list in pattern_list:
+            for node in node_list:
+                node.meta["match"] = True
+
+        class MatchTag(OperatorSupportBase):
+            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+                return node.meta.get("match", False)
+
+        final_op_support = (
+            MatchTag()
+            if final_op_support is None
+            else any_chain(final_op_support, MatchTag())
+        )
+
+    assert (
+        final_op_support is not None
+    ), "Did not give a pattern or OperatorSupportBase instance to partition with"
+
+    # Run the GroupBasedPartitioner to return the largest possible
+    # subgraphs containing the nodes with the tags
+    group_partitioner = GroupBasedPartitioner(
+        graph_module,
+        final_op_support,
+        node_groups=pattern_list,
+        allows_single_node_partition=True,
+    )
+    partition_list = group_partitioner.propose_partitions()
+
+    # Remove the metadata field we added
+    for partition in partition_list:
+        for node in partition.nodes:
+            node.meta.pop("match", False)
+    return partition_list


 def generate_pattern_op_partitions(
     graph_module: torch.fx.GraphModule,
     patterns: Optional[List[torch.fx.Graph]] = None,
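
For reference, a minimal sketch of driving the new grouped entry point directly (assumes executorch is installed; the toy module and node names below are illustrative, not from the PR):

import torch
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_grouped_partitions_from_list_of_nodes,
)

class Toy(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = a * 2
        return b - 3

gm = torch.fx.symbolic_trace(Toy())
nodes = {n.name: n for n in gm.graph.nodes}

# One pre-merged group, shaped like the output of get_matched_nodes_from_configs
groups = [[nodes["add"], nodes["mul"], nodes["sub"]]]
partitions = generate_grouped_partitions_from_list_of_nodes(gm, groups)
print([sorted(n.name for n in p.nodes) for p in partitions])  # e.g. [['add', 'mul', 'sub']]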