Commit 6a4bf3b

jon-chuang authored and pytorchmergebot committed
feat(fx): make_fx should be aware of functions wrapped with @fx.wrap (pytorch#93273)
Fixes pytorch#89421

The strategy is to patch the given function wrapped with `@torch.fx.wrap` so that, if a tensor tracer is active, we `proxy_call` the function. `proxy_call` also skips certain checks when the function being proxy-called is not a torch op (checked with `isinstance(..., OpOverload)`).

@IvanYashchuk @ezyang @Chillee

Pull Request resolved: pytorch#93273
Approved by: https://github.com/ezyang
1 parent dd8662d, commit 6a4bf3b

File tree

4 files changed (+100, -27 lines)
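
Before the file-by-file diff, a minimal sketch of the behavior this change enables (assuming a PyTorch build that includes this commit; `a_lifted_leaf` is defined inline here for illustration, whereas the test below reuses a module-level helper that already exists in test_fx.py):

import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import make_fx

@torch.fx.wrap
def a_lifted_leaf(pair, const):
    # Opaque to make_fx after this change: recorded as a single call_function
    # node instead of being decomposed into aten.add.Tensor calls.
    return pair[0] + pair[1] + const

def to_trace(y):
    return a_lifted_leaf((4, y), 3) * 2

m = make_fx(to_trace, tracing_mode="real")(torch.tensor([10]))
print(m.code)  # shows a call to a_lifted_leaf and no aten.add.Tensor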

test/test_fx.py (+40)

@@ -31,6 +31,7 @@
 from torch.fx.passes import shape_prop
 from torch.fx.immutable_collections import immutable_dict, immutable_list
 from torch.fx.experimental.rewriter import RewritingTracer
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.operator_schemas import get_signature_for_torch_op
 from copy import deepcopy
 from collections import namedtuple
@@ -477,6 +478,45 @@ def to_trace(y):
         self.assertIn('wrapped_decorated_fn', m.code)
         self.assertEqual(m(1), 1)
 
+    @unittest.skipIf(sys.version_info >= (3, 11, 0), "FX currently does not have 3.11 support")
+    def test_wrap_with_make_fx(self):
+        def to_trace(y):
+            return a_lifted_leaf((4, y), 3) * a_lifted_leaf((3, 4), 5) * a_lifted_leaf((y, y), y)
+
+        expected_code = """def forward(self, y_1):
+    a_lifted_leaf = __main___a_lifted_leaf((4, y_1), 3)
+    a_lifted_leaf_1 = __main___a_lifted_leaf((3, 4), 5)
+    mul = torch.ops.aten.mul.Tensor(a_lifted_leaf, 12); a_lifted_leaf = None
+    a_lifted_leaf_2 = __main___a_lifted_leaf((y_1, y_1), y_1); y_1 = None
+    mul_1 = torch.ops.aten.mul.Tensor(mul, a_lifted_leaf_2); mul = a_lifted_leaf_2 = None
+    return mul_1"""
+
+        m = make_fx(to_trace, tracing_mode="real")(torch.tensor([10]))
+        self.assertIn('a_lifted_leaf', m.code)
+        # aten.add.Tensor should be internal to `a_lifted_leaf` when some of the parameters are tensors.
+        # However, it should not be traced, as the function is marked as opaque.
+        self.assertNotIn('aten.add.Tensor', m.code)
+        self.assertExpectedInline(
+            m.code.strip(),
+            expected_code
+        )
+
+        m = make_fx(to_trace, tracing_mode="fake")(torch.tensor([10]))
+        self.assertIn('a_lifted_leaf', m.code)
+        self.assertNotIn('aten.add.Tensor', m.code)
+        self.assertExpectedInline(
+            m.code.strip(),
+            expected_code
+        )
+
+        m = make_fx(to_trace, tracing_mode="symbolic")(torch.tensor([10]))
+        self.assertIn('a_lifted_leaf', m.code)
+        self.assertNotIn('aten.add.Tensor', m.code)
+        self.assertExpectedInline(
+            m.code.strip(),
+            expected_code
+        )
+
     def test_graph_edit_with_proxy(self):
         class M(torch.nn.Module):
             def forward(self, a, b):
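
To run just the new test from a pytorch source checkout at this commit, something like the following should work (pytest collects the unittest-style case; the exact invocation is an assumption, not part of the commit):

pytest test/test_fx.py -k test_wrap_with_make_fx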

torch/fx/_symbolic_trace.py (+28, -2)

@@ -849,6 +849,18 @@ def wrapped(*args, **kwargs):
             )
             return_proxy.node.meta["is_wrapped"] = True
             return return_proxy
+
+        # import here to avoid circular imports
+        from .experimental.proxy_tensor import get_innermost_proxy_mode, proxy_call, disable_proxy_modes_tracing
+
+        # If there is no input with a proxy, see if we are tracing with proxy tensors
+        proxy_mode = get_innermost_proxy_mode()
+        if proxy_mode is not None:
+            # Disable tracing of the interior of the wrapped fn while evaluating
+            with disable_proxy_modes_tracing():
+                out = proxy_call(proxy_mode, orig_fn, args, kwargs)
+            return out
+
         return orig_fn(*args, **kwargs)
 
     return wrapped
@@ -868,6 +880,18 @@ def wrapped(*args, **kwargs):
         proxy = _find_proxy(args, kwargs)
         if proxy is not None:
             return proxy.tracer.create_proxy("call_method", name, args, kwargs)
+
+        # import here to avoid circular imports
+        from .experimental.proxy_tensor import get_innermost_proxy_mode, proxy_call, disable_proxy_modes_tracing
+
+        # If there is no input with a proxy, see if we are tracing with proxy tensors
+        proxy_mode = get_innermost_proxy_mode()
+        if proxy_mode is not None:
+            # Disable tracing of the interior of the wrapped method while evaluating
+            with disable_proxy_modes_tracing():
+                out = proxy_call(proxy_mode, orig_fn, args, kwargs)
+            return out
+
         return orig_fn(*args, **kwargs)
 
     return wrapped
@@ -913,7 +937,7 @@ def patch(
         """
         Replace frame_dict[name] with new_fn until we exit the context manager.
         """
-        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
+        setattr(new_fn, "__fx_already_patched", deduplicate)  # noqa: B010
         if name not in frame_dict and hasattr(builtins, name):
             self.patches_made.append(_PatchedFnDel(frame_dict, name, None))
         elif getattr(frame_dict[name], "__fx_already_patched", False):
@@ -923,19 +947,21 @@ def patch(
                 _PatchedFnSetItem(frame_dict, name, frame_dict[name])
             )
         frame_dict[name] = new_fn
+        assert(getattr(frame_dict[name], "__fx_already_patched", False) == deduplicate)
 
     def patch_method(
         self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
     ):
         """
         Replace object_or_dict.name with new_fn until we exit the context manager.
        """
-        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
+        setattr(new_fn, "__fx_already_patched", deduplicate)  # noqa: B010
         orig_fn = getattr(cls, name)
         if getattr(orig_fn, "__fx_already_patched", False):
             return  # already patched, no need to do it again
         self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn))
         setattr(cls, name, new_fn)
+        assert(getattr(getattr(cls, name), "__fx_already_patched", False) == deduplicate)
 
     def visit_once(self, thing: Any):
         """Return True on the first call to with thing, otherwise false"""

torch/fx/experimental/proxy_tensor.py (+31, -24)

@@ -235,6 +235,11 @@ def fetch_tensor_proxy(tracer):
 HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
 
 def proxy_call(proxy_mode, func, args, kwargs):
+    # `__torch_dispatch__` is only called on torch ops, which must subclass `OpOverload`.
+    # We treat all other functions as an `external_call`, for instance, a function decorated
+    # with `@torch.fx.wrap`
+    external_call = not isinstance(func, torch._ops.OpOverload)
+
     def can_handle_tensor(x):
         return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
 
@@ -243,17 +248,17 @@ def can_handle_tensor(x):
     if not pytree.tree_all_only(torch.Tensor, can_handle_tensor, (args, kwargs)):
         return NotImplemented
 
-    if func in CURRENT_DECOMPOSITION_TABLE:
+    if not external_call:
+        if func in CURRENT_DECOMPOSITION_TABLE:
+            with proxy_mode:
+                r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)
+                if r is not NotImplemented:
+                    return r
         with proxy_mode:
-            r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)
+            r = func.decompose(*args, **kwargs)
             if r is not NotImplemented:
                 return r
 
-    with proxy_mode:
-        r = func.decompose(*args, **kwargs)
-        if r is not NotImplemented:
-            return r
-
     tracer = proxy_mode.tracer
     f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs))
 
@@ -266,8 +271,7 @@ def can_handle_tensor(x):
         # this can happen
         and pytree.tree_all_only((SymInt, SymFloat, SymBool), lambda _: False, (args, kwargs))
     )
-
-    if torch.Tag.data_dependent_output in func.tags:  # type: ignore[attr-defined]
+    if not external_call and torch.Tag.data_dependent_output in func.tags:  # type: ignore[attr-defined]
         # Check if all of the Tensor inputs are constants
         if all_constant:
             const_args, const_kwargs = pytree.tree_map_only(
@@ -327,20 +331,23 @@ def can_handle_tensor(x):
     if func is torch.ops.aten.lift_fresh.default:
         func = torch.ops.aten.lift_fresh_copy.default
 
-    proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
-                                               name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
-
-    # This makes DCE marginally less likely to DCE inplace operations.
-    # It is not strictly necessary
-    # Kind of a hacky way to test if an op is in-place or not
-    if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_":
-        if isinstance(args[0], List):
-            # e.g., c10d::allreduce_ returns a list of tensors as the first element
-            # in the output.
-            for i, a in enumerate(args[0]):
-                a.proxy = proxy_out[0][i]
-        else:
-            args[0].proxy = proxy_out
+    if external_call:
+        proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs, name=func.__name__)
+    else:
+        proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
+                                                   name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
+
+        # This makes DCE marginally less likely to DCE inplace operations.
+        # It is not strictly necessary
+        # Kind of a hacky way to test if an op is in-place or not
+        if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_":
+            if isinstance(args[0], List):
+                # e.g., c10d::allreduce_ returns a list of tensors as the first element
+                # in the output.
+                for i, a in enumerate(args[0]):
+                    a.proxy = proxy_out[0][i]
+            else:
+                args[0].proxy = proxy_out
 
     out = func(*args, **kwargs)
 
@@ -376,7 +383,7 @@ def can_handle_tensor(x):
         with maybe_disable_fake_tensor_mode():
             constant = args[0].clone()
     elif (
-        torch.Tag.nondeterministic_seeded not in func.tags  # type: ignore[attr-defined]
+        (external_call or torch.Tag.nondeterministic_seeded not in func.tags)  # type: ignore[attr-defined]
         and all_constant
         and any_constant
         and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out)
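
As a quick standalone illustration of the `external_call` test introduced above (`is_external_call` is a hypothetical helper name; the real code inlines the isinstance check):

import torch

def is_external_call(func):
    # Torch ops that reach __torch_dispatch__ are OpOverload instances;
    # anything else (e.g. a plain function lifted with @torch.fx.wrap) is
    # treated as an external call and traced as a single opaque node.
    return not isinstance(func, torch._ops.OpOverload)

def a_lifted_leaf(pair, const):
    return pair[0] + pair[1] + const

print(is_external_call(torch.ops.aten.mul.Tensor))  # False: a real torch op
print(is_external_call(a_lifted_leaf))              # True: opaque to proxy_call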
