Skip to content

Commit 6de216a

Browse files
angelayipytorchmergebot
authored andcommitted
[fx] Have replace_pattern return replaced nodes (pytorch#90244)
Summary: Modified replace_pattern in the subgraph rewriter to return a list of pairs of matches along with their corresponding replacement nodes in the modified graph (`List[Tuple[Match, List[Node]]]`). This allows us to easily modify the replaced nodes, including setting the metadata. Test Plan: CI Differential Revision: D41737056 Pull Request resolved: pytorch#90244 Approved by: https://github.com/SherlockNoMad
1 parent 4a1633c commit 6de216a

File tree

2 files changed

+51
-14
lines changed

2 files changed

+51
-14
lines changed

test/fx/test_subgraph_rewriter.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -819,17 +819,22 @@ def second_input_is_scalar(match, original_graph, pattern_graph):
819819
return False
820820
return True
821821

822-
def num_repalcement_node_found(traced):
823-
return sum(1 for node in traced.graph.nodes if node.target == torch.mul)
822+
def check_replacement_nodes(self, traced, matches):
823+
replacement_nodes_in_graph = [node for node in traced.graph.nodes if node.target == torch.mul]
824+
replacement_nodes_in_res = [r for m in matches for r in m.replacements]
825+
self.assertEqual(len(replacement_nodes_in_graph), len(replacement_nodes_in_res))
826+
self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res)
827+
return len(replacement_nodes_in_graph)
824828

825829
# match without filter, should find 2 match
826830
traced = symbolic_trace(M())
827-
matches = subgraph_rewriter.replace_pattern(
831+
matches = subgraph_rewriter.replace_pattern_with_filters(
828832
traced,
829833
BinaryOpScalarReLUPattern,
830-
BinaryOpScalarReLUReplacement)
834+
BinaryOpScalarReLUReplacement,
835+
None)
831836
self.assertEqual(len(matches), 2)
832-
self.assertEqual(num_repalcement_node_found(traced), 2)
837+
self.assertEqual(check_replacement_nodes(self, traced, matches), 2)
833838

834839
# match with filter, should find 1 match
835840
traced = symbolic_trace(M())
@@ -839,7 +844,7 @@ def num_repalcement_node_found(traced):
839844
BinaryOpScalarReLUReplacement,
840845
[second_input_is_scalar])
841846
self.assertEqual(len(matches), 1)
842-
self.assertEqual(num_repalcement_node_found(traced), 1)
847+
self.assertEqual(check_replacement_nodes(self, traced, matches), 1)
843848

844849
def test_matching_pattern_with_list_type_arg(self):
845850
class M(torch.nn.Module):

torch/fx/subgraph_rewriter.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from ._compatibility import compatibility
66

77
import copy
8+
from dataclasses import dataclass
89
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Union
910
import torch
1011

11-
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters']
12+
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]
1213

1314
@compatibility(is_backward_compatible=True)
1415
class Match(NamedTuple):
@@ -17,6 +18,15 @@ class Match(NamedTuple):
1718
# Maps nodes in the pattern subgraph to nodes in the larger graph
1819
nodes_map: Dict[Node, Node]
1920

21+
@compatibility(is_backward_compatible=False)
22+
@dataclass
23+
class ReplacedPatterns:
24+
# Node from which the match was found
25+
anchor: Node
26+
# Maps nodes in the pattern subgraph to nodes in the larger graph
27+
nodes_map: Dict[Node, Node]
28+
# List of nodes that were added into the graph
29+
replacements: List[Node]
2030

2131
def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
2232
gm.delete_all_unused_submodules()
@@ -183,7 +193,8 @@ def forward(self, x, w1, w2):
183193
add_2 = add_1 + max_2
184194
return add_2
185195
"""
186-
return _replace_pattern(gm, pattern, replacement)
196+
match_and_replacements = _replace_pattern(gm, pattern, replacement)
197+
return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements]
187198

188199

189200
# Experimental API, not backward compatible
@@ -193,7 +204,7 @@ def replace_pattern_with_filters(
193204
pattern: Union[Callable, GraphModule],
194205
replacement: Union[Callable, GraphModule],
195206
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined]
196-
) -> List[Match]:
207+
) -> List[ReplacedPatterns]:
197208
"""
198209
See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
199210
@@ -211,8 +222,8 @@ def _replace_pattern(
211222
gm: GraphModule,
212223
pattern: Union[Callable, GraphModule],
213224
replacement: Union[Callable, GraphModule],
214-
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
215-
) -> List[Match]:
225+
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
226+
) -> List[ReplacedPatterns]:
216227

217228
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch
218229

@@ -248,6 +259,7 @@ def _replace_pattern(
248259
# As we progressively replace nodes, we'll need to keep track of how the match results should change
249260
match_changed_node: Dict[Node, Node] = {}
250261

262+
match_and_replacements = []
251263
for match in _matches:
252264

253265
# Build connecting between replacement graph's input and original graph input producer node
@@ -285,6 +297,20 @@ def _replace_pattern(
285297
if isinstance(copied_returning_nodes, Node):
286298
copied_returning_nodes = (copied_returning_nodes, )
287299

300+
# Get a list of nodes that have been replaced into the graph
301+
replacement_nodes = []
302+
303+
def get_replacement_nodes(curr_node: Node):
304+
nonlocal replacement_nodes
305+
for arg in curr_node.args:
306+
if isinstance(arg, Node):
307+
if arg not in val_map.values():
308+
get_replacement_nodes(arg)
309+
replacement_nodes.append(curr_node)
310+
311+
for ret_node in copied_returning_nodes:
312+
get_replacement_nodes(ret_node)
313+
288314
# Hook the output Node of the replacement subgraph in to the
289315
# original Graph at the correct location
290316
assert len(match.returning_nodes) == len(copied_returning_nodes)
@@ -297,6 +323,14 @@ def _replace_pattern(
297323
gn = match.nodes_map[node]
298324
gm.graph.erase_node(gn)
299325

326+
match_and_replacements.append(
327+
ReplacedPatterns(
328+
anchor=match.anchors[0],
329+
nodes_map=match.nodes_map,
330+
replacements=replacement_nodes
331+
)
332+
)
333+
300334
# Update the passed-in GraphModule to reflect the new state of
301335
# `original_graph`
302336
gm.recompile()
@@ -306,6 +340,4 @@ def _replace_pattern(
306340
if isinstance(replacement, torch.nn.Module):
307341
_replace_submodules(gm, replacement)
308342

309-
# Convert _matches: InternalMatch to Match to comply with backward compatibility of this function
310-
matches: List[Match] = [Match(anchor=match.anchors[0], nodes_map=match.nodes_map) for match in _matches]
311-
return matches
343+
return match_and_replacements

0 commit comments

Comments
 (0)