
Commit 7b73fdf

James Reed authored and facebook-github-bot committed
[FX] Fix retracing wrapped functions (pytorch#58061)
Summary: Pull Request resolved: pytorch#58061
Test Plan: Imported from OSS
Reviewed By: yuhc
Differential Revision: D28358801
Pulled By: jamesr66a
fbshipit-source-id: c7c9a8a80e5bfe1eb1f6d2cf858ac7e57153a860
1 parent 5fa4541 commit 7b73fdf

12 files changed (+38, -15 lines)
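
For context, a minimal sketch of the scenario this commit fixes (the helper name and its body are assumptions modeled on the new test below, not part of the commit): torch.fx.wrap marks a free function so tracing records it as an opaque call_function node instead of tracing into its body. Before this change, the code generated for the traced GraphModule never re-registered the wrap, so tracing that module a second time lost the wrapping.

import torch
import torch.fx

# Hypothetical wrapped helper (not part of this commit).
def my_wrapped_fn(x):
    return x + 1

torch.fx.wrap('my_wrapped_fn')  # keep calls to my_wrapped_fn opaque during tracing

def to_trace(y):
    return my_wrapped_fn(y)

m = torch.fx.symbolic_trace(to_trace)
assert 'my_wrapped_fn' in m.code          # first trace keeps the call opaque

# Before this commit, retracing the generated GraphModule traced through the
# helper and dropped the call; with this fix the call is preserved.
retraced = torch.fx.symbolic_trace(m)
assert 'my_wrapped_fn' in retraced.code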

test/test_fx.py

Lines changed: 13 additions & 1 deletion
@@ -327,6 +327,18 @@ def forward(self, x: torch.Tensor):
         ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
         self.assertEqual(ref_batchnorm1d(input), m(input))
 
+    def test_wrapped_retrace(self):
+        def to_trace(y):
+            return wrapped_via_decorator(y)
+
+        m = symbolic_trace(to_trace)
+        self.assertIn('wrapped_via_decorator', m.code)
+        self.assertEqual(m(0), 1)
+
+        retraced = symbolic_trace(m)
+        self.assertIn('wrapped_via_decorator', retraced.code)
+        self.assertEqual(retraced(0), 1)
+
     def test_graph_edit_with_proxy(self):
         class M(torch.nn.Module):
             def forward(self, a, b):
@@ -2547,7 +2559,7 @@ def f(b, a):
 
 
 def run_getitem_target():
-    from torch.fx.symbolic_trace import _wrapped_methods_to_patch
+    from torch.fx._symbolic_trace import _wrapped_methods_to_patch
     _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
     try:
         TestFX().getitem_inner()
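
The new test relies on wrapped_via_decorator, a helper defined elsewhere in test_fx.py and not shown in this diff. Roughly, it is a function wrapped by an ordinary decorator and then registered with torch.fx.wrap; a sketch of such a helper (an assumption inferred from how the test uses it):

import functools
import torch.fx

def decorator(fn):
    # Ordinary Python decorator; the wrapper hides the original function object.
    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner

@decorator
def wrapped_via_decorator(a):
    return a + 1

torch.fx.wrap('wrapped_via_decorator')  # registered at module scope, as torch.fx.wrap requires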

test/test_fx_experimental.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 import math
 import numbers
 from typing import Callable, Dict, Union, List, Optional
-from torch.fx.symbolic_trace import symbolic_trace
+from torch.fx._symbolic_trace import symbolic_trace
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node
 from torch.fx.experimental import graph_manipulation

torch/fx/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def forward(self, x):
 '''
 
 from .graph_module import GraphModule
-from .symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
+from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
 from .graph import Graph
 from .node import Node, map_arg
 from .proxy import Proxy

torch/fx/__init__.pyi

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@ from .graph import Graph as Graph
 from .graph_module import GraphModule as GraphModule
 from .node import Node as Node, map_arg as map_arg
 from .proxy import Proxy as Proxy
-from .symbolic_trace import Tracer as Tracer, symbolic_trace as symbolic_trace, wrap as wrap
+from ._symbolic_trace import Tracer as Tracer, symbolic_trace as symbolic_trace, wrap as wrap
 from .interpreter import Interpreter as Interpreter, Transformer as Transformer
 from .subgraph_rewriter import replace_pattern as replace_pattern

torch/fx/symbolic_trace.py renamed to torch/fx/_symbolic_trace.py

Lines changed: 3 additions & 1 deletion
@@ -620,7 +620,9 @@ def wrapped(*args, **kwargs):
         """
         proxy = _find_proxy(args, kwargs)
         if proxy is not None:
-            return proxy.tracer.create_proxy('call_function', orig_fn, args, kwargs)
+            return_proxy = proxy.tracer.create_proxy('call_function', orig_fn, args, kwargs)
+            return_proxy.node.meta['is_wrapped'] = True
+            return return_proxy
         return orig_fn(*args, **kwargs)
 
     return wrapped
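
The wrapped-function patch now tags the proxy's node with meta['is_wrapped'] = True instead of returning the proxy directly, which is what the code generator in torch/fx/graph.py (below) reads when deciding which call_function targets need a wrap() statement. A quick way to see the flag on a traced graph (a sketch; to_trace is the hypothetical example from the commit header above):

import torch.fx

traced = torch.fx.symbolic_trace(to_trace)
for node in traced.graph.nodes:
    if node.op == 'call_function':
        # Nodes created for wrapped functions now carry the is_wrapped marker.
        print(node.target, node.meta.get('is_wrapped', False))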

torch/fx/experimental/merge_matmul.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 from torch.fx.graph import Graph
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node
-from torch.fx.symbolic_trace import symbolic_trace
+from torch.fx._symbolic_trace import symbolic_trace
 
 import itertools
 import operator

torch/fx/experimental/rewriter.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import copy
 from types import FunctionType
 from typing import cast, Union, Callable, Dict, Optional, Any
-from torch.fx.symbolic_trace import Tracer
+from torch.fx._symbolic_trace import Tracer
 from torch.fx.graph import Graph
 from torch.jit.frontend import normalize_source_lines
 import torch

torch/fx/graph.py

Lines changed: 13 additions & 4 deletions
@@ -15,7 +15,7 @@
 
 
 if TYPE_CHECKING:
-    from .graph_module import GraphModule
+    from .graph_module import GraphModule  # noqa: F401
 
 
 # Mapping of builtins to their `typing` equivalent.
@@ -805,6 +805,7 @@ def _python_code(self, root_module: str, namespace: _Namespace) -> PythonCode:
         free_vars: List[str] = []
         body: List[str] = []
         globals_: Dict[str, Any] = {}
+        wrapped_fns: Dict[str, None] = {}
 
         # Wrap string in list to pass by reference
         maybe_return_annotation : List[str] = ['']
@@ -919,6 +920,8 @@ def emit_node(node : Node):
                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
                    return
                body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
+               if node.meta.get('is_wrapped', False):
+                   wrapped_fns.setdefault(global_name)
                return
            elif node.op == 'call_module':
                assert isinstance(node.target, str)
@@ -960,18 +963,24 @@ def emit_node(node : Node):
         else:
             orig_args = free_vars
 
+        if len(wrapped_fns) > 0:
+            wrap_name = add_global('wrap', torch.fx.wrap)
+            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+        else:
+            wrap_stmts = ''
+
         # If the original function didn't have self as its first argument, we
         # would have added it.
         if len(orig_args) == 0 or orig_args[0] != 'self':
             orig_args.insert(0, 'self')
         code = ''.join(body)
         code = '\n'.join('    ' + line for line in code.split('\n'))
         fn_code = f"""
+{wrap_stmts}
+
 def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
 {code}"""
-
-        return PythonCode(fn_code,
-                          globals_)
+        return PythonCode(fn_code, globals_)
 
     def __str__(self) -> str:
         """

torch/fx/graph_module.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def __init__(self, body):
    # Try to retrieve the forward source in a backward-compatible way
    CodeOnlyModule.forward = forward
 
-   from .symbolic_trace import Tracer
+   from ._symbolic_trace import Tracer
 
    # we shouldn't trace into any of the submodules, they were not
    # because they were not traced in the original GraphModule

torch/fx/interpreter.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 from .graph import Graph
 from .node import Argument, Node, Target, map_arg, map_aggregate
 from .proxy import Proxy
-from .symbolic_trace import Tracer
+from ._symbolic_trace import Tracer
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 class Interpreter:

torch/fx/subgraph_rewriter.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from .graph_module import GraphModule
 from .graph import Graph
 from .node import Node
-from .symbolic_trace import symbolic_trace
+from ._symbolic_trace import symbolic_trace
 
 import copy
 from typing import Callable, Dict, List, NamedTuple, Optional, Set

torch/quantization/quantize_fx.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import torch
 from torch.fx import GraphModule
-from torch.fx.symbolic_trace import Tracer
+from torch.fx._symbolic_trace import Tracer
 from torch.fx.node import Target, Node, Argument
 from .fx import Fuser  # noqa: F401
 from .fx import Quantizer  # noqa: F401
