Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New is_pickled_module() function #556

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dill/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
UnpicklingWarning,
)
from .session import (
dump_module, load_module, load_module_asdict,
dump_module, load_module, load_module_asdict, is_pickled_module,
dump_session, load_session # backward compatibility
)
from . import detect, logger, session, source, temp
Expand Down
104 changes: 104 additions & 0 deletions dill/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#!/usr/bin/env python
#
# Author: Leonardo Gama (@leogama)
# Copyright (c) 2022 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
"""
Auxiliary classes and functions used in more than one module, defined here to
avoid circular import problems.
"""

import contextlib
import io
from contextlib import suppress


## File-related utilities ##

class _PeekableReader(contextlib.AbstractContextManager):
"""lightweight readable stream wrapper that implements peek()"""
def __init__(self, stream, closing=True):
self.stream = stream
self.closing = closing
def __exit__(self, *exc_info):
if self.closing:
self.stream.close()
def read(self, n):
return self.stream.read(n)
def readline(self):
return self.stream.readline()
def tell(self):
return self.stream.tell()
def close(self):
return self.stream.close()
def peek(self, n):
stream = self.stream
try:
if hasattr(stream, 'flush'):
stream.flush()
position = stream.tell()
stream.seek(position) # assert seek() works before reading
chunk = stream.read(n)
stream.seek(position)
return chunk
except (AttributeError, OSError):
raise NotImplementedError("stream is not peekable: %r", stream) from None

class _SeekableWriter(io.BytesIO, contextlib.AbstractContextManager):
"""works as an unlimited buffer, writes to file on close"""
def __init__(self, stream, closing=True, *args, **kwds):
super().__init__(*args, **kwds)
self.stream = stream
self.closing = closing
def __exit__(self, *exc_info):
self.close()
def close(self):
self.stream.write(self.getvalue())
with suppress(AttributeError):
self.stream.flush()
super().close()
if self.closing:
self.stream.close()

def _open(file, mode, *, peekable=False, seekable=False):
"""return a context manager with an opened file-like object"""
readonly = ('r' in mode and '+' not in mode)
if not readonly and peekable:
raise ValueError("the 'peekable' option is invalid for writable files")
if readonly and seekable:
raise ValueError("the 'seekable' option is invalid for read-only files")
should_close = not hasattr(file, 'read' if readonly else 'write')
if should_close:
file = open(file, mode)
# Wrap stream in a helper class if necessary.
if peekable and not hasattr(file, 'peek'):
# Try our best to return it as an object with a peek() method.
if hasattr(file, 'seekable'):
file_seekable = file.seekable()
elif hasattr(file, 'seek') and hasattr(file, 'tell'):
try:
file.seek(file.tell())
file_seekable = True
except Exception:
file_seekable = False
else:
file_seekable = False
if file_seekable:
file = _PeekableReader(file, closing=should_close)
else:
try:
file = io.BufferedReader(file)
except Exception:
# It won't be peekable, but will fail gracefully in _identify_module().
file = _PeekableReader(file, closing=should_close)
elif seekable and (
not hasattr(file, 'seek')
or not hasattr(file, 'truncate')
or (hasattr(file, 'seekable') and not file.seekable())
):
file = _SeekableWriter(file, closing=should_close)
if should_close or isinstance(file, (_PeekableReader, _SeekableWriter)):
return file
else:
return contextlib.nullcontext(file)
107 changes: 89 additions & 18 deletions dill/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@
"""

__all__ = [
'dump_module', 'load_module', 'load_module_asdict',
'dump_module', 'load_module', 'load_module_asdict', 'is_pickled_module',
'dump_session', 'load_session' # backward compatibility
]

import re
import sys
import warnings

from dill import _dill, Pickler, Unpickler
from dill import _dill
from dill import Pickler, Unpickler, UnpicklingError
from ._dill import (
BuiltinMethodType, FunctionType, MethodType, ModuleType, TypeType,
_import_module, _is_builtin_module, _is_imported_module, _main_module,
_reverse_typemap, __builtin__,
)
from ._utils import _open

# Type hints.
from typing import Optional, Union
Expand Down Expand Up @@ -285,26 +287,95 @@ def _make_peekable(stream):

def _identify_module(file, main=None):
"""identify the name of the module stored in the given file-type object"""
from pickletools import genops
UNICODE = {'UNICODE', 'BINUNICODE', 'SHORT_BINUNICODE'}
found_import = False
import pickletools
NEUTRAL = {'PROTO', 'FRAME', 'PUT', 'BINPUT', 'MEMOIZE', 'MARK', 'STACK_GLOBAL'}
try:
for opcode, arg, pos in genops(file.peek(256)):
if not found_import:
if opcode.name in ('GLOBAL', 'SHORT_BINUNICODE') and \
arg.endswith('_import_module'):
found_import = True
else:
if opcode.name in UNICODE:
return arg
else:
raise UnpicklingError("reached STOP without finding main module")
opcodes = ((opcode.name, arg) for opcode, arg, pos in pickletools.genops(file.peek(256))
if opcode.name not in NEUTRAL)
opcode, arg = next(opcodes)
if (opcode, arg) == ('SHORT_BINUNICODE', 'dill._dill'):
# The file uses STACK_GLOBAL instead of GLOBAL.
opcode, arg = next(opcodes)
if not (opcode in ('SHORT_BINUNICODE', 'GLOBAL') and arg.split()[-1] == '_import_module'):
raise ValueError
opcode, arg = next(opcodes)
if not opcode in ('SHORT_BINUNICODE', 'BINUNICODE', 'UNICODE'):
raise ValueError
module_name = arg
if not (
next(opcodes)[0] in ('TUPLE1', 'TUPLE') and
next(opcodes)[0] == 'REDUCE' and
next(opcodes)[0] in ('EMPTY_DICT', 'DICT')
):
raise ValueError
return module_name
except StopIteration:
raise UnpicklingError("reached STOP without finding module") from None
except (NotImplementedError, ValueError) as error:
# ValueError occours when the end of the chunk is reached (without a STOP).
# ValueError also occours when the end of the chunk is reached (without a STOP).
if isinstance(error, NotImplementedError) and main is not None:
# file is not peekable, but we have main.
# The file is not peekable, but we have the argument main.
return None
raise UnpicklingError("unable to identify main module") from error
raise UnpicklingError("unable to identify module") from error

def is_pickled_module(
filename, importable: bool = True, identify: bool = False
) -> Union[bool, str]:
"""Check if a file can be loaded with :func:`load_module`.

Check if the file is a pickle file generated with :func:`dump_module`,
and thus can be loaded with :func:`load_module`.

Parameters:
filename: a path-like object or a readable stream.
importable: expected kind of the file's saved module. Use `True` for
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doc is slightly confusing, especially if people might think that the file itself is importable (without first unpickling). Fundamentally, it's either a file-like module object or a module class instance... so is something in that vein a better name? It's a bit of an unusual thing for people to think about, so lets give the name some extra thought.

importable modules (the default) or `False` for module-type objects.
identify: if `True`, return the module name if the test succeeds.

Returns:
`True` if the pickle file at ``filename`` was generated with
:func:`dump_module` **AND** the module whose state is saved in it is
of the kind specified by the ``importable`` argument. `False` otherwise.
If `identify` is set, return the name of the module instead of `True`.

Examples:
Create three types of pickle files:

>>> import dill
>>> import types
>>> dill.dump_module('module_session.pkl') # saves __main__
>>> dill.dump_module('module_object.pkl', module=types.ModuleType('example'))
>>> with open('common_object.pkl', 'wb') as file:
>>> dill.dump('example', file)

Test each file's kind:

>>> dill.is_pickled_module('module_session.pkl') # the module is importable
True
>>> dill.is_pickled_module('module_session.pkl', importable=False)
False
>>> dill.is_pickled_module('module_object.pkl') # the module is not importable
False
>>> dill.is_pickled_module('module_object.pkl', importable=False)
True
>>> dill.is_pickled_module('module_object.pkl', importable=False, identify=True)
'example'
>>> dill.is_pickled_module('common_object.pkl') # always return False
False
>>> dill.is_pickled_module('common_object.pkl', importable=False)
False
"""
with _open(filename, 'rb', peekable=True) as file:
try:
pickle_main = _identify_module(file)
except UnpicklingError:
return False
is_runtime_mod = pickle_main.startswith('__runtime__.')
res = importable ^ is_runtime_mod
if res and identify:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this function could be named pickled_module or identify_module, and the identify kwarg set to always be true and then removed.

return pickle_main.partition('.')[-1] if is_runtime_mod else pickle_main
else:
return res

def load_module(
filename = str(TEMPDIR/'session.pkl'),
Expand Down
32 changes: 32 additions & 0 deletions dill/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import __main__
from contextlib import suppress
from io import BytesIO
from types import ModuleType

import dill

Expand Down Expand Up @@ -271,10 +272,41 @@ def test_load_module_asdict():
assert 'y' not in main_vars
assert 'empty' in main_vars

def test_is_pickled_module():
import tempfile
import warnings

# Module saved with dump().
pickle_file = tempfile.NamedTemporaryFile(mode='wb')
dill.dump(os, pickle_file)
pickle_file.flush()
assert not dill.is_pickled_module(pickle_file.name)
assert not dill.is_pickled_module(pickle_file.name, importable=False)
pickle_file.close()

# Importable module saved with dump_module().
pickle_file = tempfile.NamedTemporaryFile(mode='wb')
dill.dump_module(pickle_file, local_mod)
pickle_file.flush()
assert dill.is_pickled_module(pickle_file.name)
assert not dill.is_pickled_module(pickle_file.name, importable=False)
assert dill.is_pickled_module(pickle_file.name, identify=True) == local_mod.__name__
pickle_file.close()

# Module-type object saved with dump_module().
pickle_file = tempfile.NamedTemporaryFile(mode='wb')
dill.dump_module(pickle_file, ModuleType('runtime'))
pickle_file.flush()
assert not dill.is_pickled_module(pickle_file.name)
assert dill.is_pickled_module(pickle_file.name, importable=False)
assert dill.is_pickled_module(pickle_file.name, importable=False, identify=True) == 'runtime'
pickle_file.close()

if __name__ == '__main__':
test_session_main(refimported=False)
test_session_main(refimported=True)
test_session_other()
test_runtime_module()
test_refimported_imported_as()
test_load_module_asdict()
test_is_pickled_module()
63 changes: 63 additions & 0 deletions dill/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python

# Author: Leonardo Gama (@leogama)
# Copyright (c) 2022 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE

"""test general utilities in _utils.py"""

import io
import os

from dill import _utils

def test_open():
file_unpeekable = open(__file__, 'rb', buffering=0)
assert not hasattr(file_unpeekable, 'peek')

content = file_unpeekable.read()
peeked_chars = content[:10]
first_line = content[:100].partition(b'\n')[0] + b'\n'
file_unpeekable.seek(0)

# Test _PeekableReader for seekable stream
with _utils._open(file_unpeekable, 'r', peekable=True) as file:
assert isinstance(file, _utils._PeekableReader)
assert file.peek(10)[:10] == peeked_chars
assert file.readline() == first_line
assert not file_unpeekable.closed
file_unpeekable.close()

_pipe_r, _pipe_w = os.pipe()
pipe_r = io.FileIO(_pipe_r, closefd=False)
pipe_w = io.FileIO(_pipe_w, mode='w')
assert not hasattr(pipe_r, 'peek')
assert not pipe_r.seekable()
assert not pipe_w.seekable()

# Test io.BufferedReader for unseekable stream
with _utils._open(pipe_r, 'r', peekable=True) as file:
assert isinstance(file, io.BufferedReader)
pipe_w.write(content[:100])
assert file.peek(10)[:10] == peeked_chars
assert file.readline() == first_line
assert not pipe_r.closed

# Test _SeekableWriter for unseekable stream
with _utils._open(pipe_w, 'w', seekable=True) as file:
# pipe_r is closed here for some reason...
assert isinstance(file, _utils._SeekableWriter)
file.write(content)
file.flush()
file.seek(0)
file.truncate()
file.write(b'a line of text\n')
assert not pipe_w.closed
pipe_r = io.FileIO(_pipe_r)
assert pipe_r.readline() == b'a line of text\n'
pipe_r.close()
pipe_w.close()

if __name__ == '__main__':
test_open()