Skip to content

Commit 957a9b6

Browse files
SherlockNoMadpytorchmergebot
authored andcommitted
fx.replace_pattern accepts pattern/replacement as GraphModule (pytorch#88479)
Symbolic tracer is no longer the default tracer to produce fx graph. SubgraphRewriter should thus accept a raw GraphModule, rather than use symbolic tracer by default. Pull Request resolved: pytorch#88479 Approved by: https://github.com/jerryzh168
1 parent 4bb5c2c commit 957a9b6

2 files changed

+21
-9
lines changed

test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,4 @@ torch.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator
7171
torch.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any
7272
torch.fx.proxy.TracerBase.proxy(self, node: torch.fx.node.Node) -> 'Proxy'
7373
torch.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool
74-
torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Callable, replacement: Callable) -> List[torch.fx.subgraph_rewriter.Match]
74+
torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Union[Callable, torch.fx.graph_module.GraphModule], replacement: Union[Callable, torch.fx.graph_module.GraphModule]) -> List[torch.fx.subgraph_rewriter.Match]

torch/fx/subgraph_rewriter.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ._compatibility import compatibility
66

77
import copy
8-
from typing import Callable, Dict, List, NamedTuple, Optional, Set
8+
from typing import Callable, Dict, List, NamedTuple, Optional, Set, Union
99
import torch
1010

1111
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters']
@@ -65,7 +65,11 @@ def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Mo
6565

6666

6767
@compatibility(is_backward_compatible=True)
68-
def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
68+
def replace_pattern(
69+
gm: GraphModule,
70+
pattern: Union[Callable, GraphModule],
71+
replacement: Union[Callable, GraphModule]
72+
) -> List[Match]:
6973
"""
7074
Matches all possible non-overlapping sets of operators and their
7175
data dependencies (``pattern``) in the Graph of a GraphModule
@@ -187,8 +191,8 @@ def forward(self, x, w1, w2):
187191
@compatibility(is_backward_compatible=False)
188192
def replace_pattern_with_filters(
189193
gm: GraphModule,
190-
pattern: Callable,
191-
replacement: Callable,
194+
pattern: Union[Callable, GraphModule],
195+
replacement: Union[Callable, GraphModule],
192196
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined]
193197
) -> List[Match]:
194198
"""
@@ -205,8 +209,8 @@ def replace_pattern_with_filters(
205209

206210
def _replace_pattern(
207211
gm: GraphModule,
208-
pattern: Callable,
209-
replacement: Callable,
212+
pattern: Union[Callable, GraphModule],
213+
replacement: Union[Callable, GraphModule],
210214
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
211215
) -> List[Match]:
212216

@@ -217,8 +221,16 @@ def _replace_pattern(
217221

218222
# Get the graphs for `gm`, `pattern`, `replacement`
219223
original_graph: Graph = gm.graph
220-
pattern_graph: Graph = symbolic_trace(pattern).graph
221-
replacement_graph: Graph = symbolic_trace(replacement).graph
224+
225+
if isinstance(pattern, GraphModule):
226+
pattern_graph = pattern.graph
227+
else:
228+
pattern_graph = symbolic_trace(pattern).graph
229+
230+
if isinstance(replacement, GraphModule):
231+
replacement_graph = replacement.graph
232+
else:
233+
replacement_graph = symbolic_trace(replacement).graph
222234

223235
matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
224236
remove_overlapping_matches=True)

0 commit comments

Comments
 (0)