@@ -49,20 +49,63 @@ def post_process_error_msg(
4949 return constraint_violation_error
5050
5151
52- def clean_nn_module_stack (graph_module : torch .fx .GraphModule ) -> torch .fx .GraphModule :
52+ def clean_nn_module_stack (
53+ graph_module : torch .fx .GraphModule , is_inline_builtin = False
54+ ) -> torch .fx .GraphModule :
55+ """
56+ Clean up nn_module_stack metadata by removing export_root references.
57+
58+ Removes the _export_root module references from nn_module_stack metadata
59+ in graph nodes, which are artifacts from the export process. Fixes two patterns:
60+
61+ 1. Keys: Removes "__export_root_" and "__modules['_export_root']_" prefixes
62+ - Normal case: "L__self____export_root_child" -> "L__self__child"
63+ - inline_builtin case: Uses numeric ID strings like "140468831433840"
64+
65+ 2. Values: Removes "._export_root" and "._modules['_export_root']" from child names
66+ e.g., "L['self']._export_root.child" -> "L['self'].child"
67+ e.g., "L['self']._modules['_export_root'].child" -> "L['self'].child"
68+
69+ Also removes the root export entry "L__self____export_root" entirely.
70+
71+ Args:
72+ graph_module: The GraphModule to clean up
73+ is_inline_builtin: If True, keys are numeric ID strings and self references
74+ (L['self']) are filtered out
75+
76+ Returns:
77+ The cleaned GraphModule (modified in-place)
78+ """
5379 for node in graph_module .graph .nodes :
54- if "nn_module_stack" in node .meta :
55- nn_module_stack = node .meta ["nn_module_stack" ].copy ()
56- first_key = next (iter (nn_module_stack .keys ()))
57- if "export_root" in first_key :
58- del nn_module_stack [first_key ]
59- nn_module_stack_corrected = {}
60- for k , v in nn_module_stack .items ():
61- k_new = "" .join (k .split ("__export_root" ))
62- child_name , child_class = v
63- child_name = child_name .replace ("._export_root" , "" )
64- nn_module_stack_corrected [k_new ] = (child_name , child_class )
65- node .meta ["nn_module_stack" ] = nn_module_stack_corrected
80+ if "nn_module_stack" not in node .meta :
81+ continue
82+
83+ nn_module_stack = node .meta ["nn_module_stack" ].copy ()
84+
85+ if "L__self____export_root" in nn_module_stack :
86+ del nn_module_stack ["L__self____export_root" ]
87+
88+ # Clean up remaining entries
89+ cleaned_stack = {}
90+ for key , (child_name , child_class ) in nn_module_stack .items ():
91+ # Clean key by removing export_root patterns
92+ clean_key = key .replace ("__modules['_export_root']_" , "" ).replace (
93+ "__export_root_" , ""
94+ )
95+
96+ # Clean child_name by removing export_root patterns
97+ clean_name = child_name .replace ("._modules['_export_root']" , "" ).replace (
98+ "._export_root" , ""
99+ )
100+
101+ # Skip self reference for inline builtin case
102+ if is_inline_builtin and clean_name == "L['self']" :
103+ continue
104+
105+ cleaned_stack [clean_key ] = (clean_name , child_class )
106+
107+ node .meta ["nn_module_stack" ] = cleaned_stack
108+
66109 return graph_module
67110
68111
@@ -71,7 +114,11 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
71114
72115 # Clean parameter names: L__self____export_root_param -> L__self___param
73116 def clean_name (name ) -> str :
74- return name .replace ("__export_root_" , "_" ) if "__export_root_" in name else name
117+ if "____modules___export_root_" in name :
118+ return name .replace ("____modules___export_root_" , "_" )
119+ if "__export_root_" in name :
120+ return name .replace ("__export_root_" , "_" )
121+ return name
75122
76123 # Update get_attr nodes in-place
77124 for node in graph_module .graph .nodes :
@@ -409,7 +456,9 @@ def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
409456 )
410457 transformed_graph .recompile ()
411458
412- clean_nn_module_stack (transformed_graph )
459+ clean_nn_module_stack (
460+ transformed_graph , torch ._dynamo .config .inline_inbuilt_nn_modules
461+ )
413462 clean_export_root (transformed_graph )
414463
415464 transformed_graph .meta ["module_call_specs" ] = module_call_spec
0 commit comments