Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 52 additions & 21 deletions bilby/core/sampler/dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,13 @@ class Dynesty(NestedSampler):
bound: {'live', 'live-multi', 'none', 'single', 'multi', 'balls', 'cubes'}, ('live')
Method used to select new points
sample: {'act-walk', 'acceptance-walk', 'unif', 'rwalk', 'slice',
'rslice', 'hslice', 'rwalk_dynesty'}, ('act-walk')
'rslice', 'hslice', 'rwalk_dynesty', BaseEnsembleSampler}, ('act-walk')
Method used to sample uniformly within the likelihood constraints,
conditioned on the provided bounds
conditioned on the provided bounds. The default is the bilby-implemented `act-walk` method.
Passing custom subclasses of the bilby-implemented `BaseEnsembleSampler` is also supported.
Custom kwargs can be passed to the samplers via the `custom_sampler_kwargs` argument.
custom_sampler_kwargs: dict (None)
Dictionary of custom keyword arguments passed to initialize the sampling method.
walks: int (100)
Number of walks taken if using the dynesty implemented sample methods
Note that the default `walks` in dynesty itself is 25, although using
Expand Down Expand Up @@ -227,13 +231,15 @@ def __init__(
naccept=60,
rejection_sample_posterior=True,
proposals=None,
custom_sampler_kwargs=None,
**kwargs,
):
self.nact = nact
self.naccept = naccept
self.maxmcmc = maxmcmc
self.proposals = proposals
self.print_method = print_method
self.custom_sampler_kwargs = custom_sampler_kwargs
self._translate_kwargs(kwargs)
super(Dynesty, self).__init__(
likelihood=likelihood,
Expand Down Expand Up @@ -282,49 +288,74 @@ def sampler_init_kwargs(self):
# method. If we aren't we need to make sure the default "live" isn't set as
# the bounding method
if self.new_dynesty_api:
internal_kwargs = dict(
ndim=self.ndim,
nonbounded=self.kwargs.get("nonbounded", None),
periodic=self.kwargs.get("periodic", None),
reflective=self.kwargs.get("reflective", None),
maxmcmc=self.maxmcmc,
)

from . import dynesty3_utils as dynesty_utils

if kwargs["sample"] == "act-walk":
internal_kwargs["nact"] = self.nact
internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk(
**internal_kwargs
def init_internal_sampler(internal_sampler):
use_kwargs = internal_sampler.internal_sampler_init_kwargs()
# collect the kwargs from attributes or from the kwargs dict
kwargs = {
key: val
for key in use_kwargs
if (val := getattr(self, key, None)) is not None
}
kwargs |= {
key: val
for key in use_kwargs
if (val := self.kwargs.get(key, None)) is not None
}
if self.custom_sampler_kwargs is not None:
kwargs |= {
key: self.custom_sampler_kwargs[key]
for key in use_kwargs
if key in self.custom_sampler_kwargs
}
internal_sampler = internal_sampler(**kwargs)
return internal_sampler

if not isinstance(kwargs["sample"], str):
if not isinstance(kwargs["sample"], type) or not issubclass(
kwargs["sample"], dynesty_utils.BaseEnsembleSampler
):
raise DynestySetupError(
"If sample is not a string, it must be a subclass of "
"bilby.core.sampler.dynesty_utils.BaseEnsembleSampler"
)
internal_sampler = kwargs["sample"]
internal_sampler = init_internal_sampler(internal_sampler)
bound = "none"
logger.info(
f"Using the custom {internal_sampler.__class__.__name__} sampling method with "
f"parameters {internal_sampler.input_kwargs}."
)
elif kwargs["sample"] == "act-walk":
internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk
internal_sampler = init_internal_sampler(internal_sampler)
bound = "none"
logger.info(
f"Using the bilby-implemented ensemble rwalk sampling tracking the "
f"autocorrelation function and thinning by {internal_sampler.thin} with "
f"maximum length {internal_sampler.thin * internal_sampler.maxmcmc}."
)
elif kwargs["sample"] == "acceptance-walk":
internal_kwargs["naccept"] = self.naccept
internal_kwargs["walks"] = self.kwargs["walks"]
internal_sampler = dynesty_utils.EnsembleWalkSampler(**internal_kwargs)
internal_sampler = dynesty_utils.EnsembleWalkSampler
internal_sampler = init_internal_sampler(internal_sampler)
bound = "none"
logger.info(
f"Using the bilby-implemented ensemble rwalk sampling method with an "
f"average of {internal_sampler.naccept} accepted steps up to chain "
f"length {internal_sampler.maxmcmc}."
)
elif kwargs["sample"] == "rwalk":
internal_kwargs["nact"] = self.nact
internal_sampler = dynesty_utils.AcceptanceTrackingRWalk(
**internal_kwargs
)
internal_sampler = dynesty_utils.AcceptanceTrackingRWalk
internal_sampler = init_internal_sampler(internal_sampler)
bound = "none"
logger.info(
f"Using the bilby-implemented ensemble rwalk sampling method with ACT "
f"estimated chain length. An average of {2 * internal_sampler.nact} "
f"steps will be accepted up to chain length {internal_sampler.maxmcmc}."
)
elif kwargs["bound"] == "live":
elif "live" in kwargs["bound"]:
logger.info(
"Live-point based bound method requested with dynesty sample "
f"'{kwargs['sample']}', overwriting to 'multi'"
Expand Down
21 changes: 21 additions & 0 deletions bilby/core/sampler/dynesty3_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,26 @@


class BaseEnsembleSampler(InternalSampler):

_init_kwargs = {
"ndim",
"ncdim",
"nonbounded",
"periodic",
"reflective",
"proposals",
}

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.ncdim = kwargs.get("ncdim")
self.sampler_kwargs["ncdim"] = self.ncdim
self.sampler_kwargs["proposals"] = kwargs.get("proposals", ["diff"])

@classmethod
def internal_sampler_init_kwargs(cls):
return {kwarg for c in cls.mro() for kwarg in getattr(c, "_init_kwargs", set())}

def prepare_sampler(
self,
loglstar=None,
Expand Down Expand Up @@ -103,6 +117,9 @@ def prepare_sampler(


class EnsembleWalkSampler(BaseEnsembleSampler):

_init_kwargs = {"walks", "naccept", "maxmcmc"}

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.walks = max(2, kwargs.get("walks", 25))
Expand Down Expand Up @@ -256,6 +273,8 @@ class ACTTrackingEnsembleWalk(BaseEnsembleSampler):
# iteration when using multiprocessing
_cache = list()

_init_kwargs = {"nact", "maxmcmc"}

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.act = 1
Expand Down Expand Up @@ -560,6 +579,8 @@ class AcceptanceTrackingRWalk(EnsembleWalkSampler):
# level attribute
old_act = None

_init_kwargs = {"nact"}

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.nact = kwargs.get("nact", 40)
Expand Down
21 changes: 21 additions & 0 deletions test/core/sampler/dynesty_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,27 @@ def test_sampler_kwargs_acceptance_walk(self):
self.assertEqual(self.sampler.naccept, 5)
self.assertEqual(self.sampler.maxmcmc, 200)

@pytest.mark.skipif(not NEW_DYNESTY_API, reason="Custom samplers only implemented for new dynesty API")
def test_sampler_kwargs_custom(self):
base_sample = bilby.core.sampler.dynesty3_utils.EnsembleWalkSampler

class custom_sample(base_sample):

_init_kwargs = {"custom_kwarg"}

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.custom_kwarg = kwargs.get("custom_kwarg")

self.init_sampler(sample=custom_sample, naccept=5,
maxmcmc=200, custom_sampler_kwargs={"custom_kwarg": 5})
self.assertIsInstance(
self.dysampler.internal_sampler_next, custom_sample
)
self.assertEqual(self.dysampler.internal_sampler_next.naccept, 5)
self.assertEqual(self.dysampler.internal_sampler_next.maxmcmc, 200)
self.assertEqual(self.dysampler.internal_sampler_next.custom_kwarg, 5)

def test_run_test_runs(self):
self.sampler._run_test()

Expand Down
Loading