diff --git a/pyadjoint/overloaded_type.py b/pyadjoint/overloaded_type.py index 38ebeb74..75690781 100644 --- a/pyadjoint/overloaded_type.py +++ b/pyadjoint/overloaded_type.py @@ -1,3 +1,4 @@ +import weakref from .block_variable import BlockVariable from .tape import get_working_tape @@ -64,6 +65,39 @@ def register_overloaded_type(overloaded_type, classes=None): return overloaded_type +class Weakref: + """Weakref which is picklable if the referenced object is picklable or + None. + + Args: + obj (:obj:`object`): The object to hold a weak reference to. None + indicates a reference to no object. + """ + + def __init__(self, obj=None): + self._init(obj) + + def _init(self, obj): + if obj is None: + self._obj = lambda: None + else: + self._obj = weakref.ref(obj) + + def __call__(self): + return self._obj() + + def __getstate__(self): + state = self.__dict__.copy() + state["_obj"] = self() + return state + + def __setstate__(self, state): + state = state.copy() + obj = state.pop("_obj") + self.__dict__.update(state) + self._init(obj) + + class OverloadedType(object): """Base class for OverloadedType types. @@ -74,8 +108,7 @@ class OverloadedType(object): """ def __init__(self, *args, **kwargs): - self.block_variable = None - self.create_block_variable() + self.clear_block_variable() @classmethod def _ad_init_object(cls, obj): @@ -93,9 +126,21 @@ def _ad_init_object(cls, obj): """ return cls(obj) + @property + def block_variable(self): + block_variable = self._block_variable() + return self.create_block_variable() if block_variable is None else block_variable + + @block_variable.setter + def block_variable(self, value): + self._block_variable = Weakref(value) + + def clear_block_variable(self): + self._block_variable = Weakref() + def create_block_variable(self): - self.block_variable = BlockVariable(self) - return self.block_variable + self.block_variable = block_variable = BlockVariable(self) + return block_variable def _ad_convert_type(self, value, options={}): """This method must be overridden. diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py index 64f3b2ba..9384d689 100644 --- a/pyadjoint/tape.py +++ b/pyadjoint/tape.py @@ -121,10 +121,10 @@ def __exit__(self, *args): _annotation_enabled = self._orig_annotation_enabled.pop() if self.modifies is not None: try: - self.modifies.create_block_variable() + self.modifies.clear_block_variable() except AttributeError: for var in self.modifies: - var.create_block_variable() + var.clear_block_variable() no_annotations = stop_annotating()