diff --git a/dill/__init__.py b/dill/__init__.py index 1cf886ce..34ac5169 100644 --- a/dill/__init__.py +++ b/dill/__init__.py @@ -25,7 +25,8 @@ from ._dill import dump, dumps, load, loads, dump_session, load_session, \ Pickler, Unpickler, register, copy, pickle, pickles, check, \ HIGHEST_PROTOCOL, DEFAULT_PROTOCOL, PicklingError, UnpicklingError, \ - HANDLE_FMODE, CONTENTS_FMODE, FILE_FMODE + HANDLE_FMODE, CONTENTS_FMODE, FILE_FMODE, PickleError, PickleWarning, \ + PicklingWarning, UnpicklingWarning from . import source, temp, detect # get global settings diff --git a/dill/_dill.py b/dill/_dill.py index 9c1813c9..2dac61c6 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -18,7 +18,8 @@ __all__ = ['dump','dumps','load','loads','dump_session','load_session', 'Pickler','Unpickler','register','copy','pickle','pickles', 'check','HIGHEST_PROTOCOL','DEFAULT_PROTOCOL','PicklingError', - 'UnpicklingError','HANDLE_FMODE','CONTENTS_FMODE','FILE_FMODE'] + 'UnpicklingError','HANDLE_FMODE','CONTENTS_FMODE','FILE_FMODE', + 'PickleError','PickleWarning','PicklingWarning','UnpicklingWarning'] import logging log = logging.getLogger("dill") @@ -28,8 +29,7 @@ def _trace(boolean): if boolean: log.setLevel(logging.INFO) else: log.setLevel(logging.WARN) return - -stack = dict() # record of 'recursion-sensitive' pickled objects +import warnings import os import sys @@ -39,6 +39,7 @@ def _trace(boolean): # OLDER: 3.0 <= x < 3.4 *OR* x < 2.7.10 #NOTE: guessing relevant versions OLDER = (PY3 and sys.hexversion < 0x3040000) or (sys.hexversion < 0x2070ab1) OLD33 = (sys.hexversion < 0x3030000) +OLD37 = (sys.hexversion < 0x3070000) PY34 = (0x3040000 <= sys.hexversion < 0x3050000) if PY3: #XXX: get types from .objtypes ? import builtins as __builtin__ @@ -72,7 +73,7 @@ def _trace(boolean): GeneratorType, DictProxyType, XRangeType, SliceType, TracebackType, \ NotImplementedType, EllipsisType, FrameType, ModuleType, \ BufferType, BuiltinMethodType, TypeType -from pickle import HIGHEST_PROTOCOL, PicklingError, UnpicklingError +from pickle import HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError try: from pickle import DEFAULT_PROTOCOL except ImportError: @@ -261,6 +262,34 @@ def get_file_type(*args, **kwargs): except NameError: ExitType = None singletontypes = [] +import inspect + +### Shims for different versions of Python and dill +class Sentinel(object): + """ + Create a unique sentinel object that is pickled as a constant. + """ + def __init__(self, name, module_name=None): + self.name = name + if module_name is None: + # Use the calling frame's module + self.__module__ = inspect.currentframe().f_back.f_globals['__name__'] + else: + self.__module__ = module_name # pragma: no cover + def __repr__(self): + return self.__module__ + '.' + self.name # pragma: no cover + def __copy__(self): + return self # pragma: no cover + def __deepcopy__(self, memo): + return self # pragma: no cover + def __reduce__(self): + return self.name + def __reduce_ex__(self, protocol): + return self.name + +from . import _shims +from ._shims import Reduce, Getattr + ### File modes #: Pickles the file handle, preserving mode. The position of the unpickled #: object is as for a new file handle. @@ -460,6 +489,14 @@ def __missing__(self, key): else: raise KeyError() +class PickleWarning(Warning, PickleError): + pass + +class PicklingWarning(PickleWarning, PicklingError): + pass + +class UnpicklingWarning(PickleWarning, UnpicklingError): + pass ### Extend the Picklers class Pickler(StockPickler): @@ -481,9 +518,9 @@ def __init__(self, *args, **kwds): self._strictio = False #_strictio self._fmode = settings['fmode'] if _fmode is None else _fmode self._recurse = settings['recurse'] if _recurse is None else _recurse + self._postproc = {} def dump(self, obj): #NOTE: if settings change, need to update attributes - stack.clear() # clear record of 'recursion-sensitive' pickled objects # register if the object is a numpy ufunc # thanks to Paul Kienzle for pointing out ufuncs didn't pickle if NumpyUfuncType and numpyufunc(obj): @@ -528,7 +565,6 @@ def save_numpy_array(pickler, obj): raise PicklingError(msg) else: StockPickler.dump(self, obj) - stack.clear() # clear record of 'recursion-sensitive' pickled objects return dump.__doc__ = StockPickler.dump.__doc__ pass @@ -680,6 +716,7 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None, # thus we need to make sure that we have __builtins__ as well if "__builtins__" not in func.__globals__: func.__globals__["__builtins__"] = globals()["__builtins__"] + # assert id(fglobals) == id(func.__globals__) return func def _create_code(*args): @@ -876,12 +913,25 @@ def __getattribute__(self, attr): attrs[index] = ".".join([attrs[index], attr]) return type(self)(attrs, index) +# _CELL_REF and _CELL_EMPTY are used to stay compatible with versions of dill +# whose _create_cell functions do not have a default value. +# _CELL_REF can be safely removed entirely (replaced by empty tuples for calls +# to _create_cell) once breaking changes are allowed. +_CELL_REF = None +_CELL_EMPTY = Sentinel('_CELL_EMPTY') + if PY3: - def _create_cell(contents): - return (lambda y: contents).__closure__[0] + def _create_cell(contents=None): + if contents is not _CELL_EMPTY: + value = contents + return (lambda: value).__closure__[0] + else: - def _create_cell(contents): - return (lambda y: contents).func_closure[0] + def _create_cell(contents=None): + if contents is not _CELL_EMPTY: + value = contents + return (lambda: value).func_closure[0] + def _create_weakref(obj, *args): from weakref import ref @@ -977,6 +1027,56 @@ def _locate_function(obj, session=False): found = _import_module(obj.__module__ + '.' + obj.__name__, safe=True) return found is obj + +def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO_DEFAULT, postproc_list=None): + if obj is Getattr.NO_DEFAULT: + obj = Reduce(reduction) # pragma: no cover + + if is_pickler_dill is None: + is_pickler_dill = is_dill(pickler, child=True) + if is_pickler_dill: + # assert id(obj) not in pickler._postproc, str(obj) + ' already pushed on stack!' + # if not hasattr(pickler, 'x'): pickler.x = 0 + # print(pickler.x*' ', 'push', obj, id(obj), pickler._recurse) + # pickler.x += 1 + if postproc_list is None: + postproc_list = [] + + # Recursive object not supported. Default to a global instead. + if id(obj) in pickler._postproc: + name = '%s.%s ' % (obj.__module__, getattr(obj, '__qualname__', obj.__name__)) if hasattr(obj, '__module__') else '' + warnings.warn('Cannot pickle %r: %shas recursive self-references that trigger a RecursionError.' % (obj, name), PicklingWarning) + pickler.save_global(obj) + return + pickler._postproc[id(obj)] = postproc_list + + # TODO: Use state_setter in Python 3.8 to allow for faster cPickle implementations + pickler.save_reduce(*reduction, obj=obj) + + if is_pickler_dill: + # pickler.x -= 1 + # print(pickler.x*' ', 'pop', obj, id(obj)) + postproc = pickler._postproc.pop(id(obj)) + # assert postproc_list == postproc, 'Stack tampered!' + for reduction in reversed(postproc): + if reduction[0] is dict.update and type(reduction[1][0]) is dict: + # use the internal machinery of pickle.py to speedup when + # updating a dictionary in postproc + dest, source = reduction[1] + if source: + pickler.write(pickler.get(pickler.memo[id(dest)][0])) + pickler._batch_setitems(iter(source.items())) + else: + # Updating with an empty dictionary. Same as doing nothing. + continue + else: + pickler.save_reduce(*reduction) + # pop None created by calling preprocessing step off stack + if PY3: + pickler.write(bytes('0', 'UTF-8')) + else: + pickler.write('0') + #@register(CodeType) #def save_code(pickler, obj): # log.info("Co: %s" % obj) @@ -1067,7 +1167,6 @@ def save_module_dict(pickler, obj): @register(ClassType) def save_classobj(pickler, obj): #FIXME: enable pickler._byref - #stack[id(obj)] = len(stack), obj if obj.__module__ == '__main__': #XXX: use _main_module.__name__ everywhere? log.info("C1: %s" % obj) pickler.save_reduce(ClassType, (obj.__name__, obj.__bases__, @@ -1280,6 +1379,10 @@ def save_wrapper_descriptor(pickler, obj): @register(MethodWrapperType) def save_instancemethod(pickler, obj): log.info("Mw: %s" % obj) + if IS_PYPY2 and obj.__self__ is None and obj.im_class: + # Can be a class method in PYPY2 if __self__ is none + pickler.save_reduce(getattr, (obj.im_class, obj.__name__), obj=obj) + return pickler.save_reduce(getattr, (obj.__self__, obj.__name__), obj=obj) log.info("# Mw") return @@ -1296,10 +1399,44 @@ def save_wrapper_descriptor(pickler, obj): @register(CellType) def save_cell(pickler, obj): - log.info("Ce: %s" % obj) - f = obj.cell_contents + try: + f = obj.cell_contents + except: + log.info("Ce3: %s" % obj) + # _shims._CELL_EMPTY is defined in _shims.py to support PyPy 2.7. + # It unpickles to a sentinel object _dill._CELL_EMPTY, also created in + # _shims.py. This object is not present in Python 3 because the cell's + # contents can be deleted in newer versions of Python. The reduce object + # will instead unpickle to None if unpickled in Python 3. + + # When breaking changes are made to dill, (_shims._CELL_EMPTY,) can + # be replaced by () OR the delattr function can be removed repending on + # whichever is more convienient. + pickler.save_reduce(_create_cell, (_shims._CELL_EMPTY,), obj=obj) + # Call the function _delattr on the cell's cell_contents attribute + # The result of this function call will be None + pickler.save_reduce(_shims._delattr, (obj, 'cell_contents')) + # pop None created by calling _delattr off stack + if PY3: + pickler.write(bytes('0', 'UTF-8')) + else: + pickler.write('0') + log.info("# Ce3") + return + if is_dill(pickler, child=True): + postproc = pickler._postproc.get(id(f)) + if postproc is not None: + log.info("Ce2: %s" % obj) + # _CELL_REF is defined in _shims.py to support older versions of + # dill. When breaking changes are made to dill, (_CELL_REF,) can + # be replaced by () + postproc.append((_shims._setattr, (obj, 'cell_contents', f))) + pickler.save_reduce(_create_cell, (_CELL_REF,), obj=obj) + log.info("# Ce2") + return + log.info("Ce1: %s" % obj) pickler.save_reduce(_create_cell, (f,), obj=obj) - log.info("# Ce") + log.info("# Ce1") return if not IS_PYPY: @@ -1412,22 +1549,22 @@ def save_weakproxy(pickler, obj): @register(ModuleType) def save_module(pickler, obj): if False: #_use_diff: - if obj.__name__ != "dill": + if obj.__name__.split('.', 1)[0] != "dill": try: changed = diff.whats_changed(obj, seen=pickler._diff_cache)[0] except RuntimeError: # not memorised module, probably part of dill pass else: - log.info("M1: %s with diff" % obj) + log.info("M2: %s with diff" % obj) log.info("Diff: %s", changed.keys()) pickler.save_reduce(_import_module, (obj.__name__,), obj=obj, state=changed) - log.info("# M1") + log.info("# M2") return - log.info("M2: %s" % obj) + log.info("M1: %s" % obj) pickler.save_reduce(_import_module, (obj.__name__,), obj=obj) - log.info("# M2") + log.info("# M1") else: # if a module file name starts with prefix, it should be a builtin # module, so should be pickled as a reference @@ -1440,7 +1577,7 @@ def save_module(pickler, obj): 'site-packages' in obj.__file__) else: builtin_mod = True - if obj.__name__ not in ("builtins", "dill") \ + if obj.__name__ not in ("builtins", "dill", "dill._dill") \ and not builtin_mod or is_dill(pickler, child=True) and obj is pickler._main: log.info("M1: %s" % obj) _main_dict = obj.__dict__.copy() #XXX: better no copy? option to copy? @@ -1449,6 +1586,10 @@ def save_module(pickler, obj): pickler.save_reduce(_import_module, (obj.__name__,), obj=obj, state=_main_dict) log.info("# M1") + elif PY3 and obj.__name__ == "dill._dill": + log.info("M2: %s" % obj) + pickler.save_global(obj, name="_dill") + log.info("# M2") else: log.info("M2: %s" % obj) pickler.save_reduce(_import_module, (obj.__name__,), obj=obj) @@ -1457,8 +1598,7 @@ def save_module(pickler, obj): return @register(TypeType) -def save_type(pickler, obj): - #stack[id(obj)] = len(stack), obj #XXX: probably don't obj in all cases below +def save_type(pickler, obj, postproc_list=None): if obj in _typemap: log.info("T1: %s" % obj) pickler.save_reduce(_load_type, (_typemap[obj],), obj=obj) @@ -1469,33 +1609,7 @@ def save_type(pickler, obj): pickler.save_reduce(_create_namedtuple, (getattr(obj, "__qualname__", obj.__name__), obj._fields, obj.__module__), obj=obj) log.info("# T6") return - elif obj.__module__ == '__main__': - if issubclass(type(obj), type): - # try: # used when pickling the class as code (or the interpreter) - if is_dill(pickler, child=True) and not pickler._byref: - # thanks to Tom Stepleton pointing out pickler._session unneeded - _t = 'T2' - log.info("%s: %s" % (_t, obj)) - _dict = _dict_from_dictproxy(obj.__dict__) - # except: # punt to StockPickler (pickle by class reference) - else: - log.info("T5: %s" % obj) - name = getattr(obj, '__qualname__', getattr(obj, '__name__', None)) - StockPickler.save_global(pickler, obj, name=name) - log.info("# T5") - return - else: - _t = 'T3' - log.info("%s: %s" % (_t, obj)) - _dict = obj.__dict__ - #print (_dict) - #print ("%s\n%s" % (type(obj), obj.__name__)) - #print ("%s\n%s" % (obj.__bases__, obj.__dict__)) - for name in _dict.get("__slots__", []): - del _dict[name] - pickler.save_reduce(_create_type, (type(obj), obj.__name__, - obj.__bases__, _dict), obj=obj) - log.info("# %s" % _t) + # special cases: NoneType, NotImplementedType, EllipsisType elif obj is type(None): log.info("T7: %s" % obj) @@ -1513,16 +1627,50 @@ def save_type(pickler, obj): log.info("T7: %s" % obj) pickler.save_reduce(type, (Ellipsis,), obj=obj) log.info("# T7") + else: - log.info("T4: %s" % obj) - #print (obj.__dict__) - #print ("%s\n%s" % (type(obj), obj.__name__)) - #print ("%s\n%s" % (obj.__bases__, obj.__dict__)) - name = getattr(obj, '__qualname__', getattr(obj, '__name__', None)) - StockPickler.save_global(pickler, obj, name=name) - log.info("# T4") + obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None)) + _byref = getattr(pickler, '_byref', None) + obj_recursive = id(obj) in getattr(pickler, '_postproc', ()) + incorrectly_named = not _locate_function(obj) + if not _byref and not obj_recursive and incorrectly_named: # not a function, but the name was held over + if issubclass(type(obj), type): + # thanks to Tom Stepleton pointing out pickler._session unneeded + _t = 'T2' + log.info("%s: %s" % (_t, obj)) + _dict = _dict_from_dictproxy(obj.__dict__) + else: + _t = 'T3' + log.info("%s: %s" % (_t, obj)) + _dict = obj.__dict__ + #print (_dict) + #print ("%s\n%s" % (type(obj), obj.__name__)) + #print ("%s\n%s" % (obj.__bases__, obj.__dict__)) + for name in _dict.get("__slots__", []): + del _dict[name] + _save_with_postproc(pickler, (_create_type, ( + type(obj), obj_name, obj.__bases__, _dict + )), obj=obj, postproc_list=postproc_list) + log.info("# %s" % _t) + else: + log.info("T4: %s" % obj) + if incorrectly_named: + warnings.warn('Cannot locate reference to %r.' % (obj,), PicklingWarning) + if obj_recursive: + warnings.warn('Cannot pickle %r: %s.%s has recursive self-references that trigger a RecursionError.' % (obj, obj.__module__, obj_name), PicklingWarning) + #print (obj.__dict__) + #print ("%s\n%s" % (type(obj), obj.__name__)) + #print ("%s\n%s" % (obj.__bases__, obj.__dict__)) + StockPickler.save_global(pickler, obj, name=obj_name) + log.info("# T4") return +# Error in PyPy 2.7 when adding ABC support +if IS_PYPY2: + @register(FrameType) + def save_frame(pickler, obj): + raise PicklingError('Cannot pickle a Python stack frame') + @register(property) def save_property(pickler, obj): log.info("Pr: %s" % obj) @@ -1541,6 +1689,18 @@ def save_classmethod(pickler, obj): orig_func = obj.__get__(None, object) if isinstance(obj, classmethod): orig_func = getattr(orig_func, im_func) # Unbind + + # if PY3: + # if type(obj.__dict__) is dict: + # if obj.__dict__: + # state = obj.__dict__ + # else: + # state = None + # else: + # state = (None, {'__dict__', obj.__dict__}) + # else: + # state = None + pickler.save_reduce(type(obj), (orig_func,), obj=obj) log.info("# Cm") @@ -1548,53 +1708,56 @@ def save_classmethod(pickler, obj): def save_function(pickler, obj): if not _locate_function(obj): #, pickler._session): log.info("F1: %s" % obj) - if getattr(pickler, '_recurse', False): + _recurse = getattr(pickler, '_recurse', None) + _byref = getattr(pickler, '_byref', None) + _postproc = getattr(pickler, '_postproc', None) + postproc_list = [] + if _recurse: # recurse to get all globals referred to by obj from .detect import globalvars - globs = globalvars(obj, recurse=True, builtin=True) - # remove objects that have already been serialized - #stacktypes = (ClassType, TypeType, FunctionType) - #for key,value in list(globs.items()): - # if isinstance(value, stacktypes) and id(value) in stack: - # del globs[key] - # ABORT: if self-references, use _recurse=False - if id(obj) in stack: # or obj in globs.values(): - globs = obj.__globals__ if PY3 else obj.func_globals + globs_copy = globalvars(obj, recurse=True, builtin=True) + + # Add the name of the module to the globs dictionary to prevent + # the duplication of the dictionary. Pickle the unpopulated + # globals dictionary and set the remaining items after the function + # is created to correctly handle recursion. + globs = {'__name__': obj.__module__} else: - globs = obj.__globals__ if PY3 else obj.func_globals - _byref = getattr(pickler, '_byref', None) - _recurse = getattr(pickler, '_recurse', None) - _memo = (id(obj) in stack) and (_recurse is not None) - #print("stack: %s + '%s'" % (set(hex(i) for i in stack),hex(id(obj)))) - stack[id(obj)] = len(stack), obj + globs_copy = obj.__globals__ if PY3 else obj.func_globals + + # If the globals is a module __dict__, do not save it in the pickle. + if globs_copy is not None and obj.__module__ is not None and \ + getattr(_import_module(obj.__module__, True), '__dict__', None) is globs_copy: + globs = globs_copy + else: + globs = {'__name__': obj.__module__} + + if globs_copy is not None and globs is not globs_copy: + # In the case that the globals are copied, we need to ensure that + # the globals dictionary is updated when all objects in the + # dictionary are already created. + if PY3: + glob_ids = {id(g) for g in globs_copy.values()} + else: + glob_ids = {id(g) for g in globs_copy.itervalues()} + for stack_element in _postproc: + if stack_element in glob_ids: + _postproc[stack_element].append((dict.update, (globs, globs_copy))) + break + else: + postproc_list.append((dict.update, (globs, globs_copy))) + if PY3: - #NOTE: workaround for 'super' (see issue #75) - _super = ('super' in getattr(obj.__code__,'co_names',())) and (_byref is not None) - if _super: pickler._byref = True - if _memo: pickler._recurse = False fkwdefaults = getattr(obj, '__kwdefaults__', None) - pickler.save_reduce(_create_function, (obj.__code__, - globs, obj.__name__, - obj.__defaults__, obj.__closure__, - obj.__dict__, fkwdefaults), obj=obj) + _save_with_postproc(pickler, (_create_function, ( + obj.__code__, globs, obj.__name__, obj.__defaults__, + obj.__closure__, obj.__dict__, fkwdefaults + )), obj=obj, postproc_list=postproc_list) else: - _super = ('super' in getattr(obj.func_code,'co_names',())) and (_byref is not None) and getattr(pickler, '_recurse', False) - if _super: pickler._byref = True - if _memo: pickler._recurse = False - pickler.save_reduce(_create_function, (obj.func_code, - globs, obj.func_name, - obj.func_defaults, obj.func_closure, - obj.__dict__), obj=obj) - if _super: pickler._byref = _byref - if _memo: pickler._recurse = _recurse - #clear = (_byref, _super, _recurse, _memo) - #print(clear + (OLDER,)) - #NOTE: workaround for #234; "partial" still is problematic for recurse - if OLDER and not _byref and (_super or (not _super and _memo) or (not _super and not _memo and _recurse)): pickler.clear_memo() - #if _memo: - # stack.remove(id(obj)) - # #pickler.clear_memo() - # #StockPickler.clear_memo(pickler) + _save_with_postproc(pickler, (_create_function, ( + obj.func_code, globs, obj.func_name, obj.func_defaults, + obj.func_closure, obj.__dict__ + )), obj=obj, postproc_list=postproc_list) log.info("# F1") else: log.info("F2: %s" % obj) @@ -1626,7 +1789,6 @@ def pickles(obj,exact=False,safe=False,**kwds): #FIXME: should be "(pik == obj).all()" for numpy comparison, though that'll fail if shapes differ result = bool(pik.all() == obj.all()) except AttributeError: - import warnings warnings.filterwarnings('ignore') result = pik == obj warnings.resetwarnings() diff --git a/dill/_shims.py b/dill/_shims.py new file mode 100644 index 00000000..ac20eca3 --- /dev/null +++ b/dill/_shims.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python +# +# Author: Mike McKerns (mmckerns @caltech and @uqfoundation) +# Author: Anirudh Vegesana (avegesan@stanford.edu) +# Copyright (c) 2021 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE +""" +Provides shims for compatibility between versions of dill and Python. + +Compatibility shims should be provided in this file. Here are two simple example +use cases. + +Deprecation of constructor function: +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Assume that we were transitioning _import_module in _dill.py to +the builtin function importlib.import_module when present. + +@move_to(_dill) +def _import_module(import_name): + ... # code already in _dill.py + +_import_module = Getattr(importlib, 'import_module', Getattr(_dill, '_import_module', None)) + +The code will attempt to find import_module in the importlib module. If not +present, it will use the _import_module function in _dill. + +Emulate new Python behavior in older Python versions: +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +CellType.cell_contents behaves differently in Python 3.6 and 3.7. It is +read-only in Python 3.6 and writable and deletable in 3.7. + +if _dill.OLD37 and _dill.HAS_CTYPES and ...: + @move_to(_dill) + def _setattr(object, name, value): + if type(object) is _dill.CellType and name == 'cell_contents': + _PyCell_Set.argtypes = (ctypes.py_object, ctypes.py_object) + _PyCell_Set(object, value) + else: + setattr(object, name, value) +... # more cases below + +_setattr = Getattr(_dill, '_setattr', setattr) + +_dill._setattr will be used when present to emulate Python 3.7 functionality in +older versions of Python while defaulting to the standard setattr in 3.7+. + +See this PR for the discussion that lead to this system: +https://github.com/uqfoundation/dill/pull/443 +""" + +import inspect, sys + +_dill = sys.modules['dill._dill'] + + +class Reduce(object): + """ + Reduce objects are wrappers used for compatibility enforcement during + unpickle-time. They should only be used in calls to pickler.save and + other Reduce objects. They are only evaluated within unpickler.load. + + Pickling a Reduce object makes the two implementations equivalent: + + pickler.save(Reduce(*reduction)) + + pickler.save_reduce(*reduction, obj=reduction) + """ + __slots__ = ['reduction'] + def __new__(cls, *reduction, **kwargs): + """ + Args: + *reduction: a tuple that matches the format given here: + https://docs.python.org/3/library/pickle.html#object.__reduce__ + is_callable: a bool to indicate that the object created by + unpickling `reduction` is callable. If true, the current Reduce + is allowed to be used as the function in further save_reduce calls + or Reduce objects. + """ + is_callable = kwargs.get('is_callable', False) # Pleases Py2. Can be removed later + if is_callable: + self = object.__new__(_CallableReduce) + else: + self = object.__new__(Reduce) + self.reduction = reduction + return self + def __repr__(self): + return 'Reduce%s' % (self.reduction,) + def __copy__(self): + return self # pragma: no cover + def __deepcopy__(self, memo): + return self # pragma: no cover + def __reduce__(self): + return self.reduction + def __reduce_ex__(self, protocol): + return self.__reduce__() + +class _CallableReduce(Reduce): + # A version of Reduce for functions. Used to trick pickler.save_reduce into + # thinking that Reduce objects of functions are themselves meaningful functions. + def __call__(self, *args, **kwargs): + reduction = self.__reduce__() + func = reduction[0] + f_args = reduction[1] + obj = func(*f_args) + return obj(*args, **kwargs) + +__NO_DEFAULT = _dill.Sentinel('Getattr.NO_DEFAULT') + +def Getattr(object, name, default=__NO_DEFAULT): + """ + A Reduce object that represents the getattr operation. When unpickled, the + Getattr will access an attribute 'name' of 'object' and return the value + stored there. If the attribute doesn't exist, the default value will be + returned if present. + + The following statements are equivalent: + + Getattr(collections, 'OrderedDict') + Getattr(collections, 'spam', None) + Getattr(*args) + + Reduce(getattr, (collections, 'OrderedDict')) + Reduce(getattr, (collections, 'spam', None)) + Reduce(getattr, args) + + During unpickling, the first two will result in collections.OrderedDict and + None respectively because the first attribute exists and the second one does + not, forcing it to use the default value given in the third argument. + """ + + if default is Getattr.NO_DEFAULT: + reduction = (getattr, (object, name)) + else: + reduction = (getattr, (object, name, default)) + + return Reduce(*reduction, is_callable=callable(default)) + +Getattr.NO_DEFAULT = __NO_DEFAULT +del __NO_DEFAULT + +def move_to(module, name=None): + def decorator(func): + if name is None: + fname = func.__name__ + else: + fname = name + module.__dict__[fname] = func + func.__module__ = module.__name__ + return func + return decorator + +###################### +## Compatibility Shims are defined below +###################### + +_CELL_EMPTY = Getattr(_dill, '_CELL_EMPTY', None) + +if _dill.OLD37: + if _dill.HAS_CTYPES and hasattr(_dill.ctypes, 'pythonapi') and hasattr(_dill.ctypes.pythonapi, 'PyCell_Set'): + # CPython + ctypes = _dill.ctypes + + _PyCell_Set = ctypes.pythonapi.PyCell_Set + + @move_to(_dill) + def _setattr(object, name, value): + if type(object) is _dill.CellType and name == 'cell_contents': + _PyCell_Set.argtypes = (ctypes.py_object, ctypes.py_object) + _PyCell_Set(object, value) + else: + setattr(object, name, value) + + @move_to(_dill) + def _delattr(object, name): + if type(object) is _dill.CellType and name == 'cell_contents': + _PyCell_Set.argtypes = (ctypes.py_object, ctypes.c_void_p) + _PyCell_Set(object, None) + else: + delattr(object, name) + + # General Python (not CPython) up to 3.6 is in a weird case, where it is + # possible to pickle recursive cells, but we can't assign directly to the + # cell. + elif _dill.PY3: + # Use nonlocal variables to reassign the cell value. + # https://stackoverflow.com/a/59276835 + __nonlocal = ('nonlocal cell',) + exec('''def _setattr(cell, name, value): + if type(cell) is _dill.CellType and name == 'cell_contents': + def cell_setter(value): + %s + cell = value # pylint: disable=unused-variable + func = _dill.FunctionType(cell_setter.__code__, globals(), "", None, (cell,)) # same as cell_setter, but with cell being the cell's contents + func(value) + else: + setattr(cell, name, value)''' % __nonlocal) + move_to(_dill)(_setattr) + + exec('''def _delattr(cell, name): + if type(cell) is _dill.CellType and name == 'cell_contents': + try: + cell.cell_contents + except: + return + def cell_deleter(): + %s + del cell # pylint: disable=unused-variable + func = _dill.FunctionType(cell_deleter.__code__, globals(), "", None, (cell,)) # same as cell_deleter, but with cell being the cell's contents + func() + else: + delattr(cell, name)''' % __nonlocal) + move_to(_dill)(_delattr) + + else: + # Likely PyPy 2.7. Simulate the nonlocal keyword with bytecode + # manipulation. + + # The following function is based on 'cell_set' from 'cloudpickle' + # https://github.com/cloudpipe/cloudpickle/blob/5d89947288a18029672596a4d719093cc6d5a412/cloudpickle/cloudpickle.py#L393-L482 + # Copyright (c) 2012, Regents of the University of California. + # Copyright (c) 2009 `PiCloud, Inc. `_. + # License: https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE + @move_to(_dill) + def _setattr(cell, name, value): + if type(cell) is _dill.CellType and name == 'cell_contents': + _cell_set = _dill.FunctionType( + _cell_set_template_code, {}, '_cell_set', (), (cell,),) + _cell_set(value) + else: + setattr(cell, name, value) + + def _cell_set_factory(value): + lambda: cell + cell = value + + co = _cell_set_factory.__code__ + + _cell_set_template_code = _dill.CodeType( + co.co_argcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # co_freevars is initialized with co_cellvars + (), # co_cellvars is made empty + ) + + del co + + @move_to(_dill) + def _delattr(cell, name): + if type(cell) is _dill.CellType and name == 'cell_contents': + pass + else: + delattr(cell, name) + +_setattr = Getattr(_dill, '_setattr', setattr) +_delattr = Getattr(_dill, '_delattr', delattr) diff --git a/dill/detect.py b/dill/detect.py index 59abee3f..3e5768ae 100644 --- a/dill/detect.py +++ b/dill/detect.py @@ -157,7 +157,16 @@ def freevars(func): func = getattr(func, func_code).co_freevars # get freevars else: return {} - return dict((name,c.cell_contents) for (name,c) in zip(func,closures)) + + def get_cell_contents(): + for (name,c) in zip(func,closures): + try: + cell_contents = c.cell_contents + except: + continue + yield (name,c.cell_contents) + + return dict(get_cell_contents()) # thanks to Davies Liu for recursion of globals def nestedglobals(func, recurse=True): @@ -201,9 +210,14 @@ def globalvars(func, recurse=True, builtin=False): # get references from within closure orig_func, func = func, set() for obj in getattr(orig_func, func_closure) or {}: - _vars = globalvars(obj.cell_contents, recurse, builtin) or {} - func.update(_vars) #XXX: (above) be wary of infinte recursion? - globs.update(_vars) + try: + cell_contents = obj.cell_contents + except: + pass + else: + _vars = globalvars(cell_contents, recurse, builtin) or {} + func.update(_vars) #XXX: (above) be wary of infinte recursion? + globs.update(_vars) # get globals globs.update(getattr(orig_func, func_globals) or {}) # get names of references diff --git a/tests/test_classdef.py b/tests/test_classdef.py index 3b2442e9..5f07be5e 100644 --- a/tests/test_classdef.py +++ b/tests/test_classdef.py @@ -85,8 +85,10 @@ def test_class_objects(): assert type(_cls).__name__ == "_meta" # test NoneType -def test_none(): +def test_specialtypes(): assert dill.pickles(type(None)) + assert dill.pickles(type(NotImplemented)) + assert dill.pickles(type(Ellipsis)) if hex(sys.hexversion) >= '0x20600f0': from collections import namedtuple @@ -204,7 +206,7 @@ def test_slots(): if __name__ == '__main__': test_class_instances() test_class_objects() - test_none() + test_specialtypes() test_namedtuple() test_dtype() test_array_nested() diff --git a/tests/test_functions.py b/tests/test_functions.py index 23de5f09..48d62f5e 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -36,6 +36,12 @@ def function_e(e, *e1, e2=1, e3=2): return e + sum(e1) + e2 + e3''') +def function_with_unassigned_variable(): + if False: + value = None + return (lambda: value) + + def test_functions(): dumped_func_a = dill.dumps(function_a) assert dill.loads(dumped_func_a)(0) == 0 @@ -52,6 +58,17 @@ def test_functions(): assert dill.loads(dumped_func_d)(1, 2, 3) == 6 assert dill.loads(dumped_func_d)(1, 2, d2=3) == 6 + empty_cell = function_with_unassigned_variable() + cell_copy = dill.loads(dill.dumps(empty_cell)) + assert 'empty' in str(cell_copy.__closure__[0]) + try: + cell_copy() + except: + # this is good + pass + else: + raise AssertionError('cell_copy() did not read an empty cell') + if is_py3(): exec(''' dumped_func_e = dill.dumps(function_e) @@ -62,6 +79,5 @@ def test_functions(): assert dill.loads(dumped_func_e)(1, 2, 3, e2=4) == 12 assert dill.loads(dumped_func_e)(1, 2, 3, e2=4, e3=5) == 15''') - if __name__ == '__main__': test_functions() diff --git a/tests/test_recursive.py b/tests/test_recursive.py index 78e5790b..a042385f 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -6,9 +6,27 @@ # - https://github.com/uqfoundation/dill/blob/master/LICENSE import dill +from dill._dill import PY3 from functools import partial -from dill._dill import PY3, OLDER -_super = super +import warnings + + +def copy(obj, byref=False, recurse=False): + if byref: + try: + return dill.copy(obj, byref=byref, recurse=recurse) + except: + pass + else: + raise AssertionError('Copy of %s with byref=True should have given a warning!' % (obj,)) + + warnings.simplefilter('ignore') + val = dill.copy(obj, byref=byref, recurse=recurse) + warnings.simplefilter('error') + return val + else: + return dill.copy(obj, byref=byref, recurse=recurse) + class obj1(object): def __init__(self): @@ -16,7 +34,7 @@ def __init__(self): class obj2(object): def __init__(self): - _super(obj2, self).__init__() + super(obj2, self).__init__() class obj3(object): super_ = super @@ -25,20 +43,20 @@ def __init__(self): def test_super(): - assert dill.copy(obj1(), byref=True) - assert dill.copy(obj1(), byref=True, recurse=True) - #assert dill.copy(obj1(), recurse=True) #FIXME: fails __main__.py - assert dill.copy(obj1()) + assert copy(obj1(), byref=True) + assert copy(obj1(), byref=True, recurse=True) + assert copy(obj1(), recurse=True) + assert copy(obj1()) - assert dill.copy(obj2(), byref=True) - assert dill.copy(obj2(), byref=True, recurse=True) - #assert dill.copy(obj2(), recurse=True) #FIXME: fails __main__.py - assert dill.copy(obj2()) + assert copy(obj2(), byref=True) + assert copy(obj2(), byref=True, recurse=True) + assert copy(obj2(), recurse=True) + assert copy(obj2()) - assert dill.copy(obj3(), byref=True) - assert dill.copy(obj3(), byref=True, recurse=True) - #assert dill.copy(obj3(), recurse=True) #FIXME: fails __main__.py - assert dill.copy(obj3()) + assert copy(obj3(), byref=True) + assert copy(obj3(), byref=True, recurse=True) + assert copy(obj3(), recurse=True) + assert copy(obj3()) def get_trigger(model): @@ -56,11 +74,10 @@ class Model(object): def test_partial(): - assert dill.copy(Machine(), byref=True) - assert dill.copy(Machine(), byref=True, recurse=True) - if not OLDER: - assert dill.copy(Machine(), recurse=True) - assert dill.copy(Machine()) + assert copy(Machine(), byref=True) + assert copy(Machine(), byref=True, recurse=True) + assert copy(Machine(), recurse=True) + assert copy(Machine()) class Machine2(object): @@ -72,21 +89,77 @@ def member(self, model): class SubMachine(Machine2): def __init__(self): - _super(SubMachine, self).__init__() - #super(SubMachine, self).__init__() #XXX: works, except for 3.1-3.3 + super(SubMachine, self).__init__() def test_partials(): - assert dill.copy(SubMachine(), byref=True) - assert dill.copy(SubMachine(), byref=True, recurse=True) - #if not OLDER: #FIXME: fails __main__.py - # assert dill.copy(SubMachine(), recurse=True) - assert dill.copy(SubMachine()) + assert copy(SubMachine(), byref=True) + assert copy(SubMachine(), byref=True, recurse=True) + assert copy(SubMachine(), recurse=True) + assert copy(SubMachine()) + +class obj4(object): + def __init__(self): + super(obj4, self).__init__() + a = self + class obj5(object): + def __init__(self): + super(obj5, self).__init__() + self.a = a + self.b = obj5() + + +def test_circular_reference(): + assert copy(obj4()) + obj4_copy = dill.loads(dill.dumps(obj4())) + if PY3: + assert type(obj4_copy) is type(obj4_copy).__init__.__closure__[0].cell_contents + assert type(obj4_copy.b) is type(obj4_copy.b).__init__.__closure__[0].cell_contents + + +def f(): + def g(): + return g + return g + + +def test_function_cells(): + assert copy(f()) + + +def fib(n): + assert n >= 0 + if n <= 1: + return n + else: + return fib(n-1) + fib(n-2) + + +def test_recursive_function(): + global fib + fib2 = copy(fib, recurse=True) + fib3 = copy(fib) + fib4 = fib + del fib + assert fib2(5) == 5 + for _fib in (fib3, fib4): + try: + _fib(5) + except: + # This is expected to fail because fib no longer exists + pass + else: + raise AssertionError("Function fib shouldn't have been found") + fib = fib4 if __name__ == '__main__': - #print(('byref','_super','_recurse','_memo','_stop','OLDER')) - test_super() - test_partial() - test_partials() + with warnings.catch_warnings(): + warnings.simplefilter('error') + test_super() + test_partial() + test_partials() + test_circular_reference() + test_function_cells() + test_recursive_function() diff --git a/tests/test_selected.py b/tests/test_selected.py index a3bf7487..ef2b9f7e 100644 --- a/tests/test_selected.py +++ b/tests/test_selected.py @@ -79,7 +79,7 @@ def test_frame_related(): _is = lambda ok: not ok if dill._dill.IS_PYPY2 else ok ok = dill.pickles(f) if verbose: print ("%s: %s, %s" % (ok, type(f), f)) - assert _is(not ok) #XXX: dill fails + assert not ok ok = dill.pickles(g) if verbose: print ("%s: %s, %s" % (ok, type(g), g)) assert _is(not ok) #XXX: dill fails