diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 18a178a41..cfb3e13db 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -143,9 +143,13 @@ class Dynesty(NestedSampler): bound: {'live', 'live-multi', 'none', 'single', 'multi', 'balls', 'cubes'}, ('live') Method used to select new points sample: {'act-walk', 'acceptance-walk', 'unif', 'rwalk', 'slice', - 'rslice', 'hslice', 'rwalk_dynesty'}, ('act-walk') + 'rslice', 'hslice', 'rwalk_dynesty', BaseEnsembleSampler}, ('act-walk') Method used to sample uniformly within the likelihood constraints, - conditioned on the provided bounds + conditioned on the provided bounds. The default is the bilby-implemented `act-walk` method. + Passing custom subclasses of the bilby-implemented `BaseEnsembleSampler` is also supported. + Custom kwargs can be passed to the samplers via the `custom_sampler_kwargs` argument. + custom_sampler_kwargs: dict (None) + Dictionary of custom keyword arguments passed to initialize the sampling method. walks: int (100) Number of walks taken if using the dynesty implemented sample methods Note that the default `walks` in dynesty itself is 25, although using @@ -227,6 +231,7 @@ def __init__( naccept=60, rejection_sample_posterior=True, proposals=None, + custom_sampler_kwargs=None, **kwargs, ): self.nact = nact @@ -234,6 +239,7 @@ def __init__( self.maxmcmc = maxmcmc self.proposals = proposals self.print_method = print_method + self.custom_sampler_kwargs = custom_sampler_kwargs self._translate_kwargs(kwargs) super(Dynesty, self).__init__( likelihood=likelihood, @@ -282,21 +288,49 @@ def sampler_init_kwargs(self): # method. If we aren't we need to make sure the default "live" isn't set as # the bounding method if self.new_dynesty_api: - internal_kwargs = dict( - ndim=self.ndim, - nonbounded=self.kwargs.get("nonbounded", None), - periodic=self.kwargs.get("periodic", None), - reflective=self.kwargs.get("reflective", None), - maxmcmc=self.maxmcmc, - ) from . import dynesty3_utils as dynesty_utils - if kwargs["sample"] == "act-walk": - internal_kwargs["nact"] = self.nact - internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk( - **internal_kwargs + def init_internal_sampler(internal_sampler): + use_kwargs = internal_sampler.internal_sampler_init_kwargs() + # collect the kwargs from attributes or from the kwargs dict + kwargs = { + key: val + for key in use_kwargs + if (val := getattr(self, key, None)) is not None + } + kwargs |= { + key: val + for key in use_kwargs + if (val := self.kwargs.get(key, None)) is not None + } + if self.custom_sampler_kwargs is not None: + kwargs |= { + key: self.custom_sampler_kwargs[key] + for key in use_kwargs + if key in self.custom_sampler_kwargs + } + internal_sampler = internal_sampler(**kwargs) + return internal_sampler + + if not isinstance(kwargs["sample"], str): + if not isinstance(kwargs["sample"], type) or not issubclass( + kwargs["sample"], dynesty_utils.BaseEnsembleSampler + ): + raise DynestySetupError( + "If sample is not a string, it must be a subclass of " + "bilby.core.sampler.dynesty_utils.BaseEnsembleSampler" + ) + internal_sampler = kwargs["sample"] + internal_sampler = init_internal_sampler(internal_sampler) + bound = "none" + logger.info( + f"Using the custom {internal_sampler.__class__.__name__} sampling method with " + f"parameters {internal_sampler.input_kwargs}." ) + elif kwargs["sample"] == "act-walk": + internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk + internal_sampler = init_internal_sampler(internal_sampler) bound = "none" logger.info( f"Using the bilby-implemented ensemble rwalk sampling tracking the " @@ -304,9 +338,8 @@ def sampler_init_kwargs(self): f"maximum length {internal_sampler.thin * internal_sampler.maxmcmc}." ) elif kwargs["sample"] == "acceptance-walk": - internal_kwargs["naccept"] = self.naccept - internal_kwargs["walks"] = self.kwargs["walks"] - internal_sampler = dynesty_utils.EnsembleWalkSampler(**internal_kwargs) + internal_sampler = dynesty_utils.EnsembleWalkSampler + internal_sampler = init_internal_sampler(internal_sampler) bound = "none" logger.info( f"Using the bilby-implemented ensemble rwalk sampling method with an " @@ -314,17 +347,15 @@ def sampler_init_kwargs(self): f"length {internal_sampler.maxmcmc}." ) elif kwargs["sample"] == "rwalk": - internal_kwargs["nact"] = self.nact - internal_sampler = dynesty_utils.AcceptanceTrackingRWalk( - **internal_kwargs - ) + internal_sampler = dynesty_utils.AcceptanceTrackingRWalk + internal_sampler = init_internal_sampler(internal_sampler) bound = "none" logger.info( f"Using the bilby-implemented ensemble rwalk sampling method with ACT " f"estimated chain length. An average of {2 * internal_sampler.nact} " f"steps will be accepted up to chain length {internal_sampler.maxmcmc}." ) - elif kwargs["bound"] == "live": + elif "live" in kwargs["bound"]: logger.info( "Live-point based bound method requested with dynesty sample " f"'{kwargs['sample']}', overwriting to 'multi'" diff --git a/bilby/core/sampler/dynesty3_utils.py b/bilby/core/sampler/dynesty3_utils.py index 7c8f56b9f..62d3096de 100644 --- a/bilby/core/sampler/dynesty3_utils.py +++ b/bilby/core/sampler/dynesty3_utils.py @@ -36,12 +36,26 @@ class BaseEnsembleSampler(InternalSampler): + + _init_kwargs = { + "ndim", + "ncdim", + "nonbounded", + "periodic", + "reflective", + "proposals", + } + def __init__(self, **kwargs): super().__init__(**kwargs) self.ncdim = kwargs.get("ncdim") self.sampler_kwargs["ncdim"] = self.ncdim self.sampler_kwargs["proposals"] = kwargs.get("proposals", ["diff"]) + @classmethod + def internal_sampler_init_kwargs(cls): + return {kwarg for c in cls.mro() for kwarg in getattr(c, "_init_kwargs", set())} + def prepare_sampler( self, loglstar=None, @@ -103,6 +117,9 @@ def prepare_sampler( class EnsembleWalkSampler(BaseEnsembleSampler): + + _init_kwargs = {"walks", "naccept", "maxmcmc"} + def __init__(self, **kwargs): super().__init__(**kwargs) self.walks = max(2, kwargs.get("walks", 25)) @@ -256,6 +273,8 @@ class ACTTrackingEnsembleWalk(BaseEnsembleSampler): # iteration when using multiprocessing _cache = list() + _init_kwargs = {"nact", "maxmcmc"} + def __init__(self, **kwargs): super().__init__(**kwargs) self.act = 1 @@ -560,6 +579,8 @@ class AcceptanceTrackingRWalk(EnsembleWalkSampler): # level attribute old_act = None + _init_kwargs = {"nact"} + def __init__(self, **kwargs): super().__init__(**kwargs) self.nact = kwargs.get("nact", 40) diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py index 0be239f19..dc04d3d51 100644 --- a/test/core/sampler/dynesty_test.py +++ b/test/core/sampler/dynesty_test.py @@ -174,6 +174,27 @@ def test_sampler_kwargs_acceptance_walk(self): self.assertEqual(self.sampler.naccept, 5) self.assertEqual(self.sampler.maxmcmc, 200) + @pytest.mark.skipif(not NEW_DYNESTY_API, reason="Custom samplers only implemented for new dynesty API") + def test_sampler_kwargs_custom(self): + base_sample = bilby.core.sampler.dynesty3_utils.EnsembleWalkSampler + + class custom_sample(base_sample): + + _init_kwargs = {"custom_kwarg"} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.custom_kwarg = kwargs.get("custom_kwarg") + + self.init_sampler(sample=custom_sample, naccept=5, + maxmcmc=200, custom_sampler_kwargs={"custom_kwarg": 5}) + self.assertIsInstance( + self.dysampler.internal_sampler_next, custom_sample + ) + self.assertEqual(self.dysampler.internal_sampler_next.naccept, 5) + self.assertEqual(self.dysampler.internal_sampler_next.maxmcmc, 200) + self.assertEqual(self.dysampler.internal_sampler_next.custom_kwarg, 5) + def test_run_test_runs(self): self.sampler._run_test()