5
5
from ._compatibility import compatibility
6
6
7
7
import copy
8
+ from dataclasses import dataclass
8
9
from typing import Callable , Dict , List , NamedTuple , Optional , Set , Union
9
10
import torch
10
11
11
- __all__ = ['Match' , 'replace_pattern' , 'replace_pattern_with_filters' ]
12
+ __all__ = ['Match' , 'replace_pattern' , 'replace_pattern_with_filters' , "ReplacedPatterns" ]
12
13
13
14
@compatibility (is_backward_compatible = True )
14
15
class Match (NamedTuple ):
@@ -17,6 +18,15 @@ class Match(NamedTuple):
17
18
# Maps nodes in the pattern subgraph to nodes in the larger graph
18
19
nodes_map : Dict [Node , Node ]
19
20
21
+ @compatibility (is_backward_compatible = False )
22
+ @dataclass
23
+ class ReplacedPatterns :
24
+ # Node from which the match was found
25
+ anchor : Node
26
+ # Maps nodes in the pattern subgraph to nodes in the larger graph
27
+ nodes_map : Dict [Node , Node ]
28
+ # List of nodes that were added into the graph
29
+ replacements : List [Node ]
20
30
21
31
def _replace_submodules (gm : GraphModule , replacement : torch .nn .Module ) -> None :
22
32
gm .delete_all_unused_submodules ()
@@ -183,7 +193,8 @@ def forward(self, x, w1, w2):
183
193
add_2 = add_1 + max_2
184
194
return add_2
185
195
"""
186
- return _replace_pattern (gm , pattern , replacement )
196
+ match_and_replacements = _replace_pattern (gm , pattern , replacement )
197
+ return [Match (anchor = m .anchor , nodes_map = m .nodes_map ) for m in match_and_replacements ]
187
198
188
199
189
200
# Experimental API, not backward compatible
@@ -193,7 +204,7 @@ def replace_pattern_with_filters(
193
204
pattern : Union [Callable , GraphModule ],
194
205
replacement : Union [Callable , GraphModule ],
195
206
match_filters : List [Callable [["InternalMatch" , Graph , Graph ], bool ]], # type: ignore[name-defined]
196
- ) -> List [Match ]:
207
+ ) -> List [ReplacedPatterns ]:
197
208
"""
198
209
See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
199
210
@@ -211,8 +222,8 @@ def _replace_pattern(
211
222
gm : GraphModule ,
212
223
pattern : Union [Callable , GraphModule ],
213
224
replacement : Union [Callable , GraphModule ],
214
- match_filters : List [Callable [["InternalMatch" , Graph , Graph ], bool ]] = None # type: ignore[name-defined]
215
- ) -> List [Match ]:
225
+ match_filters : List [Callable [["InternalMatch" , Graph , Graph ], bool ]] = None , # type: ignore[name-defined]
226
+ ) -> List [ReplacedPatterns ]:
216
227
217
228
from torch .fx .passes .utils .matcher_utils import SubgraphMatcher , InternalMatch
218
229
@@ -248,6 +259,7 @@ def _replace_pattern(
248
259
# As we progressively replace nodes, we'll need to keep track of how the match results should change
249
260
match_changed_node : Dict [Node , Node ] = {}
250
261
262
+ match_and_replacements = []
251
263
for match in _matches :
252
264
253
265
# Build connecting between replacement graph's input and original graph input producer node
@@ -285,6 +297,20 @@ def _replace_pattern(
285
297
if isinstance (copied_returning_nodes , Node ):
286
298
copied_returning_nodes = (copied_returning_nodes , )
287
299
300
+ # Get a list of nodes that have been replaced into the graph
301
+ replacement_nodes = []
302
+
303
+ def get_replacement_nodes (curr_node : Node ):
304
+ nonlocal replacement_nodes
305
+ for arg in curr_node .args :
306
+ if isinstance (arg , Node ):
307
+ if arg not in val_map .values ():
308
+ get_replacement_nodes (arg )
309
+ replacement_nodes .append (curr_node )
310
+
311
+ for ret_node in copied_returning_nodes :
312
+ get_replacement_nodes (ret_node )
313
+
288
314
# Hook the output Node of the replacement subgraph in to the
289
315
# original Graph at the correct location
290
316
assert len (match .returning_nodes ) == len (copied_returning_nodes )
@@ -297,6 +323,14 @@ def _replace_pattern(
297
323
gn = match .nodes_map [node ]
298
324
gm .graph .erase_node (gn )
299
325
326
+ match_and_replacements .append (
327
+ ReplacedPatterns (
328
+ anchor = match .anchors [0 ],
329
+ nodes_map = match .nodes_map ,
330
+ replacements = replacement_nodes
331
+ )
332
+ )
333
+
300
334
# Update the passed-in GraphModule to reflect the new state of
301
335
# `original_graph`
302
336
gm .recompile ()
@@ -306,6 +340,4 @@ def _replace_pattern(
306
340
if isinstance (replacement , torch .nn .Module ):
307
341
_replace_submodules (gm , replacement )
308
342
309
- # Convert _matches: InternalMatch to Match to comply with backward compatibility of this function
310
- matches : List [Match ] = [Match (anchor = match .anchors [0 ], nodes_map = match .nodes_map ) for match in _matches ]
311
- return matches
343
+ return match_and_replacements
0 commit comments