[Group Partitioner] Optimize Speed #12844


Merged: 3 commits, Jul 25, 2025
60 changes: 44 additions & 16 deletions exir/backend/canonical_partitioners/group_partitioner.py
@@ -86,7 +86,7 @@ def __init__(
         )
         self.node_to_group = collections.defaultdict(int)
         self.all_nodes_in_groups = set()
-        if node_groups:
+        if self.node_groups:
             for i, group in enumerate(self.node_groups):
                 for node in group:
                     # Node is in multiple groups - not allowed
@@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
         p2_nodes = set(partitions_by_id[p2].nodes.keys())
         combined_nodes = p1_nodes.union(p2_nodes)
 
-        for node in combined_nodes:
-            # Get all downstream nodes that are not in the combined partition
-            external_downstreams = {
-                n
-                for n in self.dependency_viewer.downstreams_of(node)
-                if n not in combined_nodes
-            }
-
-            # Check if any external downstream nodes have downstream nodes in the combined partition
-            for external_node in external_downstreams:
-                downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
-                if any(n in combined_nodes for n in downstream_nodes):
-                    return False
+        user_nodes = []
+        # Topologically, p2_nodes comes before p1_nodes, so we only
+        # need to check the downstream nodes of p2.
+        # Additionally, we don't need to check all the downstream nodes
+        # of p2; we only need to check the nodes directly outside of p2.
+        # Example:
+        #   partition[a --> b --> c] --> d --> e --> f
+        # We don't need to check [d, e, f], only [d], because the
+        # downstream users of [d] will include [e, f].
+        for node in p2_nodes:
+            for user in node.users:
+                if user not in combined_nodes:
+                    user_nodes.append(user)
+
+        for external_node in user_nodes:
+            downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
+            if any(n in combined_nodes for n in downstream_nodes):
+                return False
 
         return True
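
Why the rewritten check is still sufficient: merging p1 and p2 is only illegal when some node outside the merged set both consumes a p2 output and feeds back into the merged set, since the merge would then create a cycle. A minimal runnable sketch of that condition, using a hypothetical `Node` class and a hand-rolled `downstreams_of` in place of the real `dependency_viewer` (none of these names are the ExecuTorch API):

```python
class Node:
    """Hypothetical stand-in for a graph node; only .users is modeled."""
    def __init__(self, name):
        self.name = name
        self.users = []  # nodes that consume this node's output

def downstreams_of(node):
    """All nodes transitively reachable from node (stand-in for
    dependency_viewer.downstreams_of)."""
    seen, stack = set(), list(node.users)
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(n.users)
    return seen

def can_merge(p1_nodes, p2_nodes):
    combined = p1_nodes | p2_nodes
    # Only direct users of p2 that fall outside the merged set need
    # checking; their transitive downstreams cover everything further out.
    external = [u for n in p2_nodes for u in n.users if u not in combined]
    return not any(d in combined for u in external for d in downstreams_of(u))

# Graph: a -> b -> c, plus an escape path b -> x -> c.
a, b, c, x = Node("a"), Node("b"), Node("c"), Node("x")
a.users, b.users, x.users = [b], [c, x], [c]

# Merging {a, b} (p2, earlier) with {c} (p1, later) would trap x in a
# cycle: merged -> x -> merged. The check catches it via x's downstreams.
print(can_merge({c}, {a, b}))  # False
```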

@@ -133,13 +139,30 @@ def _process_node_groups(
         if not self.node_groups:
             return group_to_partition_id
 
-        for i, group in enumerate(self.node_groups):
-            # Create a partition for each group
+        processed_nodes = set()
+
+        # We have to create the partitions in reverse topological order
+        # so we find the groups as we traverse backwards in the graph.
+        # This likely needs to be combined with _process_remaining_nodes.
+        # TODO: this currently doesn't work with _process_remaining_nodes,
+        # so if a user provides grouped nodes with operator_support, this
+        # will fail.
+        for node in reversed(self.graph_module.graph.nodes):
+            if node not in self.node_to_group:
+                continue
+
+            if node in processed_nodes:
+                continue
+
+            group_idx = self.node_to_group[node]
+            group = self.node_groups[group_idx]
+
+            # Create a partition for the group
             partition_id = next(new_partition_id)
             partition = Partition(id=partition_id, nodes=set())
             partitions_by_id[partition_id] = partition
             partitions_order[partition_id] = partition_id
-            group_to_partition_id[i] = partition_id
+            group_to_partition_id[group_idx] = partition_id
 
             # Add all supported nodes from the group to the partition
             for node in group:
@@ -164,6 +187,12 @@ def _process_node_groups(
                     partition_map[partition_id].add(target_id)
                     partition_map[partition_id].update(partition_map[target_id])
 
+            # All the nodes in the group have now been processed,
+            # so skip them if we encounter them again in our reverse
+            # topological iteration.
+            for node in group:
+                processed_nodes.add(node)
+
         return group_to_partition_id
 
     def _process_remaining_nodes(
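
For reference, a self-contained sketch of the new traversal order: walk the graph backwards, let the first grouped node encountered stand in for its whole group, and use `processed_nodes` to suppress duplicate partitions. The function and toy data below are illustrative, not the actual method:

```python
def partition_order(graph_nodes, node_to_group, node_groups):
    processed_nodes = set()
    emitted = []  # group indices in partition-creation order
    for node in reversed(graph_nodes):
        if node not in node_to_group or node in processed_nodes:
            continue
        group_idx = node_to_group[node]
        emitted.append(group_idx)
        # Mark the whole group so its other members don't spawn
        # a duplicate partition later in the walk.
        processed_nodes.update(node_groups[group_idx])
    return emitted

# Toy graph in topological order, two groups plus one ungrouped node.
nodes = ["n0", "n1", "n2", "n3", "n4"]
groups = [["n0", "n1"], ["n3", "n4"]]
node_to_group = {n: i for i, g in enumerate(groups) for n in g}
print(partition_order(nodes, node_to_group, groups))  # [1, 0]
```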
@@ -209,7 +238,6 @@ def _merge_partitions(
 
         # Set to track removed partitions from initial static list so we can skip them
         already_merged = set()
-
         # Try to merge each pair of partitions
         for i, p1 in enumerate(partition_ids):
             # Skip if this partition has been already merged
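
The surrounding merge loop is mostly elided in this hunk; from the shown context it appears to try every pair of partitions once, skipping partitions already absorbed into another. A rough sketch of that shape, with hypothetical `can_merge`/`merge` helpers standing in for the real methods:

```python
def merge_all(partition_ids, can_merge, merge):
    already_merged = set()
    # Try each ordered pair of partitions exactly once, skipping
    # partitions that were already folded into an earlier one.
    for i, p1 in enumerate(partition_ids):
        if p1 in already_merged:
            continue
        for p2 in partition_ids[i + 1:]:
            if p2 in already_merged:
                continue
            if can_merge(p1, p2):
                merge(p1, p2)  # fold p2 into p1
                already_merged.add(p2)
    return already_merged
```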