
Commit b1dddcf

JacobSzwejbka authored and facebook-github-bot committed
generate shared state test program
Summary: Had to hack around some legacy code here that assumes all entry points use the same inputs.

Differential Revision: D82329519
1 parent 654e722 commit b1dddcf
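
The mechanism under test is the new share_mutable_buffers flag threaded through to MemoryPlanningPass: when several exported methods touch the same mutable buffer, planning that buffer into one shared location presumably lets a write from one entry point (set_state) be observed by another (forward). A minimal sketch of constructing the pass with that flag, mirroring the exported_module.py diff below (the import path is an assumption based on how MemoryPlanningPass is used elsewhere in the ExecuTorch tree):

# Sketch only: mirrors the MemoryPlanningPass construction in the
# exported_module.py diff below. The import path is assumed.
from executorch.exir.passes import MemoryPlanningPass

memory_planning_pass = MemoryPlanningPass(
    # Allocate planned memory for mutable buffers...
    alloc_mutable_buffers=True,
    # ...and (presumably) give each buffer a single shared allocation across
    # all exported methods, so forward/get_state/set_state see the same state.
    share_mutable_buffers=True,
)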

File tree: 4 files changed, +62 −21 lines


extension/module/test/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def define_common_targets(is_fbcode=False):
         "ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
         "ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])",
         "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])",
+        "ET_MODULE_SHARED_STATE": "$(location fbcode//executorch/test/models:exported_programs[ModuleSharedState.pte])",
     }

     for aten_mode in get_aten_mode_options():

test/end2end/exported_module.py

Lines changed: 19 additions & 19 deletions
@@ -15,6 +15,7 @@
 import executorch.exir as exir
 import torch
 from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
+from executorch.exir.capture._capture import patch_forward
 from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
 from executorch.exir.passes import (
     DebugPass,
@@ -70,6 +71,7 @@ def export(
         export_joint_graph: bool = False,
         external_constants: bool = False,
         export_state_names: bool = False,
+        share_mutable_buffers: bool = False,
     ) -> "ExportedModule":
         """
         Creates a new ExportedModule for the specified module class.
@@ -134,10 +136,13 @@ def return_wrapper():
             # all exported methods must have the same signature so just pick the first one.
             methods[0],
         )
-        trace_inputs: Sequence = get_trace_inputs()
+        inputs: Sequence = get_trace_inputs()
         method_name_to_args = {}
         for method in methods:
-            method_name_to_args[method] = trace_inputs
+            if hasattr(eager_module, "get_random_inputs_per_method"):
+                # pyre-ignore
+                inputs = eager_module.get_random_inputs_per_method()[method]
+            method_name_to_args[method] = inputs

         method_name_to_dynamic_shapes = None
         if hasattr(eager_module, "get_dynamic_shapes"):
@@ -149,23 +154,17 @@ def return_wrapper():
                 method_name_to_dynamic_shapes[method] = trace_dynamic_shapes

         memory_planning_pass = MemoryPlanningPass(
-            alloc_mutable_buffers=not export_state_names
+            alloc_mutable_buffers=not export_state_names, share_mutable_buffers=share_mutable_buffers
         )
         if hasattr(eager_module, "get_memory_planning_pass"):
             memory_planning_pass = eager_module.get_memory_planning_pass()  # type: ignore[operator]

-        class WrapperModule(nn.Module):
-            def __init__(self, method):
-                super().__init__()
-                self.forward = method
-
         exported_methods = {}
         # These cleanup passes are required to convert the `add` op to its out
         # variant, along with some other transformations.
         for method_name, method_input in method_name_to_args.items():
             # if not isinstance(eager_module, torch.nn.Module):
             if export_joint_graph:
-                # _export was having issues with WrapperModule.
                 assert method_name == "forward"
                 ep = _export(
                     eager_module,
@@ -179,15 +178,16 @@ def __init__(self, method):
                 )
                 exported_methods[method_name] = _export_forward_backward(ep)
             else:
-                exported_methods[method_name] = export(
-                    eager_module,
-                    method_input,  # type: ignore[arg-type]
-                    dynamic_shapes=(
-                        method_name_to_dynamic_shapes[method_name]
-                        if method_name_to_dynamic_shapes
-                        else None
-                    ),
-                )
+                with patch_forward(eager_module, getattr(eager_module, method_name)):
+                    exported_methods[method_name] = export(
+                        eager_module,
+                        method_input,  # type: ignore[arg-type]
+                        dynamic_shapes=(
+                            method_name_to_dynamic_shapes[method_name]
+                            if method_name_to_dynamic_shapes
+                            else None
+                        ),
+                    )

         exec_prog = to_edge(
             exported_methods,
@@ -229,6 +229,6 @@ def __init__(self, method):
             methods=methods,
             executorch_program=exec_prog,
             exported_program=exported_program,
-            trace_inputs=trace_inputs,
+            trace_inputs=inputs,
             get_random_inputs_fn=get_random_inputs_fn,
         )
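
For context on the patch_forward usage above: torch.export only traces a module's forward(), so exporting entry points like get_state and set_state requires temporarily rebinding forward to the target method. The import comes from executorch.exir.capture._capture; a minimal sketch of what such a helper might look like (an assumption, the real implementation may differ):

# Assumed sketch of a patch_forward-style context manager; the real helper in
# executorch.exir.capture._capture may differ in detail.
from contextlib import contextmanager
from typing import Callable, Iterator

import torch


@contextmanager
def patch_forward(module: torch.nn.Module, method: Callable) -> Iterator[None]:
    """Temporarily rebind module.forward so torch.export traces `method`."""
    original = module.forward
    module.forward = method  # instance attribute shadows the class method
    try:
        yield
    finally:
        module.forward = original  # restore the previous binding

This also explains why the old WrapperModule indirection and its "_export was having issues" workaround could be deleted: each method is now exported directly on the eager module.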

test/models/export_program.py

Lines changed: 41 additions & 2 deletions
@@ -261,6 +261,38 @@ def forward(self, x):
     def get_random_inputs(self):
         return (torch.randint(100, [1, 3], dtype=torch.long),)

+
+class ModuleSharedState(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.register_buffer("state", torch.ones(1))
+
+    def forward(self, x):
+        return self.state.add_(1) + x
+
+    def get_state(self):
+        return self.state
+
+    def set_state(self, x):
+        self.state.copy_(x)
+
+    # Including this is tech debt, since we will immediately override it with
+    # the per-method one. ExportedModule is really old infra from before
+    # multiple methods were supported, though, so it's really obnoxious to change.
+    def get_random_inputs(self):
+        return (torch.ones(1),)
+
+    def get_random_inputs_per_method(self):
+        return {"forward": (torch.ones(1),), "get_state": (), "set_state": (torch.ones(1),)}
+
+    @staticmethod
+    def get_method_names_to_export() -> List[str]:
+        return ["forward", "get_state", "set_state"]
+
+    @staticmethod
+    def share_mutable_buffers():
+        return True
+
+
 #
 # Main logic.
@@ -280,21 +312,28 @@ def export_module_to_program(
         export_kwargs = module_class.get_export_kwargs()
     export_joint = False
     export_state_names = False
+    share_mutable_buffers = False
     if hasattr(module_class, "export_joint"):
-        export_joint = module_class.export_joint()  # pyre-ignore
+        # pyre-ignore[16]: pyre just can't figure it out
+        export_joint = module_class.export_joint()
     if hasattr(module_class, "export_state_names"):
+        # pyre-ignore[16]: pyre just can't figure it out
         export_state_names = module_class.export_state_names()
     if hasattr(module_class, "get_method_names_to_export"):
-        # pyre-ignore[16]: pyre doesn't know about get_export_kwargs.
+        # pyre-ignore[16]: pyre just can't figure it out
         methods = module_class.get_method_names_to_export()
     else:
         methods = ["forward"]
+    if hasattr(module_class, "share_mutable_buffers"):
+        # pyre-ignore[16]: pyre just can't figure it out
+        share_mutable_buffers = module_class.share_mutable_buffers()
     module = ExportedModule.export(
         module_class,
         methods,
         export_joint_graph=export_joint,
         external_constants=external_constants,
         export_state_names=export_state_names,
+        share_mutable_buffers=share_mutable_buffers,
         **export_kwargs,
     )
     return module.executorch_program
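
As a quick sanity check of the behavior this test program encodes, here is how the three methods interact in eager mode (assuming ModuleSharedState from the diff above is in scope):

# Eager-mode walkthrough of the shared buffer; values verified by hand.
import torch

m = ModuleSharedState()
assert torch.equal(m.get_state(), torch.ones(1))  # buffer starts at 1.0

out = m(torch.zeros(1))  # forward mutates state in place: 1.0 -> 2.0
assert torch.equal(out, torch.tensor([2.0]))
assert torch.equal(m.get_state(), torch.tensor([2.0]))

m.set_state(torch.zeros(1))  # overwrite the shared buffer
assert torch.equal(m.get_state(), torch.zeros(1))

Once compiled into ModuleSharedState.pte with share_mutable_buffers=True, this same cross-method visibility is what the module tests (via the new ET_MODULE_SHARED_STATE resource) can verify against the runtime.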

test/models/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ def define_common_targets():
         "ModuleDynamicCatUnallocatedIO",
         "ModuleSimpleTrain",
         "ModuleStateful",
+        "ModuleSharedState",
     ]

     # Generates Executorch .pte program files for various modules at build time.
