Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 51 additions & 190 deletions bilby/core/sampler/dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
logger,
safe_file_dump,
)
from . import dynesty_utils
from .base_sampler import NestedSampler, Sampler, _SamplingContainer, signal_wrapper


Expand Down Expand Up @@ -196,15 +197,6 @@ def default_kwargs(self):
kwargs["seed"] = None
return kwargs

@property
def new_dynesty_api(self):
try:
import dynesty.internal_samplers # noqa

return True
except ImportError:
return False

def __init__(
self,
likelihood,
Expand Down Expand Up @@ -281,61 +273,54 @@ def sampler_init_kwargs(self):
# if we're using a Bilby implemented sampling method we need to register the
# method. If we aren't we need to make sure the default "live" isn't set as
# the bounding method
if self.new_dynesty_api:
internal_kwargs = dict(
ndim=self.ndim,
nonbounded=self.kwargs.get("nonbounded", None),
periodic=self.kwargs.get("periodic", None),
reflective=self.kwargs.get("reflective", None),
maxmcmc=self.maxmcmc,
)

from . import dynesty3_utils as dynesty_utils
internal_kwargs = dict(
ndim=self.ndim,
nonbounded=self.kwargs.get("nonbounded", None),
periodic=self.kwargs.get("periodic", None),
reflective=self.kwargs.get("reflective", None),
maxmcmc=self.maxmcmc,
)

if kwargs["sample"] == "act-walk":
internal_kwargs["nact"] = self.nact
internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk(
**internal_kwargs
)
bound = "none"
logger.info(
f"Using the bilby-implemented ensemble rwalk sampling tracking the "
f"autocorrelation function and thinning by {internal_sampler.thin} with "
f"maximum length {internal_sampler.thin * internal_sampler.maxmcmc}."
)
elif kwargs["sample"] == "acceptance-walk":
internal_kwargs["naccept"] = self.naccept
internal_kwargs["walks"] = self.kwargs["walks"]
internal_sampler = dynesty_utils.EnsembleWalkSampler(**internal_kwargs)
bound = "none"
logger.info(
f"Using the bilby-implemented ensemble rwalk sampling method with an "
f"average of {internal_sampler.naccept} accepted steps up to chain "
f"length {internal_sampler.maxmcmc}."
)
elif kwargs["sample"] == "rwalk":
internal_kwargs["nact"] = self.nact
internal_sampler = dynesty_utils.AcceptanceTrackingRWalk(
**internal_kwargs
)
bound = "none"
logger.info(
f"Using the bilby-implemented ensemble rwalk sampling method with ACT "
f"estimated chain length. An average of {2 * internal_sampler.nact} "
f"steps will be accepted up to chain length {internal_sampler.maxmcmc}."
)
elif kwargs["bound"] == "live":
logger.info(
"Live-point based bound method requested with dynesty sample "
f"'{kwargs['sample']}', overwriting to 'multi'"
)
internal_sampler = kwargs["sample"]
bound = "multi"
else:
internal_sampler = kwargs["sample"]
bound = kwargs["bound"]
kwargs["sample"] = internal_sampler
kwargs["bound"] = bound
if kwargs["sample"] == "act-walk":
internal_kwargs["nact"] = self.nact
internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk(**internal_kwargs)
bound = "none"
logger.info(
f"Using the bilby-implemented ensemble rwalk sampling tracking the "
f"autocorrelation function and thinning by {internal_sampler.thin} with "
f"maximum length {internal_sampler.thin * internal_sampler.maxmcmc}."
)
elif kwargs["sample"] == "acceptance-walk":
internal_kwargs["naccept"] = self.naccept
internal_kwargs["walks"] = self.kwargs["walks"]
internal_sampler = dynesty_utils.EnsembleWalkSampler(**internal_kwargs)
bound = "none"
logger.info(
f"Using the bilby-implemented ensemble rwalk sampling method with an "
f"average of {internal_sampler.naccept} accepted steps up to chain "
f"length {internal_sampler.maxmcmc}."
)
elif kwargs["sample"] == "rwalk":
internal_kwargs["nact"] = self.nact
internal_sampler = dynesty_utils.AcceptanceTrackingRWalk(**internal_kwargs)
bound = "none"
logger.info(
f"Using the bilby-implemented ensemble rwalk sampling method with ACT "
f"estimated chain length. An average of {2 * internal_sampler.nact} "
f"steps will be accepted up to chain length {internal_sampler.maxmcmc}."
)
elif kwargs["bound"] == "live":
logger.info(
"Live-point based bound method requested with dynesty sample "
f"'{kwargs['sample']}', overwriting to 'multi'"
)
internal_sampler = kwargs["sample"]
bound = "multi"
else:
internal_sampler = kwargs["sample"]
bound = kwargs["bound"]
kwargs["sample"] = internal_sampler
kwargs["bound"] = bound
return kwargs

def _translate_kwargs(self, kwargs):
Expand Down Expand Up @@ -514,107 +499,12 @@ def sampler_class(self):

return Sampler

def _set_sampling_method(self):
"""
Resolve the sampling method and sampler to use from the provided
:code:`bound` and :code:`sample` arguments.

This requires registering the :code:`bilby` specific methods in the
appropriate locations within :code:`dynesty`.

Additionally, some combinations of bound/sample/proposals are not
compatible and so we either warn the user or raise an error.
"""
if self.new_dynesty_api:
return

import dynesty

_set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept))

sample = self.kwargs["sample"]
bound = self.kwargs["bound"]

if sample not in ["rwalk", "act-walk", "acceptance-walk"] and bound in [
"live",
"live-multi",
]:
logger.info(
"Live-point based bound method requested with dynesty sample "
f"'{sample}', overwriting to 'multi'"
)
self.kwargs["bound"] = "multi"
elif bound == "live":
from .dynesty_utils import LivePointSampler

dynesty.dynamicsampler._SAMPLERS["live"] = LivePointSampler
elif bound == "live-multi":
from .dynesty_utils import MultiEllipsoidLivePointSampler

dynesty.dynamicsampler._SAMPLERS[
"live-multi"
] = MultiEllipsoidLivePointSampler
elif sample == "acceptance-walk":
raise DynestySetupError(
"bound must be set to live or live-multi for sample=acceptance-walk"
)
elif self.proposals is None:
logger.warning(
"No proposals specified using dynesty sampling, defaulting "
"to 'volumetric'."
)
self.proposals = ["volumetric"]
_SamplingContainer.proposals = self.proposals
elif "diff" in self.proposals:
raise DynestySetupError(
"bound must be set to live or live-multi to use differential "
"evolution proposals"
)

if sample == "rwalk":
logger.info(
f"Using the bilby-implemented {sample} sample method with ACT estimated walks. "
f"An average of {2 * self.nact} steps will be accepted up to chain length "
f"{self.maxmcmc}."
)
from .dynesty_utils import AcceptanceTrackingRWalk

if self.kwargs["walks"] > self.maxmcmc:
raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)")
if self.nact < 1:
raise DynestySetupError("Unable to run with nact < 1")
AcceptanceTrackingRWalk.old_act = None
dynesty.nestedsamplers._SAMPLING["rwalk"] = AcceptanceTrackingRWalk()
elif sample == "acceptance-walk":
logger.info(
f"Using the bilby-implemented {sample} sampling with an average of "
f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}"
)
from .dynesty_utils import FixedRWalk

dynesty.nestedsamplers._SAMPLING["acceptance-walk"] = FixedRWalk()
elif sample == "act-walk":
logger.info(
f"Using the bilby-implemented {sample} sampling tracking the "
f"autocorrelation function and thinning by "
f"{self.nact} with maximum length {self.nact * self.maxmcmc}"
)
from .dynesty_utils import ACTTrackingRWalk

ACTTrackingRWalk._cache = list()
dynesty.nestedsamplers._SAMPLING["act-walk"] = ACTTrackingRWalk()
elif sample == "rwalk_dynesty":
sample = sample.strip("_dynesty")
self.kwargs["sample"] = sample
logger.info(f"Using the dynesty-implemented {sample} sample method")

@signal_wrapper
def run_sampler(self):
import dynesty

logger.info(f"Using dynesty version {dynesty.__version__}")

self._set_sampling_method()
self._setup_pool()

if self.resume:
Expand Down Expand Up @@ -666,25 +556,6 @@ def run_sampler(self):

return self.result

def _setup_pool(self):
"""
In addition to the usual steps, we need to set the sampling kwargs on
every process. To make sure we get every process, run the kwarg setting
more times than we have processes.
"""
super(Dynesty, self)._setup_pool()

if self.new_dynesty_api:
return

if self.pool is not None:
args = (
[(self.nact, self.maxmcmc, self.proposals, self.naccept)]
* self.npool
* 10
)
self.pool.map(_set_sampling_kwargs, args)

def _generate_result(self, out):
"""
Extract the information we need from the dynesty output. This includes
Expand Down Expand Up @@ -895,10 +766,7 @@ def read_saved_state(self, continuing=False):
mapper = self.pool.map
else:
mapper = map
if self.new_dynesty_api:
self.sampler.mapper = mapper
else:
self.sampler.M = mapper
self.sampler.mapper = mapper
return True
else:
logger.info(f"Resume file {self.resume_file} does not exist.")
Expand Down Expand Up @@ -941,10 +809,7 @@ def write_current_state(self):
metadata = dict()
versions = dict(bilby=bilby_version, dynesty=dynesty_version)
self.sampler.pool = None
if self.new_dynesty_api:
self.sampler.mapper = map
else:
self.sampler.M = map
self.sampler.mapper = map
if dill.pickles(self.sampler):
safe_file_dump((self.sampler, versions, metadata), self.resume_file, dill)
logger.info(f"Written checkpoint file {self.resume_file}")
Expand All @@ -955,10 +820,7 @@ def write_current_state(self):
)
self.sampler.pool = self.pool
if self.sampler.pool is not None:
if self.new_dynesty_api:
self.sampler.mapper = self.sampler.pool.map
else:
self.sampler.M = self.sampler.pool.map
self.sampler.mapper = self.sampler.pool.map

def dump_samples_to_dat(self):
"""
Expand Down Expand Up @@ -1090,7 +952,6 @@ def _run_test(self):
"""Run the sampler very briefly as a sanity test that it works."""
import pandas as pd

self._set_sampling_method()
self._setup_pool()
self.sampler = self.sampler_init(
loglikelihood=_log_likelihood_wrapper,
Expand Down
Loading
Loading