|
1 | 1 | #!/usr/bin/env python3
|
2 | 2 | import sys
|
3 | 3 | import pickle
|
| 4 | +import struct |
4 | 5 | import pprint
|
5 | 6 | import zipfile
|
6 | 7 | import fnmatch
|
7 |
| -from typing import IO, BinaryIO, Union |
| 8 | +from typing import Any, IO, BinaryIO, Union |
8 | 9 |
|
9 | 10 |
|
10 | 11 | class FakeObject(object):
|
@@ -58,12 +59,42 @@ def fake_new(self, *args):
|
58 | 59 |
|
59 | 60 |
|
60 | 61 | class DumpUnpickler(pickle._Unpickler): # type: ignore[name-defined]
|
| 62 | + def __init__( |
| 63 | + self, |
| 64 | + file, |
| 65 | + *, |
| 66 | + catch_invalid_utf8=False, |
| 67 | + **kwargs): |
| 68 | + super().__init__(file, **kwargs) |
| 69 | + self.catch_invalid_utf8 = catch_invalid_utf8 |
| 70 | + |
61 | 71 | def find_class(self, module, name):
|
62 | 72 | return FakeClass(module, name)
|
63 | 73 |
|
64 | 74 | def persistent_load(self, pid):
|
65 | 75 | return FakeObject("pers", "obj", (pid,))
|
66 | 76 |
|
| 77 | + dispatch = dict(pickle._Unpickler.dispatch) # type: ignore[attr-defined] |
| 78 | + |
| 79 | + # Custom objects in TorchScript are able to return invalid UTF-8 strings |
| 80 | + # from their pickle (__getstate__) functions. Install a custom loader |
| 81 | + # for strings that catches the decode exception and replaces it with |
| 82 | + # a sentinel object. |
| 83 | + def load_binunicode(self): |
| 84 | + strlen, = struct.unpack("<I", self.read(4)) |
| 85 | + if strlen > sys.maxsize: |
| 86 | + raise Exception("String too long.") |
| 87 | + str_bytes = self.read(strlen) |
| 88 | + obj: Any |
| 89 | + try: |
| 90 | + obj = str(str_bytes, "utf-8", "surrogatepass") |
| 91 | + except UnicodeDecodeError as exn: |
| 92 | + if not self.catch_invalid_utf8: |
| 93 | + raise |
| 94 | + obj = FakeObject("builtin", "UnicodeDecodeError", (str(exn),)) |
| 95 | + self.append(obj) |
| 96 | + dispatch[pickle.BINUNICODE[0]] = load_binunicode |
| 97 | + |
67 | 98 | @classmethod
|
68 | 99 | def dump(cls, in_stream, out_stream):
|
69 | 100 | value = cls(in_stream).load()
|
|
0 commit comments