Skip to content

Commit eeb3e49

Browse files
voznesenskympytorchmergebot
authored andcommitted
Torch package support in dynamo (pytorch#91821)
Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#91821 Approved by: https://github.com/suo
1 parent 73e5379 commit eeb3e49

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

test/dynamo/test_misc.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ def my_custom_function(x):
3939
return x + 1
4040

4141

42+
class MyPickledModule(torch.nn.Module):
43+
def __init__(self, z):
44+
super().__init__()
45+
self.z = z
46+
47+
def forward(self, x, y):
48+
return x * x * x + y + self.z
49+
50+
4251
class MiscTests(torch._dynamo.test_case.TestCase):
4352
def test_boolarg(self):
4453
def boolarg(aa, bb, flag):
@@ -3260,6 +3269,31 @@ def fn(x, y):
32603269
res = opt_fn(x, y)
32613270
self.assertTrue(same(ref, res))
32623271

3272+
def test_torch_package_working_with_inductor_trace(self):
3273+
inputs = [torch.randn([2, 2]), torch.randn([2, 2])]
3274+
3275+
optimized_model = torch._dynamo.optimize(backend="inductor")(
3276+
MyPickledModule(torch.randn([2, 2]))
3277+
)
3278+
from torch import package
3279+
3280+
path = "/tmp/MyPickledModule.pt"
3281+
package_name = "MyPickledModule"
3282+
resource_name = "MyPickledModule.pkl"
3283+
3284+
model = MyPickledModule(torch.randn([2, 2]))
3285+
3286+
with package.PackageExporter(path) as exp:
3287+
exp.extern("**")
3288+
exp.save_pickle(package_name, resource_name, model)
3289+
3290+
imp = package.PackageImporter(path)
3291+
loaded_model = imp.load_pickle(package_name, resource_name)
3292+
3293+
optimized_loaded_model = torch._dynamo.optimize(backend="inductor")(
3294+
loaded_model
3295+
)
3296+
32633297

32643298
class CustomFunc1(torch.autograd.Function):
32653299
@staticmethod

torch/_dynamo/guards.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,19 @@ def __init__(
9797
else:
9898
scope = dict()
9999
self.scope: Dict[str, object] = scope
100+
101+
if "__builtins__" not in self.scope:
102+
self.scope["__builtins__"] = {}
103+
for (
104+
name,
105+
package_module,
106+
) in torch.package.package_importer._package_imported_modules.items():
107+
name = name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
108+
# Write the package module into the scope so that we can import it
109+
self.scope["__builtins__"][name] = package_module # type: ignore[index]
110+
# Write the demangled name to the scope so that we can use it
111+
self.scope[name] = package_module
112+
100113
self.argnames: List[str] = []
101114
# Code is python expression strings generated for each guard
102115
self.code: List[str] = []

torch/_dynamo/symbolic_convert.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -662,8 +662,16 @@ def STORE_GLOBAL(self, inst):
662662

663663
def import_source(self, module_name):
664664
"""Create an alias to a module for use in guards"""
665-
value = importlib.import_module(module_name)
666-
alias = f"__import_{module_name.replace('.', '_dot_')}"
665+
if "torch_package" in module_name:
666+
value = torch.package.package_importer._package_imported_modules[
667+
module_name
668+
]
669+
alias = (
670+
module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
671+
)
672+
else:
673+
value = importlib.import_module(module_name)
674+
alias = f"__import_{module_name.replace('.', '_dot_')}"
667675
f_globals = self.output.root_globals
668676
assert alias not in f_globals or f_globals[alias] is value
669677
f_globals[alias] = value

0 commit comments

Comments
 (0)