From edaca2c538dfdf60e31d2db14f0fe70829a7bc89 Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Tue, 4 Oct 2022 12:29:25 -0300 Subject: [PATCH 1/2] new is_pickled_module() function; fix _identify_module() --- dill/__init__.py | 2 +- dill/session.py | 107 ++++++++++++++++++++++++++++++------- dill/tests/test_session.py | 32 +++++++++++ 3 files changed, 122 insertions(+), 19 deletions(-) diff --git a/dill/__init__.py b/dill/__init__.py index 6f71bbe5..8d5a656c 100644 --- a/dill/__init__.py +++ b/dill/__init__.py @@ -31,7 +31,7 @@ UnpicklingWarning, ) from .session import ( - dump_module, load_module, load_module_asdict, + dump_module, load_module, load_module_asdict, is_pickled_module, dump_session, load_session # backward compatibility ) from . import detect, logger, session, source, temp diff --git a/dill/session.py b/dill/session.py index 6acdd432..29368792 100644 --- a/dill/session.py +++ b/dill/session.py @@ -11,7 +11,7 @@ """ __all__ = [ - 'dump_module', 'load_module', 'load_module_asdict', + 'dump_module', 'load_module', 'load_module_asdict', 'is_pickled_module', 'dump_session', 'load_session' # backward compatibility ] @@ -19,12 +19,14 @@ import sys import warnings -from dill import _dill, Pickler, Unpickler +from dill import _dill +from dill import Pickler, Unpickler, UnpicklingError from ._dill import ( BuiltinMethodType, FunctionType, MethodType, ModuleType, TypeType, _import_module, _is_builtin_module, _is_imported_module, _main_module, _reverse_typemap, __builtin__, ) +from ._utils import _open # Type hints. 
from typing import Optional, Union @@ -285,26 +287,95 @@ def _make_peekable(stream): def _identify_module(file, main=None): """identify the name of the module stored in the given file-type object""" - from pickletools import genops - UNICODE = {'UNICODE', 'BINUNICODE', 'SHORT_BINUNICODE'} - found_import = False + import pickletools + NEUTRAL = {'PROTO', 'FRAME', 'PUT', 'BINPUT', 'MEMOIZE', 'MARK', 'STACK_GLOBAL'} try: - for opcode, arg, pos in genops(file.peek(256)): - if not found_import: - if opcode.name in ('GLOBAL', 'SHORT_BINUNICODE') and \ - arg.endswith('_import_module'): - found_import = True - else: - if opcode.name in UNICODE: - return arg - else: - raise UnpicklingError("reached STOP without finding main module") + opcodes = ((opcode.name, arg) for opcode, arg, pos in pickletools.genops(file.peek(256)) + if opcode.name not in NEUTRAL) + opcode, arg = next(opcodes) + if (opcode, arg) == ('SHORT_BINUNICODE', 'dill._dill'): + # The file uses STACK_GLOBAL instead of GLOBAL. + opcode, arg = next(opcodes) + if not (opcode in ('SHORT_BINUNICODE', 'GLOBAL') and arg.split()[-1] == '_import_module'): + raise ValueError + opcode, arg = next(opcodes) + if not opcode in ('SHORT_BINUNICODE', 'BINUNICODE', 'UNICODE'): + raise ValueError + module_name = arg + if not ( + next(opcodes)[0] in ('TUPLE1', 'TUPLE') and + next(opcodes)[0] == 'REDUCE' and + next(opcodes)[0] in ('EMPTY_DICT', 'DICT') + ): + raise ValueError + return module_name + except StopIteration: + raise UnpicklingError("reached STOP without finding module") from None except (NotImplementedError, ValueError) as error: - # ValueError occours when the end of the chunk is reached (without a STOP). + # ValueError also occurs when the end of the chunk is reached (without a STOP). if isinstance(error, NotImplementedError) and main is not None: - # file is not peekable, but we have main. + # The file is not peekable, but we have the argument main. 
return None - raise UnpicklingError("unable to identify main module") from error + raise UnpicklingError("unable to identify module") from error + +def is_pickled_module( + filename, importable: bool = True, identify: bool = False +) -> Union[bool, str]: + """Check if a file can be loaded with :func:`load_module`. + + Check if the file is a pickle file generated with :func:`dump_module`, + and thus can be loaded with :func:`load_module`. + + Parameters: + filename: a path-like object or a readable stream. + importable: expected kind of the file's saved module. Use `True` for + importable modules (the default) or `False` for module-type objects. + identify: if `True`, return the module name if the test succeeds. + + Returns: + `True` if the pickle file at ``filename`` was generated with + :func:`dump_module` **AND** the module whose state is saved in it is + of the kind specified by the ``importable`` argument. `False` otherwise. + If `identify` is set, return the name of the module instead of `True`. 
+ + Examples: + Create three types of pickle files: + + >>> import dill + >>> import types + >>> dill.dump_module('module_session.pkl') # saves __main__ + >>> dill.dump_module('module_object.pkl', module=types.ModuleType('example')) + >>> with open('common_object.pkl', 'wb') as file: + ... dill.dump('example', file) + + Test each file's kind: + + >>> dill.is_pickled_module('module_session.pkl') # the module is importable + True + >>> dill.is_pickled_module('module_session.pkl', importable=False) + False + >>> dill.is_pickled_module('module_object.pkl') # the module is not importable + False + >>> dill.is_pickled_module('module_object.pkl', importable=False) + True + >>> dill.is_pickled_module('module_object.pkl', importable=False, identify=True) + 'example' + >>> dill.is_pickled_module('common_object.pkl') # always return False + False + >>> dill.is_pickled_module('common_object.pkl', importable=False) + False + """ + with _open(filename, 'rb', peekable=True) as file: + try: + pickle_main = _identify_module(file) + except UnpicklingError: + return False + is_runtime_mod = pickle_main.startswith('__runtime__.') + res = importable ^ is_runtime_mod + if res and identify: + return pickle_main.partition('.')[-1] if is_runtime_mod else pickle_main + else: + return res def load_module( filename = str(TEMPDIR/'session.pkl'), diff --git a/dill/tests/test_session.py b/dill/tests/test_session.py index 51128916..e6a7921a 100644 --- a/dill/tests/test_session.py +++ b/dill/tests/test_session.py @@ -11,6 +11,7 @@ import __main__ from contextlib import suppress from io import BytesIO +from types import ModuleType import dill @@ -271,6 +272,36 @@ def test_load_module_asdict(): assert 'y' not in main_vars assert 'empty' in main_vars +def test_is_pickled_module(): + import tempfile + import warnings + + # Module saved with dump(). 
+ pickle_file = tempfile.NamedTemporaryFile(mode='wb') + dill.dump(os, pickle_file) + pickle_file.flush() + assert not dill.is_pickled_module(pickle_file.name) + assert not dill.is_pickled_module(pickle_file.name, importable=False) + pickle_file.close() + + # Importable module saved with dump_module(). + pickle_file = tempfile.NamedTemporaryFile(mode='wb') + dill.dump_module(pickle_file, local_mod) + pickle_file.flush() + assert dill.is_pickled_module(pickle_file.name) + assert not dill.is_pickled_module(pickle_file.name, importable=False) + assert dill.is_pickled_module(pickle_file.name, identify=True) == local_mod.__name__ + pickle_file.close() + + # Module-type object saved with dump_module(). + pickle_file = tempfile.NamedTemporaryFile(mode='wb') + dill.dump_module(pickle_file, ModuleType('runtime')) + pickle_file.flush() + assert not dill.is_pickled_module(pickle_file.name) + assert dill.is_pickled_module(pickle_file.name, importable=False) + assert dill.is_pickled_module(pickle_file.name, importable=False, identify=True) == 'runtime' + pickle_file.close() + if __name__ == '__main__': test_session_main(refimported=False) test_session_main(refimported=True) @@ -278,3 +309,4 @@ def test_load_module_asdict(): test_runtime_module() test_refimported_imported_as() test_load_module_asdict() + test_is_pickled_module() From a0de1c975e13f5b4df6bdad58d280c18bb81f2af Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Tue, 4 Oct 2022 12:35:45 -0300 Subject: [PATCH 2/2] didn't add new files... 
--- dill/_utils.py | 104 +++++++++++++++++++++++++++++++++++++++ dill/tests/test_utils.py | 63 ++++++++++++++++++++++++ 2 files changed, 167 insertions(+) create mode 100644 dill/_utils.py create mode 100644 dill/tests/test_utils.py diff --git a/dill/_utils.py b/dill/_utils.py new file mode 100644 index 00000000..ae11b3c8 --- /dev/null +++ b/dill/_utils.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# +# Author: Leonardo Gama (@leogama) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE +""" +Auxiliary classes and functions used in more than one module, defined here to +avoid circular import problems. +""" + +import contextlib +import io +from contextlib import suppress + + +## File-related utilities ## + +class _PeekableReader(contextlib.AbstractContextManager): + """lightweight readable stream wrapper that implements peek()""" + def __init__(self, stream, closing=True): + self.stream = stream + self.closing = closing + def __exit__(self, *exc_info): + if self.closing: + self.stream.close() + def read(self, n): + return self.stream.read(n) + def readline(self): + return self.stream.readline() + def tell(self): + return self.stream.tell() + def close(self): + return self.stream.close() + def peek(self, n): + stream = self.stream + try: + if hasattr(stream, 'flush'): + stream.flush() + position = stream.tell() + stream.seek(position) # assert seek() works before reading + chunk = stream.read(n) + stream.seek(position) + return chunk + except (AttributeError, OSError): + raise NotImplementedError("stream is not peekable: %r", stream) from None + +class _SeekableWriter(io.BytesIO, contextlib.AbstractContextManager): + """works as an unlimited buffer, writes to file on close""" + def __init__(self, stream, closing=True, *args, **kwds): + super().__init__(*args, **kwds) + self.stream = stream + self.closing = closing + def 
__exit__(self, *exc_info): + self.close() + def close(self): + self.stream.write(self.getvalue()) + with suppress(AttributeError): + self.stream.flush() + super().close() + if self.closing: + self.stream.close() + +def _open(file, mode, *, peekable=False, seekable=False): + """return a context manager with an opened file-like object""" + readonly = ('r' in mode and '+' not in mode) + if not readonly and peekable: + raise ValueError("the 'peekable' option is invalid for writable files") + if readonly and seekable: + raise ValueError("the 'seekable' option is invalid for read-only files") + should_close = not hasattr(file, 'read' if readonly else 'write') + if should_close: + file = open(file, mode) + # Wrap stream in a helper class if necessary. + if peekable and not hasattr(file, 'peek'): + # Try our best to return it as an object with a peek() method. + if hasattr(file, 'seekable'): + file_seekable = file.seekable() + elif hasattr(file, 'seek') and hasattr(file, 'tell'): + try: + file.seek(file.tell()) + file_seekable = True + except Exception: + file_seekable = False + else: + file_seekable = False + if file_seekable: + file = _PeekableReader(file, closing=should_close) + else: + try: + file = io.BufferedReader(file) + except Exception: + # It won't be peekable, but will fail gracefully in _identify_module(). 
+ file = _PeekableReader(file, closing=should_close) + elif seekable and ( + not hasattr(file, 'seek') + or not hasattr(file, 'truncate') + or (hasattr(file, 'seekable') and not file.seekable()) + ): + file = _SeekableWriter(file, closing=should_close) + if should_close or isinstance(file, (_PeekableReader, _SeekableWriter)): + return file + else: + return contextlib.nullcontext(file) diff --git a/dill/tests/test_utils.py b/dill/tests/test_utils.py new file mode 100644 index 00000000..49a5b86d --- /dev/null +++ b/dill/tests/test_utils.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +# Author: Leonardo Gama (@leogama) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE + +"""test general utilities in _utils.py""" + +import io +import os + +from dill import _utils + +def test_open(): + file_unpeekable = open(__file__, 'rb', buffering=0) + assert not hasattr(file_unpeekable, 'peek') + + content = file_unpeekable.read() + peeked_chars = content[:10] + first_line = content[:100].partition(b'\n')[0] + b'\n' + file_unpeekable.seek(0) + + # Test _PeekableReader for seekable stream + with _utils._open(file_unpeekable, 'r', peekable=True) as file: + assert isinstance(file, _utils._PeekableReader) + assert file.peek(10)[:10] == peeked_chars + assert file.readline() == first_line + assert not file_unpeekable.closed + file_unpeekable.close() + + _pipe_r, _pipe_w = os.pipe() + pipe_r = io.FileIO(_pipe_r, closefd=False) + pipe_w = io.FileIO(_pipe_w, mode='w') + assert not hasattr(pipe_r, 'peek') + assert not pipe_r.seekable() + assert not pipe_w.seekable() + + # Test io.BufferedReader for unseekable stream + with _utils._open(pipe_r, 'r', peekable=True) as file: + assert isinstance(file, io.BufferedReader) + pipe_w.write(content[:100]) + assert file.peek(10)[:10] == peeked_chars + assert file.readline() == first_line + assert not pipe_r.closed + 
+ # Test _SeekableWriter for unseekable stream + with _utils._open(pipe_w, 'w', seekable=True) as file: + # pipe_r is closed here for some reason... + assert isinstance(file, _utils._SeekableWriter) + file.write(content) + file.flush() + file.seek(0) + file.truncate() + file.write(b'a line of text\n') + assert not pipe_w.closed + pipe_r = io.FileIO(_pipe_r) + assert pipe_r.readline() == b'a line of text\n' + pipe_r.close() + pipe_w.close() + +if __name__ == '__main__': + test_open()