@@ -15,6 +15,7 @@
 import executorch.exir as exir
 import torch
 from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
+from executorch.exir.capture._capture import patch_forward
 from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
 from executorch.exir.passes import (
     DebugPass,
@@ -70,6 +71,7 @@ def export(
         export_joint_graph: bool = False,
         external_constants: bool = False,
         export_state_names: bool = False,
+        share_mutable_buffers: bool = False,
     ) -> "ExportedModule":
         """
         Creates a new ExportedModule for the specified module class.
@@ -134,10 +136,13 @@ def return_wrapper():
             # all exported methods must have the same signature so just pick the first one.
             methods[0],
         )
-        trace_inputs: Sequence = get_trace_inputs()
+        inputs: Sequence = get_trace_inputs()
         method_name_to_args = {}
         for method in methods:
-            method_name_to_args[method] = trace_inputs
+            if hasattr(eager_module, "get_random_inputs_per_method"):
+                # pyre-ignore
+                inputs = eager_module.get_random_inputs_per_method()[method]
+            method_name_to_args[method] = inputs

         method_name_to_dynamic_shapes = None
         if hasattr(eager_module, "get_dynamic_shapes"):
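The loop above consults an optional per-method hook: when the eager module defines get_random_inputs_per_method, its result is used as a mapping from method name to the argument tuple traced for that method, instead of reusing one shared input set for every method. A minimal sketch of such a hook on a hypothetical test module (the class name, methods, and shapes are illustrative, not part of this PR):

import torch
import torch.nn as nn


class ModuleWithTwoMethods(nn.Module):
    """Hypothetical module whose exported methods take different inputs."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    def double(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return (x + y) * 2

    def get_random_inputs_per_method(self):
        # One args tuple per exported method name, matching each signature.
        return {
            "forward": (torch.randn(4),),
            "double": (torch.randn(4), torch.randn(4)),
        }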
@@ -149,23 +154,17 @@ def return_wrapper():
                 method_name_to_dynamic_shapes[method] = trace_dynamic_shapes

         memory_planning_pass = MemoryPlanningPass(
-            alloc_mutable_buffers=not export_state_names
+            alloc_mutable_buffers=not export_state_names, share_mutable_buffers=share_mutable_buffers
         )
         if hasattr(eager_module, "get_memory_planning_pass"):
             memory_planning_pass = eager_module.get_memory_planning_pass()  # type: ignore[operator]

-        class WrapperModule(nn.Module):
-            def __init__(self, method):
-                super().__init__()
-                self.forward = method
-
         exported_methods = {}
         # These cleanup passes are required to convert the `add` op to its out
         # variant, along with some other transformations.
         for method_name, method_input in method_name_to_args.items():
             # if not isinstance(eager_module, torch.nn.Module):
             if export_joint_graph:
-                # _export was having issues with WrapperModule.
                 assert method_name == "forward"
                 ep = _export(
                     eager_module,
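Two things change in this hunk: the default MemoryPlanningPass now forwards the new share_mutable_buffers flag alongside alloc_mutable_buffers, and the throwaway WrapperModule shim is deleted (its replacement, patch_forward, is discussed after the next hunk). A module can still override planning entirely through the existing get_memory_planning_pass hook; a hedged sketch of such an override, assuming MemoryPlanningPass accepts exactly the keyword arguments this diff passes:

import torch.nn as nn

from executorch.exir.passes import MemoryPlanningPass


class ModuleWithCustomPlanning(nn.Module):
    """Hypothetical module supplying its own memory planning pass."""

    def forward(self, x):
        return x * 2

    def get_memory_planning_pass(self) -> MemoryPlanningPass:
        # Mirror the flags that export() now threads into the default pass.
        return MemoryPlanningPass(
            alloc_mutable_buffers=True,
            share_mutable_buffers=True,
        )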
@@ -179,15 +178,16 @@ def __init__(self, method):
                 )
                 exported_methods[method_name] = _export_forward_backward(ep)
             else:
-                exported_methods[method_name] = export(
-                    eager_module,
-                    method_input,  # type: ignore[arg-type]
-                    dynamic_shapes=(
-                        method_name_to_dynamic_shapes[method_name]
-                        if method_name_to_dynamic_shapes
-                        else None
-                    ),
-                )
+                with patch_forward(eager_module, getattr(eager_module, method_name)):
+                    exported_methods[method_name] = export(
+                        eager_module,
+                        method_input,  # type: ignore[arg-type]
+                        dynamic_shapes=(
+                            method_name_to_dynamic_shapes[method_name]
+                            if method_name_to_dynamic_shapes
+                            else None
+                        ),
+                    )

         exec_prog = to_edge(
             exported_methods,
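patch_forward takes over from the deleted WrapperModule: rather than wrapping a bound method in a fresh nn.Module (which the removed comment says _export had issues with), it temporarily rebinds the module's own forward so torch.export.export traces the requested method against the original module and its parameters and buffers. The actual implementation lives in executorch.exir.capture._capture; the following is only a rough sketch of the idea, under the assumption that it behaves like a plain context manager:

from contextlib import contextmanager
from typing import Callable

import torch.nn as nn


@contextmanager
def patch_forward_sketch(module: nn.Module, method: Callable):
    """Temporarily route `module.forward` to `method` while tracing."""
    # The instance attribute shadows the class-level forward for the
    # duration of the block.
    module.forward = method
    try:
        yield module
    finally:
        # Drop the shadow so lookup falls back to the class's forward.
        del module.__dict__["forward"]

Tracing the module itself, rather than a wrapper, keeps parameter and buffer ownership with eager_module across all exported methods.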
@@ -229,6 +229,6 @@ def __init__(self, method):
             methods=methods,
             executorch_program=exec_prog,
             exported_program=exported_program,
-            trace_inputs=trace_inputs,
+            trace_inputs=inputs,
             get_random_inputs_fn=get_random_inputs_fn,
         )