Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions pyadjoint/overloaded_type.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import weakref
from .block_variable import BlockVariable
from .tape import get_working_tape

Expand Down Expand Up @@ -64,6 +65,39 @@ def register_overloaded_type(overloaded_type, classes=None):
return overloaded_type


class Weakref:
"""Weakref which is picklable if the referenced object is picklable or
None.

Args:
obj (:obj:`object`): The object to hold a weak reference to. None
indicates a reference to no object.
"""

def __init__(self, obj=None):
self._init(obj)

def _init(self, obj):
if obj is None:
self._obj = lambda: None
else:
self._obj = weakref.ref(obj)

def __call__(self):
return self._obj()

def __getstate__(self):
state = self.__dict__.copy()
state["_obj"] = self()
return state

def __setstate__(self, state):
state = state.copy()
obj = state.pop("_obj")
self.__dict__.update(state)
self._init(obj)


class OverloadedType(object):
"""Base class for OverloadedType types.

Expand All @@ -74,8 +108,7 @@ class OverloadedType(object):
"""

def __init__(self, *args, **kwargs):
self.block_variable = None
self.create_block_variable()
self.clear_block_variable()

@classmethod
def _ad_init_object(cls, obj):
Expand All @@ -93,9 +126,21 @@ def _ad_init_object(cls, obj):
"""
return cls(obj)

@property
def block_variable(self):
block_variable = self._block_variable()
return self.create_block_variable() if block_variable is None else block_variable

@block_variable.setter
def block_variable(self, value):
self._block_variable = Weakref(value)

def clear_block_variable(self):
self._block_variable = Weakref()

def create_block_variable(self):
self.block_variable = BlockVariable(self)
return self.block_variable
self.block_variable = block_variable = BlockVariable(self)
return block_variable

def _ad_convert_type(self, value, options={}):
"""This method must be overridden.
Expand Down
4 changes: 2 additions & 2 deletions pyadjoint/tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ def __exit__(self, *args):
_annotation_enabled = self._orig_annotation_enabled.pop()
if self.modifies is not None:
try:
self.modifies.create_block_variable()
self.modifies.clear_block_variable()
except AttributeError:
for var in self.modifies:
var.create_block_variable()
var.clear_block_variable()


no_annotations = stop_annotating()
Expand Down