Commit 866d9d4

suo authored and facebook-github-bot committed
[jit] Fix name collision on load (pytorch#35720)
Summary:
Pull Request resolved: pytorch#35720

When modules are saved, all relevant types are serialized according to their
qualified name within a compilation unit. Since qualified names are guaranteed
to be unique within a compilation unit, this normally works fine.

On load, all types are registered in a compilation unit owned by the
script::Module. Type names are not unique across compilation units, so if you
load two modules with colliding type names, make them submodules of yet
another module, and save that module, there is the potential of a name
collision. See the added tests for examples if that description is confusing.

The solution is to unique type names when serializing code, mangling them if
we detect a name collision.

Test Plan: Imported from OSS

Differential Revision: D20749423

Pulled By: suo

fbshipit-source-id: a8827ff1d4a89f3e7964dbbb49b4381863da3e6a
1 parent ee6f7c3 commit 866d9d4

16 files changed (+507 −52 lines)
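The core of the fix is a new type name uniquer (torch/csrc/jit/serialization/type_name_uniquer.cpp, wired into the build files below) that hands out collision-free names at serialization time. A rough standalone sketch of the mangling idea follows; the class name, method, and exact mangle spelling are assumptions for illustration, not the committed interface:

#include <cstddef>
#include <string>
#include <unordered_set>

// Illustrative sketch only: hand out serialization names, mangling a
// qualified name when it collides with one we have already given out.
class NameUniquer {
 public:
  std::string getUniqueName(const std::string& qualifiedName) {
    if (used_.insert(qualifiedName).second) {
      return qualifiedName; // first use, keep the name as-is
    }
    // Collision: splice a mangle namespace before the last atom,
    // e.g. "__torch__.Foo" -> "__torch__.___torch_mangle_0.Foo".
    auto dot = qualifiedName.rfind('.');
    std::string prefix =
        dot == std::string::npos ? "" : qualifiedName.substr(0, dot + 1);
    std::string atom =
        dot == std::string::npos ? qualifiedName : qualifiedName.substr(dot + 1);
    std::string candidate;
    do {
      candidate = prefix + "___torch_mangle_" +
          std::to_string(mangleIndex_++) + "." + atom;
    } while (!used_.insert(candidate).second);
    return candidate;
  }

 private:
  std::unordered_set<std::string> used_;
  std::size_t mangleIndex_ = 0;
};

Under a scheme like this, the first `__torch__.Foo` keeps its name while the colliding type serializes under a mangle namespace, so both can be registered when the combined module is loaded again.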

aten/src/ATen/core/jit_type.h

Lines changed: 6 additions & 1 deletion
@@ -112,7 +112,7 @@ struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
   //
   // Takes a custom printer that users can pass in to customize the output of
   // this method.
-  std::string python_str(TypePrinter printer = nullptr) const {
+  std::string python_str(TypePrinter printer) const {
     if (printer) {
       // the printer can return nullopt to fall through to the default impl
       if (auto renamed = printer(shared_from_this())) {
@@ -121,6 +121,11 @@ struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
     }
     return python_str_impl(printer);
   }
+  std::string python_str() const {
+    // Overload instead of define a default value for `printer` to help
+    // debuggers out.
+    return python_str(nullptr);
+  }
 
   TypeKind kind() const {
     return kind_;
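A side note on the `python_str` change above: it swaps a defaulted argument for an explicit zero-argument overload. A minimal sketch of that pattern with illustrative names (not the real TypePrinter signature), assuming the goal is simply a dedicated entry point that debuggers can target:

#include <functional>
#include <string>

using Printer = std::function<std::string(int)>;

// Takes a custom printer; falls back to a default rendering when none is set.
std::string render(int value, Printer printer) {
  if (printer) {
    return printer(value);
  }
  return std::to_string(value);
}

// Explicit zero-argument overload instead of `Printer printer = nullptr`:
// a debugger can break on this entry point and see exactly the calls that
// omitted the printer, which a default argument does not let you distinguish.
std::string render(int value) {
  return render(value, nullptr);
}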

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -467,6 +467,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/runtime/jit_exception.cpp
     ${TORCH_SRC_DIR}/csrc/jit/frontend/string_to_type.cpp
     ${TORCH_SRC_DIR}/csrc/jit/serialization/source_range_serialization.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/serialization/type_name_uniquer.cpp
     ${TORCH_SRC_DIR}/csrc/jit/frontend/tracer.cpp
     ${TORCH_SRC_DIR}/csrc/jit/testing/hooks_for_testing.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp

test/jit/test_save_load.py

Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
import os
import io
import sys

import torch
from torch import Tensor
from typing import NamedTuple

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, clear_class_registry

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestSaveLoad(JitTestCase):
    def test_different_modules(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """
        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        class Foo(torch.nn.Module):
            def __init__(self):
                super(Foo, self).__init__()
                self.foo = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                return x

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(torch.jit.script(Foo()), second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_functions(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """
        def lol(x):
            return x

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        def lol(x):  # noqa: F811
            return "hello"

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(torch.jit.script(Foo()), second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_interfaces(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """
        @torch.jit.interface
        class MyInterface(object):
            def bar(self, x):
                # type: (Tensor) -> Tensor
                pass

        @torch.jit.script
        class ImplementInterface(object):
            def __init__(self):
                pass

            def bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self):
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.bar(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface(object):
            def not_bar(self, x):
                # type: (Tensor) -> Tensor
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface(object):  # noqa: F811
            def __init__(self):
                pass

            def not_bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self):
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.not_bar(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(torch.jit.script(Foo()), second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_many_collisions(self):
        class MyCoolNamedTuple(NamedTuple):
            a: int

        @torch.jit.interface
        class MyInterface(object):
            def bar(self, x):
                # type: (Tensor) -> Tensor
                pass

        @torch.jit.script
        class ImplementInterface(object):
            def __init__(self):
                pass

            def bar(self, x):
                return x

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                x = lol(x)
                x = self.interface.bar(x)

                return x, MyCoolNamedTuple(a=5)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface(object):
            def not_bar(self, x):
                # type: (Tensor) -> Tensor
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface(object):  # noqa: F811
            def __init__(self):
                pass

            def not_bar(self, x):
                return x

        def lol(x):  # noqa: F811
            return "asdofij"

        class MyCoolNamedTuple(NamedTuple):  # noqa: F811
            a: str

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self):
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                self.interface.not_bar(x)
                x = lol(x)
                return x, MyCoolNamedTuple(a="hello")

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name, second_script_module._c.qualified_name
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x, named_tuple_1 = self.first(x)
                x, named_tuple_2 = self.second(x)
                return len(x + named_tuple_2.a) + named_tuple_1.a

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

test/test_jit.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
 from jit.test_unsupported_ops import TestUnsupportedOps  # noqa: F401
 from jit.test_freezing import TestFreezing  # noqa: F401
 from jit.test_functional_blocks import TestFunctionalBlocks  # noqa: F401
+from jit.test_save_load import TestSaveLoad  # noqa: F401
 
 # Torch
 from torch import Tensor

tools/build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ libtorch_sources = [
     "torch/csrc/jit/ir/type_hashing.cpp",
     "torch/csrc/jit/serialization/export.cpp",
     "torch/csrc/jit/serialization/export_module.cpp",
+    "torch/csrc/jit/serialization/type_name_uniquer.cpp",
     "torch/csrc/jit/passes/pass_manager.cpp",
     "torch/csrc/jit/serialization/pickler.cpp",
     "torch/csrc/jit/serialization/unpickler.cpp",

torch/csrc/distributed/rpc/utils.cpp

Lines changed: 4 additions & 6 deletions
@@ -264,12 +264,10 @@ std::string wireSerialize(
   }
 
   if (!tensors.empty()) {
-    torch::jit::Pickler pickler(
-        [&](const void* buf, size_t sz) -> size_t {
-          metaEntry.append(static_cast<const char*>(buf), sz);
-          return sz;
-        },
-        nullptr);
+    torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
+      metaEntry.append(static_cast<const char*>(buf), sz);
+      return sz;
+    });
     pickler.protocol();
     pickler.pushIValue(cloneSparseTensors(tensors));
     pickler.stop();
