Skip to content

Commit 68df4d4

Browse files
dreissfacebook-github-bot
authored andcommitted
show_pickle/model_dump: Handle invalid UTF-8 in pickles (pytorch#57661)
Summary: Pull Request resolved: pytorch#57661 Thie Pickle "specification" (pickletools.py) states that the argument to a BINUNICODE opcode must be UTF-8 encoded. However, if a PyTorch custom class returns a non-UTF-8 std::string from its pickle method the libtorch Pickler will write it to the output pickle without complaining. Python's _Unpickler (the Python implementation of Unpickler) always throws an exception when trying to deserialize these invalid pickles. We still want to be able to dump these pickle files. Update DumpUnpickler to create its own opcode dispatch table (initialized as a clone of the _Unpickler dispatch table) and patch in a custom function for the BINUNICODE op. We try to emulate the default behavior, but any UnicodeDecodeError is caught and replaced with a dummy object. This could violate the assumptions of a user that expects a str in that position, so we disable this behavior by default. Update model_dump to recognize this special object and allow it to be rendered. Test Plan: Dumped and viewed a model with an invalid string in an object state. Reviewed By: malfet Differential Revision: D28531392 Pulled By: dreiss fbshipit-source-id: ab5aea20975a0ef53ef52a880deaa2c5a626e4a2
1 parent ba3a90b commit 68df4d4

File tree

3 files changed

+50
-3
lines changed

3 files changed

+50
-3
lines changed

torch/utils/model_dump/__init__.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,17 @@ def hierarchical_pickle(data):
162162
assert isinstance(name, str)
163163
# Just forget that it was a device and return the name.
164164
return name
165+
if typename == "builtin.UnicodeDecodeError":
166+
assert data.state is None
167+
msg, = data.args
168+
assert isinstance(msg, str)
169+
# Hack: Pretend this is a module so we don't need custom serialization.
170+
# Hack: Wrap the message in a tuple so it looks like a nice state object.
171+
# TODO: Undo at least that second hack. We should support string states.
172+
return {
173+
"__module_type__": typename,
174+
"state": hierarchical_pickle((msg,)),
175+
}
165176
raise Exception(f"Can't prepare fake object of type for JS: {typename}")
166177
raise Exception(f"Can't prepare data of type for JS: {type(data)}")
167178

@@ -210,7 +221,7 @@ def get_model_info(
210221
version = zf.read(path_prefix + "/version").decode("utf-8").strip()
211222

212223
with zf.open(path_prefix + "/data.pkl") as handle:
213-
raw_model_data = torch.utils.show_pickle.DumpUnpickler(handle).load()
224+
raw_model_data = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
214225
model_data = hierarchical_pickle(raw_model_data)
215226

216227
# Intern strings that are likely to be re-used.
@@ -287,7 +298,7 @@ def ist(s):
287298
# TODO: handle errors here and just ignore the file?
288299
# NOTE: For a lot of these files (like bytecode),
289300
# we could get away with just unpickling, but this should be safer.
290-
obj = torch.utils.show_pickle.DumpUnpickler(handle).load()
301+
obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
291302
buf = io.StringIO()
292303
pprint.pprint(obj, buf)
293304
contents = buf.getvalue()

torch/utils/model_dump/code.js

+5
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ class ModelData extends Component {
276276
}
277277
} else if (mstate.__tuple_values__) {
278278
parts.push(html`<br/><${ModelData} prefix="" indent=${new_indent} data=${mstate} />`);
279+
} else if (mstate.__module_type__) {
280+
// We normally wouldn't have the state of a module be another module,
281+
// but we use "modules" to encode special values (like Unicode decode
282+
// errors) that might be valid states. Just go with it.
283+
parts.push(html`<br/><${ModelData} prefix="" indent=${new_indent} data=${mstate} />`);
279284
} else {
280285
throw new Error("Bad module state");
281286
}

torch/utils/show_pickle.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#!/usr/bin/env python3
22
import sys
33
import pickle
4+
import struct
45
import pprint
56
import zipfile
67
import fnmatch
7-
from typing import IO, BinaryIO, Union
8+
from typing import Any, IO, BinaryIO, Union
89

910

1011
class FakeObject(object):
@@ -58,12 +59,42 @@ def fake_new(self, *args):
5859

5960

6061
class DumpUnpickler(pickle._Unpickler): # type: ignore[name-defined]
62+
def __init__(
63+
self,
64+
file,
65+
*,
66+
catch_invalid_utf8=False,
67+
**kwargs):
68+
super().__init__(file, **kwargs)
69+
self.catch_invalid_utf8 = catch_invalid_utf8
70+
6171
def find_class(self, module, name):
6272
return FakeClass(module, name)
6373

6474
def persistent_load(self, pid):
6575
return FakeObject("pers", "obj", (pid,))
6676

77+
dispatch = dict(pickle._Unpickler.dispatch) # type: ignore[attr-defined]
78+
79+
# Custom objects in TorchScript are able to return invalid UTF-8 strings
80+
# from their pickle (__getstate__) functions. Install a custom loader
81+
# for strings that catches the decode exception and replaces it with
82+
# a sentinel object.
83+
def load_binunicode(self):
84+
strlen, = struct.unpack("<I", self.read(4))
85+
if strlen > sys.maxsize:
86+
raise Exception("String too long.")
87+
str_bytes = self.read(strlen)
88+
obj: Any
89+
try:
90+
obj = str(str_bytes, "utf-8", "surrogatepass")
91+
except UnicodeDecodeError as exn:
92+
if not self.catch_invalid_utf8:
93+
raise
94+
obj = FakeObject("builtin", "UnicodeDecodeError", (str(exn),))
95+
self.append(obj)
96+
dispatch[pickle.BINUNICODE[0]] = load_binunicode
97+
6798
@classmethod
6899
def dump(cls, in_stream, out_stream):
69100
value = cls(in_stream).load()

0 commit comments

Comments
 (0)