15
15
import executorch .exir as exir
16
16
import torch
17
17
from executorch .exir import ExecutorchBackendConfig , ExecutorchProgramManager , to_edge
18
+ from executorch .exir .capture ._capture import patch_forward
18
19
from executorch .exir .dynamic_shape import DynamicMemoryPlanningMode
19
20
from executorch .exir .passes import (
20
21
DebugPass ,
@@ -70,6 +71,7 @@ def export(
70
71
export_joint_graph : bool = False ,
71
72
external_constants : bool = False ,
72
73
export_state_names : bool = False ,
74
+ share_mutable_buffers : bool = False ,
73
75
) -> "ExportedModule" :
74
76
"""
75
77
Creates a new ExportedModule for the specified module class.
@@ -134,10 +136,13 @@ def return_wrapper():
134
136
# all exported methods must have the same signature so just pick the first one.
135
137
methods [0 ],
136
138
)
137
- trace_inputs : Sequence = get_trace_inputs ()
139
+ inputs : Sequence = get_trace_inputs ()
138
140
method_name_to_args = {}
139
141
for method in methods :
140
- method_name_to_args [method ] = trace_inputs
142
+ if hasattr (eager_module , "get_random_inputs_per_method" ):
143
+ # pyre-ignore
144
+ inputs = eager_module .get_random_inputs_per_method ()[method ]
145
+ method_name_to_args [method ] = inputs
141
146
142
147
method_name_to_dynamic_shapes = None
143
148
if hasattr (eager_module , "get_dynamic_shapes" ):
@@ -149,23 +154,18 @@ def return_wrapper():
149
154
method_name_to_dynamic_shapes [method ] = trace_dynamic_shapes
150
155
151
156
memory_planning_pass = MemoryPlanningPass (
152
- alloc_mutable_buffers = not export_state_names
157
+ alloc_mutable_buffers = not export_state_names ,
158
+ share_mutable_buffers = share_mutable_buffers ,
153
159
)
154
160
if hasattr (eager_module , "get_memory_planning_pass" ):
155
161
memory_planning_pass = eager_module .get_memory_planning_pass () # type: ignore[operator]
156
162
157
- class WrapperModule (nn .Module ):
158
- def __init__ (self , method ):
159
- super ().__init__ ()
160
- self .forward = method
161
-
162
163
exported_methods = {}
163
164
# These cleanup passes are required to convert the `add` op to its out
164
165
# variant, along with some other transformations.
165
166
for method_name , method_input in method_name_to_args .items ():
166
167
# if not isinstance(eager_module, torch.nn.Module):
167
168
if export_joint_graph :
168
- # _export was having issues with WrapperModule.
169
169
assert method_name == "forward"
170
170
ep = _export (
171
171
eager_module ,
@@ -179,15 +179,16 @@ def __init__(self, method):
179
179
)
180
180
exported_methods [method_name ] = _export_forward_backward (ep )
181
181
else :
182
- exported_methods [method_name ] = export (
183
- eager_module ,
184
- method_input , # type: ignore[arg-type]
185
- dynamic_shapes = (
186
- method_name_to_dynamic_shapes [method_name ]
187
- if method_name_to_dynamic_shapes
188
- else None
189
- ),
190
- )
182
+ with patch_forward (eager_module , getattr (eager_module , method_name )):
183
+ exported_methods [method_name ] = export (
184
+ eager_module ,
185
+ method_input , # type: ignore[arg-type]
186
+ dynamic_shapes = (
187
+ method_name_to_dynamic_shapes [method_name ]
188
+ if method_name_to_dynamic_shapes
189
+ else None
190
+ ),
191
+ )
191
192
192
193
exec_prog = to_edge (
193
194
exported_methods ,
@@ -229,6 +230,6 @@ def __init__(self, method):
229
230
methods = methods ,
230
231
executorch_program = exec_prog ,
231
232
exported_program = exported_program ,
232
- trace_inputs = trace_inputs ,
233
+ trace_inputs = inputs ,
233
234
get_random_inputs_fn = get_random_inputs_fn ,
234
235
)
0 commit comments