Skip to content

Commit

Permalink
workaround PicklingError from uqfoundation/dill#443
Browse files Browse the repository at this point in the history
  • Loading branch information
mmckerns committed Mar 12, 2022
1 parent 939d7ae commit f6f71f5
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 16 deletions.
43 changes: 31 additions & 12 deletions cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,34 @@
from klepto import lru_cache, lfu_cache, mru_cache
from klepto import rr_cache, inf_cache, no_cache

class _cached(object):
def __init__(self, func, multivalued=False):
self.func = func
self.multivalued = multivalued
def mvmodel(self, x, *argz, **kwdz):
axis = kwdz.pop('axis', None)
if axis is None: axis = slice(None)
return self.func(x, *argz, **kwdz)[axis]
def model(self, x, *argz, **kwdz):
axis = kwdz.pop('axis', None)
return self.func(x, *argz, **kwdz)
def __call__(self, x, *argz, **kwdz):
doit = self.mvmodel if self.multivalued else self.model
return doit(x, *argz, **kwdz)

class _cache(object):
def __init__(self, func):
self.func = func
def __call__(self):
return self.func.__cache__()

class _imodel(object):
def __init__(self, model):
self.model = model
def __call__(self, *args, **kwds):
return -self.model(*args, **kwds)


def cached(**kwds):
"""build a caching archive for an objective function
Expand Down Expand Up @@ -73,22 +101,13 @@ def _model(x, *args, **kwds):
_model.__inner__ = inner

# when caching, always cache the multi-valued tuple
if multivalued:
def model(x, *argz, **kwdz):
axis = kwdz.pop('axis', None)
if axis is None: axis = slice(None)
return _model(x, *argz, **kwdz)[axis]
else:
def model(x, *argz, **kwdz):
axis = kwdz.pop('axis', None)
return _model(x, *argz, **kwdz)
model = _cached(_model, multivalued)
# produce objective function that caches multi-valued output
model.__cache__ = lambda : inner.__cache__()
model.__cache__ = _cache(inner)
model.__doc__ = objective.__doc__

# produce model inverse with shared cache
imodel = lambda *args, **kwds: -model(*args, **kwds)
model.__inverse__ = imodel
model.__inverse__ = imodel = _imodel(model)
imodel.__inverse__ = model
imodel.__cache__ = model.__cache__

Expand Down
11 changes: 10 additions & 1 deletion examples3/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,18 +213,27 @@ def learn_ax(i):
from mystic.math.interpolate import _getaxis
from ml import Estimator as Learner
func = Learner(estimator, transform)
_z = _getaxis(z, i)
#from ml import improve_score, MLData
#_z = MLData(x, x, _z, _z)
#kwds = dict(tries=10, verbose=True)
with np.warnings.catch_warnings(): #FIXME: enable warn=True
np.warnings.filterwarnings('ignore')
func = func.train(x, _getaxis(z, i))
func = func.train(x, _z)
#func = improve_score(func, _z, **kwds)
return func
function.__axis__ = list(_map(learn_ax, range(len(z[0]))))
return function
else:
from mystic.math.interpolate import _getaxis
z = _getaxis(z, axis)
#from ml import improve_score, MLData
#_z = MLData(x, x, z, z)
#kwds = dict(tries=10, verbose=True)
with np.warnings.catch_warnings(): #FIXME: enable warn=True
np.warnings.filterwarnings('ignore')
function = learner.train(x, z)
#function = improve_score(learner, _z, **kwds)
function.__axis__ = axis
return function

Expand Down
4 changes: 4 additions & 0 deletions examples3/ouq.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def solve(self, objective, **kwds): #NOTE: single axis only
maxfun: max number of objective evaluations [default: defined in solver]
evalmon: mystic.monitor instance [default: Monitor], for evaluations
stepmon: mystic.monitor instance [default: Monitor], for iterations
save: iteration frequency to save solver [default: None]
opts: dict of configuration options for solver.Solve [default: {}]
Returns:
Expand All @@ -251,6 +252,9 @@ def solve(self, objective, **kwds): #NOTE: single axis only
else: # DiffEv/Nelder/Powell
if x0 is None: solver.SetRandomInitialPoints(min=lb,max=ub)
else: solver.SetInitialPoints(x0)
save = kwds.get('save', None)
if save is not None:
solver.SetSaveFrequency(save, 'Solver.pkl') #XXX: set name?
mapper = kwds.get('pool', None)
if mapper is not None:
pool = mapper() #XXX: ThreadPool, ProcessPool, etc
Expand Down
11 changes: 8 additions & 3 deletions examples3/ouq_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
'''
model objects (and helper functions) to be used with OUQ classes
'''
def _modelaxis(model, axis=None):
def axmodel(x):
return model(x, axis=axis)
return axmodel

#FIXME: hardwired to multivalue function
#FIXME: dict_archive('truth', cached=False) does not cache (is empty)
#FIXME: option to cache w/o lookup (e.g. for model with randomness)
Expand Down Expand Up @@ -85,7 +90,7 @@ def sample(model, bounds, pts=None, **kwds):
elif pts > 0: # sample pts without optimizing
pts = pts if _pts is None else _pts
def doit(axis=None):
_model = lambda x: model(x, axis=axis)
_model = _modelaxis(model, axis)
s = searcher(bounds, _model, npts=pts, dist=dist, **kwds)
s.sample()
return s
Expand All @@ -97,12 +102,12 @@ def doit(axis=None):
else: # search for minima until terminated
pts = -pts if _pts is None else _pts
def lower(axis=None):
_model = lambda x: model(x, axis=axis)
_model = _modelaxis(model, axis)
s = searcher(bounds, _model, npts=pts, dist=dist, **kwds)
s.sample_until(terminated=all)
return s
def upper(axis=None):
model_ = lambda x: imodel(x, axis=axis)
model_ = _modelaxis(imodel, axis)
si = searcher(bounds, model_, npts=pts, dist=dist, **kwds)
si.sample_until(terminated=all)
return si
Expand Down

0 comments on commit f6f71f5

Please sign in to comment.