diff --git a/src/nessai_bilby/plugin.py b/src/nessai_bilby/plugin.py index ee8bbfc..02b0d5b 100644 --- a/src/nessai_bilby/plugin.py +++ b/src/nessai_bilby/plugin.py @@ -208,6 +208,13 @@ def run_sampler(self): use_ratio=self.use_ratio, ) + n_pool = kwargs.pop("n_pool", None) + if n_pool == 1: + logger.debug( + "n_pool=1, overriding n_pool to None to disable multiprocessing" + ) + n_pool = None + # Configure the sampler self.fs = FlowSampler( model, diff --git a/tests/conftest.py b/tests/conftest.py index 3dd6669..7814eec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,19 +4,20 @@ from nessai.livepoint import reset_extra_live_points_parameters -@pytest.fixture -def bilby_gaussian_likelihood_and_priors(): - class GaussianLikelihood(bilby.Likelihood): - def __init__(self): - """A very simple Gaussian likelihood""" - super().__init__(parameters={"x": None, "y": None}) +class GaussianLikelihood(bilby.Likelihood): + def __init__(self): + """A very simple Gaussian likelihood""" + super().__init__(parameters={"x": None, "y": None}) - def log_likelihood(self): - """Log-likelihood.""" - return -0.5 * ( - self.parameters["x"] ** 2.0 + self.parameters["y"] ** 2.0 - ) - np.log(2.0 * np.pi) + def log_likelihood(self): + """Log-likelihood.""" + return -0.5 * ( + self.parameters["x"] ** 2.0 + self.parameters["y"] ** 2.0 + ) - np.log(2.0 * np.pi) + +@pytest.fixture +def bilby_gaussian_likelihood_and_priors(): likelihood = GaussianLikelihood() priors = dict( x=bilby.core.prior.Uniform(-10, 10, "x"), @@ -34,3 +35,8 @@ def reset_live_point_parameters(): @pytest.fixture() def rng(): return np.random.default_rng() + + +@pytest.fixture(params=[None, 2]) +def n_pool(request): + return request.param diff --git a/tests/test_bilby_integration.py b/tests/test_bilby_integration.py index 05d2b52..b4c95f6 100644 --- a/tests/test_bilby_integration.py +++ b/tests/test_bilby_integration.py @@ -13,6 +13,7 @@ def test_sampling_nessai( bilby_gaussian_likelihood_and_priors, tmp_path, likelihood_constraint, + n_pool, ): likelihood, priors = bilby_gaussian_likelihood_and_priors @@ -31,7 +32,7 @@ def test_sampling_nessai( analytic_priors=True, seed=1234, nessai_likelihood_constraint=likelihood_constraint, - n_pool=None, + n_pool=n_pool, ) # Assert plots are made assert list(outdir.glob("*_nessai/*.png")) @@ -41,6 +42,7 @@ def test_sampling_inessai( bilby_gaussian_likelihood_and_priors, tmp_path, likelihood_constraint, + n_pool, ): likelihood, priors = bilby_gaussian_likelihood_and_priors @@ -58,7 +60,7 @@ def test_sampling_inessai( injection_parameters={"x": 0.0, "y": 0.0}, seed=1234, nessai_likelihood_constraint=likelihood_constraint, - n_pool=None, + n_pool=n_pool, )