Skip to content

Commit d8fd387

Browse files
yushangdizhudada0120
authored andcommitted
[annotate] Annotate bw nodes before eliminate dead code (pytorch#165782)
Fixes pytorch/torchtitan#1907 Pull Request resolved: pytorch#165782 Approved by: https://github.com/SherlockNoMad
1 parent 4aad712 commit d8fd387

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torch/_functorch/_aot_autograd/graph_capture.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,12 +468,16 @@ def aot_dispatch_autograd_graph(
468468
# a fake tensor. Unlikely.
469469
# See Note: [Fake Modules and AOTAutograd]
470470
torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
471+
472+
# Have to copy before eliminate_dead_code otherwise the
473+
# fw node match might be erased
474+
copy_fwd_metadata_to_bw_nodes(fx_g)
475+
471476
fx_g.graph.eliminate_dead_code()
472477
if not aot_config.disable_functionalization:
473478
# There should be *NO* mutating ops in the graph at this point.
474479
assert_functional_graph(fx_g.graph)
475480

476-
copy_fwd_metadata_to_bw_nodes(fx_g)
477481
fx_g.recompile()
478482

479483
# TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect

0 commit comments

Comments
 (0)